Loading [MathJax]/extensions/tex2jax.js

Codeforces 1106F:Lunar New Year and a Recursive Sequence 题解

主要思路

哇这道题还是很神仙的。

首先,有递推式:

\[ f_i = \left(\prod_{j = 1}^{k} f_{i - j}^{b_j}\right) \bmod p \]

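题目给出 \(f_n=m\),并且数列的前 \(k-1\) 项都是 \(1\),可以看作 \(f_k\) 的 \(0\) 次幂。于是对于任意 \(j \ge k\),\(f_j\) 都可以写成 \(f_k^{C_j}\) 的形式,其中 \(C_j\) 是一个整数指数(而不是多项式),初值为 \(C_1=\dots=C_{k-1}=0\),\(C_k=1\)。由费马小定理,对不被 \(p\) 整除的 \(f_k\) 有 \(f_k^{p-1}\equiv 1 \pmod p\),所以指数可以对 \(p-1\) 取模。于是指数序列满足线性递推: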

\[ C_i = \left(\sum_{j = 1}^{k} b_j C_{i-j}\right) \bmod (p-1), \quad i > k \]

看到 \(n\) 的数据范围很大,考虑用 \(k \times k\) 的矩阵快速幂在 \(O(k^3 \log n)\) 的时间内得到 \(C_n\)。所以现在我们有:

\[ f_k^{C_n} \equiv m \pmod p \]

我们现在已知 \(k,m,p,C_n\),要求的是 \(f_k\)。考虑使用原根:众所周知,\(p = 998244353\) 的原根是 \(3\)。原根的幂可以取遍模 \(p\) 的所有非零剩余,所以可以把这个式子写成:

\[ (3^t)^{C_n} \equiv 3^s \pmod p, \text{其中设 } m = 3^s,\ f_k = 3^t \]

对两边取以 \(3\) 为底的离散对数(指数在模 \(p-1\) 意义下),变成:

\[ t \cdot C_n \equiv s \pmod{p-1} \]

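其中 \(s\) 可以用 BSGS 求出。然后解这个线性同余方程:设 \(d = \gcd(C_n, p-1)\),方程有解当且仅当 \(d \mid s\);有解时用 exgcd 求出 \(C_n/d\) 在模 \((p-1)/d\) 下的逆元,乘上 \(s/d\) 即得 \(t\),否则输出 \(-1\)。算出 \(t\) 之后用快速幂计算 \(3^t\) 输出。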

代码

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// CF1106F.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Upper bound on the matrix dimension. Matrices are indexed from 1, so this
// must be strictly greater than any k read from input.
const int MAX_MAT = 150;
// The prime modulus p of the sequence; matrix entries (exponents) are kept
// modulo mod - 1.
const int mod = 998244353;
// k: order of the recurrence, n: index of the known term, m: value of f_n.
int n, m, k;
// ki[1..k]: the recurrence exponents b_1..b_k.
int ki[MAX_MAT];
// Square matrix of size k x k (k is the global recurrence order), stored
// 1-indexed in mat[1..k][1..k]. Products are taken modulo (mod - 1) because the
// entries are exponents of f_k, and by Fermat's little theorem exponents of a
// residue not divisible by the prime `mod` can be reduced modulo mod - 1.
struct matrix
{
    ll mat[MAX_MAT][MAX_MAT];

    // Zero-fill on construction: a local `matrix` would otherwise hold
    // indeterminate values in every entry that is never explicitly written.
    matrix() { memset(mat, 0, sizeof(mat)); }

    // Overwrite this matrix with the k x k identity (all other entries 0).
    void basify()
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 1; i <= k; i++)
            mat[i][i] = 1;
    }

    ll *operator[](const int &id) { return mat[id]; }

    // Matrix product modulo (mod - 1), over the top-left k x k part.
    matrix operator*(const matrix &mt) const
    {
        matrix ans; // starts all-zero
        // i-mid-j order walks both mt and ans row-wise (cache friendly).
        for (int i = 1; i <= k; i++)
            for (int mid = 1; mid <= k; mid++)
                for (int j = 1; j <= k; j++)
                    ans.mat[i][j] = (ans.mat[i][j] + mat[i][mid] * mt.mat[mid][j] % (mod - 1)) % (mod - 1);
        return ans;
    }

    // Binary exponentiation: returns this^tim modulo (mod - 1).
    // tim is expected to be >= 0; a negative value yields the identity instead
    // of looping forever on the arithmetic right shift.
    matrix operator^(const int &tim) const
    {
        matrix ans, bas = *this;
        ans.basify();
        for (int ti = tim; ti > 0; ti >>= 1)
        {
            if (ti & 1)
                ans = ans * bas;
            bas = bas * bas;
        }
        return ans;
    }
};
// Square-and-multiply: returns bas^tim modulo `mod`, consuming the bits of
// tim from the least significant end. tim == 0 yields 1.
ll quick_pow(ll bas, ll tim)
{
    ll result = 1;
    for (ll base = bas, e = tim; e != 0; e >>= 1)
    {
        if (e & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Baby-step giant-step discrete logarithm modulo `mod`.
// Returns a non-negative t with a^t == y (mod `mod`), or -1 if none is found.
// Special case a == 0: returns 1 when y == 0 and -1 otherwise.
ll bsgs(ll a, ll y)
{
    if (a == 0)
        return (y == 0) ? 1 : -1;

    // Block size used for both the baby and the giant steps.
    ll u = ceil(sqrt(mod));

    // Baby steps: table maps y * a^j -> j for j = 0..u. On a collision the
    // later (larger) j overwrites the earlier one.
    map<ll, ll> table;
    ll pw = 1;
    for (int j = 0; j <= u; j++)
    {
        table[y * pw % mod] = j;
        pw = pw * a % mod;
    }

    // Giant steps: find the first i with a^(i*u) == y * a^j, so t = i*u - j.
    ll giant = quick_pow(a, u);
    ll cur = giant;
    for (int i = 1; i <= u; i++)
    {
        map<ll, ll>::iterator hit = table.find(cur);
        if (hit != table.end())
            return i * u - hit->second;
        cur = cur * giant % mod;
    }
    return -1;
}
// Extended Euclidean algorithm (iterative form).
// Returns g, the last non-zero remainder of Euclid's algorithm on (a, b), and
// stores into x, y coefficients satisfying a*x + b*y == g.
ll exgcd(ll a, ll b, ll &x, ll &y)
{
    // Invariant: a*x0 + b*y0 == r0 and a*x1 + b*y1 == r1.
    ll r0 = a, r1 = b;
    ll x0 = 1, x1 = 0;
    ll y0 = 0, y1 = 1;
    while (r1 != 0)
    {
        ll q = r0 / r1;
        ll t = r0 % r1;
        r0 = r1, r1 = t;
        t = x0 - q * x1;
        x0 = x1, x1 = t;
        t = y0 - q * y1;
        y0 = y1, y1 = t;
    }
    x = x0, y = y0;
    return r0;
}
// Euclid's algorithm: repeatedly replaces (a, b) by (b, a % b) until b is 0.
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Solves the linear congruence a * t == b (mod c) for t.
// Returns the smallest non-negative solution, or -1 if none exists.
// Requires c > 0; a and b are first normalised into [0, c).
// A solution exists iff d = gcd(a, c) divides b; then
// t == (a/d)^{-1} * (b/d) (mod c/d).
ll solve(ll a, ll b, ll c)
{
    a = (a % c + c) % c;
    b = (b % c + c) % c;
    if (b == 0)
        return 0; // t = 0 always satisfies a*0 == 0
    ll x, y;
    // With a >= 0 and c > 0, exgcd returns d = gcd(a, c) > 0 and a*x + c*y == d.
    ll d = exgcd(a, c, x, y);
    if (b % d != 0)
        return -1;
    c /= d;
    b /= d;
    // x is the inverse of a/d modulo c/d; bring it into [0, c) before scaling.
    x = (x % c + c) % c;
    return x * b % c;
}
int main()
{
scanf("%d", &k);
for (int i = 1; i <= k; i++)
scanf("%d", &ki[i]);
scanf("%d%d", &n, &m);
matrix B;
for (int i = 1; i <= k; i++)
B[1][i] = ki[i];
for (int i = 2; i <= k; i++)
B[i][i - 1] = 1;
B = B ^ (n - k);
ll res = B[1][1], s = bsgs(3, m), ans = solve(res, s, mod - 1);
if (ans == -1)
puts("-1");
else
printf("%lld", quick_pow(3, ans));
return 0;
}
// CF1106F.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// MAX_MAT bounds the 1-indexed k x k matrices (must exceed any input k);
// mod is the prime p, and matrix entries (exponents) are kept modulo mod - 1.
const int MAX_MAT = 150, mod = 998244353;
int n, m, k, ki[MAX_MAT];

// k x k matrix stored 1-indexed, with arithmetic modulo (mod - 1).
struct matrix
{
    ll mat[MAX_MAT][MAX_MAT];

    // Zero-fill so entries that are never written are well-defined.
    matrix() { memset(mat, 0, sizeof(mat)); }

    // Overwrite this matrix with the k x k identity.
    void basify()
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 1; i <= k; i++)
            mat[i][i] = 1;
    }

    ll *operator[](const int &id) { return mat[id]; }

    // Matrix product modulo (mod - 1).
    matrix operator*(const matrix &mt) const
    {
        matrix ans;
        for (int i = 1; i <= k; i++)
            for (int mid = 1; mid <= k; mid++)
                for (int j = 1; j <= k; j++)
                    ans.mat[i][j] = (ans.mat[i][j] + mat[i][mid] * mt.mat[mid][j] % (mod - 1)) % (mod - 1);
        return ans;
    }

    // Binary exponentiation; a non-positive exponent yields the identity.
    matrix operator^(const int &tim) const
    {
        matrix ans, bas = *this;
        ans.basify();
        for (int ti = tim; ti > 0; ti >>= 1)
        {
            if (ti & 1)
                ans = ans * bas;
            bas = bas * bas;
        }
        return ans;
    }
};

// Returns bas^tim modulo mod.
ll quick_pow(ll bas, ll tim)
{
    ll ans = 1;
    while (tim)
    {
        if (tim & 1)
            ans = ans * bas % mod;
        bas = bas * bas % mod;
        tim >>= 1;
    }
    return ans;
}

// Baby-step giant-step: a non-negative t with a^t == y (mod mod), or -1.
ll bsgs(ll a, ll y)
{
    if (a == 0 && y == 0)
        return 1;
    if (a == 0 && y != 0)
        return -1;
    map<ll, ll> hsh;
    ll u = ceil(sqrt(mod));
    for (int i = 0, x = 1; i <= u; i++, x = x * a % mod)
        hsh[y * x % mod] = i;
    ll unit = quick_pow(a, u);
    for (int i = 1, x = unit; i <= u; i++, x = x * unit % mod)
        if (hsh.count(x))
            return i * u - hsh[x];
    return -1;
}

// Extended Euclid: returns gcd(a, b) and sets a*x + b*y == gcd(a, b).
ll exgcd(ll a, ll b, ll &x, ll &y)
{
    if (b == 0)
    {
        x = 1, y = 0;
        return a;
    }
    ll d = exgcd(b, a % b, x, y), z = x;
    x = y, y = z - (a / b) * y;
    return d;
}

ll gcd(ll a, ll b) { return (b == 0) ? a : gcd(b, a % b); }

// Smallest non-negative t with a*t == b (mod c), or -1 if gcd(a, c) does not
// divide b. Requires c > 0.
ll solve(ll a, ll b, ll c)
{
    a = (a % c + c) % c;
    b = (b % c + c) % c;
    if (b == 0)
        return 0;
    ll x, y;
    ll d = exgcd(a, c, x, y);
    if (b % d != 0)
        return -1;
    c /= d;
    b /= d;
    x = (x % c + c) % c;
    return x * b % c;
}

int main()
{
    scanf("%d", &k);
    for (int i = 1; i <= k; i++)
        scanf("%d", &ki[i]);
    scanf("%d%d", &n, &m);
    // Companion matrix of the exponent recurrence (zero-initialised).
    matrix B;
    memset(B.mat, 0, sizeof(B.mat));
    for (int i = 1; i <= k; i++)
        B[1][i] = ki[i] % (mod - 1);
    for (int i = 2; i <= k; i++)
        B[i][i - 1] = 1;
    B = B ^ (n - k);
    ll res = B[1][1];
    // m = 3^s; solve res * t == s (mod p - 1); answer 3^t.
    ll s = bsgs(3, m);
    ll ans = (s == -1) ? -1 : solve(res, s, mod - 1);
    if (ans == -1)
        puts("-1");
    else
        printf("%lld\n", quick_pow(3, ans));
    return 0;
}
// CF1106F.cpp
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Upper bound on the matrix dimension. Matrices are indexed from 1, so this
// must be strictly greater than any k read from input.
const int MAX_MAT = 150;
// The prime modulus p of the sequence; matrix entries (exponents) are kept
// modulo mod - 1.
const int mod = 998244353;
// k: order of the recurrence, n: index of the known term, m: value of f_n.
int n, m, k;
// ki[1..k]: the recurrence exponents b_1..b_k.
int ki[MAX_MAT];
// Square matrix of size k x k (k is the global recurrence order), stored
// 1-indexed in mat[1..k][1..k]. Products are taken modulo (mod - 1) because the
// entries are exponents of f_k, and by Fermat's little theorem exponents of a
// residue not divisible by the prime `mod` can be reduced modulo mod - 1.
struct matrix
{
    ll mat[MAX_MAT][MAX_MAT];

    // Zero-fill on construction: a local `matrix` would otherwise hold
    // indeterminate values in every entry that is never explicitly written.
    matrix() { memset(mat, 0, sizeof(mat)); }

    // Overwrite this matrix with the k x k identity (all other entries 0).
    void basify()
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 1; i <= k; i++)
            mat[i][i] = 1;
    }

    ll *operator[](const int &id) { return mat[id]; }

    // Matrix product modulo (mod - 1), over the top-left k x k part.
    matrix operator*(const matrix &mt) const
    {
        matrix ans; // starts all-zero
        // i-mid-j order walks both mt and ans row-wise (cache friendly).
        for (int i = 1; i <= k; i++)
            for (int mid = 1; mid <= k; mid++)
                for (int j = 1; j <= k; j++)
                    ans.mat[i][j] = (ans.mat[i][j] + mat[i][mid] * mt.mat[mid][j] % (mod - 1)) % (mod - 1);
        return ans;
    }

    // Binary exponentiation: returns this^tim modulo (mod - 1).
    // tim is expected to be >= 0; a negative value yields the identity instead
    // of looping forever on the arithmetic right shift.
    matrix operator^(const int &tim) const
    {
        matrix ans, bas = *this;
        ans.basify();
        for (int ti = tim; ti > 0; ti >>= 1)
        {
            if (ti & 1)
                ans = ans * bas;
            bas = bas * bas;
        }
        return ans;
    }
};
// Square-and-multiply: returns bas^tim modulo `mod`, consuming the bits of
// tim from the least significant end. tim == 0 yields 1.
ll quick_pow(ll bas, ll tim)
{
    ll result = 1;
    for (ll base = bas, e = tim; e != 0; e >>= 1)
    {
        if (e & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Baby-step giant-step discrete logarithm modulo `mod`.
// Returns a non-negative t with a^t == y (mod `mod`), or -1 if none is found.
// Special case a == 0: returns 1 when y == 0 and -1 otherwise.
ll bsgs(ll a, ll y)
{
    if (a == 0)
        return (y == 0) ? 1 : -1;

    // Block size used for both the baby and the giant steps.
    ll u = ceil(sqrt(mod));

    // Baby steps: table maps y * a^j -> j for j = 0..u. On a collision the
    // later (larger) j overwrites the earlier one.
    map<ll, ll> table;
    ll pw = 1;
    for (int j = 0; j <= u; j++)
    {
        table[y * pw % mod] = j;
        pw = pw * a % mod;
    }

    // Giant steps: find the first i with a^(i*u) == y * a^j, so t = i*u - j.
    ll giant = quick_pow(a, u);
    ll cur = giant;
    for (int i = 1; i <= u; i++)
    {
        map<ll, ll>::iterator hit = table.find(cur);
        if (hit != table.end())
            return i * u - hit->second;
        cur = cur * giant % mod;
    }
    return -1;
}
// Extended Euclidean algorithm (iterative form).
// Returns g, the last non-zero remainder of Euclid's algorithm on (a, b), and
// stores into x, y coefficients satisfying a*x + b*y == g.
ll exgcd(ll a, ll b, ll &x, ll &y)
{
    // Invariant: a*x0 + b*y0 == r0 and a*x1 + b*y1 == r1.
    ll r0 = a, r1 = b;
    ll x0 = 1, x1 = 0;
    ll y0 = 0, y1 = 1;
    while (r1 != 0)
    {
        ll q = r0 / r1;
        ll t = r0 % r1;
        r0 = r1, r1 = t;
        t = x0 - q * x1;
        x0 = x1, x1 = t;
        t = y0 - q * y1;
        y0 = y1, y1 = t;
    }
    x = x0, y = y0;
    return r0;
}
// Euclid's algorithm: repeatedly replaces (a, b) by (b, a % b) until b is 0.
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Solves the linear congruence a * t == b (mod c) for t.
// Returns the smallest non-negative solution, or -1 if none exists.
// Requires c > 0; a and b are first normalised into [0, c).
// A solution exists iff d = gcd(a, c) divides b; then
// t == (a/d)^{-1} * (b/d) (mod c/d).
ll solve(ll a, ll b, ll c)
{
    a = (a % c + c) % c;
    b = (b % c + c) % c;
    if (b == 0)
        return 0; // t = 0 always satisfies a*0 == 0
    ll x, y;
    // With a >= 0 and c > 0, exgcd returns d = gcd(a, c) > 0 and a*x + c*y == d.
    ll d = exgcd(a, c, x, y);
    if (b % d != 0)
        return -1;
    c /= d;
    b /= d;
    // x is the inverse of a/d modulo c/d; bring it into [0, c) before scaling.
    x = (x % c + c) % c;
    return x * b % c;
}
int main()
{
    scanf("%d", &k);
    for (int i = 1; i <= k; i++)
        scanf("%d", &ki[i]);
    scanf("%d%d", &n, &m);
    matrix B;
    for (int i = 1; i <= k; i++)
        B[1][i] = ki[i];
    for (int i = 2; i <= k; i++)
        B[i][i - 1] = 1;
    B = B ^ (n - k);
    ll res = B[1][1], s = bsgs(3, m), ans = solve(res, s, mod - 1);
    if (ans == -1)
        puts("-1");
    else
        printf("%lld", quick_pow(3, ans));
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *