简述
树的结构让某些数据难以直接快速合并,所以就有了树上启发式合并,一种把每个结点参与合并的代价从\(O(n)\)优化到\(O(\log n)\)(总复杂度\(O(n \log n)\))的神奇方式。如果您先前学习过树链剖分,那么学习树上启发式合并就非常简单了。
运作原理
我们以一道例题作为引入:Codeforces 600E
一棵树有\(n\)个结点,每个结点都有一种颜色,每个颜色有一个编号。对树中的每一棵子树,求其中「最多颜色的编号」(即出现次数最多的颜色的编号)的和。一棵子树中可能存在多个「最多颜色的编号」,都要计入答案。
那么这道题\(O(n^2)\)很好做,直接枚举每一个点作为子树根,然后再\(O(n)\)扫一遍即可。那么,如果我们用树上启发式合并来做,如何把这个问题解决呢?
我们发现暴力的做法并没有完全利用好子树内的信息。我们考虑在处理时把所有链都缩在一起并上传,避免浪费资源。
但是,我们发现树形的结构无法完美地将链合并。启发式合并的原理在于保留重儿子的链,然后以重儿子为主,把轻儿子的信息合并到重儿子上并且按情况上传。
我们一般写成这样:
// CF600E.cpp
// Tree heuristic merge ("dsu on tree"): for every vertex u, answer[u] is the
// sum of the colour ids whose occurrence count inside u's subtree equals the
// maximum occurrence count in that subtree.
#include <cstdio>
#include <cstring>

using ll = long long;

const int MAX_N = 1e5 + 200;

// Forward-star adjacency list; head[] is filled with -1 in main().
int head[MAX_N], current, col[MAX_N], n, siz[MAX_N], son[MAX_N], lft[MAX_N], rig[MAX_N];
// anti[t] is the vertex whose preorder index is t; bucket[c] counts colour c.
int anti[MAX_N], bucket[MAX_N], ptot;
// most[u]: highest colour count in u's subtree; answer[u]: sum of colours reaching it.
ll most[MAX_N], answer[MAX_N];

struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

// Appends the directed edge src -> dst to the adjacency list.
void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

// First pass: assigns the preorder interval [lft[u], rig[u]] of u's subtree,
// records subtree sizes, and picks son[u] as the child with the largest
// subtree (son[u] stays 0 when u has no child).
void predfs(int u, int fa)
{
    lft[u] = ++ptot, anti[ptot] = u, siz[u] = 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        const int v = edges[i].to;
        if (v == fa)
            continue;
        predfs(v, u);
        siz[u] += siz[v];
        if (siz[son[u]] < siz[v])
            son[u] = v;
    }
    rig[u] = ptot;
}

// Counts one more occurrence of `color` in bucket[] and keeps the running
// pair (best = highest count seen, sum = sum of colours having that count)
// consistent with it.
static void add_color(int color, ll &best, ll &sum)
{
    const int cnt = ++bucket[color];
    if (cnt > best)
        best = cnt, sum = color;
    else if (cnt == best)
        sum += color;
}

// Second pass. Computes answer[u] for every vertex in u's subtree.
// `save` tells whether the colour counts of u's subtree must remain in
// bucket[] when the call returns (true) or be removed again (false).
void dfs(int u, int fa, bool save)
{
    // Light children first; each of them cleans bucket[] up after itself.
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa && edges[i].to != son[u])
            dfs(edges[i].to, u, false);

    // The heavy child goes last and leaves its counts in bucket[], so its
    // result is the starting point for u.
    ll acc = 0;
    if (son[u] != 0)
    {
        dfs(son[u], u, true);
        most[u] = most[son[u]];
        acc = answer[son[u]];
    }

    add_color(col[u], most[u], acc);

    // Re-insert every vertex of every light subtree through its preorder range.
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        const int v = edges[i].to;
        if (v == fa || v == son[u])
            continue;
        for (int id = lft[v]; id <= rig[v]; id++)
            add_color(col[anti[id]], most[u], acc);
    }

    // The caller treats u as a light child: undo every count added for u's subtree.
    if (!save)
        for (int id = lft[u]; id <= rig[u]; id++)
            --bucket[col[anti[id]]];

    answer[u] = acc;
}

// Reads n, the n colours and the n - 1 edges, then prints answer[1..n]
// separated by single spaces.
int main()
{
    std::memset(head, -1, sizeof(head));
    std::scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        std::scanf("%d", &col[i]);
    for (int i = 1, u, v; i <= n - 1; i++)
        std::scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u);
    predfs(1, 0), dfs(1, 0, true);
    for (int i = 1; i <= n; i++)
        std::printf("%lld ", answer[i]);
    return 0;
}
具体证明见:树上启发式合并 – OI Wiki
又一道例题:Codeforces 741D
我们发现,如果子树内存在一条这样的可重组为回文串的链,那么其状态压缩之后的异或和要么等于零、要么等于\(2^k\)(恰有一种字符出现奇数次的情况)。所以,我们先对子树内的答案取个最大值,然后再以根节点为中转点,用桶来拼接两条半链即可。具体见代码:
// CF741D.cpp
// dsu on tree;
// For every vertex u, answer[u] is the largest edge count of a path inside
// u's subtree whose letter-parity mask is zero or has exactly one bit set.
#include <algorithm>
#include <cstdio>
#include <cstring>

const int MAX_N = 5e5 + 200;
// Number of letter bits probed as the single odd letter (bits 0..21).
const int ALPHABET = 22;

// Forward-star adjacency list (parent -> child only); head[] is set to -1 in main().
int head[MAX_N], current, n, fa[MAX_N], siz[MAX_N], lft[MAX_N], rig[MAX_N], ptot;
// dist[v]: XOR of the edge masks on the root-to-v path; dep[v]: depth, root = 1.
// book[mask]: deepest dep[] among the stored vertices with that mask, 0 = none.
int son[MAX_N], dep[MAX_N], anti[MAX_N], book[1 << 23], answer[MAX_N], dist[MAX_N];
char opt[10];

struct edge
{
    int to, nxt, weight;
} edges[MAX_N << 1];

// Appends the directed edge src -> dst carrying the one-bit letter mask `weight`.
void addpath(int src, int dst, int weight)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    edges[current].weight = weight, head[src] = current++;
}

// First pass: preorder interval [lft[u], rig[u]], depth, root-path mask,
// subtree size, and son[u] = child with the largest subtree (0 if none).
void predfs(int u)
{
    lft[u] = ++ptot, siz[u] = 1, dep[u] = dep[fa[u]] + 1, anti[ptot] = u;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        const int v = edges[i].to;
        dist[v] = dist[u] ^ edges[i].weight;
        predfs(v);
        siz[u] += siz[v];
        if (siz[son[u]] < siz[v])
            son[u] = v;
    }
    rig[u] = ptot;
}

// Returns the largest stored depth among masks that differ from `mask` in at
// most one of the ALPHABET letter bits, or 0 when no such mask is stored.
static int best_partner(int mask)
{
    int best = book[mask];
    // iterate the mid char;
    for (int ch = 0; ch < ALPHABET; ch++)
        best = std::max(best, book[mask ^ (1 << ch)]);
    return best;
}

// Second pass. Fills answer[] for u's subtree. When `save` is true the
// book[] entries of u's subtree stay in place on return; otherwise they are
// reset to 0.
void dfs(int u, bool save)
{
    // Light children first; they leave book[] empty behind them.
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != son[u])
            dfs(edges[i].to, false), answer[u] = std::max(answer[u], answer[edges[i].to]);

    // Heavy child last; its entries remain in book[].
    if (son[u] != 0)
        dfs(son[u], true), answer[u] = std::max(answer[u], answer[son[u]]);

    // Paths that start at u itself and go down into the stored vertices.
    const int from_u = best_partner(dist[u]);
    if (from_u != 0)
        answer[u] = std::max(answer[u], from_u - dep[u]);
    book[dist[u]] = std::max(book[dist[u]], dep[u]);

    // Light subtrees: query the whole subtree against book[] first, insert
    // it afterwards, so both path ends never come from the same child.
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        const int v = edges[i].to;
        if (v == son[u])
            continue;
        for (int id = lft[v]; id <= rig[v]; id++)
        {
            const int curt = anti[id];
            const int partner = best_partner(dist[curt]);
            if (partner != 0)
                answer[u] = std::max(answer[u], dep[curt] + partner - (dep[u] << 1));
        }
        for (int id = lft[v]; id <= rig[v]; id++)
            book[dist[anti[id]]] = std::max(book[dist[anti[id]]], dep[anti[id]]);
    }

    // The caller treats u as a light child: wipe every entry of u's subtree.
    if (!save)
        for (int id = lft[u]; id <= rig[u]; id++)
            book[dist[anti[id]]] = 0;
}

// Reads n, then for each vertex 2..n its parent and the letter on the edge to
// it; prints answer[1..n] separated by single spaces.
int main()
{
    std::memset(head, -1, sizeof(head));
    std::scanf("%d", &n);
    for (int i = 2; i <= n; i++)
        std::scanf("%d%s", &fa[i], opt + 1), addpath(fa[i], i, 1 << (opt[1] - 'a'));
    predfs(1), dfs(1, true);
    for (int i = 1; i <= n; i++)
        std::printf("%d ", answer[i]);
    return 0;
}
经典例题:「十二省联考 2019」春节十二响
这道题就是经典的树上启发式合并裸题。
// Merges per-subtree max-heaps small-into-large: when two child heaps are
// combined, their elements are paired off in descending order and only the
// larger of each pair is kept. The printed result is the sum of the root heap.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <utility>

const int MAX_N = 2e5 + 2000;

// Forward-star adjacency list; head[] is filled with -1 in main().
int fa[MAX_N], head[MAX_N], current, weight[MAX_N], n;
// dfn[u] is the index of the heap currently owned by u (heaps get swapped
// between a vertex and its child); tmp[] is scratch space for one merge.
int dfn[MAX_N], tot, tmp[MAX_N];
std::priority_queue<int> queues[MAX_N];

struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

// Appends the directed edge src -> dst to the adjacency list.
void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

// Builds the heap of u's subtree in queues[dfn[u]].
void dfs(int u)
{
    dfn[u] = ++tot;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        const int v = edges[i].to;
        if (v == fa[u])
            continue;
        dfs(v);

        // Make u own the larger heap so the loop below runs over the smaller one.
        if (queues[dfn[v]].size() > queues[dfn[u]].size())
            std::swap(dfn[v], dfn[u]);
        std::priority_queue<int> &keep = queues[dfn[u]];
        std::priority_queue<int> &drain = queues[dfn[v]];

        // Pair the tops of both heaps until the smaller one is empty, keeping
        // the larger value of every pair; the rest of `keep` is left untouched.
        const int m = static_cast<int>(drain.size());
        for (int pt = 1; pt <= m; pt++)
        {
            tmp[pt] = std::max(keep.top(), drain.top());
            drain.pop();
            keep.pop();
        }
        for (int pt = 1; pt <= m; pt++)
            keep.push(tmp[pt]);
    }
    // u's own value is added only after all children have been merged.
    queues[dfn[u]].push(weight[u]);
}

// Reads n, the n weights and the parents of vertices 2..n, then prints the
// sum of all values left in the root's heap.
int main()
{
    std::memset(head, -1, sizeof(head));
    std::scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        std::scanf("%d", &weight[i]);
    for (int i = 2; i <= n; i++)
        std::scanf("%d", &fa[i]), addpath(fa[i], i), addpath(i, fa[i]);
    dfs(1);
    long long ans = 0;
    while (!queues[dfn[1]].empty())
        ans += queues[dfn[1]].top(), queues[dfn[1]].pop();
    std::printf("%lld", ans);
    return 0;
}
[…] 考虑每一次在子树内枚举,损失了太多已知信息。我们考虑进行启发式合并,然后就可以优化到$\Theta(n \log n)$。具体也可以看这篇博客。 […]