概述
AC 自动机是一种有限状态自动机,常用于多模式串的字符串匹配。
P2414(NOI2011 阿狸的打字机)这道题,我要特别强调字符串的读入问题(我在这里被卡了一下午和一晚上)。我们需要边读入边在 Trie 树上进行插入操作,而不要把打印出的每个字符串单独存下来!否则你就会像我一样 MLE。具体代码:
// in function main(); int now = root; for (int i = 1, l = strlen(buff + 1); i <= l; ++i) { if (buff[i] >= 'a' && buff[i] <= 'z') { if (!nodes[now].nxt[buff[i] - 'a']) nodes[now].nxt[buff[i] - 'a'] = ++tot, nodes[tot].fa = now; now = nodes[now].nxt[buff[i] - 'a']; } if (buff[i] == 'B') now = nodes[now].fa; if (buff[i] == 'P') { nds[++n] = now; nodes[now].id = n; } }
讲完了我的惨痛教训,下面正式阐述本题的思路。我们要把这些字符串全部加入 Trie 树,并构造 fail(失配)指针——这些都是基本操作。
void insert(string str, int id) { int p = root; for (int i = 0; i < str.length(); i++) { int curt = str[i] - 'a', fa = p; if (nodes[p].nxt[curt] == 0) nodes[p].nxt[curt] = ++tot; p = nodes[p].nxt[curt]; nodes[p].fa = fa; } nodes[p].id = id, nds[id] = p; } void bfs() { queue<int> q; for (int i = 0; i < 26; i++) if (nodes[root].nxt[i] != 0) q.push(nodes[root].nxt[i]); while (!q.empty()) { int curt = q.front(); q.pop(); for (int i = 0; i < 26; i++) if (nodes[curt].nxt[i] != 0) nodes[nodes[curt].nxt[i]].fail = nodes[nodes[curt].fail].nxt[i], q.push(nodes[curt].nxt[i]); else nodes[curt].nxt[i] = nodes[nodes[curt].fail].nxt[i]; } }
之后我们会发现,对于每一次询问 \((x,y)\),所求的是:在 Trie 树上,对从根到字符串 \(y\) 的末尾节点这条路径上的每一个节点(即 \(y\) 的每一个前缀)沿着 fail 指针向上跳,若能到达 \(x\) 的末尾节点,答案就加一。换句话说,答案等于这条路径上有多少个节点位于 fail 树中以 \(x\) 的末尾节点为根的子树内。
具体实现时,我们离线处理所有询问:按 \(y\) 从小到大排序,让 \(y\) 相同的询问连续排列。然后用 fail 指针建一棵树(fail 树:每个节点向它的 fail 指针所指的节点连边,一定要建出这棵树)。再在 Trie 树上做 DFS,用树状数组标记当前从根到所在节点路径上的所有节点。走到某个 \(y\) 的末尾节点时,统计这些被标记的节点中有多少个位于 \(x\) 的末尾节点在 fail 树上的子树内,就得到了询问 \((x,y)\) 的答案。
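在获取答案之前,我们要先在 fail 树上写一个 DFS,记录每个节点的 \(dfn\)(进入该节点时的时间戳)和 \(low\)(其子树中最大的时间戳)。这样,fail 树上以 \(u\) 为根的子树恰好对应连续的时间戳区间 \([dfn_u, low_u]\)。统计答案时,需要统计的点分布在 Trie 上从根出发的一整条链上:在 Trie 树上 DFS,进入节点时在其 \(dfn\) 位置加一,离开时减一。于是询问 \((x,y)\) 的答案就是走到 \(y\) 的末尾节点时,树状数组在区间 \([dfn_x, low_x]\) 上的和(这里的 \(dfn_x\)、\(low_x\) 指 \(x\) 的末尾节点的值)。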
具体代码:
// P2414.cpp -- NOI2011 "阿狸的打字机".
//
// The trie is built on the fly from the typed keys, then turned into an AC
// automaton whose fail links form a tree. Occurrences of string x in string y =
// number of nodes on the trie path root -> end(y) that lie in the fail-subtree of
// end(x). A fail-subtree is the contiguous dfn interval [dfn, low], so a DFS over
// the trie that keeps the current root path marked in a Fenwick tree answers each
// query with two prefix sums.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <queue>
#include <algorithm>
#define lowbit(num) ((num) & -(num))
using namespace std;

const int MX_N = 200200;

// Fail tree as a forward-star adjacency list (edge fail[v] -> v); head[] = -1 means empty.
int head[MX_N], current = 0;
struct edge { int to, nxt; } edges[MX_N];

char buff[MX_N];
// T: number of queries; n: number of 'P' keys seen; tot: last allocated trie node;
// nds[k]: trie node where the k-th printed string ends; tree[]: Fenwick tree over dfn;
// tim: last dfn handed out; ql[y]..qr[y]: slice of the sorted queries with (canonical) y.
int T, anses[MX_N], root, nds[MX_N], tree[MX_N], tim, n, ql[MX_N], qr[MX_N], tot;

struct node {
    int nxt[26];   // trie edges, completed into the goto function by bfs()
    int fail, fa;  // fail link and trie parent
    int pre[26];   // copy of the pure trie edges, taken before bfs() fills in nxt
    int id;        // index of the LAST 'P' that ended at this node (0 = none)
    int dfn, low;  // the fail-subtree of this node is the dfn interval [dfn, low]
} nodes[MX_N];

struct queryInfo {
    int x, y, ans, id;
    bool operator<(const queryInfo &q) const { return y < q.y; }
} qi[MX_N];

// Fast reader for a (possibly negative) decimal integer; stops cleanly at EOF.
inline int read() {
    int x = 0, t = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9') && ch != '-') ch = getchar();
    if (ch == '-') t = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
    return x * t;
}

// Add the directed fail-tree edge src -> dst.
void addpath(int src, int dst) {
    edges[current].to = dst, edges[current].nxt = head[src], head[src] = current++;
}

// Fenwick tree: add c at position x (positions are dfn values 1..tim).
void update(int x, int c) {
    while (x <= tim) tree[x] += c, x += lowbit(x);
}

// Fenwick tree: sum of positions 1..x.
int getsum(int x) {
    int ret = 0;
    while (x > 0) ret += tree[x], x -= lowbit(x);
    return ret;
}

// Compute fail links in BFS order and complete the goto function.
// Relies on root == 0: an absent edge (0) doubles as "go to root".
void bfs() {
    queue<int> q;
    for (int i = 0; i < 26; i++)
        if (nodes[root].nxt[i] != 0) q.push(nodes[root].nxt[i]);
    while (!q.empty()) {
        int curt = q.front();
        q.pop();
        for (int i = 0; i < 26; i++)
            if (nodes[curt].nxt[i] != 0)
                nodes[nodes[curt].nxt[i]].fail = nodes[nodes[curt].fail].nxt[i], q.push(nodes[curt].nxt[i]);
            else
                nodes[curt].nxt[i] = nodes[nodes[curt].fail].nxt[i];
    }
}

// Assign dfn/low on the fail tree so each subtree is the interval [dfn, low].
void dfs(int u) {
    nodes[u].dfn = ++tim;
    for (int i = head[u]; i != -1; i = edges[i].nxt) dfs(edges[i].to);
    nodes[u].low = tim;
}

// Walk the pure trie; while at u, the Fenwick tree marks exactly the nodes on the
// path root -> u. Answer every query whose (canonical) y ends at u.
void dfsans(int u) {
    update(nodes[u].dfn, 1);
    int y = nodes[u].id;
    if (y != 0 && ql[y] != 0)  // ql[y] == 0: no query asks about this string
        for (int i = ql[y]; i <= qr[y]; i++) {
            int v = nds[qi[i].x];
            qi[i].ans = getsum(nodes[v].low) - getsum(nodes[v].dfn - 1);
        }
    for (int i = 0; i < 26; i++)
        if (nodes[u].pre[i] != 0) dfsans(nodes[u].pre[i]);
    update(nodes[u].dfn, -1);
}

int main() {
    memset(head, -1, sizeof(head));
    root = 0;
    scanf("%s", buff + 1);

    // Build the trie while reading; nds[k] = end node of the k-th printed string.
    int now = root;
    for (int i = 1, l = strlen(buff + 1); i <= l; ++i) {
        if (buff[i] >= 'a' && buff[i] <= 'z') {
            if (!nodes[now].nxt[buff[i] - 'a'])
                nodes[now].nxt[buff[i] - 'a'] = ++tot, nodes[tot].fa = now;
            now = nodes[now].nxt[buff[i] - 'a'];
        }
        if (buff[i] == 'B') now = nodes[now].fa;
        if (buff[i] == 'P') { nds[++n] = now; nodes[now].id = n; }
    }

    // Keep the pure trie edges before bfs() overwrites the missing ones.
    for (int i = 0; i <= tot; i++)
        for (int j = 0; j < 26; j++) nodes[i].pre[j] = nodes[i].nxt[j];
    bfs();
    for (int i = 1; i <= tot; i++) addpath(nodes[i].fail, i);
    dfs(root);

    T = read();
    for (int i = 1; i <= T; i++) {
        qi[i].x = read(), qi[i].y = read(), qi[i].id = i;
        // Several 'P's may print the same string ("aPP", "aPaBP"); they share one end
        // node, which only remembers the last print index. The answer depends only on
        // the string, so map y to that index, otherwise dfsans() never sees the query.
        qi[i].y = nodes[nds[qi[i].y]].id;
    }
    sort(qi + 1, qi + 1 + T);
    for (int i = 1, pos = 1; i <= T; i = pos) {
        ql[qi[i].y] = i;
        while (pos <= T && qi[pos].y == qi[i].y) pos++;
        qr[qi[i].y] = pos - 1;
    }

    dfsans(root);
    for (int i = 1; i <= T; i++) anses[qi[i].id] = qi[i].ans;
    for (int i = 1; i <= T; i++) printf("%d\n", anses[i]);
    return 0;
}
P2444([POI2000] 病毒)这道题显然要用 AC 自动机来实现。把所有病毒代码通过 insert 操作放入自动机之后,该怎样改写 build_ac_automation 操作来解决这道题呢?
首先,借助 fail 指针把每个节点在 0 和 1 上缺失的转移补全(每一个点都加上虚边),整棵 Trie 树就变成了一张每个节点都有两条出边的有向图。只要在这张图中能从根出发走到一个环,并且途经的节点和环上的节点都不是危险节点,就可以认定存在无限长的安全代码。危险节点指某个病毒代码的结尾点,或者其 fail 链上存在这样的结尾点。原因很简单:沿着这个环无限循环就能得到一串无限长的代码,而它经过的每个状态都不匹配任何病毒代码,所以这串代码是安全的。反过来,自动机的状态数有限,无限长的安全代码必然会重复经过某个安全状态,也就形成了这样一个环。
详见代码。
// P2444.cpp -- [POI2000] 病毒: is there an infinitely long binary string that
// contains none of the given virus codes? Prints TAK if a cycle of safe states is
// reachable from the root of the AC automaton, NIE otherwise.
#include <iostream>
#include <queue>
#include <cmath>
#include <cstring>
#include <cstdio>
#include <string>
using namespace std;
const int MX_N = 30020;

struct node {
    node *nxt[2], *fail;  // nxt: trie edges, completed into the full goto function
    bool sum;             // dangerous: some virus code is a suffix of this state
    bool ins, vis;        // DFS state: on the current recursion stack / visited
    node() { nxt[0] = nxt[1] = fail = NULL, sum = 0, ins = vis = false; }
};
node *root;
int n, longest;

// Insert a virus code made of '0'/'1' characters and mark its end node dangerous.
void insert(string str) {
    int siz = str.length();
    node *p = root;
    for (int i = 0; i < siz; i++)
        p = (p->nxt[str[i] - '0'] == NULL) ? (p->nxt[str[i] - '0'] = new node()) : p->nxt[str[i] - '0'];
    p->sum = true;
}

// Build fail links in BFS order and complete every state's goto function so that
// each state has a transition on both bits. Missing edges at the root loop back to
// the root itself. Previously they stayed NULL (as did every transition that falls
// back to them), so the cycle search ignored real moves to the root. For example,
// with the single code "1" the safe string "000..." was missed and NIE was printed.
void build_ac_automation() {
    queue<node *> q;
    for (int i = 0; i < 2; i++)
        if (root->nxt[i] != NULL) {
            root->nxt[i]->fail = root;
            q.push(root->nxt[i]);
        } else {
            root->nxt[i] = root;
        }
    while (!q.empty()) {
        node *curt = q.front();
        q.pop();
        for (int i = 0; i < 2; i++) {
            // curt->fail is shallower than curt, so it was dequeued earlier and its
            // goto function is already complete (never NULL).
            node *target = curt->fail->nxt[i];
            if (curt->nxt[i] != NULL) {
                curt->nxt[i]->fail = target;
                // A dangerous suffix state makes this state dangerous as well.
                curt->nxt[i]->sum |= target->sum;
                q.push(curt->nxt[i]);
            } else {
                curt->nxt[i] = target;
            }
        }
    }
}

// Return true if a cycle of safe states is reachable from curt.
bool dfs(node *curt) {
    if (curt->ins) return true;  // back edge onto the recursion stack: cycle found
    if (curt->vis || curt->sum) return false;
    curt->vis = curt->ins = true;
    for (int i = 0; i < 2; i++)
        if (curt->nxt[i] != NULL && dfs(curt->nxt[i])) return true;
    curt->ins = false;
    return false;
}

char buff[30020];

int main() {
    scanf("%d", &n);
    root = new node();
    for (int i = 1; i <= n; i++) scanf("%s", buff), insert(buff);
    build_ac_automation();
    if (dfs(root))
        printf("TAK");
    else
        printf("NIE");
    return 0;
}
HDU 2222(Keywords Search)这道题就是一道 AC 自动机的模板题。AC 自动机适用于多模式串匹配的场景,实际上,它就是在 Trie 树上进行的匹配算法。这里有一篇很优秀的讲稿,不了解 AC 自动机的读者可以自行阅读:https://blog.csdn.net/creatorx/article/details/71100840
这道题我们只需要给 Trie 树的每个节点增加一个字段 \(sum\),记录以该节点结尾的关键词个数;匹配时每个节点只统计一次,统计后将其 \(sum\) 置为 \(-1\)。
// HDU-2222.cpp -- Keywords Search: for each test case, count how many of the n
// keywords occur in the description text.
#include <iostream>
#include <queue>
#include <cstdio>
#include <cstring>
#include <string>
using namespace std;

// Trie node. nxt[] only ever holds real trie edges (build_AC_automation sets fail
// links but never fills in goto transitions), so the structure is a tree and the
// recursive destructor frees each node exactly once. Nodes are never copied.
struct node {
    node *nxt[26], *fail;
    int sum;  // number of keywords ending here; set to -1 once counted
    node() : fail(NULL), sum(0) {
        for (int i = 0; i < 26; i++) nxt[i] = NULL;
    }
    ~node() {
        for (int i = 0; i < 26; i++) delete nxt[i];  // delete on NULL is a no-op
    }
};
node *root;
int cnt;
char str[1000020];

// Insert one keyword made of 'a'-'z' and bump the counter at its end node.
void insert(string str) {
    int siz = str.length();
    node *p = root;
    for (int i = 0; i < siz; i++) {
        int curt = str[i] - 'a';
        if (p->nxt[curt] == NULL) p->nxt[curt] = new node();
        p = p->nxt[curt];
    }
    p->sum++;
}

// Set fail links in BFS order (children of the root fail to the root).
void build_AC_automation() {
    queue<node *> q;
    q.push(root);
    node *p, *tmp;
    while (!q.empty()) {
        tmp = q.front();
        q.pop();
        for (int i = 0; i < 26; i++)
            if (tmp->nxt[i]) {
                if (tmp == root)
                    tmp->nxt[i]->fail = root;
                else {
                    // Walk up tmp's fail chain to the first state with an edge on i.
                    p = tmp->fail;
                    while (p)
                        if (p->nxt[i]) {
                            tmp->nxt[i]->fail = p->nxt[i];
                            break;
                        } else
                            p = p->fail;
                    if (p == NULL) tmp->nxt[i]->fail = root;
                }
                q.push(tmp->nxt[i]);
            }
    }
}

// Scan the text, adding to cnt every keyword seen for the first time. Taken by const
// reference to avoid copying the (up to 1e6-char) text.
void ac_automation(const string &passage) {
    int siz = passage.length();
    node *p = root;
    for (int i = 0; i < siz; i++) {
        int curt = passage[i] - 'a';
        // Keywords contain only 'a'-'z', so no match can span any other byte; this
        // also keeps curt a valid index into nxt[].
        if (curt < 0 || curt >= 26) {
            p = root;
            continue;
        }
        while (!p->nxt[curt] && p != root) p = p->fail;
        p = p->nxt[curt];
        if (p == NULL) p = root;
        // Count every keyword ending at p or along its fail chain. A state marked -1
        // was counted before, and so was its whole fail chain, so we stop there.
        for (node *tmp = p; tmp != root && tmp->sum >= 0; tmp = tmp->fail) {
            cnt += tmp->sum;
            tmp->sum = -1;
        }
    }
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        int n;
        scanf("%d", &n);
        root = new node();
        for (int i = 1; i <= n; i++) scanf("%s", str), insert(str);
        build_AC_automation();
        scanf("%s", str);
        ac_automation(str);
        printf("%d\n", cnt);
        cnt = 0;
        // Free this test case's trie; previously every case leaked its whole trie.
        delete root;
        root = NULL;
    }
    return 0;
}