原理推导
按照多项式求逆中倍增的思想,可以写出这样的推导:设 \(B(x)\) 为所求的解,即 \(B^2(x) \equiv A(x) \pmod{x^n}\)。假设我们已经求得 \(B'(x)\),它是 \(\pmod{x^{\lceil \frac{n}{2} \rceil}}\) 意义下的解,现在要步进至 \(\pmod{x^n}\),那么我们可以写出
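\[ \begin{aligned} B'^2(x) &\equiv A(x) \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ B'^2(x) &\equiv B^2(x) \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ (B'(x) - B(x))(B'(x) + B(x)) &\equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ B'(x) - B(x) &\equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ (B'(x) - B(x))^2 &\equiv 0 \pmod {x^n} \\ B'^2(x) + B^2(x) &\equiv 2B'(x)B(x) \pmod {x^n} \\ B'^2(x) + A(x) &\equiv 2B'(x)B(x) \pmod {x^n} \\ \frac{B'(x)}{2} + \frac{A(x)}{2B'(x)} &\equiv B(x) \pmod {x^n} \end{aligned} \]
其中,能够约去 \(B'(x) + B(x)\) 是因为 \(B'(x)\) 与 \(B(x)\) 取相同的常数项 \(b_0 \neq 0\),于是 \(B'(x) + B(x)\) 的常数项为 \(2b_0 \neq 0\),在模意义下可逆;而对 \(B'(x) - B(x) \equiv 0 \pmod{x^{\lceil \frac{n}{2} \rceil}}\) 两边平方后得到的结果在 \(\pmod{x^{2\lceil \frac{n}{2} \rceil}}\) 下成立,由于 \(2\lceil \frac{n}{2} \rceil \ge n\),它在 \(\pmod{x^n}\) 下同样成立。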
每次倍增的时候需要做一次多项式求逆和 4 次 NTT。
代码
// P5205.cpp
// Square root of a formal power series modulo x^n over Z/998244353Z,
// computed by Newton doubling (see the derivation above).
#include <bits/stdc++.h>

using namespace std;

// MAX_N sizes every global buffer; main() rejects any n whose transforms would
// not fit. mod = 998244353 = 119 * 2^23 + 1 and G = 3 is a primitive root of it.
const int MAX_N = 4e5 + 200, mod = 998244353, G = 3;

// seq: input coefficients, ans: result, rev: bit-reversal table,
// tmp / tmpA: scratch buffers shared by poly_inverse() and poly_sqrt().
int n, seq[MAX_N], rev[MAX_N], tmp[MAX_N], tmpA[MAX_N], ans[MAX_N];

// Returns bas^tim mod `mod` by binary exponentiation (tim >= 0).
int quick_pow(int bas, int tim) {
    int ret = 1;
    while (tim) {
        if (tim & 1) ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

// Modular inverses of G and of 2 (Fermat's little theorem).
const int Gi = quick_pow(G, mod - 2), inv2 = quick_pow(2, mod - 2);

// Fills rev[0, 2^poly_bit) with the bit-reversal permutation used by ntt().
// rev[0] stays 0 (global zero-initialisation).
void ntt_initialize(int poly_bit) {
    for (int i = 1; i < (1 << poly_bit); i++)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (poly_bit - 1)));
}

// In-place iterative NTT of arr[0, poly_siz).
// dft == 1 is the forward transform; dft == -1 is the inverse transform,
// including the final 1/poly_siz scaling. poly_siz must be a power of two and
// rev[] must have been filled by ntt_initialize() for that same size.
void ntt(int *arr, int dft, int poly_siz) {
    for (int i = 0; i < poly_siz; i++)
        if (i < rev[i]) swap(arr[i], arr[rev[i]]);
    for (int step = 1; step < poly_siz; step <<= 1) {
        // (2 * step)-th root of unity (its inverse for the inverse transform).
        int omega_n = quick_pow(dft == 1 ? G : Gi, (mod - 1) / (step << 1));
        for (int j = 0; j < poly_siz; j += (step << 1)) {
            int omega_nk = 1;
            for (int k = j; k < j + step; k++, omega_nk = 1LL * omega_nk * omega_n % mod) {
                int t = 1LL * arr[k + step] * omega_nk % mod;
                arr[k + step] = (0LL + arr[k] + mod - t) % mod;
                arr[k] = (1LL * arr[k] + t) % mod;
            }
        }
    }
    if (dft == -1) {
        int inv_n = quick_pow(poly_siz, mod - 2);
        for (int i = 0; i < poly_siz; i++) arr[i] = 1LL * arr[i] * inv_n % mod;
    }
}

// Computes dst = src^{-1} mod x^deg by Newton doubling:
//   C = C' * (2 - src * C')  (mod x^deg),  C' = inverse mod x^ceil(deg/2).
// Requires src[0] != 0 (mod p). Expects dst to already be zero at every index
// below the transform size that the recursion does not write itself
// (poly_sqrt() zero-fills tmpA before calling). Clobbers rev[] and tmp[].
void poly_inverse(int *src, int *dst, int deg) {
    if (deg <= 0) return;  // (0 + 1) >> 1 == 0 would otherwise recurse forever
    if (deg == 1) return (void)(dst[0] = quick_pow(src[0], mod - 2));
    poly_inverse(src, dst, (deg + 1) >> 1);
    // Transform length: smallest power of two >= 2 * deg, so the products
    // below (degree <= 2 * deg - 2) do not wrap around cyclically.
    int poly_bit = 0;
    while ((deg << 1) > (1 << poly_bit)) poly_bit++;
    int poly_siz = (1 << poly_bit);
    ntt_initialize(poly_bit);
    for (int i = 0; i < deg; i++) tmp[i] = src[i];
    for (int i = deg; i < poly_siz; i++) tmp[i] = 0;
    ntt(tmp, 1, poly_siz);
    ntt(dst, 1, poly_siz);
    for (int i = 0; i < poly_siz; i++)
        dst[i] = 1LL * dst[i] * ((2LL + mod - 1LL * dst[i] * tmp[i] % mod) % mod) % mod;
    ntt(dst, -1, poly_siz);
    // Truncate back to mod x^deg.
    for (int i = deg; i < poly_siz; i++) dst[i] = 0;
}

// Computes dst = sqrt(src) mod x^deg by Newton doubling:
//   B = (B' + src * B'^{-1}) / 2  (mod x^deg),  B' = root mod x^ceil(deg/2).
// Assumes src[0] == 1 (P5205 guarantees a_0 = 1), so the root's constant term
// is 1; other constant terms would need a modular square root. Uses the global
// scratch buffers tmp[], tmpA[] and rev[].
void poly_sqrt(int *src, int *dst, int deg) {
    if (deg <= 0) return;  // (0 + 1) >> 1 == 0 would otherwise recurse forever
    if (deg == 1) return (void)(dst[0] = 1);
    poly_sqrt(src, dst, (deg + 1) >> 1);
    // Clear transform values left in tmpA by the smaller levels, so that
    // poly_inverse() starts from zeros.
    for (int i = 0; i <= (deg << 1); i++) tmpA[i] = 0;
    poly_inverse(dst, tmpA, deg);
    // Smallest power of two >= 2 * deg: the product src * B'^{-1} has degree <= 2 * deg - 2.
    int poly_bit = 0, poly_siz = 1;
    while ((1 << poly_bit) < (deg << 1)) poly_bit++, poly_siz <<= 1;
    ntt_initialize(poly_bit);
    for (int i = 0; i < deg; i++) tmp[i] = src[i];
    for (int i = deg; i < poly_siz; i++) tmp[i] = 0;
    ntt(tmpA, 1, poly_siz);
    ntt(tmp, 1, poly_siz);
    ntt(dst, 1, poly_siz);
    for (int i = 0; i < poly_siz; i++)
        dst[i] = 1LL * inv2 * ((0LL + dst[i] + 1LL * tmpA[i] * tmp[i] % mod) % mod) % mod;
    ntt(dst, -1, poly_siz);
    // Truncate back to mod x^deg.
    for (int i = deg; i < poly_siz; i++) dst[i] = 0;
}

int main() {
    if (scanf("%d", &n) != 1) return 1;
    // The largest transform is the smallest power of two >= 2n, and
    // poly_sqrt() clears tmpA up to index 2n; both must fit in the buffers.
    long long need = 1;
    while (need < 2LL * n) need <<= 1;
    if (n < 0 || need > MAX_N || 2LL * n >= MAX_N) {
        fprintf(stderr, "n out of supported range\n");
        return 1;
    }
    for (int i = 0; i < n; i++)
        if (scanf("%d", &seq[i]) != 1) return 1;
    poly_sqrt(seq, ans, n);
    for (int i = 0; i < n; i++) printf("%d ", ans[i]);
    return 0;
}