Loading [MathJax]/extensions/tex2jax.js

单位根反演

简述

在计数题中,如果遇到形如 \(x \bmod m = 0\)(即 \(m \mid x\))的限制条件,直接计算的复杂度又比较高,而式子里恰好含有二项式系数,就可以考虑使用单位根反演。

原理

长这个样子:

\[ \frac{1}{m} \sum_{i = 0}^{m - 1} (\omega_m^k)^i = [m \mid k] \]

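证明不难:若 \(m \mid k\),则 \(\omega_m^k = 1\),每一项都是 \(1\),和为 \(m\);否则 \(\omega_m^k \neq 1\),由等比数列求和得 \(\sum_{i = 0}^{m - 1} (\omega_m^k)^i = \frac{(\omega_m^k)^m - 1}{\omega_m^k - 1} = 0\)。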

例题 A:LibreOJ #6358. 前夕

这道题题意:给定 \(n\) 个元素,从这些元素构成的子集中选出若干个,使得选出的这些集合的交集大小是 \(4\) 的倍数,求方案数。

可以考虑先试试暴力容斥:枚举钦定包含在交集中的元素个数 \(k\),然后乘上容斥系数 \(f(k)\):

\[ \sum_{k = 0}^n f(k) {n \choose k} (2^{2^{n - k}} - 1) \]

接下来考虑如何计算这个容斥系数,算出它就能得到结果。为了让交集大小为 \(n\) 的方案恰好被计入 \([4 \mid n]\) 次,容斥系数需要对每一个 \(n \ge 0\) 都满足:

\[ \sum_{i = 0}^n {n \choose i} f(i) = [4 | n] \]

二项式反演一下:

\[ f(n) = \sum_{i = 0}^n (-1)^{n - i} {n \choose i} [4 \mid i] \]

正常计算要 \(\Theta(n^2)\) 的时间,肯定超时。设 \(m = 4\),我们尝试套单位根反演进去:

\[ \begin{aligned} f(k) &= \sum_{i = 0}^k (-1)^{k - i} {k \choose i} [m \mid i] \\ &= \sum_{i = 0}^k (-1)^{k - i} {k \choose i} \frac{1}{m} \sum_{j = 0}^{m - 1} (\omega_m^i)^j \\ &= \frac{1}{m} \sum_{i = 0}^k \sum_{j = 0}^{m - 1} (-1)^{k - i} {k \choose i} (\omega_m^i)^j \\ &= \frac{1}{m} \sum_{i = 0}^{m - 1} \sum_{j = 0}^{k} {k \choose j} (\omega_m^i)^j (-1)^{k - j} \end{aligned} \]

然后就是喜闻乐见的二项式定理合并:

\[ \begin{aligned} f(k) &= \frac{1}{m} \sum_{i = 0}^{m - 1} \sum_{j = 0}^{k} {k \choose j} (\omega_m^i)^j (-1)^{k - j} \\ &= \frac{1}{m} \sum_{i = 0}^{m - 1} (\omega_m^i - 1)^k \end{aligned} \]

然后直接算就完事了。最后简单说明一下模意义下单位根的一般求法:找到模数 \(P\) 的原根 \(g\) 之后(要求 \(m \mid P - 1\)),取 \(\omega_m = g^{\frac{P - 1}{m}}\) 即可。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// LOJ6358.cpp
#include <bits/stdc++.h>
using namespace std;
// Problem-wide constants: array capacity (10^7 + 200), the m of the
// "multiple of m" condition, the prime modulus and its primitive root.
const int MAX_N = 10000200;
const int m = 4;
const int mod = 998244353;
const int G = 3;

int n;              // input size
int f[MAX_N];       // f[k]: inclusion-exclusion coefficient for size k
int fac[MAX_N];     // fac[i] = i! mod `mod`
int fac_inv[MAX_N]; // fac_inv[i] = (i!)^{-1} mod `mod`
// Modular exponentiation by repeated squaring: returns bas^tim mod cmod.
// Products are formed in long long so two residues never overflow.
int quick_pow(int bas, int tim, int cmod)
{
    long long result = 1;
    long long base = bas;
    for (int e = tim; e != 0; e >>= 1)
    {
        if (e & 1)
            result = result * base % cmod;
        base = base * base % cmod;
    }
    return static_cast<int>(result);
}
// wn = G^((mod - 1) / 4): a 4th root of unity modulo `mod`.
const int wn = quick_pow(G, (mod - 1) / 4, mod);
// wn_org[i] = wn^i - 1; wns[i] = (wn^i - 1)^k for the k currently processed.
int wns[4];
int wn_org[4];

// C(n_, k_) modulo `mod` from the factorial tables.
// Requires 0 <= k_ <= n_ and fac / fac_inv filled up to n_.
int binomial(int n_, int k_)
{
    const long long top = fac[n_];
    const long long inv_bottom = 1LL * fac_inv[k_] * fac_inv[n_ - k_] % mod;
    return static_cast<int>(top * inv_bottom % mod);
}
/*
 * Reads n and prints, modulo `mod`,
 *   1 + sum_{k=0}^{n} f(k) * C(n, k) * (2^(2^(n-k)) - 1),
 * where f(k) = (1/m) * sum_{i=0}^{m-1} (wn^i - 1)^k is the inclusion-exclusion
 * coefficient obtained by roots-of-unity filtering (m = 4).
 * NOTE(review): the trailing "+ 1" is kept from the original program; the
 * write-up above does not derive it.
 */
int main()
{
    scanf("%d", &n);

    // Factorials and inverse factorials up to n for binomial().
    for (int i = fac[0] = fac_inv[0] = 1; i <= n; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[n] = quick_pow(fac[n], mod - 2, mod);
    for (int i = n - 1; i >= 1; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;

    // wn_org[i] = wn^i - 1 (wn^i is a nonzero residue, so this is >= 0);
    // wns[i] tracks (wn^i - 1)^k, starting from k = 0.
    for (int i = 0; i < 4; i++)
        wn_org[i] = quick_pow(wn, i, mod) - 1, wns[i] = 1;

    // f[k] = (1/m) * sum_i (wn^i - 1)^k.
    int m_inv = quick_pow(m, mod - 2, mod);
    for (int k = 0; k <= n; k++)
    {
        for (int i = 0; i < m; i++)
            f[k] = (1LL * f[k] + wns[i]) % mod;
        f[k] = 1LL * f[k] * m_inv % mod;
        for (int i = 0; i < m; i++)
            wns[i] = 1LL * wns[i] * wn_org[i] % mod;
    }

    // Walk k downward so pow_2 = 2^(2^(n-k)) is maintained by squaring.
    // (pow_2 - 1) counts the non-empty families of subsets containing a fixed
    // k-element set.
    int ans = 0;
    for (int k = n, pow_2 = 2; k >= 0; k--, pow_2 = 1LL * pow_2 * pow_2 % mod)
    {
        int tmp = 1LL * f[k] * binomial(n, k) % mod * ((pow_2 + mod - 1) % mod) % mod;
        ans = (0LL + ans + tmp) % mod;
    }

    // ans is in [0, mod - 1]; reduce again so "ans + 1" never prints `mod`.
    printf("%d\n", (ans + 1) % mod);
    return 0;
}
// LOJ6358.cpp
#include <bits/stdc++.h>

using namespace std;

// Array capacity, the m of the "multiple of m" condition, the prime modulus
// and its primitive root.
const int MAX_N = 1e7 + 200, m = 4, mod = 998244353, G = 3;

int f[MAX_N], n, fac[MAX_N], fac_inv[MAX_N];

// bas^tim mod cmod by repeated squaring.
int quick_pow(int bas, int tim, int cmod)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % cmod;
        bas = 1LL * bas * bas % cmod;
        tim >>= 1;
    }
    return ret;
}

// wn = G^((mod - 1) / 4): a 4th root of unity modulo mod.
const int wn = quick_pow(G, (mod - 1) / 4, mod);
// wn_org[i] = wn^i - 1; wns[i] = (wn^i - 1)^k for the current k.
int wns[4], wn_org[4];

// C(n_, k_) mod `mod` from the factorial tables (requires 0 <= k_ <= n_ <= n).
int binomial(int n_, int k_) { return 1LL * fac[n_] * fac_inv[k_] % mod * fac_inv[n_ - k_] % mod; }

/*
 * Reads n and prints, modulo `mod`,
 *   1 + sum_{k=0}^{n} f(k) * C(n, k) * (2^(2^(n-k)) - 1),
 * with f(k) = (1/m) * sum_{i=0}^{m-1} (wn^i - 1)^k.
 * NOTE(review): the "+ 1" is kept from the original program.
 */
int main()
{
    scanf("%d", &n);
    for (int i = fac[0] = fac_inv[0] = 1; i <= n; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[n] = quick_pow(fac[n], mod - 2, mod);
    for (int i = n - 1; i >= 1; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
    for (int i = 0; i < 4; i++)
        wn_org[i] = quick_pow(wn, i, mod) - 1, wns[i] = 1;
    int m_inv = quick_pow(m, mod - 2, mod);
    for (int k = 0; k <= n; k++)
    {
        for (int i = 0; i < m; i++)
            f[k] = (1LL * f[k] + wns[i]) % mod;
        f[k] = 1LL * f[k] * m_inv % mod;
        for (int i = 0; i < m; i++)
            wns[i] = 1LL * wns[i] * wn_org[i] % mod;
    }
    // k walks downward so pow_2 = 2^(2^(n-k)) can be updated by squaring.
    int ans = 0;
    for (int k = n, pow_2 = 2; k >= 0; k--, pow_2 = 1LL * pow_2 * pow_2 % mod)
    {
        int tmp = 1LL * f[k] * binomial(n, k) % mod * ((pow_2 + mod - 1) % mod) % mod;
        ans = (0LL + ans + tmp) % mod;
    }
    // Reduce again: ans + 1 may equal mod when ans == mod - 1.
    printf("%d\n", (ans + 1) % mod);
    return 0;
}
// LOJ6358.cpp
#include <bits/stdc++.h>

using namespace std;

// Problem-wide constants: array capacity (10^7 + 200), the m of the
// "multiple of m" condition, the prime modulus and its primitive root.
const int MAX_N = 10000200;
const int m = 4;
const int mod = 998244353;
const int G = 3;

int n;              // input size
int f[MAX_N];       // f[k]: inclusion-exclusion coefficient for size k
int fac[MAX_N];     // fac[i] = i! mod `mod`
int fac_inv[MAX_N]; // fac_inv[i] = (i!)^{-1} mod `mod`

// Modular exponentiation by repeated squaring: returns bas^tim mod cmod.
// Products are formed in long long so two residues never overflow.
int quick_pow(int bas, int tim, int cmod)
{
    long long result = 1;
    long long base = bas;
    for (int e = tim; e != 0; e >>= 1)
    {
        if (e & 1)
            result = result * base % cmod;
        base = base * base % cmod;
    }
    return static_cast<int>(result);
}

// wn = G^((mod - 1) / 4): a 4th root of unity modulo `mod`.
const int wn = quick_pow(G, (mod - 1) / 4, mod);
// wn_org[i] = wn^i - 1; wns[i] = (wn^i - 1)^k for the k currently processed.
int wns[4];
int wn_org[4];

// C(n_, k_) modulo `mod` from the factorial tables.
// Requires 0 <= k_ <= n_ and fac / fac_inv filled up to n_.
int binomial(int n_, int k_)
{
    const long long top = fac[n_];
    const long long inv_bottom = 1LL * fac_inv[k_] * fac_inv[n_ - k_] % mod;
    return static_cast<int>(top * inv_bottom % mod);
}

/*
 * Reads n and prints, modulo `mod`,
 *   1 + sum_{k=0}^{n} f(k) * C(n, k) * (2^(2^(n-k)) - 1),
 * where f(k) = (1/m) * sum_{i=0}^{m-1} (wn^i - 1)^k is the inclusion-exclusion
 * coefficient obtained by roots-of-unity filtering (m = 4).
 * NOTE(review): the trailing "+ 1" is kept from the original program; the
 * write-up above does not derive it.
 */
int main()
{
    scanf("%d", &n);

    // Factorials and inverse factorials up to n for binomial().
    for (int i = fac[0] = fac_inv[0] = 1; i <= n; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[n] = quick_pow(fac[n], mod - 2, mod);
    for (int i = n - 1; i >= 1; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;

    // wn_org[i] = wn^i - 1 (wn^i is a nonzero residue, so this is >= 0);
    // wns[i] tracks (wn^i - 1)^k, starting from k = 0.
    for (int i = 0; i < 4; i++)
        wn_org[i] = quick_pow(wn, i, mod) - 1, wns[i] = 1;

    // f[k] = (1/m) * sum_i (wn^i - 1)^k.
    int m_inv = quick_pow(m, mod - 2, mod);
    for (int k = 0; k <= n; k++)
    {
        for (int i = 0; i < m; i++)
            f[k] = (1LL * f[k] + wns[i]) % mod;
        f[k] = 1LL * f[k] * m_inv % mod;
        for (int i = 0; i < m; i++)
            wns[i] = 1LL * wns[i] * wn_org[i] % mod;
    }

    // Walk k downward so pow_2 = 2^(2^(n-k)) is maintained by squaring.
    // (pow_2 - 1) counts the non-empty families of subsets containing a fixed
    // k-element set.
    int ans = 0;
    for (int k = n, pow_2 = 2; k >= 0; k--, pow_2 = 1LL * pow_2 * pow_2 % mod)
    {
        int tmp = 1LL * f[k] * binomial(n, k) % mod * ((pow_2 + mod - 1) % mod) % mod;
        ans = (0LL + ans + tmp) % mod;
    }

    // ans is in [0, mod - 1]; reduce again so "ans + 1" never prints `mod`.
    printf("%d\n", (ans + 1) % mod);
    return 0;
}

例题 B:BZOJ 3328 – PYXFIB

其实可以直接在 Fibonacci 的转移矩阵上做:单位根反演只用到了加法和数乘,对矩阵同样成立。于是可以把式子写成:

\[ \begin{aligned} & \ \ \ \ \ \sum_{i = 0}^n {n \choose i} F_i [k \mid i] \\ &= \sum_{i = 0}^n {n \choose i} F_i \cdot \frac{1}{k} \sum_{j = 0}^{k - 1} (\omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^n \sum_{j = 0}^{k - 1} {n \choose i} F_i (\omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^{k - 1} \sum_{j = 0}^{n} {n \choose j} F_j \omega_k^{ij} \end{aligned} \]

然后我们可以把 Fibonacci 的部分换成转移矩阵 \(A\):

\[ \begin{aligned} & \ \ \ \ \ \frac{1}{k} \sum_{i = 0}^{k - 1} \sum_{j = 0}^{n} {n \choose j} A^j \omega_k^{ij} \\ &= \frac{1}{k} \sum_{i = 0}^{k - 1} \sum_{j = 0}^{n} {n \choose j} I^{n - j} (A \omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^{k - 1} (A \omega_k^i + I)^n \end{aligned} \]

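然后就完事了:答案就是这个矩阵左上角的元素。取 \(A = \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}\) 时,\(A^j\) 左上角的元素正是 \(F_j\)(这里 \(F_0 = F_1 = 1\))。注意代码里用 \(m\) 表示上式中的 \(k\)。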

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// BZ3328.cpp
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity of the per-test tables (prime factors of mod - 1, roots of unity).
const int MAX_N = 100200;

int T;             // number of test cases
ll n;              // exponent n of the current test
int m;             // the k of the write-up: only multiples of m are summed
int mod;           // modulus P of the current test (assumed prime: Fermat inverse is used)
int g;             // primitive root of mod, set by find_root()
int ptot;          // number of distinct prime factors of mod - 1
int primes[MAX_N]; // primes[1..ptot]: distinct prime factors of mod - 1
int wns[MAX_N];    // wns[i] = w^i with w = g^((mod - 1) / m)
// Fixed-size 2x2 matrix whose arithmetic is carried out modulo the global
// `mod` of the current test case. Results of + and * are reduced with % mod;
// inputs are assumed to already be residues in [0, mod).
struct matrix
{
    int mat[2][2];

    // Set every entry to zero.
    void clear() { memset(mat, 0, sizeof(mat)); }

    // Row accessor so that entries can be written as x[i][j].
    int *operator[](const int &rhs) { return mat[rhs]; }

    // Entry-wise sum, reduced modulo mod.
    matrix operator+(const matrix &rhs)
    {
        matrix sum;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
            {
                const long long s = static_cast<long long>(mat[r][c]) + rhs.mat[r][c];
                sum.mat[r][c] = static_cast<int>(s % mod);
            }
        return sum;
    }

    // Matrix product; every partial product is reduced before accumulation.
    matrix operator*(const matrix &rhs)
    {
        matrix prod;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
            {
                long long acc = 0;
                for (int t = 0; t < 2; ++t)
                    acc = (acc + static_cast<long long>(mat[r][t]) * rhs.mat[t][c] % mod) % mod;
                prod.mat[r][c] = static_cast<int>(acc);
            }
        return prod;
    }

    // Multiply every entry by the scalar rhs, reduced modulo mod.
    matrix operator*(const int &rhs)
    {
        matrix scaled;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
                scaled.mat[r][c] = static_cast<int>(static_cast<long long>(mat[r][c]) * rhs % mod);
        return scaled;
    }

    // Binary exponentiation; defined out of line because it uses the global
    // identity `eps`, which is declared right after this struct.
    matrix operator^(const ll &rhs);
} eps, fib; // eps: identity matrix, fib: Fibonacci transition matrix (filled in main)
// Returns (*this)^rhs by binary exponentiation, starting from the global
// identity matrix eps (so rhs == 0 yields eps).
matrix matrix::operator^(const ll &rhs)
{
    matrix result = eps;
    matrix power = *this;
    for (ll e = rhs; e != 0; e >>= 1)
    {
        if (e & 1LL)
            result = result * power;
        power = power * power;
    }
    return result;
}
// Modular exponentiation by repeated squaring: returns bas^tim mod cmod.
// Products are formed in long long so two residues never overflow.
int quick_pow(int bas, int tim, int cmod)
{
    long long result = 1;
    long long base = bas;
    for (int e = tim; e != 0; e >>= 1)
    {
        if (e & 1)
            result = result * base % cmod;
        base = base * base % cmod;
    }
    return static_cast<int>(result);
}
// Per-test setup: factor mod - 1, find the smallest primitive root g >= 2 of
// mod, then fill wns[0..m-1] with the powers of w = g^((mod - 1) / m).
// If no candidate qualifies (e.g. mod == 2), g keeps its previous value.
void find_root()
{
    // Distinct prime factors of mod - 1 go into primes[1..ptot].
    ptot = 0;
    int rest = mod - 1;
    for (int d = 2; 1LL * d * d <= rest; ++d)
    {
        if (rest % d != 0)
            continue;
        primes[++ptot] = d;
        do
            rest /= d;
        while (rest % d == 0);
    }
    if (rest > 1)
        primes[++ptot] = rest;

    // cand is a primitive root iff cand^((mod - 1) / q) != 1 for every prime q | mod - 1.
    for (int cand = 2; cand <= mod - 1; ++cand)
    {
        bool primitive = true;
        for (int k = 1; k <= ptot && primitive; ++k)
            primitive = quick_pow(cand, (mod - 1) / primes[k], mod) != 1;
        if (primitive)
        {
            g = cand;
            break;
        }
    }

    const int wn = quick_pow(g, (mod - 1) / m, mod);
    wns[0] = 1;
    for (int i = 1; i < m; ++i)
        wns[i] = 1LL * wns[i - 1] * wn % mod;
}
int main()
{
for (int i = 0; i < 2; i++)
eps.mat[i][i] = 1;
fib[0][0] = fib[0][1] = fib[1][0] = 1;
scanf("%d", &T);
while (T--)
{
scanf("%lld%d%d", &n, &m, &mod), find_root();
matrix ans;
ans.clear();
for (int i = 0; i < m; i++)
ans = ans + ((fib * wns[i] + eps) ^ n);
printf("%lld\n", 1LL * ans[0][0] * quick_pow(m, mod - 2, mod) % mod);
}
return 0;
}
// BZ3328.cpp
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// Capacity of the per-test tables (prime factors of mod - 1, roots of unity).
const int MAX_N = 100200;

int T;             // number of test cases
ll n;              // exponent n of the current test
int m;             // the k of the write-up: only multiples of m are summed
int mod;           // modulus P of the current test (assumed prime: Fermat inverse is used)
int g;             // primitive root of mod, set by find_root()
int ptot;          // number of distinct prime factors of mod - 1
int primes[MAX_N]; // primes[1..ptot]: distinct prime factors of mod - 1
int wns[MAX_N];    // wns[i] = w^i with w = g^((mod - 1) / m)

// 2x2 matrix with arithmetic modulo the global `mod`.
struct matrix
{
    int mat[2][2];

    void clear() { memset(mat, 0, sizeof(mat)); }

    int *operator[](const int &rhs) { return mat[rhs]; }

    // Entry-wise sum modulo mod.
    matrix operator+(const matrix &rhs)
    {
        matrix sum;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
            {
                const long long s = static_cast<long long>(mat[r][c]) + rhs.mat[r][c];
                sum.mat[r][c] = static_cast<int>(s % mod);
            }
        return sum;
    }

    // Matrix product modulo mod.
    matrix operator*(const matrix &rhs)
    {
        matrix prod;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
            {
                long long acc = 0;
                for (int t = 0; t < 2; ++t)
                    acc = (acc + static_cast<long long>(mat[r][t]) * rhs.mat[t][c] % mod) % mod;
                prod.mat[r][c] = static_cast<int>(acc);
            }
        return prod;
    }

    // Scalar multiple modulo mod.
    matrix operator*(const int &rhs)
    {
        matrix scaled;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
                scaled.mat[r][c] = static_cast<int>(static_cast<long long>(mat[r][c]) * rhs % mod);
        return scaled;
    }

    matrix operator^(const ll &rhs);
} eps, fib; // eps: identity, fib: Fibonacci transition matrix (filled in main)

// Binary exponentiation starting from the identity eps.
matrix matrix::operator^(const ll &rhs)
{
    matrix result = eps;
    matrix power = *this;
    for (ll e = rhs; e != 0; e >>= 1)
    {
        if (e & 1LL)
            result = result * power;
        power = power * power;
    }
    return result;
}

// bas^tim mod cmod by repeated squaring.
int quick_pow(int bas, int tim, int cmod)
{
    long long result = 1;
    long long base = bas;
    for (int e = tim; e != 0; e >>= 1)
    {
        if (e & 1)
            result = result * base % cmod;
        base = base * base % cmod;
    }
    return static_cast<int>(result);
}

// Factor mod - 1, find the smallest primitive root g >= 2, fill wns[0..m-1].
void find_root()
{
    ptot = 0;
    int rest = mod - 1;
    for (int d = 2; 1LL * d * d <= rest; ++d)
    {
        if (rest % d != 0)
            continue;
        primes[++ptot] = d;
        do
            rest /= d;
        while (rest % d == 0);
    }
    if (rest > 1)
        primes[++ptot] = rest;

    for (int cand = 2; cand <= mod - 1; ++cand)
    {
        bool primitive = true;
        for (int k = 1; k <= ptot && primitive; ++k)
            primitive = quick_pow(cand, (mod - 1) / primes[k], mod) != 1;
        if (primitive)
        {
            g = cand;
            break;
        }
    }

    const int wn = quick_pow(g, (mod - 1) / m, mod);
    wns[0] = 1;
    for (int i = 1; i < m; ++i)
        wns[i] = 1LL * wns[i - 1] * wn % mod;
}

// Prints (1/m) * sum_i ((fib * w^i + I)^n)[0][0] modulo mod for each test.
int main()
{
    eps.mat[0][0] = 1;
    eps.mat[1][1] = 1;
    fib.mat[0][0] = 1;
    fib.mat[0][1] = 1;
    fib.mat[1][0] = 1;

    scanf("%d", &T);
    for (; T != 0; --T)
    {
        scanf("%lld%d%d", &n, &m, &mod);
        find_root();

        matrix total;
        total.clear();
        for (int i = 0; i < m; ++i)
        {
            matrix term = fib * wns[i] + eps;
            total = total + (term ^ n);
        }
        const int m_inv = quick_pow(m, mod - 2, mod);
        printf("%lld\n", 1LL * total[0][0] * m_inv % mod);
    }
    return 0;
}
// BZ3328.cpp
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// Capacity of the per-test tables (prime factors of mod - 1, roots of unity).
const int MAX_N = 100200;

int T;             // number of test cases
ll n;              // exponent n of the current test
int m;             // the k of the write-up: only multiples of m are summed
int mod;           // modulus P of the current test (assumed prime: Fermat inverse is used)
int g;             // primitive root of mod, set by find_root()
int ptot;          // number of distinct prime factors of mod - 1
int primes[MAX_N]; // primes[1..ptot]: distinct prime factors of mod - 1
int wns[MAX_N];    // wns[i] = w^i with w = g^((mod - 1) / m)

// Fixed-size 2x2 matrix whose arithmetic is carried out modulo the global
// `mod` of the current test case. Results of + and * are reduced with % mod;
// inputs are assumed to already be residues in [0, mod).
struct matrix
{
    int mat[2][2];

    // Set every entry to zero.
    void clear() { memset(mat, 0, sizeof(mat)); }

    // Row accessor so that entries can be written as x[i][j].
    int *operator[](const int &rhs) { return mat[rhs]; }

    // Entry-wise sum, reduced modulo mod.
    matrix operator+(const matrix &rhs)
    {
        matrix sum;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
            {
                const long long s = static_cast<long long>(mat[r][c]) + rhs.mat[r][c];
                sum.mat[r][c] = static_cast<int>(s % mod);
            }
        return sum;
    }

    // Matrix product; every partial product is reduced before accumulation.
    matrix operator*(const matrix &rhs)
    {
        matrix prod;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
            {
                long long acc = 0;
                for (int t = 0; t < 2; ++t)
                    acc = (acc + static_cast<long long>(mat[r][t]) * rhs.mat[t][c] % mod) % mod;
                prod.mat[r][c] = static_cast<int>(acc);
            }
        return prod;
    }

    // Multiply every entry by the scalar rhs, reduced modulo mod.
    matrix operator*(const int &rhs)
    {
        matrix scaled;
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < 2; ++c)
                scaled.mat[r][c] = static_cast<int>(static_cast<long long>(mat[r][c]) * rhs % mod);
        return scaled;
    }

    // Binary exponentiation; defined out of line because it uses the global
    // identity `eps`, which is declared right after this struct.
    matrix operator^(const ll &rhs);
} eps, fib; // eps: identity matrix, fib: Fibonacci transition matrix (filled in main)

// Returns (*this)^rhs by binary exponentiation, starting from the global
// identity matrix eps (so rhs == 0 yields eps).
matrix matrix::operator^(const ll &rhs)
{
    matrix result = eps;
    matrix power = *this;
    for (ll e = rhs; e != 0; e >>= 1)
    {
        if (e & 1LL)
            result = result * power;
        power = power * power;
    }
    return result;
}

// Modular exponentiation by repeated squaring: returns bas^tim mod cmod.
// Products are formed in long long so two residues never overflow.
int quick_pow(int bas, int tim, int cmod)
{
    long long result = 1;
    long long base = bas;
    for (int e = tim; e != 0; e >>= 1)
    {
        if (e & 1)
            result = result * base % cmod;
        base = base * base % cmod;
    }
    return static_cast<int>(result);
}

// Per-test setup: factor mod - 1, find the smallest primitive root g >= 2 of
// mod, then fill wns[0..m-1] with the powers of w = g^((mod - 1) / m).
// If no candidate qualifies (e.g. mod == 2), g keeps its previous value.
void find_root()
{
    // Distinct prime factors of mod - 1 go into primes[1..ptot].
    ptot = 0;
    int rest = mod - 1;
    for (int d = 2; 1LL * d * d <= rest; ++d)
    {
        if (rest % d != 0)
            continue;
        primes[++ptot] = d;
        do
            rest /= d;
        while (rest % d == 0);
    }
    if (rest > 1)
        primes[++ptot] = rest;

    // cand is a primitive root iff cand^((mod - 1) / q) != 1 for every prime q | mod - 1.
    for (int cand = 2; cand <= mod - 1; ++cand)
    {
        bool primitive = true;
        for (int k = 1; k <= ptot && primitive; ++k)
            primitive = quick_pow(cand, (mod - 1) / primes[k], mod) != 1;
        if (primitive)
        {
            g = cand;
            break;
        }
    }

    const int wn = quick_pow(g, (mod - 1) / m, mod);
    wns[0] = 1;
    for (int i = 1; i < m; ++i)
        wns[i] = 1LL * wns[i - 1] * wn % mod;
}

int main()
{
    for (int i = 0; i < 2; i++)
        eps.mat[i][i] = 1;
    fib[0][0] = fib[0][1] = fib[1][0] = 1;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%lld%d%d", &n, &m, &mod), find_root();
        matrix ans;
        ans.clear();
        for (int i = 0; i < m; i++)
            ans = ans + ((fib * wns[i] + eps) ^ n);
        printf("%lld\n", 1LL * ans[0][0] * quick_pow(m, mod - 2, mod) % mod);
    }
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *