Lagrange 插值

简述

如果你手上有一堆散点,然后你可以用这个公式拟合出一个函数:

\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \]

模版可以这么写:

int evaluate(int x, int *xi, int *yi)
{
    int ret = 0;
    for (int i = 1; i <= n; i++)
    {
        int pans = yi[i];
        for (int j = 1; j <= n; j++)
            if (i != j)
                pans = 1LL * pans * ((0LL + x + mod - xi[j]) % mod) % mod * fpow((0LL + xi[i] + mod - xi[j]) % mod, mod - 2) % mod;
        ret = (0LL + ret + pans) % mod;
    }
    return ret;
}

Lagrange 插值的一些优化

\(x_1\) 连续的情况

如果 \(x_1\) 出现连续的情况,那么会变成这个样子:

\[ f(x) = \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – j}{i – j} \]

那么上下两部分都可以预处理。需要注意的是,当 \(i < j\) 时,你需要判断阶乘长度来决定是否赋予一个负号。这样整个过程就变成 \(\Theta(n)\) 的了。

重心 Lagrange 插值

如果你现在需要对一堆点多次插值,如何做到优秀的复杂的呢?

考虑原来的式子:

\[ \begin{aligned} f(x) &= \sum_{i = 1}^n y_i \prod_{j \neq i}^n \frac{x – x_j}{x_i – x_j} \\ &= \prod_{i = 1}^n (x – x_i) \sum_{i = 1}^n \frac{y_i}{x – x_i} \prod_{j \neq i} \frac{1}{x_i – x_j} \end{aligned} \]

设下面几个东西:

\[ l(x) = \prod_{i = 1}^n (x – x_i), \omega_i = \prod_{j \neq i} \frac{1}{x_i – x_j} \]

原式变成:

\[ f(x) = l(x) \sum_{i = 1}^n y_i \frac{\omega_i}{x – x_i} \]

单次估值降至 \(\Theta(n)\)。

例题

常用 Lagrange 插值优化的场景有以下几种:

  • 枚举项较高、但次数不高的项
  • DP 本身就是一个多项式的系数
  • 矩阵树生成出来的计数多项式(虽然用高斯消元会更好)

BZOJ2655 – Calc

考虑与 \(A\) 相关的一个 DP:设 \(dp[i][j]\) 为从前 \(i\) 个数里去出 \(j\) 个的答案。这个东西转移很简单:

\[ dp[i][j] = dp[i – 1][j] + dp[i – 1][j – 1] \times i \]

并不显然的,这个东西最终是一个 \(2n+1\) 的多项式。所以插值即可。

// BZ2655.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 1010;

int A, n, mod, dp[MAX_N][MAX_N];

int fpow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

int main()
{
    scanf("%d%d%d", &A, &n, &mod), dp[0][0] = 1;
    for (int i = 1; i <= 2 * n + 1; i++)
    {
        dp[i][0] = dp[i - 1][0];
        for (int j = 1; j <= n; j++)
            dp[i][j] = (0LL + dp[i - 1][j] + 1LL * dp[i - 1][j - 1] * i % mod) % mod;
    }
    int ans = 0;
    if (A <= 2 * n + 1)
        ans = dp[A][n];
    else
        for (int i = 1; i <= 2 * n + 1; i++)
        {
            // yi[i] = f[i][n];
            int pans = 1;
            for (int j = 1; j <= 2 * n + 1; j++)
                if (i != j)
                    pans = 1LL * pans * ((0LL + A + mod - j) % mod) % mod * fpow((0LL + i + mod - j) % mod, mod - 2) % mod;
            ans = (0LL + ans + 1LL * dp[i][n] * pans % mod) % mod;
        }
    for (int i = 1; i <= n; i++)
        ans = 1LL * ans * i % mod;
    printf("%lld\n", ans);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *