单位根反演

简述

在计数题解题中遇到要求 \(x \bmod m = 0\) 的条件时,如果式子的计算复杂度比较高,但是式子里含有二项式系数的时候,就可以考虑用单位根反演。

原理

长这个样子:

\[ \frac{1}{m} \sum_{i = 0}^{m – 1} (\omega_m^k)^i = [m | k] \]

证明不难,留作习题。

例题 A:LibreOJ #6358. 前夕

这道题题意:给 \(n\) 个元素,然后选出若干个这些元素的集合,使得这些集合的交集是 \(4\) 的倍数。

可以考虑先试试暴力容斥。枚举交集大小,然后乘上容斥系数 \(f(i)\):

\[ \sum_{k = 0}^n f(i) {n \choose k} (2^{2^{n – k}} – 1) \]

考虑计算这个容斥系数,然后就可以算出结果了。这个容斥系数需要对前 \(n\) 个满足:

\[ \sum_{i = 0}^n {n \choose i} f(i) = [4 | n] \]

二项式反演一下:

\[ f(n) = \sum_{i = 0}^n (-1)^{n – i} {n \choose i} [4 | i] \]

正常计算要 \(\Theta(n^2)\) 的时间,肯定超时。设 \(m = 4\),我们尝试套单位根反演进去:

\[ \begin{aligned} f(k) &= \sum_{i = 0}^k (-1)^{k – i} {k \choose i} [m | i] \\ &= \sum_{i = 0}^k (-1)^{k – i} {k \choose i} \frac{1}{m} \sum_{j = 0}^{m – 1} (\omega_m^i)^j \\ &= \frac{1}{m} \sum_{i = 0}^k \sum_{j = 0}^{m – 1} (-1)^{k – i} {k \choose i} (\omega_m^i)^j \\ &= \frac{1}{m} \sum_{i = 0}^{m – 1} \sum_{j = 0}^{k} {k \choose j} (\omega_m^i)^j (-1)^{k – j} \end{aligned} \]

然后就是喜闻乐见的二项式定理合并:

\[ \begin{aligned} f(k) &= \frac{1}{m} \sum_{i = 0}^{m – 1} \sum_{j = 0}^{k} {k \choose j} (\omega_m^i)^j (-1)^{k – j} \\ &= \frac{1}{m} \sum_{i = 0}^{m – 1} (\omega_m^i -1)^k \end{aligned} \]

然后直接算就完事了。最后,在这里简单说明一下单位根一般的写法:找到模数 \(P\) 的原根 \(g\) 之后,那么 \(\omega_m = \frac{g^{P – 1}}{m}\)。

// LOJ6358.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 1e7 + 200, m = 4, mod = 998244353, G = 3;

int f[MAX_N], n, fac[MAX_N], fac_inv[MAX_N];

int quick_pow(int bas, int tim, int cmod)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % cmod;
        bas = 1LL * bas * bas % cmod;
        tim >>= 1;
    }
    return ret;
}

const int wn = quick_pow(G, (mod - 1) / 4, mod);
int wns[4], wn_org[4];

int binomial(int n_, int k_) { return 1LL * fac[n_] * fac_inv[k_] % mod * fac_inv[n_ - k_] % mod; }

int main()
{
    scanf("%d", &n);
    for (int i = fac[0] = fac_inv[0] = 1; i <= n; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[n] = quick_pow(fac[n], mod - 2, mod);
    for (int i = n - 1; i >= 1; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
    for (int i = 0; i < 4; i++)
        wn_org[i] = quick_pow(wn, i, mod) - 1, wns[i] = 1;
    int m_inv = quick_pow(m, mod - 2, mod);
    for (int k = 0; k <= n; k++)
    {
        for (int i = 0; i < m; i++)
            f[k] = (1LL * f[k] + wns[i]) % mod;
        f[k] = 1LL * f[k] * m_inv % mod;
        for (int i = 0; i < m; i++)
            wns[i] = 1LL * wns[i] * wn_org[i] % mod;
    }
    int ans = 0;
    for (int k = n, pow_2 = 2; k >= 0; k--, pow_2 = 1LL * pow_2 * pow_2 % mod)
    {
        // subset is k;
        int tmp = 1LL * f[k] * binomial(n, k) % mod * ((pow_2 + mod - 1) % mod) % mod;
        ans = (0LL + ans + tmp) % mod;
    }
    printf("%d\n", ans + 1);
    return 0;
}

例题 B:BZOJ 3328 – PYXFIB

其实可以考虑放在 Fibonacci 的矩阵上做:矩阵满足单位根反演的规则。那么我们可以把式子写成:

\[ \begin{aligned} & \ \ \ \ \ \sum_{i = 0}^n {n \choose i} F_i [k | i] \\ &= \sum_{i = 0}^n {n \choose i} F_i \cdot \frac{1}{k} \sum_{j = 0}^{k – 1} (\omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^n \sum_{j = 0}^{k – 1} {n \choose i} F_i (\omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^{k – 1} \sum_{j = 0}^{n} {n \choose j} F_j \omega_k^{ij} \end{aligned} \]

然后我们可以把 Fibonacci 的部分换成转移矩阵 \(A\):

\[ \begin{aligned} & \ \ \ \ \ \frac{1}{k} \sum_{i = 0}^{k – 1} \sum_{j = 0}^{n} {n \choose j} A^j \omega_k^{ij} \\ &= \frac{1}{k} \sum_{i = 0}^{k – 1} \sum_{j = 0}^{n} {n \choose j}I^{n – j} (A \omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^{k – 1} (A \omega_k^i + I)^n \end{aligned} \]

然后就完事了。

// BZ3328.cpp
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAX_N = 1e5 + 200;

int T, m, mod, ptot, primes[MAX_N], g, wns[MAX_N];
ll n;

struct matrix
{
    int mat[2][2];

    void clear() { memset(mat, 0, sizeof(mat)); }

    int *operator[](const int &rhs) { return mat[rhs]; }

    matrix operator+(const matrix &rhs)
    {
        matrix ret;
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                ret[i][j] = (1LL * mat[i][j] + rhs.mat[i][j]) % mod;
        return ret;
    }

    matrix operator*(const matrix &rhs)
    {
        matrix ret;
        ret.clear();
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                for (int k = 0; k < 2; k++)
                    ret[i][j] = (1LL * ret[i][j] + 1LL * mat[i][k] * rhs.mat[k][j] % mod) % mod;
        return ret;
    }

    matrix operator*(const int &rhs)
    {
        matrix ret;
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                ret[i][j] = 1LL * mat[i][j] * rhs % mod;
        return ret;
    }

    matrix operator^(const ll &rhs);
} eps, fib;

matrix matrix::operator^(const ll &rhs)
{
    ll tim = rhs;
    matrix ret = eps, bas = *this;
    while (tim)
    {
        if (tim & 1LL)
            ret = ret * bas;
        bas = bas * bas;
        tim >>= 1;
    }
    return ret;
}

int quick_pow(int bas, int tim, int cmod)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % cmod;
        bas = 1LL * bas * bas % cmod;
        tim >>= 1;
    }
    return ret;
}

void find_root()
{
    ptot = 0;
    int x = mod - 1;
    for (int i = 2; 1LL * i * i <= x; i++)
        if (x % i == 0)
        {
            primes[++ptot] = i;
            while (x % i == 0)
                x /= i;
        }
    if (x > 1)
        primes[++ptot] = x;
    for (int i = 2; i <= mod - 1; i++)
    {
        bool flag = true;
        for (int k = 1; flag && k <= ptot; k++)
            if (quick_pow(i, (mod - 1) / primes[k], mod) == 1)
                flag = false;
        if (flag)
        {
            g = i;
            break;
        }
    }
    int wn = quick_pow(g, (mod - 1) / m, mod);
    wns[0] = 1;
    for (int i = 1; i < m; i++)
        wns[i] = 1LL * wns[i - 1] * wn % mod;
}

int main()
{
    for (int i = 0; i < 2; i++)
        eps.mat[i][i] = 1;
    fib[0][0] = fib[0][1] = fib[1][0] = 1;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%lld%d%d", &n, &m, &mod), find_root();
        matrix ans;
        ans.clear();
        for (int i = 0; i < m; i++)
            ans = ans + ((fib * wns[i] + eps) ^ n);
        printf("%lld\n", 1LL * ans[0][0] * quick_pow(m, mod - 2, mod) % mod);
    }
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *