简述
如果你手上有一堆散点,然后你可以用这个公式拟合出一个函数:
\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \]
模版可以这么写:
int evaluate(int x, int *xi, int *yi) { int ret = 0; for (int i = 1; i <= n; i++) { int pans = yi[i]; for (int j = 1; j <= n; j++) if (i != j) pans = 1LL * pans * ((0LL + x + mod - xi[j]) % mod) % mod * fpow((0LL + xi[i] + mod - xi[j]) % mod, mod - 2) % mod; ret = (0LL + ret + pans) % mod; } return ret; }
Lagrange 插值的一些优化
\(x_1\) 连续的情况
如果 \(x_1\) 出现连续的情况,那么会变成这个样子:
\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – j}{i – j} \]
那么上下两部分都可以预处理。需要注意的是,当 \(i < j\) 时,你需要判断阶乘长度来决定是否赋予一个负号。这样整个过程就变成 \(\Theta(n)\) 的了。
重心 Lagrange 插值
如果你现在需要对一堆点多次插值,如何做到优秀的复杂的呢?
考虑原来的式子:
\[ \begin{aligned} f(x) &= \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \\ &= \prod_{i = 1}^n (x – x_i) \sum_{i = 1}^n \frac{y_i}{x – x_i} \prod_{j \neq i} \frac{1}{x_i – x_j} \end{aligned} \]
设下面几个东西:
\[ l(x) = \prod_{i = 1}^n (x – x_i), \omega_i = \prod_{j \neq i} \frac{1}{x_i – x_j} \]
原式变成:
\[ f(x) = l(x) \sum_{i = 1}^n y_i \frac{\omega_i}{x – x_i} \]
单次估值降至 \(\Theta(n)\)。
例题
常用 Lagrange 插值优化的场景有以下几种:
- 枚举项较高、但次数不高的项
- DP 本身就是一个多项式的系数
- 矩阵树生成出来的计数多项式(虽然用高斯消元会更好)
BZOJ2655 – Calc
考虑与 \(A\) 相关的一个 DP:设 \(dp[i][j]\) 为从前 \(i\) 个数里去出 \(j\) 个的答案。这个东西转移很简单:
\[ dp[i][j] = dp[i – 1][j] + dp[i – 1][j – 1] \times i \]
并不显然的,这个东西最终是一个 \(2n+1\) 的多项式。所以插值即可。
// BZ2655.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 1010; int A, n, mod, dp[MAX_N][MAX_N]; int fpow(int bas, int tim) { int ret = 1; while (tim) { if (tim & 1) ret = 1LL * ret * bas % mod; bas = 1LL * bas * bas % mod; tim >>= 1; } return ret; } int main() { scanf("%d%d%d", &A, &n, &mod), dp[0][0] = 1; for (int i = 1; i <= 2 * n + 1; i++) { dp[i][0] = dp[i - 1][0]; for (int j = 1; j <= n; j++) dp[i][j] = (0LL + dp[i - 1][j] + 1LL * dp[i - 1][j - 1] * i % mod) % mod; } int ans = 0; if (A <= 2 * n + 1) ans = dp[A][n]; else for (int i = 1; i <= 2 * n + 1; i++) { // yi[i] = f[i][n]; int pans = 1; for (int j = 1; j <= 2 * n + 1; j++) if (i != j) pans = 1LL * pans * ((0LL + A + mod - j) % mod) % mod * fpow((0LL + i + mod - j) % mod, mod - 2) % mod; ans = (0LL + ans + 1LL * dp[i][n] * pans % mod) % mod; } for (int i = 1; i <= n; i++) ans = 1LL * ans * i % mod; printf("%lld\n", ans); return 0; }