主要思路
首先仔细剖析出题目意思:我们要规定一条路径,它要被全部 \(k\) 条路径经过,然后再从这条路径的两头各延伸出不超过 \(k\) 个分支,且每个分支独占一棵子树,求方案数。
那么我们可以先把 \(1\) 定为根,然后算出以点 \(u\) 为端点、分支在子树内的方案数 \(f_u\),然后再算向父亲侧的方案数 \(g_u\),再枚举 \((u, v)\):
- 如果 \((u, v)\) 的 LCA 为第三点,那么贡献显然就是 \(f_u \cdot f_v\)。
- 如果不为第三点,就让深度小的点的 \(g_u\) 乘上深度大的点的 \(f_v\) 即可。
考虑如何计算 \(f_u\),发现我们可以先定形状再染色(也就是乘上 \(P(k, i)\) ),显然答案等于:
\[ f_u = \sum_{V \subseteq son(u),\ 0 \leq |V| \leq k} P(k, |V|) \prod_{v \in V} size_v \]
\(g_u\) 的计算就可以仿照 Up & Down 的做法进行计算。这样的计算方式非常的慢,所以我们可以把目光转向生成函数。在节点 \(u\) 计算时,我们可以发现生成函数为:
\[ P(x) = \prod_{v \in son(u)} (size_v \cdot x + 1) \]
我们可以取这个函数中次数不超过 \(k\) 的各项系数,第 \(i\) 项乘上染色系数 \(P(k, i)\) 后求和作为 \(f_u\)。这个时候可以用分治 NTT 来做掉这个事情。复杂度是对的,原因是每次 NTT 的长度都为 \(\Theta(deg_u)\) 级别的,所以所有节点的总长度为 \(\Theta(n)\) 级别的。
我们算完了 \(f_u\) 之后还要考虑计算 \(g_u\),并且还要解决掉 \(\Theta(n^2)\) 的路径枚举的开销。其实这两个事情可以一起做掉,我们先固定 \((u, v), dep_u < dep_v\),进行 \(\Theta(n^2)\) 的枚举,那么对上面的生成函数动些手脚就可以拿到答案:
\[ P(x) \cdot \frac{(n - size_u) \cdot x + 1}{size_v \cdot x + 1} \]
取次数不超过 \(k\) 的各项,第 \(i\) 项再乘上一个染色系数 \(P(k, i)\) 后求和即可。乘法和除法的时间需要 \(\Theta(deg_u)\),所以看样子这个时间是 \(\Theta(n^2 deg_u)\)。然而,我们注意到对固定的 \(u\),这个函数只跟 \(size_v\) 有关,所以我们可以考虑枚举 \(size_v\)。由于 \(u\) 的所有儿子的子树大小之和不超过 \(n\),不同的 \(size_v\) 取值个数为 \(O(\sqrt{n})\) 级别的,所以可以在根号时间内做掉。
所以就大概结束了,细节很多。
代码
// CF981H.cpp
//
// Rooted at vertex 1, the program computes for every vertex u
//   dp[u] = sum_{i <= k} P(k, i) * [x^i] prod_{v in son(u)} (1 + siz[v] * x),
// then sums dp[a] * dp[b] over all unordered pairs and corrects the pairs in
// which one vertex is an ancestor of the other (see dfs_collect).
#include <cstdio>
#include <cstring>
#include <map>
#include <utility>
#include <vector>

using namespace std;

using ll = long long;

constexpr int MAX_N = 200000 + 200;
// Kept as long long on purpose: `x % mod` is then a long long, which is what
// the "%lld" conversions in main() expect.
constexpr ll mod = 998244353;

int n, k, head[MAX_N], current, siz[MAX_N], fac[MAX_N], fac_inv[MAX_N];
int dp[MAX_N], gpans;
// ans[u]: coefficients of prod_{v in son(u)} (1 + siz[v] * x), lowest degree first.
vector<int> ans[MAX_N];

struct edge {
    int to, nxt;
} edges[MAX_N << 1];

/// Returns bas^tim modulo `mod` by binary exponentiation (tim >= 0).
int quick_pow(int bas, int tim) {
    int ret = 1;
    while (tim) {
        if (tim & 1) ret = 1LL * ret * bas % mod;
        bas = 1LL * bas * bas % mod;
        tim >>= 1;
    }
    return ret;
}

// Primitive root of `mod` and its modular inverse, used as NTT roots.
const int G = 3, Gi = quick_pow(G, mod - 2);

// Current transform size (poly_siz == 1 << poly_bit); set by polyMultiply.
int poly_bit, poly_siz;
// Raw edge list as read from the input (1-indexed).
int ui[MAX_N], vi[MAX_N];

/// Appends the directed edge src -> dst to the forward-star adjacency list.
void addpath(int src, int dst) {
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

/// Returns x with its lowest `bit` bits reversed.
int getRev(int x, int bit) {
    int curt = 0;
    for (int i = 0; i < bit; i++)
        curt |= (((x >> i) & 1) << (bit - 1 - i));
    return curt;
}

/// In-place number-theoretic transform of arr[0 .. poly_siz).
/// dft == 1 runs the forward transform; dft == -1 runs the inverse transform
/// and divides by poly_siz. `arr` must hold at least poly_siz elements.
void ntt(vector<int> &arr, int dft) {
    // The bit-reversed index is kept in a local: the transform size is a power
    // of two that can exceed MAX_N, so a MAX_N-sized lookup table would overflow.
    for (int i = 0; i < poly_siz; i++) {
        const int j = getRev(i, poly_bit);
        if (i < j) swap(arr[i], arr[j]);
    }
    for (int step = 1; step < poly_siz; step <<= 1) {
        const int omega_n = quick_pow(dft == 1 ? G : Gi, (mod - 1) / (step << 1));
        for (int base = 0; base < poly_siz; base += (step << 1)) {
            int omega_nk = 1;
            for (int idx = base; idx < base + step; idx++, omega_nk = 1LL * omega_nk * omega_n % mod) {
                const int t = 1LL * omega_nk * arr[idx + step] % mod;
                arr[idx + step] = (1LL * arr[idx] + mod - t) % mod;
                arr[idx] = (1LL * arr[idx] + t) % mod;
            }
        }
    }
    if (dft == -1) {
        const int inv = quick_pow(poly_siz, mod - 2);
        for (int i = 0; i < poly_siz; i++)
            arr[i] = 1LL * arr[i] * inv % mod;
    }
}

/// Returns the product A * B modulo `mod` with trailing zero coefficients removed.
/// Side effect: overwrites the globals poly_bit and poly_siz.
vector<int> polyMultiply(vector<int> A, vector<int> B) {
    const int len = static_cast<int>(A.size() + B.size());
    poly_bit = 0;
    while ((1 << poly_bit) < len) poly_bit++;
    poly_siz = (1 << poly_bit);
    // Zero-pad both operands up to the transform size.
    A.resize(poly_siz);
    B.resize(poly_siz);
    ntt(A, 1), ntt(B, 1);
    for (int i = 0; i < poly_siz; i++)
        A[i] = 1LL * A[i] * B[i] % mod;
    ntt(A, -1);
    while (!A.empty() && A.back() == 0) A.pop_back();
    return A;
}

/// Divide and conquer: returns prod_{i in [l, r)} (1 + curt[i] * x). Requires l < r.
vector<int> solve(int l, int r, const vector<int> &curt) {
    if (l == r - 1) return vector<int>{1, curt[l]};
    const int mid = (l + r) >> 1;
    return polyMultiply(solve(l, mid, curt), solve(mid, r, curt));
}

/// Fills fac[i] = i! and fac_inv[i] = (i!)^{-1} modulo `mod` for 0 <= i < MAX_N.
void fac_init() {
    fac[0] = 1;
    for (int i = 1; i <= MAX_N - 1; i++)
        fac[i] = 1LL * fac[i - 1] * i % mod;
    fac_inv[MAX_N - 1] = quick_pow(fac[MAX_N - 1], mod - 2);
    for (int i = MAX_N - 2; i >= 0; i--)
        fac_inv[i] = 1LL * fac_inv[i + 1] * (i + 1) % mod;
}

/// First pass from the root: computes siz[u], the child-size polynomial ans[u]
/// and dp[u] = sum_{i <= k} k! / (k - i)! * ans[u][i] for the subtree of u.
void dfs(int u, int fa) {
    siz[u] = 1;
    vector<int> vec;
    for (int e = head[u]; e != -1; e = edges[e].nxt) {
        const int v = edges[e].to;
        if (v == fa) continue;
        dfs(v, u);
        siz[u] += siz[v];
        vec.push_back(siz[v]);
    }
    if (vec.empty())
        vec = vector<int>{1};  // a leaf contributes the constant polynomial 1
    else
        vec = solve(0, static_cast<int>(vec.size()), vec);
    ans[u] = vec;
    const int terms = static_cast<int>(vec.size());
    for (int i = 0; i < terms && i <= k; i++)
        dp[u] = (1LL * dp[u] + 1LL * fac[k] * vec[i] % mod * fac_inv[k - i] % mod) % mod;
}

/// Second pass: corrects gpans for pairs in which one vertex is an ancestor of u.
/// @param have     sum over the proper ancestors a of u of the value of a taken
///                 with the child subtree towards u swapped for a's parent side
/// @param del_sum  sum of dp[a] over the proper ancestors a of u
void dfs_collect(int u, int fa, int have, int del_sum) {
    // Replace the dp[a] * dp[u] terms already counted in main() by the
    // direction-aware values accumulated in `have`.
    gpans = (1LL * gpans + mod - 1LL * dp[u] * del_sum % mod) % mod;
    gpans = (1LL * gpans + 1LL * dp[u] * have % mod) % mod;

    // n_poly = ans[u] * (1 + (n - siz[u]) * x): adds the parent side as one more branch.
    vector<int> n_poly = ans[u];
    n_poly.push_back(0);
    for (int i = static_cast<int>(n_poly.size()) - 2; i >= 0; i--)
        n_poly[i + 1] = (1LL * n_poly[i + 1] + 1LL * n_poly[i] * (n - siz[u]) % mod) % mod;

    // press[s]: value of u once a child subtree of size s has been divided out.
    // Children sharing a size share the result, so each size is computed once.
    map<int, int> press;
    for (int e = head[u]; e != -1; e = edges[e].nxt) {
        const int v = edges[e].to;
        if (v == fa || press.count(siz[v]) != 0) continue;

        // quotient = n_poly / (1 + siz[v] * x), by the recurrence q[j] = p[j] - siz[v] * q[j - 1].
        vector<int> quotient = n_poly;
        for (size_t j = 1; j < quotient.size(); j++)
            quotient[j] = (1LL * n_poly[j] + mod - (1LL * quotient[j - 1] * siz[v] % mod)) % mod;

        int calc_dp = 0;
        for (int j = 0; j < static_cast<int>(quotient.size()) && j <= k; j++) {
            const int cnt = 1LL * fac[k] * quotient[j] % mod * fac_inv[k - j] % mod;
            calc_dp = (1LL * calc_dp + 1LL * cnt) % mod;
        }
        press[siz[v]] = calc_dp;
    }

    for (int e = head[u]; e != -1; e = edges[e].nxt) {
        const int v = edges[e].to;
        if (v == fa) continue;
        dfs_collect(v, u, (1LL * have + 1LL * press[siz[v]]) % mod, (1LL * del_sum + dp[u]) % mod);
    }
}

/// Reads n, k and the n - 1 tree edges from stdin and prints the answer modulo `mod`.
/// Returns 1 without printing if the input is truncated or malformed.
int main() {
    memset(head, -1, sizeof(head));
    if (scanf("%d%d", &n, &k) != 2) return 1;
    fac_init();
    if (k == 1) {
        // With a single path every unordered pair of distinct vertices is one answer.
        printf("%lld\n", (1LL * (n - 1) * n / 2) % mod);
        return 0;
    }
    for (int i = 1; i <= n - 1; i++)
        if (scanf("%d%d", &ui[i], &vi[i]) != 2) return 1;
    for (int i = n - 1; i >= 1; i--)
        addpath(ui[i], vi[i]), addpath(vi[i], ui[i]);
    dfs(1, 0);
    // Sum dp[a] * dp[b] over all unordered pairs using a running prefix sum.
    int curt = 0;
    for (int i = 1; i <= n; i++) {
        gpans = (1LL * gpans + 1LL * curt * dp[i] % mod) % mod;
        curt = (1LL * curt + 1LL * dp[i]) % mod;
    }
    dfs_collect(1, 0, 0, 0);
    // gpans % mod has type long long because mod is long long.
    printf("%lld\n", gpans % mod);
    return 0;
}