前置技能
FFT、NTT、多项式乘法。
原理推理
现在我们有一个这样的式子:
\[ A(x)B(x) \equiv C(x) \ (mod \ x^n) \]
现在我们已知\(B(x), C(x)\)的信息,而我们计算答案的时候需要用\(A(x)\)进行运算。这个时候,我们需要求出\(B(x)\)的逆元。多项式求逆元,我们一般是用倍增的方式来求解。假如我们有\(B(x)\)在模\(x^{\lceil \frac{x}{2} \rceil}\)意义下的逆元\(D(x)\):
\[ B(x)D(x) \equiv 1 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \]
现在我们需要求\( x^n \)意义下的逆元,我们现在有两个式子:
\[ \begin{cases} B(x)D(x) \equiv 1 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \\ B(x)B^{-1}(x) \equiv 1 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \end{cases} \]
上下相减:
\[ B(x)(D(x) – B^{-1}(x)) \equiv 0 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \\ D(x) – B^{-1}(x) \equiv 0 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \]
同时平方:
\[ D^2(x) + B^{-2}(x) – 2D(x)B^{-1}(x) \equiv 0 \ (mod \ n) \]
其中,这里为什么模意义突然变成了\(n\)的原因如下:
我可能会这样讲:根据\(B(x) – B^{-1}(x) \equiv 0 \pmod {x^{\lceil n/2 \rceil}}\),左边一定是一个形如\(k\cdot x^{\lceil n/2 \rceil}\)的多项式,它的平方(即\((B(x) – B^{-1}(x))^2\))为\(k^2 x^n\)(\(n\)为偶数)或\(k^2 x^{n+1}\)(\(n\)为奇数) ,显然二者均能被\(x_n\)整除
src: http://blog.miskcoo.com/2015/05/polynomial-inverse, PLANET6174
剩下的这个乘法用 NTT 解决就行了。
代码
// P4238.cpp #include <bits/stdc++.h> #define ll long long using namespace std; const int MAX_N = 3e5 + 2000, mod = 998244353, G = 3, Gi = 332748118; int n, rev[MAX_N], mx_pow, mx_bit; ll ai[MAX_N], ans[MAX_N], tmp[MAX_N]; ll quick_pow(ll bas, ll tim) { ll ans = 1; bas %= mod; while (tim) { if (tim & 1) ans = ans * bas % mod; bas = bas * bas % mod; tim >>= 1; } return ans; } inline void ntt(ll *arr, int limit, int dft) { for (int i = 0; i < limit; i++) if (i < rev[i]) swap(arr[i], arr[rev[i]]); for (int step = 1; step < limit; step <<= 1) { ll omega_n = quick_pow(dft == 1 ? G : Gi, (mod - 1) / (step << 1)); for (int j = 0; j < limit; j += (step << 1)) { ll omega_nk = 1; for (int k = j; k < j + step; k++, omega_nk = (omega_nk * omega_n % mod)) { ll x = arr[k], y = omega_nk * arr[k + step] % mod; arr[k] = (x + y) % mod, arr[k + step] = (x - y + mod) % mod; } } } if (dft != 1) { ll inv = quick_pow(limit, mod - 2); for (int i = 0; i < limit; i++) arr[i] = (arr[i] * inv % mod); } } void poly_inverse(int deg, ll *a, ll *b) { if (deg == 1) return (void)(b[0] = quick_pow(a[0], mod - 2)); poly_inverse((deg + 1) >> 1, a, b); ll limit = 2, mx_bit = 1; while ((deg << 1) > limit) limit <<= 1, mx_bit++; for (int i = 1; i < limit; i++) rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1))); for (int i = 0; i < deg; i++) tmp[i] = a[i]; for (int i = deg; i < limit; i++) tmp[i] = 0; ntt(tmp, limit, 1), ntt(b, limit, 1); for (int i = 0; i < limit; i++) b[i] = 1LL * ((2LL - tmp[i] * b[i] % mod + mod) % mod) * b[i] % mod; ntt(b, limit, -1); for (int i = deg; i < limit; i++) b[i] = 0; } int main() { scanf("%d", &n); for (int i = 0; i <= n - 1; i++) scanf("%lld", &ai[i]); poly_inverse(n, ai, ans); for (int i = 0; i < n; i++) printf("%lld ", ans[i]); return 0; }