简述
在计数题解题中遇到要求 \(x \bmod m = 0\) 的条件时,如果式子的计算复杂度比较高,但是式子里含有二项式系数的时候,就可以考虑用单位根反演。
原理
长这个样子:
\[ \frac{1}{m} \sum_{i = 0}^{m – 1} (\omega_m^k)^i = [m | k] \]
证明不难,留作习题。
例题 A:LibreOJ #6358. 前夕
这道题题意:给 \(n\) 个元素,然后选出若干个这些元素的集合,使得这些集合的交集是 \(4\) 的倍数。
可以考虑先试试暴力容斥。枚举交集大小,然后乘上容斥系数 \(f(i)\):
\[ \sum_{k = 0}^n f(i) {n \choose k} (2^{2^{n – k}} – 1) \]
考虑计算这个容斥系数,然后就可以算出结果了。这个容斥系数需要对前 \(n\) 个满足:
\[ \sum_{i = 0}^n {n \choose i} f(i) = [4 | n] \]
二项式反演一下:
\[ f(n) = \sum_{i = 0}^n (-1)^{n – i} {n \choose i} [4 | i] \]
正常计算要 \(\Theta(n^2)\) 的时间,肯定超时。设 \(m = 4\),我们尝试套单位根反演进去:
\[ \begin{aligned} f(k) &= \sum_{i = 0}^k (-1)^{k – i} {k \choose i} [m | i] \\ &= \sum_{i = 0}^k (-1)^{k – i} {k \choose i} \frac{1}{m} \sum_{j = 0}^{m – 1} (\omega_m^i)^j \\ &= \frac{1}{m} \sum_{i = 0}^k \sum_{j = 0}^{m – 1} (-1)^{k – i} {k \choose i} (\omega_m^i)^j \\ &= \frac{1}{m} \sum_{i = 0}^{m – 1} \sum_{j = 0}^{k} {k \choose j} (\omega_m^i)^j (-1)^{k – j} \end{aligned} \]
然后就是喜闻乐见的二项式定理合并:
\[ \begin{aligned} f(k) &= \frac{1}{m} \sum_{i = 0}^{m – 1} \sum_{j = 0}^{k} {k \choose j} (\omega_m^i)^j (-1)^{k – j} \\ &= \frac{1}{m} \sum_{i = 0}^{m – 1} (\omega_m^i -1)^k \end{aligned} \]
然后直接算就完事了。最后,在这里简单说明一下单位根一般的写法:找到模数 \(P\) 的原根 \(g\) 之后,那么 \(\omega_m = \frac{g^{P – 1}}{m}\)。
// LOJ6358.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 1e7 + 200, m = 4, mod = 998244353, G = 3; int f[MAX_N], n, fac[MAX_N], fac_inv[MAX_N]; int quick_pow(int bas, int tim, int cmod) { int ret = 1; while (tim) { if (tim & 1) ret = 1LL * ret * bas % cmod; bas = 1LL * bas * bas % cmod; tim >>= 1; } return ret; } const int wn = quick_pow(G, (mod - 1) / 4, mod); int wns[4], wn_org[4]; int binomial(int n_, int k_) { return 1LL * fac[n_] * fac_inv[k_] % mod * fac_inv[n_ - k_] % mod; } int main() { scanf("%d", &n); for (int i = fac[0] = fac_inv[0] = 1; i <= n; i++) fac[i] = 1LL * fac[i - 1] * i % mod; fac_inv[n] = quick_pow(fac[n], mod - 2, mod); for (int i = n - 1; i >= 1; i--) fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod; for (int i = 0; i < 4; i++) wn_org[i] = quick_pow(wn, i, mod) - 1, wns[i] = 1; int m_inv = quick_pow(m, mod - 2, mod); for (int k = 0; k <= n; k++) { for (int i = 0; i < m; i++) f[k] = (1LL * f[k] + wns[i]) % mod; f[k] = 1LL * f[k] * m_inv % mod; for (int i = 0; i < m; i++) wns[i] = 1LL * wns[i] * wn_org[i] % mod; } int ans = 0; for (int k = n, pow_2 = 2; k >= 0; k--, pow_2 = 1LL * pow_2 * pow_2 % mod) { // subset is k; int tmp = 1LL * f[k] * binomial(n, k) % mod * ((pow_2 + mod - 1) % mod) % mod; ans = (0LL + ans + tmp) % mod; } printf("%d\n", ans + 1); return 0; }
例题 B:BZOJ 3328 – PYXFIB
其实可以考虑放在 Fibonacci 的矩阵上做:矩阵满足单位根反演的规则。那么我们可以把式子写成:
\[ \begin{aligned} & \ \ \ \ \ \sum_{i = 0}^n {n \choose i} F_i [k | i] \\ &= \sum_{i = 0}^n {n \choose i} F_i \cdot \frac{1}{k} \sum_{j = 0}^{k – 1} (\omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^n \sum_{j = 0}^{k – 1} {n \choose i} F_i (\omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^{k – 1} \sum_{j = 0}^{n} {n \choose j} F_j \omega_k^{ij} \end{aligned} \]
然后我们可以把 Fibonacci 的部分换成转移矩阵 \(A\):
\[ \begin{aligned} & \ \ \ \ \ \frac{1}{k} \sum_{i = 0}^{k – 1} \sum_{j = 0}^{n} {n \choose j} A^j \omega_k^{ij} \\ &= \frac{1}{k} \sum_{i = 0}^{k – 1} \sum_{j = 0}^{n} {n \choose j}I^{n – j} (A \omega_k^i)^j \\ &= \frac{1}{k} \sum_{i = 0}^{k – 1} (A \omega_k^i + I)^n \end{aligned} \]
然后就完事了。
// BZ3328.cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; const int MAX_N = 1e5 + 200; int T, m, mod, ptot, primes[MAX_N], g, wns[MAX_N]; ll n; struct matrix { int mat[2][2]; void clear() { memset(mat, 0, sizeof(mat)); } int *operator[](const int &rhs) { return mat[rhs]; } matrix operator+(const matrix &rhs) { matrix ret; for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) ret[i][j] = (1LL * mat[i][j] + rhs.mat[i][j]) % mod; return ret; } matrix operator*(const matrix &rhs) { matrix ret; ret.clear(); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) for (int k = 0; k < 2; k++) ret[i][j] = (1LL * ret[i][j] + 1LL * mat[i][k] * rhs.mat[k][j] % mod) % mod; return ret; } matrix operator*(const int &rhs) { matrix ret; for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) ret[i][j] = 1LL * mat[i][j] * rhs % mod; return ret; } matrix operator^(const ll &rhs); } eps, fib; matrix matrix::operator^(const ll &rhs) { ll tim = rhs; matrix ret = eps, bas = *this; while (tim) { if (tim & 1LL) ret = ret * bas; bas = bas * bas; tim >>= 1; } return ret; } int quick_pow(int bas, int tim, int cmod) { int ret = 1; while (tim) { if (tim & 1) ret = 1LL * ret * bas % cmod; bas = 1LL * bas * bas % cmod; tim >>= 1; } return ret; } void find_root() { ptot = 0; int x = mod - 1; for (int i = 2; 1LL * i * i <= x; i++) if (x % i == 0) { primes[++ptot] = i; while (x % i == 0) x /= i; } if (x > 1) primes[++ptot] = x; for (int i = 2; i <= mod - 1; i++) { bool flag = true; for (int k = 1; flag && k <= ptot; k++) if (quick_pow(i, (mod - 1) / primes[k], mod) == 1) flag = false; if (flag) { g = i; break; } } int wn = quick_pow(g, (mod - 1) / m, mod); wns[0] = 1; for (int i = 1; i < m; i++) wns[i] = 1LL * wns[i - 1] * wn % mod; } int main() { for (int i = 0; i < 2; i++) eps.mat[i][i] = 1; fib[0][0] = fib[0][1] = fib[1][0] = 1; scanf("%d", &T); while (T--) { scanf("%lld%d%d", &n, &m, &mod), find_root(); matrix ans; ans.clear(); for (int i = 0; i < m; i++) ans = ans + ((fib * wns[i] + eps) ^ n); printf("%lld\n", 1LL * ans[0][0] * quick_pow(m, mod - 2, mod) % mod); } return 0; }