令人窒息的字符串读入
这道题我要非常强调的是字符串的读入问题(我被卡了一下午和一晚上)。我们需要边读入边进行 trie 树的 insert 操作,不要存下字符子串!要不然你就会像我一样 MLE。具体代码:
// in function main();
int now = root;
for (int i = 1, l = strlen(buff + 1); i <= l; ++i)
{
if (buff[i] >= 'a' && buff[i] <= 'z')
{
if (!nodes[now].nxt[buff[i] - 'a'])
nodes[now].nxt[buff[i] - 'a'] = ++tot, nodes[tot].fa = now;
now = nodes[now].nxt[buff[i] - 'a'];
}
if (buff[i] == 'B')
now = nodes[now].fa;
if (buff[i] == 'P')
{
nds[++n] = now;
nodes[now].id = n;
}
}
正式思路
刚刚讲完了我的悲惨教训之后,我来正式阐述以下本题思路。我们要把这些字符串全部加入 Trie 树并且生成 Fail 失配指针。这些都是基本操作对吧。
void insert(string str, int id)
{
int p = root;
for (int i = 0; i < str.length(); i++)
{
int curt = str[i] - 'a', fa = p;
if (nodes[p].nxt[curt] == 0)
nodes[p].nxt[curt] = ++tot;
p = nodes[p].nxt[curt];
nodes[p].fa = fa;
}
nodes[p].id = id, nds[id] = p;
}
void bfs()
{
queue<int> q;
for (int i = 0; i < 26; i++)
if (nodes[root].nxt[i] != 0)
q.push(nodes[root].nxt[i]);
while (!q.empty())
{
int curt = q.front();
q.pop();
for (int i = 0; i < 26; i++)
if (nodes[curt].nxt[i] != 0)
nodes[nodes[curt].nxt[i]].fail = nodes[nodes[curt].fail].nxt[i],
q.push(nodes[curt].nxt[i]);
else
nodes[curt].nxt[i] = nodes[nodes[curt].fail].nxt[i];
}
}
之后我们会发现,对于每一次询问 \((x,y)\),我们所要求的就是在 Trie 树上,从字符串 \(y\) 的末尾节点开始沿着 fail 指针向上跳,每经历一个尾节点 \(x\) 时,答案计数加一。
当然,我做了一些小优化,离线处理每个询问,按关键字 \(y\) 进行从小到大的排序,然后对于每一个点\(x\)我们都用树状数组来记录能沿着 fail 指针树行进的(对了我们一定要用 fail 指针建一棵树来搞定这个)、能到达的以\(y\)结尾的点的个数,及维护前缀答案来搞定。
在获取答案之前,我们要写一个 DFS,来记录每个结点 low 和 dfn 的值。然后,统计答案时,因为答案分布在一整条链上,且链上的点的时间戳是由单调性的,所以可以用树状数组统计。
具体代码:
// P2414.cpp
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <queue>
#include <algorithm>
#define lowbit(num) (num & -num)
using namespace std;
const int MX_N = 200200;
int head[MX_N], current = 0;
struct egde
{
int to, nxt;
} edges[MX_N];
char buff[MX_N];
int T, anses[MX_N], root, nds[MX_N], tree[MX_N], tim, n, ql[MX_N], qr[MX_N], tot;
struct node
{
int nxt[26], fail, fa, pre[26], id, dfn, low;
} nodes[MX_N];
struct queryInfo
{
int x, y, ans, id;
bool operator<(const queryInfo &q) const { return y < q.y; }
} qi[MX_N];
inline int read()
{
int x = 0, t = 1;
char ch = getchar();
while ((ch < '0' || ch > '9') && ch != '-')
ch = getchar();
if (ch == '-')
t = -1, ch = getchar();
while (ch <= '9' && ch >= '0')
x = x * 10 + ch - 48, ch = getchar();
return x * t;
}
void addpath(int src, int dst) { edges[current].to = dst, edges[current].nxt = head[src], head[src] = current++; }
void update(int x, int c)
{
while (x <= tim)
tree[x] += c, x += lowbit(x);
}
int getsum(int x)
{
int ret = 0;
while (x > 0)
ret += tree[x], x -= lowbit(x);
return ret;
}
void insert(string str, int id)
{
int p = root;
for (int i = 0; i < str.length(); i++)
{
int curt = str[i] - 'a', fa = p;
if (nodes[p].nxt[curt] == 0)
nodes[p].nxt[curt] = ++tot;
p = nodes[p].nxt[curt];
nodes[p].fa = fa;
}
nodes[p].id = id, nds[id] = p;
}
void bfs()
{
queue<int> q;
for (int i = 0; i < 26; i++)
if (nodes[root].nxt[i] != 0)
q.push(nodes[root].nxt[i]);
while (!q.empty())
{
int curt = q.front();
q.pop();
for (int i = 0; i < 26; i++)
if (nodes[curt].nxt[i] != 0)
nodes[nodes[curt].nxt[i]].fail = nodes[nodes[curt].fail].nxt[i],
q.push(nodes[curt].nxt[i]);
else
nodes[curt].nxt[i] = nodes[nodes[curt].fail].nxt[i];
}
}
void dfs(int u)
{
nodes[u].dfn = ++tim;
for (int i = head[u]; i != -1; i = edges[i].nxt)
dfs(edges[i].to);
nodes[u].low = tim;
}
void dfsans(int u)
{
update(nodes[u].dfn, 1);
if (nodes[u].id != 0)
for (int i = ql[nodes[u].id]; i <= qr[nodes[u].id]; i++)
qi[i].ans = getsum(nodes[nds[qi[i].x]].low) - getsum(nodes[nds[qi[i].x]].dfn - 1);
for (int i = 0; i < 26; i++)
if (nodes[u].pre[i] != 0)
dfsans(nodes[u].nxt[i]);
update(nodes[u].dfn, -1);
}
int main()
{
memset(head, -1, sizeof(head));
root = 0;
scanf("%s", buff + 1);
int now = root;
for (int i = 1, l = strlen(buff + 1); i <= l; ++i)
{
if (buff[i] >= 'a' && buff[i] <= 'z')
{
if (!nodes[now].nxt[buff[i] - 'a'])
nodes[now].nxt[buff[i] - 'a'] = ++tot, nodes[tot].fa = now;
now = nodes[now].nxt[buff[i] - 'a'];
}
if (buff[i] == 'B')
now = nodes[now].fa;
if (buff[i] == 'P')
{
nds[++n] = now;
nodes[now].id = n;
}
}
for (int i = 0; i <= tot; i++)
for (int j = 0; j < 26; j++)
nodes[i].pre[j] = nodes[i].nxt[j];
bfs();
for (int i = 1; i <= tot; i++)
addpath(nodes[i].fail, i);
dfs(root);
T = read();
for (int i = 1; i <= T; i++)
qi[i].x = read(), qi[i].y = read(), qi[i].id = i;
sort(qi + 1, qi + 1 + T);
for (int i = 1, pos = 1; i <= T; i = pos)
{
ql[qi[i].y] = i;
while (qi[i].y == qi[pos].y)
pos++;
qr[qi[i].y] = pos - 1;
}
dfsans(root);
for (int i = 1; i <= T; i++)
anses[qi[i].id] = qi[i].ans;
for (int i = 1; i <= T; i++)
printf("%d\n", anses[i]);
return 0;
}