P3346:「ZJOI2015」诸神眷顾的幻想乡 – 题解

主要思路

这算是广义后缀自动机的一个通用(general)运用。我们可以像统计多个串一样来做这个,只不过我们需要在树上 DFS 时记录分岔点的 SAM 节点,然后进行回溯。剩下计算本质不同字串个数的方法就很套路了。

代码

// P3346.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 4e6 + 200;

int head[MAX_N], current, n, ctot, val[MAX_N], deg[MAX_N];

struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

void addpath(int src, int dst)
{
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

namespace SAM
{

struct node
{
    int ch[10], dep, link;
} nodes[MAX_N];

int ptot = 1, last_ptr = 1;

int insert(int pre, int c)
{
    if (nodes[pre].ch[c] == 0)
    {
        int p = last_ptr = ++ptot;
        nodes[p].dep = nodes[pre].dep + 1;
        while (pre && nodes[pre].ch[c] == 0)
            nodes[pre].ch[c] = p, pre = nodes[pre].link;
        if (pre == 0)
            nodes[p].link = 1;
        else
        {
            int q = nodes[pre].ch[c];
            if (nodes[q].dep == nodes[pre].dep + 1)
                nodes[p].link = q;
            else
            {
                int clone = ++ptot;
                nodes[clone] = nodes[q], nodes[clone].dep = nodes[pre].dep + 1;
                nodes[q].link = nodes[p].link = clone;
                while (pre && nodes[pre].ch[c] == q)
                    nodes[pre].ch[c] = clone, pre = nodes[pre].link;
            }
        }
    }
    else
    {
        int q = nodes[pre].ch[c];
        if (nodes[q].dep == nodes[pre].dep + 1)
            last_ptr = q;
        else
        {
            int clone = last_ptr = ++ptot;
            nodes[clone] = nodes[q], nodes[clone].dep = nodes[pre].dep + 1;
            nodes[q].link = clone;
            while (pre && nodes[pre].ch[c] == q)
                nodes[pre].ch[c] = clone, pre = nodes[pre].link;
        }
    }

    return last_ptr;
}

void dfs(int u, int fa, int p)
{
    int pt = insert(p, val[u]);
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        if (edges[i].to != fa)
            dfs(edges[i].to, u, pt);
}

long long calc()
{
    long long ans = 0;
    for (int i = 1; i <= ptot; i++)
        ans += nodes[i].dep - nodes[nodes[i].link].dep;
    return ans;
}

} // namespace SAM

int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d%d", &n, &ctot);
    for (int i = 1; i <= n; i++)
        scanf("%d", &val[i]);
    for (int i = 1, u, v; i <= n - 1; i++)
        scanf("%d%d", &u, &v), addpath(u, v), addpath(v, u), deg[u]++, deg[v]++;
    for (int i = 1; i <= n; i++)
        if (deg[i] == 1)
            SAM::dfs(i, 0, 1);
    printf("%lld\n", SAM::calc());
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *