「LibreOJ」#2072「JSOI2016」独特的树叶 – 题解

主要思路

先用换根的方法求出树 A 以每个点为根时的形态 Hash 值，存入一个集合；然后枚举 B 中的每个叶子，计算删去该叶子后、以它唯一的相邻点为根时 B 的 Hash 值，在集合里查一查即可。叶子按编号从小到大枚举，第一个命中的就是答案。

代码

// P4323.cpp
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>

using namespace std;

// Array capacity, hash modulus (prime) and polynomial-hash base.
constexpr int MAX_N = 200200;
constexpr int mod = 1000000007;
constexpr int bitnum = 133;

using pii = pair<int, int>;

// In every [2][...] array, index 0 refers to tree A and index 1 to tree B.
int n;                      // number of vertices of tree A (B has n + 1)
int sub[2][MAX_N];          // hash of the subtree below each vertex (rooted at 1)
int g[2][MAX_N];            // hash of the component reached through the parent
int bpow[MAX_N];            // bpow[i] = bitnum^i mod mod
int cidx;                   // scratch slot for the current tree index
vector<int> G[2][MAX_N];    // adjacency lists
vector<pii> subs[2][MAX_N]; // (branch hash, neighbour) pairs per vertex

// Modular exponentiation by repeated squaring: returns bas^tim mod `mod`.
// Intended for tim >= 0 (the loop runs until tim becomes 0); 0^0 yields 1.
int fpow(int bas, int tim)
{
    long long result = 1;
    long long base = bas;
    for (; tim != 0; tim >>= 1, base = base * base % mod)
    {
        // Multiply in the current power of the base for each set bit of tim.
        if (tim & 1)
            result = result * base % mod;
    }
    return static_cast<int>(result);
}

const int invBitnum{fpow(bitnum, mod - 2)}; // bitnum^(-1) mod p via Fermat's little theorem (mod is prime)

void dfs(int u, int fa, int idx)
{
    for (int v : G[idx][u])
        if (v != fa)
            dfs(v, u, idx);
    cidx = idx;
    sort(G[idx][u].begin(), G[idx][u].end(), [](const int &rhs1, const int &rhs2) { return sub[cidx][rhs1] < sub[cidx][rhs2]; });
    sub[idx][u] = 1LL * bitnum * bitnum % mod;
    int rk = 0;
    for (int v : G[idx][u])
        if (v != fa)
            sub[idx][u] = (0LL + sub[idx][u] + 1LL * bpow[rk] * sub[idx][v] % mod) % mod, subs[idx][u].push_back(make_pair(sub[idx][v], v)), rk++;
    sub[idx][u] = 1LL * sub[idx][u] * bitnum % mod * bitnum % mod * bitnum % mod;
}

const int endlInv = 1LL * invBitnum * invBitnum % mod * invBitnum % mod;

void collect(int u, int fa, int idx)
{
    if (fa != 0)
        subs[idx][u].push_back(make_pair(g[idx][u], fa));
    sort(subs[idx][u].begin(), subs[idx][u].end());
    int rk = 0, acc = 0;
    deque<int> pre, suf;
    pre.push_back(0), suf.push_back(0);
    for (pii v : subs[idx][u])
    {
        acc = (0LL + acc + 1LL * bpow[rk] * v.first % mod) % mod;
        pre.push_back(acc), suf.push_back(1LL * bpow[rk] * v.first % mod), rk++;
    }
    pre.push_back(0), suf.push_back(0);
    for (int i = rk; i >= 0; i--)
        suf[i] = (0LL + suf[i] + suf[i + 1]) % mod;
    rk = 1;
    for (pii v : subs[idx][u])
    {
        if (v.second != fa)
        {
            g[idx][v.second] = (1LL * bitnum * bitnum % mod + 0LL + pre[rk - 1] + 1LL * invBitnum * suf[rk + 1] % mod) % mod * bitnum % mod * bitnum % mod * bitnum % mod;
            collect(v.second, u, idx);
        }
        rk++;
    }
}

int main()
{
    scanf("%d", &n);
    for (int i = bpow[0] = 1; i <= n + 10; i++)
        bpow[i] = 1LL * bpow[i - 1] * bitnum % mod;
    for (int i = 1, u, v; i <= n - 1; i++)
        scanf("%d%d", &u, &v), G[0][u].push_back(v), G[0][v].push_back(u);
    for (int i = 1, u, v; i <= n; i++)
        scanf("%d%d", &u, &v), G[1][u].push_back(v), G[1][v].push_back(u);
    dfs(1, 0, 0), dfs(1, 0, 1), collect(1, 0, 0), collect(1, 0, 1);
    map<int, int> mp;
    for (int i = 1; i <= n; i++)
    {
        sort(subs[0][i].begin(), subs[0][i].end());
        int acc = 0;
        for (pii v : subs[0][i])
            acc = (1LL * acc * bitnum % mod + v.first) % mod;
        mp[acc]++;
    }
    for (int i = 1; i <= n + 1; i++)
        if (G[1][i].size() == 1)
        {
            int cpt = G[1][i][0];
            sort(subs[1][cpt].begin(), subs[1][cpt].end());
            int acc = 0;
            for (pii v : subs[1][cpt])
                if (v.second != i)
                    acc = (1LL * acc * bitnum % mod + v.first) % mod;
            if (mp[acc])
                printf("%d\n", i), exit(0);
        }
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *