P5327:「ZJOI2019」语言 – 题解

主要思路

好题!好题!

这个题首先需要了解到一个性质:每个点所能到达的、进行贸易的国家的集合可以形成一个连通的生成子图。所以,我们最后需要求的就是每个国家对应的集合的大小。

如何去快速求这个东西呢?我们需要把 \(m\) 个路径中每个跨过当前点 \(u\) 的路径给拿出来,然后算这些路径的交集。这样是一种计算的方法,但是确实不快速。

我们尝试在这个方法上进行优化得到正解。考虑一个静态问题:给定一个点集 \(S\),算其生成树大小。其实这个就是「异象石」,我们需要求两两的 LCA 然后再用深度信息算大小。问题是我们这里没法去搞这个东西,要不然还是超时。

假设我们可以用树上差分来做批量修改点集就好了。这里其实我们就可以用线段树合并去维护这个信息:每个线段树的叶子代表一个差分点,然后我们可以用 \(DFS\) 序做下标,然后两个 DFS 序连续的区间进行合并时,我们就可以直接用异象石里的方法去搞了。

代码

// P5327.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 2e5 + 200;

typedef long long ll;

int n, head[MAX_N], current, dfn[MAX_N], st[20][MAX_N], stot, ptot;
int log2_[MAX_N], dep[MAX_N], m, roots[MAX_N], up[MAX_N];
vector<int> remTag[MAX_N];
ll ans;

struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

struct node
{
    int val, lson, rson, s, t, sum;
} nodes[MAX_N * 30];

void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

void dfs(int u, int fa)
{
    up[u] = fa, st[0][++stot] = u, dfn[u] = stot, dep[u] = dep[fa] + 1;
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa)
            dfs(edges[i].to, u), st[0][++stot] = u;
}

int gmin(int x, int y) { return dep[x] < dep[y] ? x : y; }

int getLCA(int x, int y)
{
    if (x == 0 || y == 0)
        return 0;
    if (dfn[x] > dfn[y])
        swap(x, y);
    int d = log2_[dfn[y] - dfn[x] + 1];
    return gmin(st[d][dfn[x]], st[d][dfn[y] - (1 << d) + 1]);
}

void pushup(int p)
{
    nodes[p].sum = nodes[nodes[p].lson].sum + nodes[nodes[p].rson].sum - dep[getLCA(nodes[nodes[p].lson].t, nodes[nodes[p].rson].s)];
    nodes[p].s = (nodes[nodes[p].lson].s ? nodes[nodes[p].lson].s : nodes[nodes[p].rson].s);
    nodes[p].t = (nodes[nodes[p].rson].t ? nodes[nodes[p].rson].t : nodes[nodes[p].lson].t);
}

#define mid ((l + r) >> 1)

int update(int qx, int l, int r, int p, int val)
{
    if (p == 0)
        p = ++ptot;
    if (l == r)
    {
        nodes[p].val += val;
        if (nodes[p].val > 0)
            nodes[p].s = nodes[p].t = qx, nodes[p].sum = dep[qx];
        else
            nodes[p].s = nodes[p].t = nodes[p].sum = 0;
        return p;
    }
    if (dfn[qx] <= mid)
        nodes[p].lson = update(qx, l, mid, nodes[p].lson, val);
    else
        nodes[p].rson = update(qx, mid + 1, r, nodes[p].rson, val);
    pushup(p);
    return p;
}

int query(int p) { return nodes[p].sum - dep[getLCA(nodes[p].s, nodes[p].t)]; }

int merge(int x, int y, int l, int r)
{
    if (x == 0 || y == 0)
        return x + y;
    if (l == r)
    {
        nodes[x].val += nodes[y].val, nodes[x].sum |= nodes[y].sum;
        nodes[x].s |= nodes[y].s, nodes[x].t |= nodes[y].t;
        return x;
    }
    nodes[x].lson = merge(nodes[x].lson, nodes[y].lson, l, mid);
    nodes[x].rson = merge(nodes[x].rson, nodes[y].rson, mid + 1, r);
    pushup(x);
    return x;
}

void solve(int u, int fa)
{
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa)
            solve(edges[i].to, u);
    for (int v : remTag[u])
        roots[u] = update(v, 1, stot, roots[u], -1);
    ans += query(roots[u]), roots[up[u]] = merge(roots[up[u]], roots[u], 1, stot);
}

int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d%d", &n, &m);
    for (int i = 1, u, v; i <= n - 1; i++)
        scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u);
    dfs(1, 0);
    for (int i = 2; i <= stot; i++)
        log2_[i] = log2_[i >> 1] + 1;
    for (int i = 1; i < 20; i++)
        for (int j = 1; j + (1 << i) - 1 <= stot; j++)
            st[i][j] = gmin(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]);
    for (int i = 1, u, v, lca; i <= m; i++)
    {
        scanf("%d%d", &u, &v), lca = getLCA(u, v);
        roots[u] = update(u, 1, stot, roots[u], 1), roots[u] = update(v, 1, stot, roots[u], 1);
        roots[v] = update(u, 1, stot, roots[v], 1), roots[v] = update(v, 1, stot, roots[v], 1);
        remTag[lca].push_back(u), remTag[lca].push_back(v);
        remTag[up[lca]].push_back(u), remTag[up[lca]].push_back(v);
    }
    solve(1, 0), printf("%lld\n", ans >> 1);
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *