Loading [MathJax]/extensions/tex2jax.js

「GDOI2020模拟02.05」生成树 – 题解

主要思路

有意思。之前做过一题是用矩阵树+高斯消元插值求本题的一维情况,本打算继续这么做,后面发现 Lagrange 插值能做到 \(\Theta(n^5)\),就学了一点新的东西。

首先,要认识到本题矩阵树出来之后可以得到一个二元多项式:

\[ \sum_{i = 0}^{n - 1} \sum_{j = 0}^{n - 1} a_{i, j} x^i y^j \]

那么我们的目标就是算 \(i \leq x, j \leq y\) 范围内所有 \(a_{i, j}\) 的和,这个就是方案数。这个多项式的点值可以通过 \(\Theta(n^5)\) 的枚举+矩阵树算出(枚举 \(n^2\) 组取值,每组做一次 \(\Theta(n^3)\) 的行列式)。我们需要用二维 Lagrange 插值得到其系数。构造一维基函数:

\[ f_{x_i}(x) = \prod_{j \neq i} \frac{x - x_j}{x_i - x_j} \]

最后:

\[ f(x, y) = \sum_i \sum_j f(x_i, y_j) f_{x_i}(x) f_{y_j}(y) \]

Lagrange 插值可以在 \(\Theta(n^2)\) 时间内算出这个多项式的系数。

代码

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// FOJ6461.cpp
#include <bits/stdc++.h>
using namespace std;
// Redirects the standard streams to files derived from `str`:
// stdin is reopened on "<str>.in" for reading and stdout on "<str>.out" for
// writing. The results of both freopen calls are ignored.
void fileIO(string str)
{
    const string inputPath = str + ".in";
    const string outputPath = str + ".out";
    freopen(inputPath.c_str(), "r", stdin);
    freopen(outputPath.c_str(), "w", stdout);
}
// MAX_N : capacity of every vertex-indexed table below (indices 1..n are used).
// MAX_M : an edge-count bound; nothing in this listing references it.
// mod   : modulus for all arithmetic; inverses are taken as fpow(v, mod - 2).
const int MAX_N = 45, MAX_M = 1e5 + 200, mod = 1e9 + 7;
// n, m, x, y   : the four integers read first by main; n bounds every vertex
//                loop, m is the number of edge records, x and y bound the two
//                exponents summed for the final answer.
// mat          : working matrix filled by buildMatrix and reduced in place by det.
// ai[a][b]     : determinant obtained for the sample pair (a, b).
// org[u][v][c] : number of edge records joining u and v whose third field,
//                minus one, equals c (c in 0..2).
// poly[i][j]   : accumulated coefficient for exponent i of the first variable
//                and exponent j of the second.
// X, Y         : scratch coefficient arrays for the two one-variable products
//                built inside main.
// inv[k]       : modular inverse of k, filled by main for 1 <= k <= n.
int n, m, x, y, mat[MAX_N][MAX_N], ai[MAX_N][MAX_N], org[MAX_N][MAX_N][3], poly[MAX_N][MAX_N], X[MAX_N], Y[MAX_N], inv[MAX_N];
// Binary exponentiation: returns bas raised to the power tim, reduced modulo
// `mod`. The exponent is consumed bit by bit from the least significant end.
int fpow(int bas, int tim)
{
    long long result = 1;
    long long factor = bas;
    for (; tim; tim >>= 1)
    {
        if (tim & 1)
            result = result * factor % mod;
        factor = factor * factor % mod;
    }
    return static_cast<int>(result);
}
// Returns the determinant, modulo `mod`, of the minor mat[1..n-1][1..n-1].
//
// Entries are expected to be reduced into [0, mod). The matrix is reduced in
// place to upper-triangular form, so its contents are destroyed; callers must
// rebuild it before the next call.
//
// The pivot search stays inside the minor (rows col..n-1) instead of comparing
// against row 0, so the result no longer depends on row 0 being all zero.
// Only rows below the pivot are eliminated, which is all a determinant needs.
int det()
{
    const int size = n - 1;
    int ret = 1;
    for (int col = 1; col <= size; col++)
    {
        // Any non-zero entry is a valid pivot in modular arithmetic.
        int pivot = col;
        while (pivot <= size && mat[pivot][col] == 0)
            pivot++;
        if (pivot > size)
            return 0; // the whole column is zero: singular matrix
        if (pivot != col)
        {
            for (int k = 1; k <= size; k++)
                std::swap(mat[col][k], mat[pivot][k]);
            ret = (mod - ret) % mod; // a row swap negates the determinant
        }
        ret = 1LL * ret * mat[col][col] % mod;
        const int pivotInv = fpow(mat[col][col], mod - 2);
        for (int row = col + 1; row <= size; row++)
        {
            const int rate = 1LL * mat[row][col] * pivotInv % mod;
            if (rate == 0)
                continue;
            // Columns left of `col` are already zero in both rows.
            for (int k = col; k <= size; k++)
                mat[row][k] = (0LL + mat[row][k] + mod - 1LL * rate * mat[col][k] % mod) % mod;
        }
    }
    return ret;
}
// Fills `mat` with the weighted Laplacian of the multigraph stored in `org`,
// reduced modulo `mod`.
//
// An edge counted in org[.][.][0] has weight 1, one in org[.][.][1] has
// weight a, and one in org[.][.][2] has weight b. For i != j, mat[i][j] is the
// negated total weight between i and j, and mat[i][i] is the total weight of
// the edges leaving i.
//
// Self-loops (org[i][i][*]) are skipped: a loop can never belong to a
// spanning tree, and counting it in the degree would corrupt the determinant.
// Row 0 and column 0 are left at zero.
void buildMatrix(int a, int b)
{
    memset(mat, 0, sizeof(mat));
    for (int i = 1; i <= n; i++)
    {
        long long degree = 0;
        for (int j = 1; j <= n; j++)
        {
            if (j == i)
                continue;
            long long weight = (org[i][j][0] + 1LL * org[i][j][1] * a + 1LL * org[i][j][2] * b) % mod;
            if (weight < 0)
                weight += mod; // keep entries in [0, mod) even for negative a or b
            mat[i][j] = static_cast<int>((mod - weight) % mod);
            degree = (degree + weight) % mod;
        }
        mat[i][i] = static_cast<int>(degree);
    }
}
int main()
{
// fileIO("tree");
scanf("%d%d%d%d", &n, &m, &x, &y);
for (int i = 1, u, v, c; i <= m; i++)
scanf("%d%d%d", &u, &v, &c), c--, org[u][v][c]++, org[v][u][c]++;
for (int i = 1; i <= n; i++)
inv[i] = fpow(i, mod - 2);
for (int a = 1; a <= n; a++)
for (int b = 1; b <= n; b++)
{
buildMatrix(a, b), ai[a][b] = det();
memset(X, 0, sizeof(X)), memset(Y, 0, sizeof(Y));
X[0] = Y[0] = 1;
int dotval = ai[a][b], len = 0;
for (int i = 1; i <= n; i++)
if (i != a)
{
dotval = 1LL * dotval * inv[abs(a - i)] % mod * ((a > i) ? 1 : mod - 1) % mod;
for (int j = ++len; j >= 1; j--)
X[j] = (0LL + mod - 1LL * X[j] * i % mod + X[j - 1]) % mod;
X[0] = (0LL + mod - 1LL * X[0] * i % mod) % mod;
}
len = 0;
for (int i = 1; i <= n; i++)
if (i != b)
{
dotval = 1LL * dotval * inv[abs(b - i)] % mod * ((b > i) ? 1 : mod - 1) % mod;
for (int j = ++len; j >= 1; j--)
Y[j] = (0LL + mod - 1LL * Y[j] * i % mod + Y[j - 1]) % mod;
Y[0] = (0LL + mod - 1LL * Y[0] * i % mod) % mod;
}
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
poly[i][j] = (0LL + poly[i][j] + 1LL * X[i] * Y[j] % mod * dotval % mod) % mod;
}
int ans = 0;
for (int i = 0; i <= x; i++)
for (int j = 0; j <= y; j++)
ans = (0LL + ans + poly[i][j]) % mod;
printf("%d\n", ans);
return 0;
}
// FOJ6461.cpp
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <utility>
using namespace std;

// Redirects stdin to "<str>.in" and stdout to "<str>.out".
// The results of both freopen calls are ignored.
void fileIO(string str)
{
    freopen((str + ".in").c_str(), "r", stdin);
    freopen((str + ".out").c_str(), "w", stdout);
}

// MAX_N : capacity of every vertex-indexed table (indices 1..n are used).
// MAX_M : an edge-count bound; nothing in this listing references it.
// mod   : modulus for all arithmetic; inverses are taken as fpow(v, mod - 2).
const int MAX_N = 45, MAX_M = 1e5 + 200, mod = 1e9 + 7;

// n, m, x, y   : the four integers read first by main.
// mat          : working matrix filled by buildMatrix and reduced by det.
// ai[a][b]     : determinant obtained for the sample pair (a, b).
// org[u][v][c] : number of edge records joining u and v with class index c.
// poly[i][j]   : accumulated coefficient for exponents i and j.
// X, Y         : scratch coefficient arrays (unused by this version of main).
// inv[k]       : modular inverse of k for 1 <= k <= n.
int n, m, x, y, mat[MAX_N][MAX_N], ai[MAX_N][MAX_N], org[MAX_N][MAX_N][3], poly[MAX_N][MAX_N], X[MAX_N], Y[MAX_N], inv[MAX_N];

// Binary exponentiation: bas^tim modulo `mod`.
int fpow(int bas, int tim)
{
    int ret = 1;
    while (tim)
    {
        if (tim & 1)
            ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

// Returns the determinant, modulo `mod`, of the minor mat[1..n-1][1..n-1].
// Entries are expected in [0, mod). The matrix is reduced in place to
// upper-triangular form. The pivot search stays inside the minor, so the
// result does not depend on the contents of row 0.
int det()
{
    const int size = n - 1;
    int ret = 1;
    for (int col = 1; col <= size; col++)
    {
        // Any non-zero entry is a valid pivot in modular arithmetic.
        int pivot = col;
        while (pivot <= size && mat[pivot][col] == 0)
            pivot++;
        if (pivot > size)
            return 0; // the whole column is zero: singular matrix
        if (pivot != col)
        {
            for (int k = 1; k <= size; k++)
                std::swap(mat[col][k], mat[pivot][k]);
            ret = (mod - ret) % mod; // a row swap negates the determinant
        }
        ret = 1LL * ret * mat[col][col] % mod;
        const int pivotInv = fpow(mat[col][col], mod - 2);
        for (int row = col + 1; row <= size; row++)
        {
            const int rate = 1LL * mat[row][col] * pivotInv % mod;
            if (rate == 0)
                continue;
            // Columns left of `col` are already zero in both rows.
            for (int k = col; k <= size; k++)
                mat[row][k] = (0LL + mat[row][k] + mod - 1LL * rate * mat[col][k] % mod) % mod;
        }
    }
    return ret;
}

// Fills `mat` with the weighted Laplacian of the multigraph in `org`: edges
// counted in layer 0 weigh 1, layer 1 weigh a, layer 2 weigh b. Self-loops are
// skipped because they can never belong to a spanning tree.
void buildMatrix(int a, int b)
{
    memset(mat, 0, sizeof(mat));
    for (int i = 1; i <= n; i++)
    {
        long long degree = 0;
        for (int j = 1; j <= n; j++)
        {
            if (j == i)
                continue;
            long long weight = (org[i][j][0] + 1LL * org[i][j][1] * a + 1LL * org[i][j][2] * b) % mod;
            if (weight < 0)
                weight += mod; // keep entries in [0, mod) even for negative a or b
            mat[i][j] = static_cast<int>((mod - weight) % mod);
            degree = (degree + weight) % mod;
        }
        mat[i][i] = static_cast<int>(degree);
    }
}

// Writes into coef[0..MAX_N-1] the coefficients (lowest degree first) of the
// Lagrange basis polynomial of node k over the nodes 1..n:
//     l_k(t) = prod_{1 <= i <= n, i != k} (t - i) / (k - i)   (mod `mod`).
// Requires inv[1..n-1] to hold modular inverses already.
static void lagrangeBasis(int k, int *coef)
{
    memset(coef, 0, sizeof(int) * MAX_N);
    coef[0] = 1;
    int scale = 1; // product of 1 / (k - i)
    int len = 0;   // current degree of the numerator
    for (int i = 1; i <= n; i++)
    {
        if (i == k)
            continue;
        scale = 1LL * scale * inv[abs(k - i)] % mod * ((k > i) ? 1 : mod - 1) % mod;
        // Multiply the numerator by (t - i), highest coefficient first.
        for (int j = ++len; j >= 1; j--)
            coef[j] = (0LL + mod - 1LL * coef[j] * i % mod + coef[j - 1]) % mod;
        coef[0] = (0LL + mod - 1LL * coef[0] * i % mod) % mod;
    }
    for (int j = 0; j < n; j++)
        coef[j] = 1LL * coef[j] * scale % mod;
}

// Reads "n m x y" followed by m edge records "u v c" (c in 1..3), samples the
// determinant at every pair (a, b) in [1, n] x [1, n], interpolates the
// two-variable polynomial through those values, and prints the sum of its
// coefficients with exponents i <= x and j <= y, modulo `mod`.
int main()
{
    // fileIO("tree");
    if (scanf("%d%d%d%d", &n, &m, &x, &y) != 4)
    {
        fprintf(stderr, "failed to read n, m, x, y\n");
        return 1;
    }
    if (n >= MAX_N)
    {
        fprintf(stderr, "n = %d exceeds the supported maximum of %d\n", n, MAX_N - 1);
        return 1;
    }
    for (int i = 1, u, v, c; i <= m; i++)
    {
        if (scanf("%d%d%d", &u, &v, &c) != 3)
            break; // truncated input: work with the edges read so far
        // Drop records that would index outside `org`, and self-loops.
        if (u < 1 || u > n || v < 1 || v > n || c < 1 || c > 3 || u == v)
            continue;
        c--;
        org[u][v][c]++, org[v][u][c]++;
    }
    for (int i = 1; i <= n; i++)
        inv[i] = fpow(i, mod - 2);

    // Both variables use the nodes 1..n, so one table of basis polynomials
    // serves both and is built once.
    static int basis[MAX_N][MAX_N];
    for (int k = 1; k <= n; k++)
        lagrangeBasis(k, basis[k]);

    for (int a = 1; a <= n; a++)
        for (int b = 1; b <= n; b++)
        {
            buildMatrix(a, b);
            ai[a][b] = det();
            for (int i = 0; i < n; i++)
            {
                const int rowScale = 1LL * ai[a][b] * basis[a][i] % mod;
                for (int j = 0; j < n; j++)
                    poly[i][j] = (poly[i][j] + 1LL * rowScale * basis[b][j]) % mod;
            }
        }

    // Coefficients with an exponent >= n are zero, so clamping the limits keeps
    // the answer unchanged while never indexing past the end of `poly`.
    const int maxI = std::min(x, n - 1);
    const int maxJ = std::min(y, n - 1);
    int ans = 0;
    for (int i = 0; i <= maxI; i++)
        for (int j = 0; j <= maxJ; j++)
            ans = (0LL + ans + poly[i][j]) % mod;
    printf("%d\n", ans);
    return 0;
}
// FOJ6461.cpp
#include <bits/stdc++.h>

using namespace std;

// Redirects the standard streams to files derived from `str`:
// stdin is reopened on "<str>.in" for reading and stdout on "<str>.out" for
// writing. The results of both freopen calls are ignored.
void fileIO(string str)
{
    const string inputPath = str + ".in";
    const string outputPath = str + ".out";
    freopen(inputPath.c_str(), "r", stdin);
    freopen(outputPath.c_str(), "w", stdout);
}

// MAX_N : capacity of every vertex-indexed table below (indices 1..n are used).
// MAX_M : an edge-count bound; nothing in this listing references it.
// mod   : modulus for all arithmetic; inverses are taken as fpow(v, mod - 2).
const int MAX_N = 45, MAX_M = 1e5 + 200, mod = 1e9 + 7;

// n, m, x, y   : the four integers read first by main; n bounds every vertex
//                loop, m is the number of edge records, x and y bound the two
//                exponents summed for the final answer.
// mat          : working matrix filled by buildMatrix and reduced in place by det.
// ai[a][b]     : determinant obtained for the sample pair (a, b).
// org[u][v][c] : number of edge records joining u and v whose third field,
//                minus one, equals c (c in 0..2).
// poly[i][j]   : accumulated coefficient for exponent i of the first variable
//                and exponent j of the second.
// X, Y         : scratch coefficient arrays for the two one-variable products
//                built inside main.
// inv[k]       : modular inverse of k, filled by main for 1 <= k <= n.
int n, m, x, y, mat[MAX_N][MAX_N], ai[MAX_N][MAX_N], org[MAX_N][MAX_N][3], poly[MAX_N][MAX_N], X[MAX_N], Y[MAX_N], inv[MAX_N];

// Binary exponentiation: returns bas raised to the power tim, reduced modulo
// `mod`. The exponent is consumed bit by bit from the least significant end.
int fpow(int bas, int tim)
{
    long long result = 1;
    long long factor = bas;
    for (; tim; tim >>= 1)
    {
        if (tim & 1)
            result = result * factor % mod;
        factor = factor * factor % mod;
    }
    return static_cast<int>(result);
}

// Returns the determinant, modulo `mod`, of the minor mat[1..n-1][1..n-1].
//
// Entries are expected to be reduced into [0, mod). The matrix is reduced in
// place to upper-triangular form, so its contents are destroyed; callers must
// rebuild it before the next call.
//
// The pivot search stays inside the minor (rows col..n-1) instead of comparing
// against row 0, so the result no longer depends on row 0 being all zero.
// Only rows below the pivot are eliminated, which is all a determinant needs.
int det()
{
    const int size = n - 1;
    int ret = 1;
    for (int col = 1; col <= size; col++)
    {
        // Any non-zero entry is a valid pivot in modular arithmetic.
        int pivot = col;
        while (pivot <= size && mat[pivot][col] == 0)
            pivot++;
        if (pivot > size)
            return 0; // the whole column is zero: singular matrix
        if (pivot != col)
        {
            for (int k = 1; k <= size; k++)
                std::swap(mat[col][k], mat[pivot][k]);
            ret = (mod - ret) % mod; // a row swap negates the determinant
        }
        ret = 1LL * ret * mat[col][col] % mod;
        const int pivotInv = fpow(mat[col][col], mod - 2);
        for (int row = col + 1; row <= size; row++)
        {
            const int rate = 1LL * mat[row][col] * pivotInv % mod;
            if (rate == 0)
                continue;
            // Columns left of `col` are already zero in both rows.
            for (int k = col; k <= size; k++)
                mat[row][k] = (0LL + mat[row][k] + mod - 1LL * rate * mat[col][k] % mod) % mod;
        }
    }
    return ret;
}

// Fills `mat` with the weighted Laplacian of the multigraph stored in `org`,
// reduced modulo `mod`.
//
// An edge counted in org[.][.][0] has weight 1, one in org[.][.][1] has
// weight a, and one in org[.][.][2] has weight b. For i != j, mat[i][j] is the
// negated total weight between i and j, and mat[i][i] is the total weight of
// the edges leaving i.
//
// Self-loops (org[i][i][*]) are skipped: a loop can never belong to a
// spanning tree, and counting it in the degree would corrupt the determinant.
// Row 0 and column 0 are left at zero.
void buildMatrix(int a, int b)
{
    memset(mat, 0, sizeof(mat));
    for (int i = 1; i <= n; i++)
    {
        long long degree = 0;
        for (int j = 1; j <= n; j++)
        {
            if (j == i)
                continue;
            long long weight = (org[i][j][0] + 1LL * org[i][j][1] * a + 1LL * org[i][j][2] * b) % mod;
            if (weight < 0)
                weight += mod; // keep entries in [0, mod) even for negative a or b
            mat[i][j] = static_cast<int>((mod - weight) % mod);
            degree = (degree + weight) % mod;
        }
        mat[i][i] = static_cast<int>(degree);
    }
}

int main()
{
    // fileIO("tree");
    scanf("%d%d%d%d", &n, &m, &x, &y);
    for (int i = 1, u, v, c; i <= m; i++)
        scanf("%d%d%d", &u, &v, &c), c--, org[u][v][c]++, org[v][u][c]++;
    for (int i = 1; i <= n; i++)
        inv[i] = fpow(i, mod - 2);
    for (int a = 1; a <= n; a++)
        for (int b = 1; b <= n; b++)
        {
            buildMatrix(a, b), ai[a][b] = det();
            memset(X, 0, sizeof(X)), memset(Y, 0, sizeof(Y));
            X[0] = Y[0] = 1;
            int dotval = ai[a][b], len = 0;
            for (int i = 1; i <= n; i++)
                if (i != a)
                {
                    dotval = 1LL * dotval * inv[abs(a - i)] % mod * ((a > i) ? 1 : mod - 1) % mod;
                    for (int j = ++len; j >= 1; j--)
                        X[j] = (0LL + mod - 1LL * X[j] * i % mod + X[j - 1]) % mod;
                    X[0] = (0LL + mod - 1LL * X[0] * i % mod) % mod;
                }
            len = 0;
            for (int i = 1; i <= n; i++)
                if (i != b)
                {
                    dotval = 1LL * dotval * inv[abs(b - i)] % mod * ((b > i) ? 1 : mod - 1) % mod;
                    for (int j = ++len; j >= 1; j--)
                        Y[j] = (0LL + mod - 1LL * Y[j] * i % mod + Y[j - 1]) % mod;
                    Y[0] = (0LL + mod - 1LL * Y[0] * i % mod) % mod;
                }
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    poly[i][j] = (0LL + poly[i][j] + 1LL * X[i] * Y[j] % mod * dotval % mod) % mod;
        }
    int ans = 0;
    for (int i = 0; i <= x; i++)
        for (int j = 0; j <= y; j++)
            ans = (0LL + ans + poly[i][j]) % mod;
    printf("%d\n", ans);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *