主要思路
真你妈毒瘤。
原式子是长这个样子的:
\[ f(n) = \sum_{i = 0}^n \sum_{j = 0}^i \begin{Bmatrix} i \\ j \end{Bmatrix} 2^j (j!) \]
首先发现,把\(j\)的上界调整到\(n\)是没有问题的,因为在\(j > i\)时,\( \begin{Bmatrix} i \\ j \end{Bmatrix} = 0 \),然后再交换求和顺序:
\[ \begin{align} f(n) &= \sum_{i = 0}^n \sum_{j = 0}^n \begin{Bmatrix} i \\ j \end{Bmatrix} 2^j (j!) \\ &= \sum_{j = 0}^n 2^j (j!) \sum_{i = 0}^n \begin{Bmatrix} i \\ j \end{Bmatrix} \end{align} \]
然后把后面的斯特林数展开成卷积的形式:
\[ \begin{align} f(n) &= \sum_{j = 0}^n 2^j (j!) \sum_{i = 0}^n \begin{Bmatrix} i \\ j \end{Bmatrix} \\ &= \sum_{j = 0}^n 2^j (j!) \sum_{i = 0}^n \sum_{k = 0}^j \frac{(-1)^{j - k}}{(j - k)!} \cdot \frac{k^i}{k!} \\ &= \sum_{j = 0}^n 2^j (j!) \sum_{i = 0}^n \sum_{k = 0}^j \frac{(-1)^{k}}{k!} \cdot \frac{(j - k)^i}{(j - k)!} \end{align} \]
有点困难?再推一推。
\[ \begin{align} f(n) &= \sum_{j = 0}^n 2^j (j!) \sum_{k = 0}^j \frac{(-1)^{k}}{k!} \sum_{i = 0}^n \frac{(j - k)^i}{(j - k)!} \\ \text{设} calc(x) &= \sum_{i = 0}^n \frac{x^i}{x!} = \frac{1}{x!} \cdot \frac{(1 - x^{n + 1})}{1 - x} = \frac{x^{n + 1} - 1}{x!(x - 1)} \\ f(n) &= \sum_{j = 0}^n 2^j (j!) \sum_{k = 0}^j \frac{(-1)^{k}}{k!} calc(j - k) \end{align} \]
注意:上面的等比数列求和公式只在\(x \neq 0, 1\)时成立,需要特判\(calc(0) = 1\)(约定\(0^0 = 1\))和\(calc(1) = n + 1\)。
发现明显的卷积\( \sum_{k = 0}^j \frac{(-1)^{k}}{k!} calc(j - k) \)。用 NTT 搞就行了。
代码
// P4091.cpp
//
// Computes, modulo 998244353,
//     f(n) = sum_{i=0..n} sum_{j=0..i} S(i, j) * 2^j * j!
// where S(i, j) is the Stirling number of the second kind.
//
// After swapping the sums and expanding S(i, j) with its explicit formula:
//     f(n) = sum_{j=0..n} 2^j * j! * sum_{k=0..j} A[k] * B[j - k]
//     A[k] = (-1)^k / k!
//     B[x] = (x^(n+1) - 1) / (x! * (x - 1)),  with B[0] = 1 and B[1] = n + 1
// The inner sum is a convolution, evaluated with a number-theoretic transform.
#include <cstdio>
#include <utility>

using ll = long long;

// mod is the NTT prime, mod_g a primitive root of it, mod_gi the inverse of mod_g.
constexpr int MAX_N = 300200, mod = 998244353, mod_g = 3, mod_gi = 332748118;
// Largest power-of-two transform length that still fits in the arrays below.
constexpr int MAX_LEN = 1 << 18;
static_assert(MAX_LEN <= MAX_N, "transform length must fit in the buffers");

int n, mx_pow, mx_bit, rev[MAX_N], fac[MAX_N], fac_inv[MAX_N];
ll poly_A[MAX_N], poly_B[MAX_N];

// Returns bas^tim modulo mod using binary exponentiation.
// A negative base is first normalised into [0, mod); tim <= 0 yields 1.
int quick_pow(int bas, int tim) {
    ll base = ((bas % mod) + mod) % mod;
    ll result = 1;
    for (; tim > 0; tim >>= 1) {
        if (tim & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}

// Picks the transform length mx_pow (smallest power of two strictly greater
// than 2 * n, at least 2), stores its bit count in mx_bit, and fills rev[]
// with the bit-reversal permutation for that length.
// The caller must ensure the resulting length does not exceed MAX_LEN.
void ntt_initialize() {
    mx_pow = 2;
    mx_bit = 1;
    while ((n << 1) >= mx_pow) {
        mx_pow <<= 1;
        mx_bit++;
    }
    for (int i = 0; i < mx_pow; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (mx_bit - 1));
}

// In-place iterative NTT over the first mx_pow entries of arr.
// dft == 1 runs the forward transform; any other value runs the inverse
// transform, including the final scaling by 1 / mx_pow.
// Entries are expected to be reduced into [0, mod); ntt_initialize() must
// have been called beforehand.
void ntt(ll *arr, int dft) {
    // Reorder the input into bit-reversed order.
    for (int i = 0; i < mx_pow; i++)
        if (i < rev[i]) std::swap(arr[i], arr[rev[i]]);

    for (int half = 1; half < mx_pow; half <<= 1) {
        const int len = half << 1;
        // Primitive len-th root of unity (or its inverse for the inverse transform).
        const ll omega_n = quick_pow(dft == 1 ? mod_g : mod_gi, (mod - 1) / len);
        for (int start = 0; start < mx_pow; start += len) {
            ll omega_nk = 1;
            for (int k = start; k < start + half; k++) {
                const ll x = arr[k];
                const ll y = omega_nk * arr[k + half] % mod;
                arr[k] = (x + y) % mod;
                arr[k + half] = (x - y + mod) % mod;
                omega_nk = omega_nk * omega_n % mod;
            }
        }
    }

    if (dft != 1) {
        const ll inv = quick_pow(mx_pow, mod - 2);
        for (int i = 0; i < mx_pow; i++) arr[i] = arr[i] * inv % mod;
    }
}

// Reads n from stdin and prints f(n) mod 998244353 (no trailing newline).
// Returns 1 without printing a result when the input is missing or out of range.
int main() {
    // The transform length is the smallest power of two > 2 * n, so n must
    // stay below MAX_LEN / 2 for every buffer access to remain in bounds.
    if (scanf("%d", &n) != 1 || n < 0 || n >= MAX_LEN / 2) {
        fprintf(stderr, "invalid input: expected an integer n with 0 <= n < %d\n", MAX_LEN / 2);
        return 1;
    }

    // Factorials and inverse factorials up to n.
    fac[0] = 1;
    for (int i = 1; i <= n; i++) fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[n] = quick_pow(fac[n], mod - 2);
    for (int i = n - 1; i >= 0; i--) fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;

    // poly_A[k] = (-1)^k / k!
    for (int i = 0; i <= n; i++)
        poly_A[i] = (i % 2 == 0) ? fac_inv[i] : (mod - fac_inv[i]) % mod;

    // poly_B[x] = sum_{i=0..n} x^i / x!. The closed form divides by (x - 1),
    // so x = 0 and x = 1 are filled in directly.
    poly_B[0] = 1;
    if (n >= 1) poly_B[1] = n + 1;
    for (int i = 2; i <= n; i++) {
        const ll numerator = (quick_pow(i, n + 1) - 1LL + mod) % mod;
        poly_B[i] = 1LL * fac_inv[i] * quick_pow(i - 1, mod - 2) % mod * numerator % mod;
    }

    // Convolve poly_A and poly_B; only coefficients 0..n of the product are used.
    ntt_initialize();
    ntt(poly_A, 1);
    ntt(poly_B, 1);
    for (int i = 0; i < mx_pow; i++) poly_A[i] = poly_A[i] * poly_B[i] % mod;
    ntt(poly_A, -1);

    // f(n) = sum_{j=0..n} 2^j * j! * (A * B)[j]
    ll ans = 0, pow_2 = 1;
    for (int i = 0; i <= n; i++) {
        ans = (ans + pow_2 * fac[i] % mod * poly_A[i]) % mod;
        pow_2 = (pow_2 << 1) % mod;
    }
    printf("%lld", ans);
    return 0;
}