BZOJ4231:回忆树 – 题解

主要思路

这也太他妈毒瘤了吧。

这个题首先肯定是要离线的,在线肯定没发做。事实上,我们可以对这个匹配分为三类:在上升链上、过 LCA 或者是在下降链上。这三者可以分开进行统计。上升链和下降链可以通过对询问字符串正反建 AC 自动机、然后在原树上做 DFS 即可;经过 LCA 的链有点难处理,发现能参与匹配的字符个数是与 \(|\text{询问串}|\) 相关的,所以我们把左右链的两个部分拼起来,然后单独跑 KMP 即可。

代码好长,狗屎一般。

代码

// BZ4231.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 6e5 + 200;

typedef pair<int, int> pii;

int n, m, head[MAX_N], current;
int ansBox[MAX_N], dep[MAX_N], up[20][MAX_N], upweight[MAX_N];
char cmdlet[MAX_N], S[MAX_N];
vector<pii> tags[2][MAX_N];

struct AC_Automaton
{
	int ch[MAX_N][26], lft[MAX_N], rig[MAX_N], nodes[MAX_N], fail[MAX_N], dtot, ptot = 1;
	vector<int> G[MAX_N];

	void dfs(int u)
	{
		lft[u] = ++dtot;
		for (int v : G[u])
			dfs(v);
		rig[u] = dtot;
	}

	void build()
	{
		queue<int> q;
		for (int i = 0; i < 26; i++)
			if (ch[1][i])
				q.push(ch[1][i]), fail[ch[1][i]] = 1;
			else
				ch[1][i] = 1;
		while (!q.empty())
		{
			int u = q.front();
			q.pop();
			for (int i = 0; i < 26; i++)
				if (ch[u][i])
					fail[ch[u][i]] = ch[fail[u]][i], q.push(ch[u][i]);
				else
					ch[u][i] = ch[fail[u]][i];
		}
		for (int i = 2; i <= ptot; i++)
			G[fail[i]].push_back(i);
		dfs(1);
	}

	inline int lowbit(int x) { return x & (-x); }

	void update(int x, int d)
	{
		for (; x <= dtot; x += lowbit(x))
			nodes[x] += d;
	}

	int query(int x, int ret = 0)
	{
		for (; x; x -= lowbit(x))
			ret += nodes[x];
		return ret;
	}

} tr[2];

struct edge
{
	int to, nxt, weight;
} edges[MAX_N << 1];

struct queryInfo
{
	int u, v, rt, pos, rev;
} queries[MAX_N];

void addpath(int src, int dst, int weight)
{
	edges[current].to = dst, edges[current].nxt = head[src];
	edges[current].weight = weight, head[src] = current++;
}

void dfs(int u, int fa)
{
	up[0][u] = fa, dep[u] = dep[fa] + 1;
	for (int i = head[u]; i != -1; i = edges[i].nxt)
		if (edges[i].to != fa)
			upweight[edges[i].to] = edges[i].weight, dfs(edges[i].to, u);
}

int getLCA(int x, int y)
{
	if (dep[x] < dep[y])
		swap(x, y);
	for (int i = 19; i >= 0; i--)
		if (dep[up[i][x]] >= dep[y])
			x = up[i][x];
	if (x == y)
		return x;
	for (int i = 19; i >= 0; i--)
		if (up[i][x] != up[i][y])
			x = up[i][x], y = up[i][y];
	return up[0][x];
}

int kmp()
{
	static int nxt[MAX_N];
	int ret = 0, tLen = strlen(cmdlet + 1), i = 0, j = -1;
	char *T = cmdlet + 1;
	nxt[0] = -1;
	while (i < tLen)
	{
		if (j == -1 || T[i] == T[j])
			nxt[++i] = ++j;
		else
			j = nxt[j];
	}
	int curt = 0, pre = 0;
	while (S[curt])
	{
		if (pre == -1 || S[curt] == T[pre])
			curt++, pre++;
		else
			pre = nxt[pre];
		if (pre == tLen)
			ret++;
	}
	return ret;
}

int jump(int u, int delta)
{
	for (int i = 19; i >= 0; i--)
		if (delta & (1 << i))
			u = up[i][u];
	return u;
}

void collect(int u, int p1, int p2, int fa)
{
	tr[0].update(tr[0].lft[p1], 1), tr[1].update(tr[1].lft[p2], 1);
	for (pii x : tags[0][u])
	{
		int org = x.first;
		ansBox[org] += x.second * (tr[0].query(tr[0].rig[queries[x.first].pos]) - tr[0].query(tr[0].lft[queries[x.first].pos] - 1));
	}
	for (pii x : tags[1][u])
	{
		int org = x.first;
		ansBox[org] += x.second * (tr[1].query(tr[1].rig[queries[x.first].rev]) - tr[1].query(tr[1].lft[queries[x.first].rev] - 1));
	}
	for (int i = head[u]; i != -1; i = edges[i].nxt)
		if (edges[i].to != fa)
			collect(edges[i].to, tr[0].ch[p1][edges[i].weight], tr[1].ch[p2][edges[i].weight], u);
	tr[0].update(tr[0].lft[p1], -1), tr[1].update(tr[1].lft[p2], -1);
}

int main()
{
	memset(head, -1, sizeof(head));
	scanf("%d%d", &n, &m);
	for (int i = 1, u, v; i <= n - 1; i++)
	{
		scanf("%d%d%s", &u, &v, cmdlet + 1);
		addpath(u, v, cmdlet[1] - 'a'), addpath(v, u, cmdlet[1] - 'a');
	}
	dfs(1, 0);
	for (int i = 1; i <= 19; i++)
		for (int j = 1; j <= n; j++)
			up[i][j] = up[i - 1][up[i - 1][j]];
	for (int qptr = 1, u, v; qptr <= m; qptr++)
	{
		scanf("%d%d%s", &u, &v, cmdlet + 1);
		queries[qptr].u = u, queries[qptr].v = v;
		queries[qptr].rt = getLCA(u, v);
		int p = 1, len = strlen(cmdlet + 1);
		for (int i = 1; i <= len; i++)
		{
			if (tr[0].ch[p][cmdlet[i] - 'a'] == 0)
				tr[0].ch[p][cmdlet[i] - 'a'] = ++tr[0].ptot;
			p = tr[0].ch[p][cmdlet[i] - 'a'];
		}
		queries[qptr].pos = p, p = 1;
		for (int i = len; i >= 1; i--)
		{
			if (tr[1].ch[p][cmdlet[i] - 'a'] == 0)
				tr[1].ch[p][cmdlet[i] - 'a'] = ++tr[1].ptot;
			p = tr[1].ch[p][cmdlet[i] - 'a'];
		}
		queries[qptr].rev = p;
		if (!((queries[qptr].rt == u) || (queries[qptr].rt == v)))
		{
			int b = min(len - 1, dep[queries[qptr].u] - dep[queries[qptr].rt]), sublen = b, a = 0, j;
			for (a = 0, j = jump(queries[qptr].u, dep[queries[qptr].u] - dep[queries[qptr].rt] - b); j != queries[qptr].rt; a++, j = up[0][j])
				S[a] = upweight[j] + 'a';
			b = min(len - 1, dep[queries[qptr].v] - dep[queries[qptr].rt]), sublen += b;
			for (a = sublen - 1, j = jump(queries[qptr].v, dep[queries[qptr].v] - dep[queries[qptr].rt] - b); j != queries[qptr].rt; a--, j = up[0][j])
				S[a] = upweight[j] + 'a';
			S[sublen] = 0;
			ansBox[qptr] = kmp();
		}
		if (dep[queries[qptr].v] - dep[queries[qptr].rt] >= len)
		{
			tags[0][queries[qptr].v].push_back(make_pair(qptr, 1));
			tags[0][jump(queries[qptr].v, dep[queries[qptr].v] - dep[queries[qptr].rt] - len + 1)].push_back(make_pair(qptr, -1));
		}
		if (dep[queries[qptr].u] - dep[queries[qptr].rt] >= len)
		{
			tags[1][queries[qptr].u].push_back(make_pair(qptr, 1));
			tags[1][jump(queries[qptr].u, dep[queries[qptr].u] - dep[queries[qptr].rt] - len + 1)].push_back(make_pair(qptr, -1));
		}
	}
	tr[0].build(), tr[1].build(), collect(1, 1, 1, 0);
	for (int i = 1; i <= m; i++)
		printf("%d\n", ansBox[i]);
	return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *