多项式求逆

前置技能

FFT、NTT、多项式乘法。

原理推理

现在我们有一个这样的式子:

\[ A(x)B(x) \equiv C(x) \ (mod \ x^n) \]

现在我们已知\(B(x), C(x)\)的信息,而我们计算答案的时候需要用\(A(x)\)进行运算。这个时候,我们需要求出\(B(x)\)的逆元。多项式求逆元,我们一般是用倍增的方式来求解。假如我们有\(B(x)\)在模\(x^{\lceil \frac{x}{2} \rceil}\)意义下的逆元\(D(x)\):

\[ B(x)D(x) \equiv 1 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \]

现在我们需要求\( x^n \)意义下的逆元,我们现在有两个式子:

\[ \begin{cases} B(x)D(x) \equiv 1 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \\ B(x)B^{-1}(x) \equiv 1 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \end{cases} \]

上下相减:

\[ B(x)(D(x) – B^{-1}(x)) \equiv 0 \ (mod \ x^{\lceil \frac{x}{2} \rceil}) \\ D(x) – B^{-1}(x) \equiv 0 \  (mod \ x^{\lceil \frac{x}{2} \rceil}) \]

同时平方:

\[ D^2(x) + B^{-2}(x) – 2D(x)B^{-1}(x) \equiv 0 \ (mod \ n) \]

其中,这里为什么模意义突然变成了\(n\)的原因如下:

我可能会这样讲:根据\(B(x) – B^{-1}(x) \equiv 0 \pmod {x^{\lceil n/2 \rceil}}\),左边一定是一个形如\(k\cdot x^{\lceil n/2 \rceil}\)的多项式,它的平方(即\((B(x) – B^{-1}(x))^2\))为\(k^2 x^n\)(\(n\)为偶数)或\(k^2 x^{n+1}\)(\(n\)为奇数) ,显然二者均能被\(x_n\)整除

src: http://blog.miskcoo.com/2015/05/polynomial-inverse, PLANET6174

剩下的这个乘法用 NTT 解决就行了。

代码

// P4238.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int MAX_N = 3e5 + 2000, mod = 998244353, G = 3, Gi = 332748118;

int n, rev[MAX_N], mx_pow, mx_bit;
ll ai[MAX_N], ans[MAX_N], tmp[MAX_N];

ll quick_pow(ll bas, ll tim)
{
    ll ans = 1;
    bas %= mod;
    while (tim)
    {
        if (tim & 1)
            ans = ans * bas % mod;
        bas = bas * bas % mod;
        tim >>= 1;
    }
    return ans;
}

inline void ntt(ll *arr, int limit, int dft)
{
    for (int i = 0; i < limit; i++)
        if (i < rev[i])
            swap(arr[i], arr[rev[i]]);
    for (int step = 1; step < limit; step <<= 1)
    {
        ll omega_n = quick_pow(dft == 1 ? G : Gi, (mod - 1) / (step << 1));
        for (int j = 0; j < limit; j += (step << 1))
        {
            ll omega_nk = 1;
            for (int k = j; k < j + step; k++, omega_nk = (omega_nk * omega_n % mod))
            {
                ll x = arr[k], y = omega_nk * arr[k + step] % mod;
                arr[k] = (x + y) % mod, arr[k + step] = (x - y + mod) % mod;
            }
        }
    }
    if (dft != 1)
    {
        ll inv = quick_pow(limit, mod - 2);
        for (int i = 0; i < limit; i++)
            arr[i] = (arr[i] * inv % mod);
    }
}

void poly_inverse(int deg, ll *a, ll *b)
{
    if (deg == 1)
        return (void)(b[0] = quick_pow(a[0], mod - 2));
    poly_inverse((deg + 1) >> 1, a, b);

    ll limit = 2, mx_bit = 1;
    while ((deg << 1) > limit)
        limit <<= 1, mx_bit++;
    for (int i = 1; i < limit; i++)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1)));

    for (int i = 0; i < deg; i++)
        tmp[i] = a[i];
    for (int i = deg; i < limit; i++)
        tmp[i] = 0;
    ntt(tmp, limit, 1), ntt(b, limit, 1);

    for (int i = 0; i < limit; i++)
        b[i] = 1LL * ((2LL - tmp[i] * b[i] % mod + mod) % mod) * b[i] % mod;

    ntt(b, limit, -1);
    for (int i = deg; i < limit; i++)
        b[i] = 0;
}

int main()
{
    scanf("%d", &n);
    for (int i = 0; i <= n - 1; i++)
        scanf("%lld", &ai[i]);
    poly_inverse(n, ai, ans);
    for (int i = 0; i < n; i++)
        printf("%lld ", ans[i]);
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *