OI

「GDOI2020模拟02.05」生成树 – 题解

主要思路

有意思。之前做过一题是用矩阵树+高斯消元插值求本题的一维情况,本打算用继续这么做,后面发现 Lagrange 插值能做到 \(\Theta(n^5)\),就学了一点新的东西。

首先,要认识到本题矩阵树出来之后可以得到一个二元多项式:

\[ \sum_{i = 0}^{n – 1} \sum_{j = 0}^{n – 1} a_{i, j} x^i y^j \]

那么我们的目标就是算 \(i \leq x, j \leq y\) 内的 \(a_{i, j}\),这个就是方案数。这个东西通过 \(\Theta(n^5)\) 的枚举+矩阵树可以算出点值。我们需要用二维 Lagrange 插值得到其系数。构造一维基函数:

\[ f_{x_i}(x) = \prod_{j \neq i} \frac{x – x_j}{x_i – x_j} \]

最后:

\[ f(x, y) = \sum_i \sum_j a_{i, j} f_{x_i}(x) f_{y_i}(y) \]

Lagrange 插值可以在 \(\Theta(n^2)\) 时间内算出这个多项式的系数。

代码

// FOJ6461.cpp
#include <bits/stdc++.h>

using namespace std;

void fileIO(string str)
{
    freopen((str + ".in").c_str(), "r", stdin);
    freopen((str + ".out").c_str(), "w", stdout);
}

const int MAX_N = 45, MAX_M = 1e5 + 200, mod = 1e9 + 7;

int n, m, x, y, mat[MAX_N][MAX_N], ai[MAX_N][MAX_N], org[MAX_N][MAX_N][3], poly[MAX_N][MAX_N], X[MAX_N], Y[MAX_N], inv[MAX_N];

int fpow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

int det()
{
    int ret = 1;
    for (int i = 1; i <= n - 1; i++)
    {
        int key = 0;
        for (int j = i; j <= n - 1; j++)
            if (mat[j][i] > mat[key][i])
                key = j;
        if (key != i)
        {
            ret = (mod - ret) % mod;
            if (mat[key][i] == 0)
                return 0;
            for (int j = 1; j <= n - 1; j++)
                swap(mat[i][j], mat[key][j]);
        }
        int inv = fpow(mat[i][i], mod - 2);
        for (int j = 1; j <= n - 1; j++)
            if (i != j)
            {
                int rate = 1LL * mat[j][i] * inv % mod;
                for (int k = 1; k <= n - 1; k++)
                    mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
            }
    }
    for (int i = 1; i <= n - 1; i++)
        ret = 1LL * ret * mat[i][i] % mod;
    return ret;
}

void buildMatrix(int a, int b)
{
    memset(mat, 0, sizeof(mat));
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            mat[i][j] = (0LL + mod - org[i][j][0] + mod - 1LL * org[i][j][1] * a % mod + mod - 1LL * org[i][j][2] * b % mod) % mod;
    for (int i = 1; i <= n; i++)
    {
        mat[i][i] = 0;
        for (int j = 1; j <= n; j++)
            mat[i][i] = (0LL + mat[i][i] + org[i][j][0] + 1LL * org[i][j][1] * a % mod + 1LL * org[i][j][2] * b % mod) % mod;
    }
}

int main()
{
    // fileIO("tree");
    scanf("%d%d%d%d", &n, &m, &x, &y);
    for (int i = 1, u, v, c; i <= m; i++)
        scanf("%d%d%d", &u, &v, &c), c--, org[u][v][c]++, org[v][u][c]++;
    for (int i = 1; i <= n; i++)
        inv[i] = fpow(i, mod - 2);
    for (int a = 1; a <= n; a++)
        for (int b = 1; b <= n; b++)
        {
            buildMatrix(a, b), ai[a][b] = det();
            memset(X, 0, sizeof(X)), memset(Y, 0, sizeof(Y));
            X[0] = Y[0] = 1;
            int dotval = ai[a][b], len = 0;
            for (int i = 1; i <= n; i++)
                if (i != a)
                {
                    dotval = 1LL * dotval * inv[abs(a - i)] % mod * ((a > i) ? 1 : mod - 1) % mod;
                    for (int j = ++len; j >= 1; j--)
                        X[j] = (0LL + mod - 1LL * X[j] * i % mod + X[j - 1]) % mod;
                    X[0] = (0LL + mod - 1LL * X[0] * i % mod) % mod;
                }
            len = 0;
            for (int i = 1; i <= n; i++)
                if (i != b)
                {
                    dotval = 1LL * dotval * inv[abs(b - i)] % mod * ((b > i) ? 1 : mod - 1) % mod;
                    for (int j = ++len; j >= 1; j--)
                        Y[j] = (0LL + mod - 1LL * Y[j] * i % mod + Y[j - 1]) % mod;
                    Y[0] = (0LL + mod - 1LL * Y[0] * i % mod) % mod;
                }
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    poly[i][j] = (0LL + poly[i][j] + 1LL * X[i] * Y[j] % mod * dotval % mod) % mod;
        }
    int ans = 0;
    for (int i = 0; i <= x; i++)
        for (int j = 0; j <= y; j++)
            ans = (0LL + ans + poly[i][j]) % mod;
    printf("%d\n", ans);
    return 0;
}

kal0rona

http://kaloronahuang.com

江西师大附中全机房最弱