树上启发式合并 | DSU on tree

简述

树的结构让某些数据难以直接快速合并,所以就有了树上启发式合并,一种把合并时间为\(O(n)\)优化成\(O(\log n)\)的神奇方式。如果您先前学习过树链剖分,那么学习树上启发式合并就非常简单了。

运作原理

我们以一道例题作为引入:Codeforces 600E

一棵树有\(n\)个结点,每个结点都是一种颜色,每个颜色有一个编号。求树中每个子树的「最多颜色的编号」的和。一颗子树中可能存在多个「最多颜色的编号」,都要计入答案。

那么这道题\(O(n^2)\)很好做,直接枚举每一个点作为子树根,然后再\(O(n)\)扫一遍即可。那么,如果我们用树上启发式合并来做,如何把这个问题解决呢?

我们发现暴力的做法并没有完全利用好子树内的信息。我们考虑在处理时把所有链都缩在一起并上传,避免浪费资源。

但是,我们发现树形的结构无法完美地将链合并。启发式合并的原理在于保留重儿子的链,然后以重儿子为主,把轻儿子的信息合并到重儿子上并且按情况上传。

我们一般写成这样:

// CF600E.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int MAX_N = 1e5 + 200;

int head[MAX_N], current, col[MAX_N], n, siz[MAX_N], son[MAX_N], lft[MAX_N], rig[MAX_N];
int anti[MAX_N], bucket[MAX_N], ptot;
ll most[MAX_N], answer[MAX_N];

struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

void predfs(int u, int fa)
{
    lft[u] = ++ptot, anti[ptot] = u, siz[u] = 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa)
        {
            predfs(edges[i].to, u), siz[u] += siz[edges[i].to];
            son[u] = (siz[son[u]] < siz[edges[i].to]) ? edges[i].to : son[u];
        }
    rig[u] = ptot;
}

void dfs(int u, int fa, bool save)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa && edges[i].to != son[u])
            dfs(edges[i].to, u, false);
    ll acc = 0;
    if (son[u] != 0)
        dfs(son[u], u, true), most[u] = most[son[u]], acc = answer[son[u]];

    if (++bucket[col[u]] > most[u])
        most[u] = bucket[col[u]], acc = col[u];
    else if (bucket[col[u]] == most[u])
        acc += col[u];

    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa && edges[i].to != son[u])
        {
            for (int id = lft[edges[i].to]; id <= rig[edges[i].to]; id++)
                if (++bucket[col[anti[id]]] > most[u])
                    most[u] = bucket[col[anti[id]]], acc = col[anti[id]];
                else if (bucket[col[anti[id]]] == most[u])
                    acc += col[anti[id]];
        }
    if (save == false)
        for (int id = lft[u]; id <= rig[u]; id++)
            --bucket[col[anti[id]]];
    answer[u] = acc;
}

int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", &col[i]);
    for (int i = 1, u, v; i <= n - 1; i++)
        scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u);
    predfs(1, 0), dfs(1, 0, true);
    for (int i = 1; i <= n; i++)
        printf("%lld ", answer[i]);
    return 0;
}

具体证明见:树上启发式合并 – OI Wiki


又一道例题:Codeforces 741D

我们发现,如果子树内存在一个这样的可重组回文串链,那么其状态压缩之后的异或和要么等于零、要么等于\(2^k\)(奇数个的情况)。所以,我们先对子树内的答案去个最大值,然后在以根节点,用桶来拼接半径即可。具体见代码:

// CF741D.cpp
// dsu on tree;
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 5e5 + 200;

int head[MAX_N], current, n, fa[MAX_N], siz[MAX_N], lft[MAX_N], rig[MAX_N], ptot;
int son[MAX_N], dep[MAX_N], anti[MAX_N], book[1 << 23], answer[MAX_N], dist[MAX_N];
char opt[10];

struct edge
{
    int to, nxt, weight;
} edges[MAX_N << 1];

void addpath(int src, int dst, int weight)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    edges[current].weight = weight, head[src] = current++;
}

void predfs(int u)
{
    lft[u] = ++ptot, siz[u] = 1, dep[u] = dep[fa[u]] + 1, anti[ptot] = u;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
    {
        dist[edges[i].to] = dist[u] ^ edges[i].weight, predfs(edges[i].to), siz[u] += siz[edges[i].to];
        if (siz[son[u]] < siz[edges[i].to])
            son[u] = edges[i].to;
    }
    rig[u] = ptot;
}

void dfs(int u, bool save)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != son[u])
            dfs(edges[i].to, false), answer[u] = max(answer[u], answer[edges[i].to]);
    if (son[u] != 0)
        dfs(son[u], true), answer[u] = max(answer[u], answer[son[u]]);
    if (book[dist[u]] != 0)
        answer[u] = max(answer[u], book[dist[u]] - dep[u]);
    // iterate the mid char;
    for (int ch = 0; ch <= 21; ch++)
        if (book[dist[u] ^ (1 << ch)])
            answer[u] = max(answer[u], book[dist[u] ^ (1 << ch)] - dep[u]);
    book[dist[u]] = max(book[dist[u]], dep[u]);
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != son[u])
        {
            for (int id = lft[edges[i].to]; id <= rig[edges[i].to]; id++)
            {
                int curt = anti[id];
                if (book[dist[curt]])
                    answer[u] = max(answer[u], dep[curt] + book[dist[curt]] - (dep[u] << 1));
                // iterate the mid char;
                for (int ch = 0; ch <= 21; ch++)
                    if (book[dist[curt] ^ (1 << ch)])
                        answer[u] = max(answer[u], book[dist[curt] ^ (1 << ch)] + dep[curt] - (dep[u] << 1));
            }
            for (int id = lft[edges[i].to]; id <= rig[edges[i].to]; id++)
                book[dist[anti[id]]] = max(book[dist[anti[id]]], dep[anti[id]]);
        }
    if (save == false)
        for (int id = lft[u]; id <= rig[u]; id++)
            book[dist[anti[id]]] = 0;
}

int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 2; i <= n; i++)
        scanf("%d%s", &fa[i], opt + 1), addpath(fa[i], i, 1 << (opt[1] - 'a'));
    predfs(1), dfs(1, true);
    for (int i = 1; i <= n; i++)
        printf("%d ", answer[i]);
    return 0;
}

经典例题:「十二省联考 2019」春节十二响

这道题就是经典的树上启发式合并裸题。

#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 2e5 + 2000;
int fa[MAX_N], head[MAX_N], current, weight[MAX_N], n, size[MAX_N];
int dfn[MAX_N], tot, tmp[MAX_N];
priority_queue<int> queues[MAX_N];
struct edge
{
    int to, nxt;
} edges[MAX_N << 1];
void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}
void dfs(int u)
{
    dfn[u] = ++tot;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa[u])
        {
            dfs(edges[i].to);
            if (queues[dfn[edges[i].to]].size() > queues[dfn[u]].size())
                swap(dfn[edges[i].to], dfn[u]);
            int m = queues[dfn[edges[i].to]].size();
            for (int pt = 1; pt <= m; pt++)
            {
                tmp[pt] = max(queues[dfn[u]].top(),
                              queues[dfn[edges[i].to]].top());
                queues[dfn[edges[i].to]].pop();
                queues[dfn[u]].pop();
            }
            for (int pt = 1; pt <= m; pt++)
                queues[dfn[u]].push(tmp[pt]);
        }
    queues[dfn[u]].push(weight[u]);
}
int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", &weight[i]);
    for (int i = 2; i <= n; i++)
        scanf("%d", &fa[i]), addpath(fa[i], i), addpath(i, fa[i]);
    dfs(1);
    long long ans = 0;
    while (!queues[dfn[1]].empty())
        ans += queues[dfn[1]].top(), queues[dfn[1]].pop();
    printf("%lld", ans);
    return 0;
}

One Comment

Leave a Reply

Your email address will not be published. Required fields are marked *