多项式开根

原理推导

按照多项式求逆中倍增的思想，可以写出如下推导：设 \(A(x)\) 为待开根的多项式，\(B(x)\) 为满足 \(B^2(x) \equiv A(x) \pmod{x^n}\) 的解。假设已经求得 \(B'(x)\)，它是 \(\pmod{x^{\lceil \frac{n}{2} \rceil}}\) 意义下的解，现在要将其推进到 \(\pmod{x^n}\)，那么可以写出

\[ \begin{aligned} B'^2(x) &\equiv A(x) \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ B'^2(x) &\equiv B^2(x) \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ (B'(x) - B(x))(B'(x) + B(x)) &\equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ B'(x) - B(x) &\equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}} \\ (B'(x) - B(x))^2 &\equiv 0 \pmod {x^n} \\ B'^2(x) + B^2(x) &\equiv 2B'(x)B(x) \pmod {x^n} \\ B'^2(x) + A(x) &\equiv 2B'(x)B(x) \pmod {x^n} \\ \frac{B'(x)}{2} + \frac{A(x)}{2B'(x)} &\equiv B(x) \pmod {x^n} \end{aligned} \]

其中由第三行推出第四行时，取 \(B'(0) = B(0)\)，于是 \(B'(x) + B(x)\) 的常数项为 \(2B(0) \neq 0\)，它在模意义下可逆，可以约去。第四行说明 \(B'(x) - B(x)\) 的最低非零项次数至少为 \(\lceil \frac{n}{2} \rceil\)，两边平方后次数至少为 \(n\)，即得第五行。

每次倍增的时候需要做一次多项式求逆和 4 次 NTT（3 次正变换、1 次逆变换）。递归的边界是 \(n = 1\)，此时 \(B(x)\) 的常数项为 \(A(0)\) 的模平方根（当 \(A(0) = 1\) 时即为 \(1\)）。

代码

// P5205.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity of every scratch buffer, the NTT-friendly prime
// 998244353 = 119 * 2^23 + 1, and the root G = 3 used by the transforms.
const int MAX_N = 400200, mod = 998244353, G = 3;

// n / seq: input length and coefficients; ans: the result polynomial;
// rev / tmp / tmpA: scratch space shared by the NTT-based routines.
int n, seq[MAX_N], rev[MAX_N], tmp[MAX_N], tmpA[MAX_N], ans[MAX_N];

// Binary exponentiation: returns bas^tim modulo `mod` (tim >= 0).
int quick_pow(int bas, int tim)
{
    long long result = 1, base = bas;
    for (; tim; tim >>= 1, base = base * base % mod)
        if (tim & 1)
            result = result * base % mod;
    return static_cast<int>(result);
}

// Gi: inverse of the root G (used by the inverse transform);
// inv2: inverse of 2, i.e. (mod + 1) / 2 since mod is odd.
const int Gi = quick_pow(G, mod - 2), inv2 = (mod + 1) / 2;

// Fills rev[1 .. 2^poly_bit) with the bit-reversal permutation for a
// transform of length 2^poly_bit. rev[0] is never written (it stays 0).
void ntt_initialize(int poly_bit)
{
    const int len = 1 << poly_bit;
    for (int i = 1; i < len; ++i)
    {
        // Reverse i by reusing the reversal of i / 2 and moving i's lowest
        // bit into the top position.
        const int high_bit = (i & 1) << (poly_bit - 1);
        rev[i] = (rev[i >> 1] >> 1) | high_bit;
    }
}

// In-place iterative radix-2 number-theoretic transform of arr[0, poly_siz)
// modulo `mod`.
//   dft == 1  : forward transform, root G.
//   otherwise : root Gi; when dft == -1 the result is also scaled by
//               1 / poly_siz, completing the inverse transform.
// Assumes poly_siz is a power of two and rev[] was prepared by
// ntt_initialize for that length.
void ntt(int *arr, int dft, int poly_siz)
{
    // Bit-reversal reordering so the butterflies can run bottom-up.
    for (int i = 0; i < poly_siz; i++)
        if (i < rev[i])
            swap(arr[i], arr[rev[i]]);

    const int root = (dft == 1) ? G : Gi;
    for (int half = 1; half < poly_siz; half <<= 1)
    {
        const int block = half << 1;
        // Principal block-th root of unity.
        const int w_step = quick_pow(root, (mod - 1) / block);
        for (int base = 0; base < poly_siz; base += block)
        {
            int *lo = arr + base;
            int *hi = arr + base + half;
            int w = 1;
            for (int k = 0; k < half; k++)
            {
                const int t = 1LL * hi[k] * w % mod;
                const int u = lo[k];
                hi[k] = (0LL + u + mod - t) % mod;
                lo[k] = (0LL + u + t) % mod;
                w = 1LL * w * w_step % mod;
            }
        }
    }

    if (dft != -1)
        return;
    const int inv_len = quick_pow(poly_siz, mod - 2);
    for (int i = 0; i < poly_siz; i++)
        arr[i] = 1LL * arr[i] * inv_len % mod;
}

// Computes dst = 1 / src (mod x^deg) by Newton doubling:
//   C(x) = C'(x) * (2 - A(x) * C'(x))   (mod x^deg),
// where C'(x) is the inverse modulo x^ceil(deg / 2).
// Requires src[0] != 0 (mod `mod`); otherwise no inverse exists and the
// output is meaningless. Only src[0, deg) is read. For deg > 1, on return
// dst[0, deg) holds the result and the rest of the transform window is zero.
// Clobbers the global scratch buffers tmp and rev.
void poly_inverse(int *src, int *dst, int deg)
{
    if (deg == 1)
        return (void)(dst[0] = quick_pow(src[0], mod - 2));
    const int half = (deg + 1) >> 1;
    poly_inverse(src, dst, half);

    // Transform length: smallest power of two >= 2 * deg. It holds the
    // C'^2 * A product (degree <= 2 * deg - 2) without cyclic wrap-around.
    int poly_bit = 0;
    while ((deg << 1) > (1 << poly_bit))
        poly_bit++;
    const int poly_siz = 1 << poly_bit;

    // The recursive call only guarantees dst[0, half). In particular, the
    // deg == 1 base case leaves dst[1..] untouched. Clear the rest of the
    // window here instead of relying on the caller to have zeroed it.
    for (int i = half; i < poly_siz; i++)
        dst[i] = 0;

    ntt_initialize(poly_bit);
    for (int i = 0; i < deg; i++)
        tmp[i] = src[i];
    for (int i = deg; i < poly_siz; i++)
        tmp[i] = 0;
    ntt(tmp, 1, poly_siz), ntt(dst, 1, poly_siz);
    // Pointwise C' * (2 - A * C') in the evaluation domain.
    for (int i = 0; i < poly_siz; i++)
        dst[i] = 1LL * dst[i] * ((2LL + mod - 1LL * dst[i] * tmp[i] % mod) % mod) % mod;
    ntt(dst, -1, poly_siz);
    // Truncate the result to mod x^deg.
    for (int i = deg; i < poly_siz; i++)
        dst[i] = 0;
}

void poly_sqrt(int *src, int *dst, int deg)
{
    if (deg == 1)
        return (void)(dst[0] = 1);
    poly_sqrt(src, dst, (deg + 1) >> 1);
    for (int i = 0; i <= (deg << 1); i++)
        tmpA[i] = 0;
    poly_inverse(dst, tmpA, deg);
    int poly_bit = 0, poly_siz = 1;
    while ((1 << poly_bit) < (deg << 1))
        poly_bit++, poly_siz <<= 1;
    ntt_initialize(poly_bit);
    for (int i = 0; i < deg; i++)
        tmp[i] = src[i];
    for (int i = deg; i < poly_siz; i++)
        tmp[i] = 0;
    ntt(tmpA, 1, poly_siz), ntt(tmp, 1, poly_siz), ntt(dst, 1, poly_siz);
    for (int i = 0; i < poly_siz; i++)
        dst[i] = 1LL * inv2 * ((0LL + dst[i] + 1LL * tmpA[i] * tmp[i] % mod) % mod) % mod;
    ntt(dst, -1, poly_siz);
    for (int i = deg; i < poly_siz; i++)
        dst[i] = 0;
}

int main()
{
    scanf("%d", &n);
    for (int i = 0; i < n; i++)
        scanf("%d", &seq[i]);
    poly_sqrt(seq, ans, n);
    for (int i = 0; i < n; i++)
        printf("%d ", ans[i]);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *