主要思路
有意思。之前做过一题是用矩阵树+高斯消元插值求本题的一维情况,本打算用继续这么做,后面发现 Lagrange 插值能做到 \(\Theta(n^5)\),就学了一点新的东西。
首先,要认识到本题矩阵树出来之后可以得到一个二元多项式:
\[ \sum_{i = 0}^{n – 1} \sum_{j = 0}^{n – 1} a_{i, j} x^i y^j \]
那么我们的目标就是算 \(i \leq x, j \leq y\) 内的 \(a_{i, j}\),这个就是方案数。这个东西通过 \(\Theta(n^5)\) 的枚举+矩阵树可以算出点值。我们需要用二维 Lagrange 插值得到其系数。构造一维基函数:
\[ f_{x_i}(x) = \prod_{j \neq i} \frac{x – x_j}{x_i – x_j} \]
最后:
\[ f(x, y) = \sum_i \sum_j a_{i, j} f_{x_i}(x) f_{y_i}(y) \]
Lagrange 插值可以在 \(\Theta(n^2)\) 时间内算出这个多项式的系数。
代码
// FOJ6461.cpp
#include <bits/stdc++.h>
using namespace std;
void fileIO(string str)
{
freopen((str + ".in").c_str(), "r", stdin);
freopen((str + ".out").c_str(), "w", stdout);
}
const int MAX_N = 45, MAX_M = 1e5 + 200, mod = 1e9 + 7;
int n, m, x, y, mat[MAX_N][MAX_N], ai[MAX_N][MAX_N], org[MAX_N][MAX_N][3], poly[MAX_N][MAX_N], X[MAX_N], Y[MAX_N], inv[MAX_N];
int fpow(int bas, int tim)
{
int ret = 1;
while (tim)
{
if (tim & 1)
ret = 1LL * ret * bas % mod;
bas = 1LL * bas * bas % mod;
tim >>= 1;
}
return ret;
}
int det()
{
int ret = 1;
for (int i = 1; i <= n - 1; i++)
{
int key = 0;
for (int j = i; j <= n - 1; j++)
if (mat[j][i] > mat[key][i])
key = j;
if (key != i)
{
ret = (mod - ret) % mod;
if (mat[key][i] == 0)
return 0;
for (int j = 1; j <= n - 1; j++)
swap(mat[i][j], mat[key][j]);
}
int inv = fpow(mat[i][i], mod - 2);
for (int j = 1; j <= n - 1; j++)
if (i != j)
{
int rate = 1LL * mat[j][i] * inv % mod;
for (int k = 1; k <= n - 1; k++)
mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
}
}
for (int i = 1; i <= n - 1; i++)
ret = 1LL * ret * mat[i][i] % mod;
return ret;
}
void buildMatrix(int a, int b)
{
memset(mat, 0, sizeof(mat));
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
mat[i][j] = (0LL + mod - org[i][j][0] + mod - 1LL * org[i][j][1] * a % mod + mod - 1LL * org[i][j][2] * b % mod) % mod;
for (int i = 1; i <= n; i++)
{
mat[i][i] = 0;
for (int j = 1; j <= n; j++)
mat[i][i] = (0LL + mat[i][i] + org[i][j][0] + 1LL * org[i][j][1] * a % mod + 1LL * org[i][j][2] * b % mod) % mod;
}
}
int main()
{
// fileIO("tree");
scanf("%d%d%d%d", &n, &m, &x, &y);
for (int i = 1, u, v, c; i <= m; i++)
scanf("%d%d%d", &u, &v, &c), c--, org[u][v][c]++, org[v][u][c]++;
for (int i = 1; i <= n; i++)
inv[i] = fpow(i, mod - 2);
for (int a = 1; a <= n; a++)
for (int b = 1; b <= n; b++)
{
buildMatrix(a, b), ai[a][b] = det();
memset(X, 0, sizeof(X)), memset(Y, 0, sizeof(Y));
X[0] = Y[0] = 1;
int dotval = ai[a][b], len = 0;
for (int i = 1; i <= n; i++)
if (i != a)
{
dotval = 1LL * dotval * inv[abs(a - i)] % mod * ((a > i) ? 1 : mod - 1) % mod;
for (int j = ++len; j >= 1; j--)
X[j] = (0LL + mod - 1LL * X[j] * i % mod + X[j - 1]) % mod;
X[0] = (0LL + mod - 1LL * X[0] * i % mod) % mod;
}
len = 0;
for (int i = 1; i <= n; i++)
if (i != b)
{
dotval = 1LL * dotval * inv[abs(b - i)] % mod * ((b > i) ? 1 : mod - 1) % mod;
for (int j = ++len; j >= 1; j--)
Y[j] = (0LL + mod - 1LL * Y[j] * i % mod + Y[j - 1]) % mod;
Y[0] = (0LL + mod - 1LL * Y[0] * i % mod) % mod;
}
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
poly[i][j] = (0LL + poly[i][j] + 1LL * X[i] * Y[j] % mod * dotval % mod) % mod;
}
int ans = 0;
for (int i = 0; i <= x; i++)
for (int j = 0; j <= y; j++)
ans = (0LL + ans + poly[i][j]) % mod;
printf("%d\n", ans);
return 0;
}
// FOJ6461.cpp
#include <bits/stdc++.h>
using namespace std;
void fileIO(string str)
{
freopen((str + ".in").c_str(), "r", stdin);
freopen((str + ".out").c_str(), "w", stdout);
}
const int MAX_N = 45, MAX_M = 1e5 + 200, mod = 1e9 + 7;
int n, m, x, y, mat[MAX_N][MAX_N], ai[MAX_N][MAX_N], org[MAX_N][MAX_N][3], poly[MAX_N][MAX_N], X[MAX_N], Y[MAX_N], inv[MAX_N];
int fpow(int bas, int tim)
{
int ret = 1;
while (tim)
{
if (tim & 1)
ret = 1LL * ret * bas % mod;
bas = 1LL * bas * bas % mod;
tim >>= 1;
}
return ret;
}
int det()
{
int ret = 1;
for (int i = 1; i <= n - 1; i++)
{
int key = 0;
for (int j = i; j <= n - 1; j++)
if (mat[j][i] > mat[key][i])
key = j;
if (key != i)
{
ret = (mod - ret) % mod;
if (mat[key][i] == 0)
return 0;
for (int j = 1; j <= n - 1; j++)
swap(mat[i][j], mat[key][j]);
}
int inv = fpow(mat[i][i], mod - 2);
for (int j = 1; j <= n - 1; j++)
if (i != j)
{
int rate = 1LL * mat[j][i] * inv % mod;
for (int k = 1; k <= n - 1; k++)
mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod;
}
}
for (int i = 1; i <= n - 1; i++)
ret = 1LL * ret * mat[i][i] % mod;
return ret;
}
void buildMatrix(int a, int b)
{
memset(mat, 0, sizeof(mat));
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
mat[i][j] = (0LL + mod - org[i][j][0] + mod - 1LL * org[i][j][1] * a % mod + mod - 1LL * org[i][j][2] * b % mod) % mod;
for (int i = 1; i <= n; i++)
{
mat[i][i] = 0;
for (int j = 1; j <= n; j++)
mat[i][i] = (0LL + mat[i][i] + org[i][j][0] + 1LL * org[i][j][1] * a % mod + 1LL * org[i][j][2] * b % mod) % mod;
}
}
int main()
{
// fileIO("tree");
scanf("%d%d%d%d", &n, &m, &x, &y);
for (int i = 1, u, v, c; i <= m; i++)
scanf("%d%d%d", &u, &v, &c), c--, org[u][v][c]++, org[v][u][c]++;
for (int i = 1; i <= n; i++)
inv[i] = fpow(i, mod - 2);
for (int a = 1; a <= n; a++)
for (int b = 1; b <= n; b++)
{
buildMatrix(a, b), ai[a][b] = det();
memset(X, 0, sizeof(X)), memset(Y, 0, sizeof(Y));
X[0] = Y[0] = 1;
int dotval = ai[a][b], len = 0;
for (int i = 1; i <= n; i++)
if (i != a)
{
dotval = 1LL * dotval * inv[abs(a - i)] % mod * ((a > i) ? 1 : mod - 1) % mod;
for (int j = ++len; j >= 1; j--)
X[j] = (0LL + mod - 1LL * X[j] * i % mod + X[j - 1]) % mod;
X[0] = (0LL + mod - 1LL * X[0] * i % mod) % mod;
}
len = 0;
for (int i = 1; i <= n; i++)
if (i != b)
{
dotval = 1LL * dotval * inv[abs(b - i)] % mod * ((b > i) ? 1 : mod - 1) % mod;
for (int j = ++len; j >= 1; j--)
Y[j] = (0LL + mod - 1LL * Y[j] * i % mod + Y[j - 1]) % mod;
Y[0] = (0LL + mod - 1LL * Y[0] * i % mod) % mod;
}
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
poly[i][j] = (0LL + poly[i][j] + 1LL * X[i] * Y[j] % mod * dotval % mod) % mod;
}
int ans = 0;
for (int i = 0; i <= x; i++)
for (int j = 0; j <= y; j++)
ans = (0LL + ans + poly[i][j]) % mod;
printf("%d\n", ans);
return 0;
}
// FOJ6461.cpp #include <bits/stdc++.h> using namespace std; void fileIO(string str) { freopen((str + ".in").c_str(), "r", stdin); freopen((str + ".out").c_str(), "w", stdout); } const int MAX_N = 45, MAX_M = 1e5 + 200, mod = 1e9 + 7; int n, m, x, y, mat[MAX_N][MAX_N], ai[MAX_N][MAX_N], org[MAX_N][MAX_N][3], poly[MAX_N][MAX_N], X[MAX_N], Y[MAX_N], inv[MAX_N]; int fpow(int bas, int tim) { int ret = 1; while (tim) { if (tim & 1) ret = 1LL * ret * bas % mod; bas = 1LL * bas * bas % mod; tim >>= 1; } return ret; } int det() { int ret = 1; for (int i = 1; i <= n - 1; i++) { int key = 0; for (int j = i; j <= n - 1; j++) if (mat[j][i] > mat[key][i]) key = j; if (key != i) { ret = (mod - ret) % mod; if (mat[key][i] == 0) return 0; for (int j = 1; j <= n - 1; j++) swap(mat[i][j], mat[key][j]); } int inv = fpow(mat[i][i], mod - 2); for (int j = 1; j <= n - 1; j++) if (i != j) { int rate = 1LL * mat[j][i] * inv % mod; for (int k = 1; k <= n - 1; k++) mat[j][k] = (0LL + mat[j][k] + mod - 1LL * rate * mat[i][k] % mod) % mod; } } for (int i = 1; i <= n - 1; i++) ret = 1LL * ret * mat[i][i] % mod; return ret; } void buildMatrix(int a, int b) { memset(mat, 0, sizeof(mat)); for (int i = 1; i <= n; i++) for (int j = 1; j <= n; j++) mat[i][j] = (0LL + mod - org[i][j][0] + mod - 1LL * org[i][j][1] * a % mod + mod - 1LL * org[i][j][2] * b % mod) % mod; for (int i = 1; i <= n; i++) { mat[i][i] = 0; for (int j = 1; j <= n; j++) mat[i][i] = (0LL + mat[i][i] + org[i][j][0] + 1LL * org[i][j][1] * a % mod + 1LL * org[i][j][2] * b % mod) % mod; } } int main() { // fileIO("tree"); scanf("%d%d%d%d", &n, &m, &x, &y); for (int i = 1, u, v, c; i <= m; i++) scanf("%d%d%d", &u, &v, &c), c--, org[u][v][c]++, org[v][u][c]++; for (int i = 1; i <= n; i++) inv[i] = fpow(i, mod - 2); for (int a = 1; a <= n; a++) for (int b = 1; b <= n; b++) { buildMatrix(a, b), ai[a][b] = det(); memset(X, 0, sizeof(X)), memset(Y, 0, sizeof(Y)); X[0] = Y[0] = 1; int dotval = ai[a][b], len = 0; for (int i = 1; i <= n; i++) if (i != a) { dotval = 1LL * dotval * inv[abs(a - i)] % mod * ((a > i) ? 1 : mod - 1) % mod; for (int j = ++len; j >= 1; j--) X[j] = (0LL + mod - 1LL * X[j] * i % mod + X[j - 1]) % mod; X[0] = (0LL + mod - 1LL * X[0] * i % mod) % mod; } len = 0; for (int i = 1; i <= n; i++) if (i != b) { dotval = 1LL * dotval * inv[abs(b - i)] % mod * ((b > i) ? 1 : mod - 1) % mod; for (int j = ++len; j >= 1; j--) Y[j] = (0LL + mod - 1LL * Y[j] * i % mod + Y[j - 1]) % mod; Y[0] = (0LL + mod - 1LL * Y[0] * i % mod) % mod; } for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) poly[i][j] = (0LL + poly[i][j] + 1LL * X[i] * Y[j] % mod * dotval % mod) % mod; } int ans = 0; for (int i = 0; i <= x; i++) for (int j = 0; j <= y; j++) ans = (0LL + ans + poly[i][j]) % mod; printf("%d\n", ans); return 0; }