Loading [MathJax]/extensions/tex2jax.js

Lagrange 插值

简述

如果你手上有一堆散点,然后你可以用这个公式拟合出一个函数:

\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \]

模版可以这么写:

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
int evaluate(int x, int *xi, int *yi)
{
int ret = 0;
for (int i = 1; i <= n; i++)
{
int pans = yi[i];
for (int j = 1; j <= n; j++)
if (i != j)
pans = 1LL * pans * ((0LL + x + mod - xi[j]) % mod) % mod * fpow((0LL + xi[i] + mod - xi[j]) % mod, mod - 2) % mod;
ret = (0LL + ret + pans) % mod;
}
return ret;
}
int evaluate(int x, int *xi, int *yi) { int ret = 0; for (int i = 1; i <= n; i++) { int pans = yi[i]; for (int j = 1; j <= n; j++) if (i != j) pans = 1LL * pans * ((0LL + x + mod - xi[j]) % mod) % mod * fpow((0LL + xi[i] + mod - xi[j]) % mod, mod - 2) % mod; ret = (0LL + ret + pans) % mod; } return ret; }
int evaluate(int x, int *xi, int *yi)
{
    int ret = 0;
    for (int i = 1; i <= n; i++)
    {
        int pans = yi[i];
        for (int j = 1; j <= n; j++)
            if (i != j)
                pans = 1LL * pans * ((0LL + x + mod - xi[j]) % mod) % mod * fpow((0LL + xi[i] + mod - xi[j]) % mod, mod - 2) % mod;
        ret = (0LL + ret + pans) % mod;
    }
    return ret;
}

Lagrange 插值的一些优化

\(x_1\) 连续的情况

如果 \(x_1\) 出现连续的情况,那么会变成这个样子:

\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – j}{i – j} \]

那么上下两部分都可以预处理。需要注意的是,当 \(i < j\) 时,你需要判断阶乘长度来决定是否赋予一个负号。这样整个过程就变成 \(\Theta(n)\) 的了。

重心 Lagrange 插值

如果你现在需要对一堆点多次插值,如何做到优秀的复杂的呢?

考虑原来的式子:

\[ \begin{aligned} f(x) &= \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \\ &= \prod_{i = 1}^n (x – x_i) \sum_{i = 1}^n \frac{y_i}{x – x_i} \prod_{j \neq i} \frac{1}{x_i – x_j} \end{aligned} \]

设下面几个东西:

\[ l(x) = \prod_{i = 1}^n (x – x_i), \omega_i = \prod_{j \neq i} \frac{1}{x_i – x_j} \]

原式变成:

\[ f(x) = l(x) \sum_{i = 1}^n y_i \frac{\omega_i}{x – x_i} \]

单次估值降至 \(\Theta(n)\)。

例题

常用 Lagrange 插值优化的场景有以下几种:

  • 枚举项较高、但次数不高的项
  • DP 本身就是一个多项式的系数
  • 矩阵树生成出来的计数多项式(虽然用高斯消元会更好)

BZOJ2655 – Calc

考虑与 \(A\) 相关的一个 DP:设 \(dp[i][j]\) 为从前 \(i\) 个数里去出 \(j\) 个的答案。这个东西转移很简单:

\[ dp[i][j] = dp[i – 1][j] + dp[i – 1][j – 1] \times i \]

并不显然的,这个东西最终是一个 \(2n+1\) 的多项式。所以插值即可。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// BZ2655.cpp
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 1010;
int A, n, mod, dp[MAX_N][MAX_N];
int fpow(int bas, int tim)
{
int ret = 1;
while (tim)
{
if (tim & 1)
ret = 1LL * ret * bas % mod;
bas = 1LL * bas * bas % mod;
tim >>= 1;
}
return ret;
}
int main()
{
scanf("%d%d%d", &A, &n, &mod), dp[0][0] = 1;
for (int i = 1; i <= 2 * n + 1; i++)
{
dp[i][0] = dp[i - 1][0];
for (int j = 1; j <= n; j++)
dp[i][j] = (0LL + dp[i - 1][j] + 1LL * dp[i - 1][j - 1] * i % mod) % mod;
}
int ans = 0;
if (A <= 2 * n + 1)
ans = dp[A][n];
else
for (int i = 1; i <= 2 * n + 1; i++)
{
// yi[i] = f[i][n];
int pans = 1;
for (int j = 1; j <= 2 * n + 1; j++)
if (i != j)
pans = 1LL * pans * ((0LL + A + mod - j) % mod) % mod * fpow((0LL + i + mod - j) % mod, mod - 2) % mod;
ans = (0LL + ans + 1LL * dp[i][n] * pans % mod) % mod;
}
for (int i = 1; i <= n; i++)
ans = 1LL * ans * i % mod;
printf("%lld\n", ans);
return 0;
}
// BZ2655.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 1010; int A, n, mod, dp[MAX_N][MAX_N]; int fpow(int bas, int tim) { int ret = 1; while (tim) { if (tim & 1) ret = 1LL * ret * bas % mod; bas = 1LL * bas * bas % mod; tim >>= 1; } return ret; } int main() { scanf("%d%d%d", &A, &n, &mod), dp[0][0] = 1; for (int i = 1; i <= 2 * n + 1; i++) { dp[i][0] = dp[i - 1][0]; for (int j = 1; j <= n; j++) dp[i][j] = (0LL + dp[i - 1][j] + 1LL * dp[i - 1][j - 1] * i % mod) % mod; } int ans = 0; if (A <= 2 * n + 1) ans = dp[A][n]; else for (int i = 1; i <= 2 * n + 1; i++) { // yi[i] = f[i][n]; int pans = 1; for (int j = 1; j <= 2 * n + 1; j++) if (i != j) pans = 1LL * pans * ((0LL + A + mod - j) % mod) % mod * fpow((0LL + i + mod - j) % mod, mod - 2) % mod; ans = (0LL + ans + 1LL * dp[i][n] * pans % mod) % mod; } for (int i = 1; i <= n; i++) ans = 1LL * ans * i % mod; printf("%lld\n", ans); return 0; }
// BZ2655.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 1010;

int A, n, mod, dp[MAX_N][MAX_N];

int fpow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

int main()
{
    scanf("%d%d%d", &A, &n, &mod), dp[0][0] = 1;
    for (int i = 1; i <= 2 * n + 1; i++)
    {
        dp[i][0] = dp[i - 1][0];
        for (int j = 1; j <= n; j++)
            dp[i][j] = (0LL + dp[i - 1][j] + 1LL * dp[i - 1][j - 1] * i % mod) % mod;
    }
    int ans = 0;
    if (A <= 2 * n + 1)
        ans = dp[A][n];
    else
        for (int i = 1; i <= 2 * n + 1; i++)
        {
            // yi[i] = f[i][n];
            int pans = 1;
            for (int j = 1; j <= 2 * n + 1; j++)
                if (i != j)
                    pans = 1LL * pans * ((0LL + A + mod - j) % mod) % mod * fpow((0LL + i + mod - j) % mod, mod - 2) % mod;
            ans = (0LL + ans + 1LL * dp[i][n] * pans % mod) % mod;
        }
    for (int i = 1; i <= n; i++)
        ans = 1LL * ans * i % mod;
    printf("%lld\n", ans);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *