简述
如果你手上有一堆散点,然后你可以用这个公式拟合出一个函数:
\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \]
模版可以这么写:
int evaluate(int x, int *xi, int *yi)
{
int ret = 0;
for (int i = 1; i <= n; i++)
{
int pans = yi[i];
for (int j = 1; j <= n; j++)
if (i != j)
pans = 1LL * pans * ((0LL + x + mod - xi[j]) % mod) % mod * fpow((0LL + xi[i] + mod - xi[j]) % mod, mod - 2) % mod;
ret = (0LL + ret + pans) % mod;
}
return ret;
}
Lagrange 插值的一些优化
\(x_1\) 连续的情况
如果 \(x_1\) 出现连续的情况,那么会变成这个样子:
\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – j}{i – j} \]
那么上下两部分都可以预处理。需要注意的是,当 \(i < j\) 时,你需要判断阶乘长度来决定是否赋予一个负号。这样整个过程就变成 \(\Theta(n)\) 的了。
重心 Lagrange 插值
如果你现在需要对一堆点多次插值,如何做到优秀的复杂的呢?
考虑原来的式子:
\[ \begin{aligned} f(x) &= \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \\ &= \prod_{i = 1}^n (x – x_i) \sum_{i = 1}^n \frac{y_i}{x – x_i} \prod_{j \neq i} \frac{1}{x_i – x_j} \end{aligned} \]
设下面几个东西:
\[ l(x) = \prod_{i = 1}^n (x – x_i), \omega_i = \prod_{j \neq i} \frac{1}{x_i – x_j} \]
原式变成:
\[ f(x) = l(x) \sum_{i = 1}^n y_i \frac{\omega_i}{x – x_i} \]
单次估值降至 \(\Theta(n)\)。
例题
常用 Lagrange 插值优化的场景有以下几种:
- 枚举项较高、但次数不高的项
- DP 本身就是一个多项式的系数
- 矩阵树生成出来的计数多项式(虽然用高斯消元会更好)
BZOJ2655 – Calc
考虑与 \(A\) 相关的一个 DP:设 \(dp[i][j]\) 为从前 \(i\) 个数里去出 \(j\) 个的答案。这个东西转移很简单:
\[ dp[i][j] = dp[i – 1][j] + dp[i – 1][j – 1] \times i \]
并不显然的,这个东西最终是一个 \(2n+1\) 的多项式。所以插值即可。
// BZ2655.cpp
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 1010;
int A, n, mod, dp[MAX_N][MAX_N];
int fpow(int bas, int tim)
{
int ret = 1;
while (tim)
{
if (tim & 1)
ret = 1LL * ret * bas % mod;
bas = 1LL * bas * bas % mod;
tim >>= 1;
}
return ret;
}
int main()
{
scanf("%d%d%d", &A, &n, &mod), dp[0][0] = 1;
for (int i = 1; i <= 2 * n + 1; i++)
{
dp[i][0] = dp[i - 1][0];
for (int j = 1; j <= n; j++)
dp[i][j] = (0LL + dp[i - 1][j] + 1LL * dp[i - 1][j - 1] * i % mod) % mod;
}
int ans = 0;
if (A <= 2 * n + 1)
ans = dp[A][n];
else
for (int i = 1; i <= 2 * n + 1; i++)
{
// yi[i] = f[i][n];
int pans = 1;
for (int j = 1; j <= 2 * n + 1; j++)
if (i != j)
pans = 1LL * pans * ((0LL + A + mod - j) % mod) % mod * fpow((0LL + i + mod - j) % mod, mod - 2) % mod;
ans = (0LL + ans + 1LL * dp[i][n] * pans % mod) % mod;
}
for (int i = 1; i <= n; i++)
ans = 1LL * ans * i % mod;
printf("%lld\n", ans);
return 0;
}