概述
后缀自动机是处理字符串信息的有力工具。后缀自动机的转移边构成一张 DAG,而 Link 指针则把所有状态连成一棵树。任意一条从原点出发的路径都对应这个字符串的一个子串,并且不同的路径对应不同的子串,所以在后缀自动机上不会存在重复的子串信息(然而,我们可以进行一些扩展来维护子串位置信息)。
SAM 的构造
// Suffix automaton over the 26 lowercase letters.
// State 1 is the root; index 0 serves as the "no state" sentinel for links and transitions.
namespace SAM {

struct node {
    int dep;     // dep of the state it was extended from, plus one
    int link;    // suffix link (0 only for the root)
    int ch[26];  // transitions, 0 meaning "absent"
    int pos;     // idx passed to insert() when the state was created (copied into clones)
} nodes[MAX_N << 1];

int ptot = 1, last_ptr = 1;

// Appends character c (0..25) to the automaton; idx is stored as the new state's pos.
void insert(int c, int idx) {
    int walker = last_ptr;
    const int fresh = ++ptot;
    last_ptr = fresh;
    nodes[fresh].dep = nodes[walker].dep + 1;
    nodes[fresh].pos = idx;

    // Give every suffix state that lacks a c-transition an edge to the new state.
    for (; walker != 0 && nodes[walker].ch[c] == 0; walker = nodes[walker].link)
        nodes[walker].ch[c] = fresh;

    if (walker == 0) {
        nodes[fresh].link = 1;
        return;
    }

    const int target = nodes[walker].ch[c];
    if (nodes[target].dep == nodes[walker].dep + 1) {
        nodes[fresh].link = target;
        return;
    }

    // target also holds longer strings: split off a copy with the shorter dep.
    const int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[fresh].link = clone;
    nodes[target].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
}

} // namespace SAM
SAM 的构造是线性时间的,且可以动态加入字符。
在这里介绍一下广义后缀自动机的构造方式。这是离线的做法:
// P6139.cpp
// Offline generalized suffix automaton: all strings are first stored in a trie,
// then the trie is traversed breadth-first and every trie node is fed to the SAM.
#include <cstdio>
#include <queue>

using namespace std;

const int MAX_N = 2e6 + 200;

int n, pos[MAX_N];  // pos[u]: SAM state recorded for trie node u
char str[MAX_N];

namespace Trie {

// up[u]: parent of trie node u; depot[u]: letter on the edge from up[u] to u.
int ch[MAX_N][26], ptot = 1, up[MAX_N], depot[MAX_N];

// Inserts the 1-indexed, NUL-terminated string currently held in str.
void insert() {
    int cur = 1;
    for (int i = 1; str[i]; i++) {
        const int c = str[i] - 'a';
        if (ch[cur][c] == 0) {
            ch[cur][c] = ++ptot;
            up[ptot] = cur;
            depot[ptot] = c;
        }
        cur = ch[cur][c];
    }
}

} // namespace Trie

namespace SAM {

struct node {
    int ch[26], link, dep;
} nodes[MAX_N];

int ptot = 1, last_ptr = 1;

// Extends the automaton from state last_ptr with character c (0..25).
void insert(int c) {
    int walker = last_ptr;
    const int fresh = ++ptot;
    last_ptr = fresh;
    nodes[fresh].dep = nodes[walker].dep + 1;

    for (; walker != 0 && nodes[walker].ch[c] == 0; walker = nodes[walker].link)
        nodes[walker].ch[c] = fresh;

    if (walker == 0) {
        nodes[fresh].link = 1;
        return;
    }

    const int target = nodes[walker].ch[c];
    if (nodes[target].dep == nodes[walker].dep + 1) {
        nodes[fresh].link = target;
        return;
    }

    const int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[fresh].link = clone;
    nodes[target].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
}

} // namespace SAM

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%s", str + 1);
        Trie::insert();
    }

    // Breadth-first walk over the trie, starting from the children of the root.
    queue<int> pending;
    for (int c = 0; c < 26; c++)
        if (Trie::ch[1][c])
            pending.push(Trie::ch[1][c]);
    pos[1] = 1;

    while (!pending.empty()) {
        const int u = pending.front();
        pending.pop();

        // Continue from the SAM state of the trie parent.
        SAM::last_ptr = pos[Trie::up[u]];
        SAM::insert(Trie::depot[u]);
        pos[u] = SAM::last_ptr;

        for (int c = 0; c < 26; c++)
            if (Trie::ch[u][c])
                pending.push(Trie::ch[u][c]);
    }

    // Sum dep[v] - dep[link[v]] over every non-root state.
    long long ans = 0;
    for (int v = 2; v <= SAM::ptot; v++)
        ans += SAM::nodes[v].dep - SAM::nodes[SAM::nodes[v].link].dep;
    printf("%lld\n", ans);
    return 0;
}
在线的做法:
// CF316G3.cpp
// Online generalized suffix automaton: every string is inserted starting again
// from the root, and insert() reuses an existing transition when there is one.
#include <cstdio>

using namespace std;

const int MAX_N = 1e6 + 200;

// n: number of rule strings; li[j] / ri[j]: bounds read together with rule string j.
int n, li[11], ri[11];
char str[11][MAX_N];

// SAM;
struct node {
    int ch[26], dep, link;
} nodes[MAX_N];

// siz[j][v]: counter for string j at state v (string 0 is the text, 1..n are the rules).
// rnk / bucket: counting sort of the states by dep.
int ptot = 1, last_ptr = 1, siz[11][MAX_N], rnk[MAX_N], bucket[MAX_N];

// Clones state q with dep = dep[pre] + 1, makes q link to the clone, and redirects the
// c-transitions that point at q along the suffix-link chain of pre. Returns the clone.
static int split_state(int pre, int q, int c) {
    const int clone = ++ptot;
    nodes[clone] = nodes[q];
    nodes[clone].dep = nodes[pre].dep + 1;
    nodes[q].link = clone;
    while (pre && nodes[pre].ch[c] == q)
        nodes[pre].ch[c] = clone, pre = nodes[pre].link;
    return clone;
}

// Appends character c (0..25) after state last_ptr and updates last_ptr.
void insert(int c) {
    int pre = last_ptr;

    if (nodes[pre].ch[c] != 0) {
        // The transition already exists: reuse its target, splitting it if it is too long.
        const int q = nodes[pre].ch[c];
        last_ptr = (nodes[q].dep == nodes[pre].dep + 1) ? q : split_state(pre, q, c);
        return;
    }

    const int p = last_ptr = ++ptot;
    nodes[p].dep = nodes[pre].dep + 1;
    while (pre && nodes[pre].ch[c] == 0)
        nodes[pre].ch[c] = p, pre = nodes[pre].link;

    if (pre == 0) {
        nodes[p].link = 1;
        return;
    }

    const int q = nodes[pre].ch[c];
    nodes[p].link = (nodes[q].dep == nodes[pre].dep + 1) ? q : split_state(pre, q, c);
}

int main() {
    scanf("%s%d", str[0] + 1, &n);
    for (int i = 1; str[0][i]; i++)
        insert(str[0][i] - 'a'), siz[0][last_ptr]++;

    for (int i = 1; i <= n; i++) {
        scanf("%s%d%d", str[i] + 1, &li[i], &ri[i]);
        last_ptr = 1;  // every string starts again from the root
        for (int j = 1; str[i][j]; j++)
            insert(str[i][j] - 'a'), siz[i][last_ptr]++;
    }

    // Counting sort of the states by dep; rnk[1] ends up being the root (dep 0).
    for (int i = 1; i <= ptot; i++)
        bucket[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= ptot; i++)
        rnk[bucket[nodes[i].dep]--] = i;

    // Push the per-string counters up the link tree, deepest states first.
    for (int i = ptot; i >= 2; i--)
        for (int j = 0; j <= n; j++)
            siz[j][nodes[rnk[i]].link] += siz[j][rnk[i]];

    long long ans = 0;
    for (int i = 2; i <= ptot; i++)
        if (siz[0][i]) {
            bool flag = true;
            for (int j = 1; j <= n; j++)
                flag &= li[j] <= siz[j][i] && siz[j][i] <= ri[j];
            if (flag)
                ans += nodes[i].dep - nodes[nodes[i].link].dep;
        }
    printf("%lld\n", ans);
    return 0;
}
Link 树
Link 树是一个很妙的东西,将若干个 Endpos 等价类做成树形结构。它有以下性质:
- \(minLen[u] = dep[fa] + 1\)(其中 \(dep\) 是代码中记录的等价类最长串长度,\(minLen\) 是等价类最短串长度),因为他们代表拥有相同的后缀,但是可以继续分化(前缀不同,但是存在规模更小的、但深度更大的 Endpos 等价类)。
- \(endpos[u] \subset endpos[fa]\),利用这个性质我们可以套路地使用线段树合并来维护具体的等价类。
SAM 的应用
求本质不同的子串个数
发现子串的表现形式在 SAM 中是唯一的、不受位置所影响(也就是不会算重复),每一条从根出发的路径都是一个唯一的子串,所以求解这个问题我们直接对 DAG 做路径计数即可。
求一个串在母串中出现的次数
我们让串在母串的 SAM 中游走到点\(p\),然后我们回顾一下 Endpos 等价类的相关信息:这样的路径和母串一一对应,所以我们只需要用\(p\)的性质来计算个数,显然答案就是\(endpos[p]\)的大小。在构造初期打好标记之后,根据树形结构,我们 DFS 一遍求得大小即可。
求两个字符串的 LCP
对其中一个串建 SAM,然后暴力往上跳即可。暴力往上跳的具体操作就是像 AC 自动机一般,但是需要对串长进行注意,每一次向上跳 Link 父亲都代表着放弃一个长度的前缀,这一点和 AC 自动机并不一样。
CF1120C Compress String
这道题显然是一个 \(O(n^2)\) 的 DP。考虑设置状态\(f_i\)为到了第\(i\)位的最优代价。转移非常显然:
\[ f_i = \min \begin{cases} f_{i - 1} + A \\ f_j + B, \text{if } S[j + 1 \dots i] \text{ 是 } S[1 \dots j] \text{ 的子串} \end{cases} \]
考虑预处理出子串的最早节点\(pos[i][j]\)和每一个节点的最早出现位置,然后进行判断即可。
// CF1120C.cpp
#include <algorithm>
#include <cstdio>
#include <cstring>

using namespace std;

const int MAX_N = 1e5 + 200;

struct node {
    // occur: position passed to insert() for states created directly; clones keep the
    // large default until stringSort() pulls the minimum up along the links.
    int ch[26], link, dep, occur = 0x3f3f3f3f;
} nodes[MAX_N];

int n, ca, cb, ptot, last_cur, bucketA[MAX_N], bucketB[MAX_N];
int pos[5001][5001], dp[MAX_N];  // pos[i][j]: state reached from the root by str[i..j]
char str[MAX_N];

// Creates the root state (index 1).
void sam_initialize() { last_cur = ++ptot; }

// Appends character c (0..25); pos is stored as the new state's occur value.
void insert(int c, int pos) {
    const int prev = last_cur;
    const int cur = ++ptot;
    last_cur = cur;
    nodes[cur].dep = nodes[prev].dep + 1;
    nodes[cur].occur = pos;

    int p = prev;
    for (; p != 0 && nodes[p].ch[c] == 0; p = nodes[p].link)
        nodes[p].ch[c] = cur;

    if (p == 0) {
        nodes[cur].link = 1;
        return;
    }

    const int q = nodes[p].ch[c];
    if (nodes[p].dep + 1 == nodes[q].dep) {
        nodes[cur].link = q;
        return;
    }

    // Split q: the clone copies only the transitions and the link, not occur.
    const int clone = ++ptot;
    nodes[clone].dep = nodes[p].dep + 1;
    memcpy(nodes[clone].ch, nodes[q].ch, sizeof(nodes[q].ch));
    nodes[clone].link = nodes[q].link;
    for (; p != 0 && nodes[p].ch[c] == q; p = nodes[p].link)
        nodes[p].ch[c] = clone;
    nodes[q].link = clone;
    nodes[cur].link = clone;
}

// Counting-sorts the states by dep, then propagates the minimum occur value
// from every state to its link target, deepest states first.
void stringSort() {
    for (int i = 1; i <= ptot; i++)
        bucketA[i] = 0;
    for (int i = 1; i <= ptot; i++)
        bucketA[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucketA[i] += bucketA[i - 1];
    for (int i = 1; i <= ptot; i++)
        bucketB[bucketA[nodes[i].dep]--] = i;

    for (int i = ptot; i >= 1; i--) {
        const int v = bucketB[i];
        const int parent = nodes[v].link;
        nodes[parent].occur = min(nodes[parent].occur, nodes[v].occur);
    }
}

int main() {
    sam_initialize();
    scanf("%d%d%d%s", &n, &ca, &cb, str + 1);
    for (int i = 1; i <= n; i++)
        insert(str[i] - 'a', i);

    // Walk every substring str[i..j] through the automaton and remember its state.
    for (int i = 1; i <= n; i++) {
        int state = 1;
        for (int j = i; j <= n; j++) {
            state = nodes[state].ch[str[j] - 'a'];
            pos[i][j] = state;
        }
    }

    stringSort();

    for (int i = 1; i <= n; i++) {
        dp[i] = dp[i - 1] + ca;
        // Only the smallest j whose block str[j+1..i] has occur <= j is tried.
        for (int j = 1; j <= i - 1; j++) {
            if (nodes[pos[j + 1][i]].occur <= j) {
                dp[i] = min(dp[i], dp[j] + cb);
                break;
            }
        }
    }
    printf("%d", dp[n]);
    return 0;
}
[SDOI2016]生成魔咒
其实这就是统计不同子串的在线版本。我们考虑每次新加入的后缀会对答案造成哪些贡献:新建节点 \(p\) 贡献了 \(maxLen(p) - minLen(p) + 1\) 个新的本质不同子串,其实也就是\(maxLen(p) - maxLen(fa[p])\)。
// P4070.cpp
// Suffix automaton with map-based transitions (the alphabet is arbitrary integers).
// After every appended symbol the running total `ans` is printed.
#include <cstdio>
#include <map>

#define ll long long

using namespace std;

const int MAX_N = 1e6 + 200;

struct node {
    map<int, int> ch;  // transitions; a missing key or the value 0 means "absent"
    int link, len;
} nodes[MAX_N];

int n, ptot = 1, last_ptr = 1, opt;
ll ans;

// Appends symbol c and adds len[p] - len[link[p]] of the new state p to ans.
void insert(int c) {
    int pre = last_ptr, p = last_ptr = ++ptot;
    nodes[p].len = nodes[pre].len + 1;

    for (; pre != 0 && nodes[pre].ch[c] == 0; pre = nodes[pre].link)
        nodes[pre].ch[c] = p;

    if (pre == 0)
        nodes[p].link = 1;
    else {
        int q = nodes[pre].ch[c];
        if (nodes[q].len == nodes[pre].len + 1)
            nodes[p].link = q;
        else {
            int clone = ++ptot;
            nodes[clone] = nodes[q];
            nodes[clone].len = nodes[pre].len + 1;
            nodes[q].link = nodes[p].link = clone;
            // Stop at the sentinel explicitly: without the pre != 0 test the loop would
            // end only by calling operator[] on nodes[0].ch, which inserts a dummy entry.
            for (; pre != 0 && nodes[pre].ch[c] == q; pre = nodes[pre].link)
                nodes[pre].ch[c] = clone;
        }
    }
    ans += nodes[p].len - nodes[nodes[p].link].len;
}

int main() {
    scanf("%d", &n);
    while (n--) {
        scanf("%d", &opt);
        insert(opt);
        printf("%lld\n", ans);
    }
    return 0;
}
[HEOI2016/TJOI2016]字符串
转换题意:问你子串\(S[a..b]\)的所有子串和\(S[c..d]\)的最长公共前缀的长度的最大值是多少?考虑把前缀转后缀(反转字符串)、对每一问进行二分,把\(S[c..c+mid-1]\)抽出来并放到 SAM 中进行 check。其实我们只需要拿到这个串所在的\(endpos\)等价类,并查询类中有没有在区间\([a, b]\)(记得反转)中的即可。线段树合并 + SAM。
// P4094.cpp
// SAM of the reversed string + segment-tree merging along the link tree + binary lifting.
#include <algorithm>
#include <cstdio>

using namespace std;

const int MAX_N = 2e5 + 200;

int n, q;
char str[MAX_N];

namespace SAM {

struct node {
    int link, dep, ch[26];
} nodes[MAX_N << 1];

int ptot = 1, last_ptr = 1;

// Appends character c (0..25) to the automaton.
void insert(int c) {
    int walker = last_ptr;
    const int fresh = ++ptot;
    last_ptr = fresh;
    nodes[fresh].dep = nodes[walker].dep + 1;

    for (; walker != 0 && nodes[walker].ch[c] == 0; walker = nodes[walker].link)
        nodes[walker].ch[c] = fresh;

    if (walker == 0) {
        nodes[fresh].link = 1;
        return;
    }

    const int target = nodes[walker].ch[c];
    if (nodes[target].dep == nodes[walker].dep + 1) {
        nodes[fresh].link = target;
        return;
    }

    const int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[target].link = clone;
    nodes[fresh].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
}

} // namespace SAM

namespace SegmentTree {

// Dynamically allocated segment tree over positions 1..n; index 0 is the empty tree.
struct node {
    int lson, rson, sum;
} nodes[MAX_N << 6];

int ptot;

void pushup(int p) {
    nodes[p].sum = nodes[nodes[p].lson].sum + nodes[nodes[p].rson].sum;
}

// Merges trees x and y. Inner nodes of the result are freshly allocated;
// at a leaf present in both trees, x is updated in place and returned.
int merge(int x, int y, int l, int r) {
    if (x == 0 || y == 0)
        return x != 0 ? x : y;
    if (l == r) {
        nodes[x].sum += nodes[y].sum;
        return x;
    }
    const int mid = (l + r) >> 1;
    const int p = ++ptot;
    nodes[p].lson = merge(nodes[x].lson, nodes[y].lson, l, mid);
    nodes[p].rson = merge(nodes[x].rson, nodes[y].rson, mid + 1, r);
    pushup(p);
    return p;
}

// Adds one to position qx in tree p (allocating nodes on demand); returns the root.
int update(int qx, int l, int r, int p) {
    if (p == 0)
        p = ++ptot;
    if (l == r) {
        nodes[p].sum++;
        return p;
    }
    const int mid = (l + r) >> 1;
    if (qx <= mid)
        nodes[p].lson = update(qx, l, mid, nodes[p].lson);
    else
        nodes[p].rson = update(qx, mid + 1, r, nodes[p].rson);
    pushup(p);
    return p;
}

// Sum over positions [ql, qr] in tree p.
int query(int ql, int qr, int l, int r, int p) {
    if (p == 0)
        return 0;
    if (ql <= l && r <= qr)
        return nodes[p].sum;
    const int mid = (l + r) >> 1;
    int total = 0;
    if (ql <= mid)
        total += query(ql, qr, l, mid, nodes[p].lson);
    if (mid < qr)
        total += query(ql, qr, mid + 1, r, nodes[p].rson);
    return total;
}

} // namespace SegmentTree

// roots[v]: segment tree of state v; pos[i]: state created for prefix i of the reversed
// string; perm: states ordered by dep; fa[k][v]: 2^k-th ancestor in the link tree.
int roots[MAX_N << 1], pos[MAX_N], perm[MAX_N << 1], fa[21][MAX_N << 1], bucket[MAX_N << 1];

// Orders state indices by dep.
bool compare(const int &rhs1, const int &rhs2) {
    return SAM::nodes[rhs1].dep < SAM::nodes[rhs2].dep;
}

// Lifts p to its highest ancestor whose dep is still >= mid, then tests whether that
// state's tree has any position inside [l + mid - 1, r].
bool check(int mid, int p, int l, int r) {
    for (int k = 20; k >= 0; k--) {
        const int anc = fa[k][p];
        if (anc && SAM::nodes[anc].dep >= mid)
            p = anc;
    }
    return SegmentTree::query(l + mid - 1, r, 1, n, roots[p]) > 0;
}

int main() {
    scanf("%d%d%s", &n, &q, str + 1);
    reverse(str + 1, str + 1 + n);

    for (int i = 1; i <= n; i++) {
        SAM::insert(str[i] - 'a');
        roots[SAM::last_ptr] = SegmentTree::update(i, 1, n, roots[SAM::last_ptr]);
        pos[i] = SAM::last_ptr;
    }

    // Counting sort of the states by dep into perm.
    for (int i = 1; i <= SAM::ptot; i++)
        perm[i] = i;
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[SAM::nodes[i].dep]++;
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= SAM::ptot; i++)
        perm[bucket[SAM::nodes[i].dep]--] = i;

    // Merge each state's tree into its link parent, deepest states first.
    for (int i = SAM::ptot; i >= 1; i--) {
        const int v = perm[i];
        const int parent = SAM::nodes[v].link;
        if (parent != 0)
            roots[parent] = SegmentTree::merge(roots[parent], roots[v], 1, n);
    }

    for (int v = 1; v <= SAM::ptot; v++)
        fa[0][v] = SAM::nodes[v].link;
    for (int k = 1; k <= 20; k++)
        for (int v = 1; v <= SAM::ptot; v++)
            fa[k][v] = fa[k - 1][fa[k - 1][v]];

    while (q--) {
        int a, b, c, d;
        scanf("%d%d%d%d", &a, &b, &c, &d);
        // Map the indices into the reversed string.
        a = n - a + 1, b = n - b + 1;
        c = n - c + 1, d = n - d + 1;

        int lo = 0, hi = min(a - b + 1, c - d + 1), best = 0;
        while (lo <= hi) {
            const int mid = (lo + hi) >> 1;
            if (check(mid, pos[c], b, a)) {
                best = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        printf("%d\n", best);
    }
    return 0;
}
CF700E Cool Slogans
这道题还挺有意思。我们最终需要一个链式结构来表示题目中所描述的东西,并求出这个链式结构的长度。其实我们可以知道,这个链式结构上的节点可以直接利用 SAM 上的节点来构造。假设串\(T\)为根节点,我们会发现接下来的每个节点都被上级节点所包含,且都为\(link\)树的一个生成子链。这是很有意思的,所以我们可以直接在 Link 树上 DP:对于每一个节点而言,如果当前的「极大父亲」(这里的「极大父亲」指的是,某些情况下直接的父子关系并不能满足题目中出现两次的要求,所以需要「极大父亲」来表示最近能使其出现两次的祖先)在该节点所代表的串中出现了至少两次,那么该节点的答案就是父亲的答案加一,并且它自己成为新的「极大父亲」;否则它沿用父亲的答案和父亲的「极大父亲」。之后就正常 DP 并取最大值就好了。
// CF700E.cpp
// SAM + segment-tree merging; DP over the link tree in order of increasing dep.
#include <algorithm>
#include <cstdio>

using namespace std;

const int MAX_N = 401000;

int n;
char str[MAX_N];

namespace SAM {

struct node {
    int dep, link, ch[26], pos;  // pos: idx given to insert(), copied into clones
} nodes[MAX_N << 1];

int ptot = 1, last_ptr = 1;

// Appends character c (0..25); idx is stored as the new state's pos.
void insert(int c, int idx) {
    int walker = last_ptr;
    const int fresh = ++ptot;
    last_ptr = fresh;
    nodes[fresh].dep = nodes[walker].dep + 1;
    nodes[fresh].pos = idx;

    for (; walker != 0 && nodes[walker].ch[c] == 0; walker = nodes[walker].link)
        nodes[walker].ch[c] = fresh;

    if (walker == 0) {
        nodes[fresh].link = 1;
        return;
    }

    const int target = nodes[walker].ch[c];
    if (nodes[target].dep == nodes[walker].dep + 1) {
        nodes[fresh].link = target;
        return;
    }

    const int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[fresh].link = clone;
    nodes[target].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
}

} // namespace SAM

namespace SegmentTree {

// Dynamically allocated segment tree over positions 1..n; index 0 is the empty tree.
struct node {
    int sum, lson, rson;
} nodes[MAX_N * 25];

int ptot;

void pushup(int p) {
    nodes[p].sum = nodes[nodes[p].lson].sum + nodes[nodes[p].rson].sum;
}

// Adds one to position qx in tree p (allocating nodes on demand); returns the root.
int update(int qx, int l, int r, int p) {
    if (p == 0)
        p = ++ptot;
    if (l == r) {
        nodes[p].sum++;
        return p;
    }
    const int mid = (l + r) >> 1;
    if (qx <= mid)
        nodes[p].lson = update(qx, l, mid, nodes[p].lson);
    else
        nodes[p].rson = update(qx, mid + 1, r, nodes[p].rson);
    pushup(p);
    return p;
}

// Merges trees x and y. Inner nodes of the result are freshly allocated;
// at a leaf present in both trees, x is updated in place and returned.
int merge(int x, int y, int l, int r) {
    if (x == 0 || y == 0)
        return x + y;
    if (l == r) {
        nodes[x].sum += nodes[y].sum;
        return x;
    }
    const int mid = (l + r) >> 1;
    const int p = ++ptot;
    nodes[p].lson = merge(nodes[x].lson, nodes[y].lson, l, mid);
    nodes[p].rson = merge(nodes[x].rson, nodes[y].rson, mid + 1, r);
    pushup(p);
    return p;
}

// Sum over positions [ql, qr] in tree p.
int query(int ql, int qr, int l, int r, int p) {
    if (p == 0)
        return 0;
    if (ql <= l && r <= qr)
        return nodes[p].sum;
    const int mid = (l + r) >> 1;
    int total = 0;
    if (ql <= mid)
        total += query(ql, qr, l, mid, nodes[p].lson);
    if (mid < qr)
        total += query(ql, qr, mid + 1, r, nodes[p].rson);
    return total;
}

} // namespace SegmentTree

// roots[v]: segment tree of state v; rnk: states ordered by dep;
// top[v]: the state the chain ending at v was last extended with; dp[v]: chain length at v.
int roots[MAX_N << 1], bucket[MAX_N << 1], rnk[MAX_N << 1], top[MAX_N << 1], dp[MAX_N << 1];

int main() {
    scanf("%d%s", &n, str + 1);
    for (int i = 1; i <= n; i++) {
        SAM::insert(str[i] - 'a', i);
        roots[SAM::last_ptr] = SegmentTree::update(i, 1, n, roots[SAM::last_ptr]);
    }

    // Counting sort of the states by dep into rnk.
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[SAM::nodes[i].dep]++;
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= SAM::ptot; i++)
        rnk[bucket[SAM::nodes[i].dep]--] = i;

    // Merge each state's tree into its link parent, deepest states first.
    for (int i = SAM::ptot; i >= 1; i--) {
        const int v = rnk[i];
        const int parent = SAM::nodes[v].link;
        if (parent != 0)
            roots[parent] = SegmentTree::merge(roots[parent], roots[v], 1, n);
    }

    int ans = 1;
    for (int i = 2; i <= SAM::ptot; i++) {
        const int u = rnk[i];
        const int fa = SAM::nodes[u].link;
        if (fa == 1) {
            dp[u] = 1;
            top[u] = u;
            continue;
        }

        // Look for another end position of top[fa] strictly before pos[u] that still
        // lies inside the occurrence of u ending at pos[u].
        const int lo = SAM::nodes[u].pos - SAM::nodes[u].dep + SAM::nodes[top[fa]].dep;
        const int hi = SAM::nodes[u].pos - 1;
        const int hits = SegmentTree::query(lo, hi, 1, n, roots[top[fa]]);

        if (hits > 0) {
            dp[u] = dp[fa] + 1;
            top[u] = u;
        } else {
            dp[u] = dp[fa];
            top[u] = top[fa];
        }
        ans = max(ans, dp[u]);
    }
    printf("%d\n", ans);
    return 0;
}
「NOI2018」你的名字
真的是很恶心的一道题目了。题意转换:计算\(T\)串的本质不同子串中,不是\(S[l\dots r]\)的子串的个数。先思考\(l = 1, r = n\)的做法:我们构建\(T\)的后缀自动机时,可以在\(S\)中跑,每次跑的时候记录下最长能与前缀匹配的长度,然后\(SAM_T\)中每一节点的贡献就是:
\[ \max(0, dep[p_T] - \max(match_i, dep[fa[p_T]])) \]
其中\(p_T\)是在\(SAM_T\)中的节点。
考虑正解,我们肯定需要一个线段树来维护集合信息,然后我们可以做类似的事情,只不过我们需要对\(match_i\)的处理进行改良兼容区间信息:LCP 暴力往上跳时判断是否在区间内。然后跟之前一样就可以做完了。
// LOJ2720.cpp
// SAM of S with merged segment trees of end positions, plus a per-query SAM of T.
#include <algorithm>
#include <cstdio>
#include <cstring>

using namespace std;

const int MAX_N = 1e6 + 2000, MAX_M = 2e7 + 200;

int n, q, limit[MAX_N];  // limit[i]: matched length recorded after reading T[1..i]
char S[MAX_N], T[MAX_N];

namespace SegmentTree {

// Dynamically allocated segment tree over positions 1..n; index 0 is the empty tree.
struct node {
    int sum, lson, rson;
} nodes[MAX_M];

int ptot, roots[MAX_N];

// Adds one to position qx in tree p (allocating nodes on demand); returns the root.
int update(int qx, int l, int r, int p) {
    if (p == 0)
        p = ++ptot;
    nodes[p].sum++;
    if (l == r)
        return p;
    const int mid = (l + r) >> 1;
    if (qx <= mid)
        nodes[p].lson = update(qx, l, mid, nodes[p].lson);
    else
        nodes[p].rson = update(qx, mid + 1, r, nodes[p].rson);
    return p;
}

// Sum over positions [ql, qr] in tree p.
int query(int ql, int qr, int l, int r, int p) {
    if (p == 0)
        return 0;
    if (ql <= l && r <= qr)
        return nodes[p].sum;
    const int mid = (l + r) >> 1;
    int total = 0;
    if (ql <= mid)
        total += query(ql, qr, l, mid, nodes[p].lson);
    if (mid < qr)
        total += query(ql, qr, mid + 1, r, nodes[p].rson);
    return total;
}

// Merges x and y into a tree made of fresh nodes wherever both inputs are non-empty.
int merge(int x, int y, int l, int r) {
    if (x == 0 || y == 0)
        return x + y;
    const int p = ++ptot;
    const int mid = (l + r) >> 1;
    nodes[p].sum = nodes[x].sum + nodes[y].sum;
    if (l == r)
        return p;
    nodes[p].lson = merge(nodes[x].lson, nodes[y].lson, l, mid);
    nodes[p].rson = merge(nodes[x].rson, nodes[y].rson, mid + 1, r);
    return p;
}

} // namespace SegmentTree

namespace SAM {

struct node {
    int dep, ch[26], link;
} nodes[MAX_N];

int last_ptr = 1, ptot = 1, bucket[MAX_N], rnk[MAX_N];

int newnode() { return ++ptot; }

// Counting-sorts the states by dep and merges every state's segment tree
// into the tree of its link parent, deepest states first.
void initialize_collection() {
    for (int i = 1; i <= ptot; i++)
        bucket[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= ptot; i++)
        rnk[bucket[nodes[i].dep]--] = i;
    for (int i = ptot; i >= 1; i--) {
        const int v = rnk[i];
        const int parent = nodes[v].link;
        if (parent != 0)
            SegmentTree::roots[parent] =
                SegmentTree::merge(SegmentTree::roots[parent], SegmentTree::roots[v], 1, n);
    }
}

} // namespace SAM

namespace SAM_T {

SAM::node nodes[MAX_N];
int ptot = 1, last_ptr = 1, pos[MAX_N];  // pos[v]: dep of the state v was created or cloned from

// Resets the automaton to a lone, zeroed root.
void clear() {
    ptot = last_ptr = 1;
    memset(nodes[1].ch, 0, sizeof(nodes[1].ch));
    nodes[1].dep = nodes[1].link = 0;
}

// Allocates a zeroed state (storage is reused between queries).
int newnode() {
    const int p = ++ptot;
    memset(nodes[p].ch, 0, sizeof(nodes[p].ch));
    nodes[p].dep = nodes[p].link = 0;
    pos[p] = 0;
    return p;
}

} // namespace SAM_T

// Generic SAM extension shared by both automata.
//   c         character (0..25) to append
//   last_ptr  the automaton's "last" state, updated in place
//   newnode   allocator of that automaton
//   nodes     state array of that automaton
//   toggle    when true, SAM_T::pos is maintained for the new state and for clones
// Returns the index of the state created for the new prefix.
int insert(int c, int &last_ptr, int (*newnode)(), SAM::node *nodes, bool toggle = false) {
    int walker = last_ptr;
    const int fresh = last_ptr = newnode();
    nodes[fresh].dep = nodes[walker].dep + 1;
    if (toggle)
        SAM_T::pos[fresh] = nodes[fresh].dep;

    for (; walker != 0 && nodes[walker].ch[c] == 0; walker = nodes[walker].link)
        nodes[walker].ch[c] = fresh;

    if (walker == 0) {
        nodes[fresh].link = 1;
        return fresh;
    }

    const int target = nodes[walker].ch[c];
    if (nodes[target].dep == nodes[walker].dep + 1) {
        nodes[fresh].link = target;
        return fresh;
    }

    const int clone = newnode();
    nodes[clone] = nodes[target];
    if (toggle)
        SAM_T::pos[clone] = SAM_T::pos[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[target].link = clone;
    nodes[fresh].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
    return fresh;
}

int main() {
    freopen("name.in", "r", stdin);
    freopen("name.out", "w", stdout);

    scanf("%s", S + 1);
    n = strlen(S + 1);
    for (int i = 1; i <= n; i++) {
        insert(S[i] - 'a', SAM::last_ptr, SAM::newnode, SAM::nodes);
        SegmentTree::roots[SAM::last_ptr] =
            SegmentTree::update(i, 1, n, SegmentTree::roots[SAM::last_ptr]);
    }
    SAM::initialize_collection();

    scanf("%d", &q);
    while (q--) {
        int l, r, m;
        scanf("%s%d%d", T + 1, &l, &r);
        SAM_T::clear();
        m = strlen(T + 1);

        // p / clen: current state in the SAM of S and the current matched length.
        for (int i = 1, p = 1, clen = 0; i <= m; i++) {
            const int c = T[i] - 'a';
            insert(c, SAM_T::last_ptr, SAM_T::newnode, SAM_T::nodes, true);

            for (;;) {
                const int nxt = SAM::nodes[p].ch[c];
                // Extend only if the extended match has an end position in [l + clen, r].
                if (nxt && SegmentTree::query(l + clen, r, 1, n, SegmentTree::roots[nxt])) {
                    clen++;
                    p = nxt;
                    break;
                }
                if (clen == 0)
                    break;
                // Shorten the match by one; move to the link parent once its dep is reached.
                clen--;
                if (clen == SAM::nodes[SAM::nodes[p].link].dep)
                    p = SAM::nodes[p].link;
            }
            limit[i] = clen;
        }

        long long ans = 0;
        for (int v = 2; v <= SAM_T::ptot; v++) {
            const int parentDep = SAM_T::nodes[SAM_T::nodes[v].link].dep;
            ans += max(0, SAM_T::nodes[v].dep - max(parentDep, limit[SAM_T::pos[v]]));
        }
        printf("%lld\n", ans);
    }
    return 0;
}
「HAOI2016」找相同字符
这道题一眼看上去的做法:把两个串拼在一起,然后遍历每一个 Endpos 集合,在线段树里找前半串和后半串出现的次数再乘起来。
正解:其实不需要复杂的线段树,只需要把 Size 分开记即可,并且在插入第二个串的时候在 SAM 上游走,并把 Last_ptr 重置。
// LOJ2064.cpp
// One automaton for both strings, with a separate occurrence counter per string.
#include <cstdio>
#include <cstring>

#define ll long long

using namespace std;

const int MAX_N = 5e5 + 200;

struct node {
    int ch[26], siz[2], link, dep;  // siz[t]: counter for string t (0 = str1, 1 = str2)
} nodes[MAX_N << 1];

int last_ptr = 1, ptot = 1, bucket[MAX_N << 1], n1, n2, rnk[MAX_N << 1];
ll ans;
char str1[MAX_N], str2[MAX_N];

// Appends character c (0..25). For typ == 0 the new state gets siz[0] = 1;
// clones always start with both counters cleared.
void insert(int c, int typ) {
    int walker = last_ptr;
    const int fresh = ++ptot;
    last_ptr = fresh;
    nodes[fresh].dep = nodes[walker].dep + 1;
    if (typ == 0)
        nodes[fresh].siz[0] = 1;

    for (; walker != 0 && nodes[walker].ch[c] == 0; walker = nodes[walker].link)
        nodes[walker].ch[c] = fresh;

    if (walker == 0) {
        nodes[fresh].link = 1;
        return;
    }

    const int target = nodes[walker].ch[c];
    if (nodes[target].dep == nodes[walker].dep + 1) {
        nodes[fresh].link = target;
        return;
    }

    const int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].siz[0] = 0;
    nodes[clone].siz[1] = 0;
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[fresh].link = clone;
    nodes[target].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
}

// Sorts the states by dep, accumulates both counters up the link tree and adds
// siz[0] * siz[1] * (dep - dep[link]) of every state to ans.
void radixSort() {
    for (int i = 1; i <= ptot; i++)
        bucket[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = ptot; i >= 1; i--)
        rnk[bucket[nodes[i].dep]--] = i;

    for (int i = ptot; i >= 1; i--) {
        const int v = rnk[i];
        const int parent = nodes[v].link;
        if (parent != 0) {
            nodes[parent].siz[0] += nodes[v].siz[0];
            nodes[parent].siz[1] += nodes[v].siz[1];
        }
    }

    for (int v = 1; v <= ptot; v++)
        ans += 1LL * nodes[v].siz[0] * nodes[v].siz[1] *
               (nodes[v].dep - nodes[nodes[v].link].dep);
}

int main() {
    scanf("%s%s", str1 + 1, str2 + 1);
    n1 = strlen(str1 + 1);
    n2 = strlen(str2 + 1);

    for (int i = 1; i <= n1; i++)
        insert(str1[i] - 'a', 0);

    // Restart from the root for the second string; its prefixes are marked on the
    // states reached by walking the transitions, not on the freshly created states.
    last_ptr = 1;
    for (int i = 1, state = 1; i <= n2; i++) {
        const int c = str2[i] - 'a';
        insert(c, 1);
        state = nodes[state].ch[c];
        nodes[state].siz[1] = 1;
    }

    radixSort();
    printf("%lld\n", ans);
    return 0;
}
「雅礼集训 2017 Day1」字符串
好难啊这道题。首先要介绍一种数据结构中出现过的套路:根号分类。对于某些乘积固定的两个参数,可能存在潜在的根号时间复杂度。比如:CF103D。
讲讲正解:先建出 SAM。如果字符串大小小于规定的块大小,那我们考虑下面这种处理方式:
\(k\)小代表着给定的\(m\)个区间的位置都比较集中(?),所以我们开一个三维空间来存区间的编号,再\(k^2\)枚举子串,在 SAM 上走。如果当前子串合规,那么就在三维空间里找有多少个囊括在询问\([l, r]\)中的区间序号。
如果不是:
我们考虑处理询问串每一位结尾向前的最大匹配长度和在 SAM 上的节点,然后对于询问我们只需要从\(a\)枚举到\(b\),并在最右点向上跳到最接近最左点的 SAM 节点。
// LOJ6031.cpp
// Square-root decomposition on the query-string length k:
//   k <= block_size : enumerate all O(k^2) substrings of the query string (solve1)
//   otherwise       : match the query string once, then lift per interval (solve2)
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

using namespace std;

const int MAX_N = 2e6 + 200, block_size = 330;

// head / edges: adjacency lists of the link tree; fa[j][v]: 2^j-th ancestor of v;
// li / ri: the m intervals, shifted to be 1-indexed.
int head[MAX_N], current, fa[20][MAX_N], li[MAX_N], ri[MAX_N], m, n, q, k;
// cpos[i] / clen[i]: SAM state and matched length after reading w[1..i] (solve2).
int cpos[MAX_N], clen[MAX_N];
char str[MAX_N], w[MAX_N];

struct edge {
    int to, nxt;
} edges[MAX_N << 1];

void addpath(int src, int dst) {
    edges[current].to = dst, edges[current].nxt = head[src];
    head[src] = current++;
}

namespace SAM {

struct node {
    int ch[26], link, dep, siz;  // siz: 1 for states created directly, 0 for clones
} nodes[MAX_N << 1];

int ptot = 1, last_ptr = 1;

// Appends character c (0..25) to the automaton.
void insert(int c) {
    int pre = last_ptr, p = last_ptr = ++ptot;
    nodes[p].dep = nodes[pre].dep + 1, nodes[p].siz = 1;
    while (pre && nodes[pre].ch[c] == 0)
        nodes[pre].ch[c] = p, pre = nodes[pre].link;
    if (pre == 0)
        nodes[p].link = 1;
    else {
        int q = nodes[pre].ch[c];
        if (nodes[q].dep == nodes[pre].dep + 1)
            nodes[p].link = q;
        else {
            int clone = ++ptot;
            nodes[clone] = nodes[q], nodes[clone].dep = nodes[pre].dep + 1;
            nodes[p].link = nodes[q].link = clone, nodes[clone].siz = 0;
            while (pre && nodes[pre].ch[c] == q)
                nodes[pre].ch[c] = clone, pre = nodes[pre].link;
        }
    }
}

// Builds the link tree (edges parent -> child) and the first level of the lifting table.
void build_graph() {
    for (int i = 1; i <= ptot; i++)
        if (nodes[i].link != 0)
            addpath(nodes[i].link, i), fa[0][i] = nodes[i].link;
}

// Fills the lifting table of u and accumulates siz over u's subtree.
void dfs(int u) {
    for (int i = 1; i <= 19; i++)
        fa[i][u] = fa[i - 1][fa[i - 1][u]];
    for (int i = head[u]; i != -1; i = edges[i].nxt)
        dfs(edges[i].to), nodes[u].siz += nodes[edges[i].to].siz;
}

} // namespace SAM

// Small-k branch: buc[l][r] lists, in increasing order, the ids of the intervals equal to
// [l, r]; every substring w[i..j] is walked through the SAM and weighted by the number of
// ids inside [a, b].
void solve1() {
    // Endpoints are 1-indexed and run up to k <= block_size, so each dimension needs
    // block_size + 1 slots (the former [block_size][block_size] overflowed for k == 330).
    // static keeps the ~2.6 MB table off the stack.
    static vector<int> buc[block_size + 1][block_size + 1];
    for (int i = 1; i <= m; i++)
        buc[li[i]][ri[i]].push_back(i);

    while (q--) {
        int a, b;
        long long ans = 0;
        scanf("%s%d%d", w + 1, &a, &b);
        a++, b++;
        for (int i = 1; i <= k; i++)
            for (int j = i, p = 1; j <= k; j++) {
                p = SAM::nodes[p].ch[w[j] - 'a'];
                if (p == 0)
                    break;
                const vector<int> &ids = buc[i][j];
                const int l = lower_bound(ids.begin(), ids.end(), a) - ids.begin();
                const int r = upper_bound(ids.begin(), ids.end(), b) - ids.begin();
                ans += 1LL * (r - l) * SAM::nodes[p].siz;
            }
        printf("%lld\n", ans);
    }
}

// Large-k branch: match w against the SAM once, then answer each interval in [a, b]
// by lifting from the state recorded at its right end.
void solve2() {
    while (q--) {
        int a, b;
        long long ans = 0;
        scanf("%s%d%d", w + 1, &a, &b);
        a++, b++;

        for (int i = 1, p = 1, len = 0; i <= k; i++) {
            while (p && SAM::nodes[p].ch[w[i] - 'a'] == 0)
                p = SAM::nodes[p].link, len = SAM::nodes[p].dep;
            if (p == 0)
                p = 1;
            else
                p = SAM::nodes[p].ch[w[i] - 'a'];
            len += (p != 1);
            cpos[i] = p, clen[i] = len;
        }

        for (int idx = a; idx <= b; idx++) {
            int p = cpos[ri[idx]], dist = ri[idx] - li[idx] + 1;
            if (dist > clen[ri[idx]])
                continue;  // the matched length at ri is shorter than the interval
            else if (dist == clen[ri[idx]]) {
                ans += SAM::nodes[p].siz;
                continue;
            }
            // Lift to the highest ancestor whose dep is still >= dist.
            for (int i = 19; i >= 0; i--)
                if (fa[i][p] && SAM::nodes[fa[i][p]].dep >= dist)
                    p = fa[i][p];
            ans += SAM::nodes[p].siz;
        }
        printf("%lld\n", ans);
    }
}

int main() {
    memset(head, -1, sizeof(head));
    scanf("%d%d%d%d%s", &n, &m, &q, &k, str + 1);
    for (int i = 1; i <= n; i++)
        SAM::insert(str[i] - 'a');
    SAM::build_graph(), SAM::dfs(1);
    for (int i = 1; i <= m; i++)
        scanf("%d%d", &li[i], &ri[i]), li[i]++, ri[i]++;
    if (1LL * k <= block_size)
        solve1();
    else
        solve2();
    return 0;
}