后缀自动机 | SAM

概述

后缀自动机是处理字符串信息的有力工具。后缀自动机的状态与转移边构成一张 DAG,而 Link(后缀链接)指针把所有状态连成一棵树。任意一条从初始状态出发的路径都对应原字符串的一个子串,且每个子串恰好对应一条这样的路径,因此后缀自动机上不会存在重复的子串信息(不过我们可以进行一些扩展来维护子串的位置信息)。

SAM 的构造

namespace SAM
{

// One state of the suffix automaton.
//   dep  : length of the longest string in this state
//   link : suffix-link parent (left 0 only for the root, node 1)
//   ch   : outgoing transitions, 0 meaning "no transition"
//   pos  : an end position of this state's strings (copied into clones)
struct node
{
    int dep, link, ch[26], pos;
} nodes[MAX_N << 1];

// Node 1 is the root (empty string); node 0 acts as the null sentinel.
int ptot = 1, last_ptr = 1;

// Appends character c (0..25), located at position idx of the text.
void insert(int c, int idx)
{
    int cur = ++ptot;
    int walker = last_ptr;
    last_ptr = cur;
    nodes[cur].dep = nodes[walker].dep + 1;
    nodes[cur].pos = idx;

    // Every suffix state lacking a c-transition now leads to the new state.
    for (; walker != 0 && !nodes[walker].ch[c]; walker = nodes[walker].link)
        nodes[walker].ch[c] = cur;

    if (!walker)
    {
        nodes[cur].link = 1;
        return;
    }

    int target = nodes[walker].ch[c];
    if (nodes[walker].dep + 1 == nodes[target].dep)
    {
        nodes[cur].link = target;
        return;
    }

    // target also holds longer strings: split off a clone for the shorter ones.
    int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[cur].link = clone;
    nodes[target].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
}

} // namespace SAM

SAM 的构造是线性时间的,且可以动态加入字符。

在这里介绍一下广义后缀自动机的构造方式。首先是离线的做法(先把所有串插入 Trie,再按 BFS 序把 Trie 节点依次插入 SAM):

// P6139.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity shared by the trie arrays and the SAM node pool.
const int MAX_N = 2e6 + 200;

// Number of input strings.
int n;
// pos[u]: SAM state reached by trie node u.
int pos[MAX_N];
// Current input string, 1-indexed and NUL-terminated.
char str[MAX_N];

namespace Trie
{

    int ch[MAX_N][26], ptot = 1, up[MAX_N], depot[MAX_N];

    void insert()
    {
        int p = 1;
        for (int i = 1; str[i]; i++)
        {
            if (ch[p][str[i] - 'a'] == 0)
                ch[p][str[i] - 'a'] = ++ptot, up[ptot] = p, depot[ptot] = str[i] - 'a';
            p = ch[p][str[i] - 'a'];
        }
    }

} // namespace Trie

namespace SAM
{

    // SAM state: ch = transitions (0 = absent), link = suffix-link parent,
    // dep = length of the longest string in the state.
    struct node
    {
        int ch[26], link, dep;
    } nodes[MAX_N];

    // Node 1 is the root; main() repositions last_ptr before each insert.
    int ptot = 1, last_ptr = 1;

    // Appends character c (0..25) after the state last_ptr; last_ptr becomes the new state.
    void insert(int c)
    {
        const int cur = ++ptot;
        int walker = last_ptr;
        last_ptr = cur;
        nodes[cur].dep = nodes[walker].dep + 1;

        for (; walker && !nodes[walker].ch[c]; walker = nodes[walker].link)
            nodes[walker].ch[c] = cur;

        if (!walker)
        {
            nodes[cur].link = 1;
            return;
        }

        const int target = nodes[walker].ch[c];
        if (nodes[walker].dep + 1 == nodes[target].dep)
        {
            nodes[cur].link = target;
            return;
        }

        // Split: the clone keeps target's transitions but the shorter length.
        const int clone = ++ptot;
        nodes[clone] = nodes[target];
        nodes[clone].dep = nodes[walker].dep + 1;
        nodes[cur].link = clone;
        nodes[target].link = clone;
        for (; walker && nodes[walker].ch[c] == target; walker = nodes[walker].link)
            nodes[walker].ch[c] = clone;
    }

} // namespace SAM

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%s", str + 1), Trie::insert();
    queue<int> q;
    for (int i = 0; i < 26; i++)
        if (Trie::ch[1][i])
            q.push(Trie::ch[1][i]);
    pos[1] = 1;
    while (!q.empty())
    {
        int u = q.front();
        q.pop();
        SAM::last_ptr = pos[Trie::up[u]];
        SAM::insert(Trie::depot[u]);
        pos[u] = SAM::last_ptr;
        for (int i = 0; i < 26; i++)
            if (Trie::ch[u][i])
                q.push(Trie::ch[u][i]);
    }
    long long ans = 0;
    for (int i = 2; i <= SAM::ptot; i++)
        ans += SAM::nodes[i].dep - SAM::nodes[SAM::nodes[i].link].dep;
    printf("%lld\n", ans);
    return 0;
}

在线的做法:

// CF316G3.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity of each input string buffer.
const int MAX_N = 1e6 + 200;

// CF316G3 bounds: s and every rule string have length <= 50000, and n <= 10.
// A (generalized) SAM over strings of total length L has fewer than 2L states,
// so every state-indexed array must hold 2 * 11 * 50000 entries. The string
// capacity MAX_N (about 1e6) is too small for that.
const int MAX_NODE = 2 * 11 * 50000 + 200;

// n rules; lens[i] = |str[i]|; [li[i], ri[i]] = allowed occurrence range of rule i.
int n, lens[11], li[11], ri[11];
// str[0] is s, str[1..n] are the rule strings (all 1-indexed).
char str[11][MAX_N];

// Generalized SAM shared by all strings.

struct node
{
    int ch[26], dep, link;
} nodes[MAX_NODE];

// siz[j][u]: occurrences in str[j] of the strings of state u (after propagation).
// rnk / bucket: counting sort of states by dep.
int ptot = 1, last_ptr = 1, siz[11][MAX_NODE], rnk[MAX_NODE], bucket[MAX_NODE];

// Online generalized-SAM append: extends the string whose state is last_ptr
// by character c (0..25) and stores the resulting state in last_ptr.
// main() resets last_ptr to 1 (the root) before each rule string.
void insert(int c)
{
    int pre = last_ptr;
    // The extended string already has a transition (it occurred in an earlier
    // string): reuse it instead of creating a brand-new state.
    if (nodes[pre].ch[c] != 0)
    {
        int q = nodes[pre].ch[c];
        // q's longest string is exactly (pre's longest) + c, so q is the state.
        if (nodes[q].dep == nodes[pre].dep + 1)
            last_ptr = q;
        else
        {
            // q also holds longer strings: split off a clone of length
            // dep(pre) + 1, which becomes the state of the extended string.
            int clone = ++ptot;
            nodes[clone] = nodes[q], nodes[clone].dep = nodes[pre].dep + 1;
            last_ptr = clone, nodes[q].link = clone;
            // Redirect c-transitions that pointed to q along pre's suffix path.
            while (pre && nodes[pre].ch[c] == q)
                nodes[pre].ch[c] = clone, pre = nodes[pre].link;
        }
    }
    else
    {
        // Ordinary SAM append: create the new state p.
        int p = last_ptr = ++ptot;
        nodes[p].dep = nodes[pre].dep + 1;
        while (pre && nodes[pre].ch[c] == 0)
            nodes[pre].ch[c] = p, pre = nodes[pre].link;
        if (pre == 0)
            nodes[p].link = 1;
        else
        {
            int q = nodes[pre].ch[c];
            if (nodes[q].dep == nodes[pre].dep + 1)
                nodes[p].link = q;
            else
            {
                // Split q; both q and p then link to the clone.
                int clone = ++ptot;
                nodes[clone] = nodes[q], nodes[clone].dep = nodes[pre].dep + 1;
                nodes[p].link = nodes[q].link = clone;
                while (pre && nodes[pre].ch[c] == q)
                    nodes[pre].ch[c] = clone, pre = nodes[pre].link;
            }
        }
    }
}

int main()
{
    scanf("%s%d", str[0] + 1, &n), lens[0] = strlen(str[0] + 1);
    for (int i = 1; str[0][i]; i++)
        insert(str[0][i] - 'a'), siz[0][last_ptr]++;
    for (int i = 1; i <= n; i++)
    {
        scanf("%s%d%d", str[i] + 1, &li[i], &ri[i]), lens[i] = strlen(str[i] + 1);
        last_ptr = 1;
        for (int j = 1; str[i][j]; j++)
            insert(str[i][j] - 'a'), siz[i][last_ptr]++;
    }
    for (int i = 1; i <= ptot; i++)
        bucket[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= ptot; i++)
        rnk[bucket[nodes[i].dep]--] = i;
    for (int i = ptot; i >= 2; i--)
        for (int j = 0; j <= n; j++)
            siz[j][nodes[rnk[i]].link] += siz[j][rnk[i]];
    long long ans = 0;
    for (int i = 2; i <= ptot; i++)
        if (siz[0][i])
        {
            bool flag = true;
            for (int j = 1; j <= n; j++)
                flag &= li[j] <= siz[j][i] && siz[j][i] <= ri[j];
            if (flag)
                ans += nodes[i].dep - nodes[nodes[i].link].dep;
        }
    printf("%lld\n", ans);
    return 0;
}

Link 树

Link 树是一个很妙的东西,将若干个 Endpos 等价类做成树形结构。它有以下性质:

  • \(minlen(u) = dep[fa] + 1\),其中 \(dep\) 表示等价类中最长串的长度(因此 \(dep[u] > dep[fa]\))。父亲中的串都是 \(u\) 中串的后缀;在它们前面继续添加字符会使 Endpos 集合分化,从而得到规模更小、深度更大的 Endpos 等价类。
  • \(endpos[u] \subsetneq endpos[fa]\),利用这个性质我们可以套路地使用线段树合并来维护每个等价类具体的 endpos 集合。

SAM 的应用

求本质不同的子串个数

发现子串在 SAM 中的表示是唯一的、不受位置影响(也就是不会重复计数):每一条从根出发的非空路径都对应一个唯一的子串,所以求解这个问题我们直接对 DAG 做路径计数即可。等价地,答案也等于所有非根节点的 \(dep[u] - dep[link[u]]\) 之和。


求一个串在母串中出现的次数

我们让串在母串的 SAM 中游走到点 \(p\)(若中途没有对应的转移,则出现次数为 0)。回顾一下 Endpos 等价类的定义:该串在母串中的每一次出现都与 \(endpos[p]\) 中的一个结束位置一一对应,所以答案就是 \(|endpos[p]|\)。在构造时给每个前缀对应的新节点打上标记 1(克隆节点不打),之后根据 Link 树的树形结构 DFS 一遍(或按 \(dep\) 基数排序后倒序累加),求出子树标记和即可。


求两个字符串的最长公共子串

对其中一个串建 SAM,然后让另一个串在 SAM 上逐字符匹配,失配时暴力往上跳即可。暴力往上跳的具体操作类似 AC 自动机的失配指针,但是需要注意维护当前匹配长度:每一次跳到 Link 父亲,都代表放弃当前匹配串开头的若干个字符,匹配长度要变为父亲的 \(dep\),这一点和 AC 自动机并不一样。匹配过程中出现过的最大匹配长度就是答案。


CF1120C Compress String

这道题显然可以写成一个 \(O(n^2)\) 的 DP。设状态 \(f_i\) 为压缩前 \(i\) 位的最优代价,转移非常显然:

\[ f_i = \min \begin{cases} f_{i - 1} + a \\ f_j + b, & \text{if } S[j + 1 \dots i] \text{ occurs in } S[1 \dots j] \end{cases} \]

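考虑预处理出子串 \(S[i \dots j]\) 在 SAM 中对应的节点 \(pos[i][j]\),以及每个节点的最早结束位置 \(occur\)(在 Link 树上对子树取最小值)。\(S[j+1 \dots i]\) 在 \(S[1 \dots j]\) 中出现,当且仅当 \(occur(pos[j+1][i]) \le j\)。由于 \(f\) 单调不降,对每个 \(i\) 取满足条件的最小的 \(j\) 进行转移即可。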

// CF1120C.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 1e5 + 200;

// SAM state. occur = earliest end position of the state's strings; it starts
// at "infinity" (0x3f3f3f3f) so that taking a minimum works.
struct node
{
    int ch[26];
    int link;
    int dep;
    int occur = 0x3f3f3f3f;
} nodes[MAX_N];

// n = |S|, ca / cb = the two costs; ptot / last_cur = SAM bookkeeping.
int n, ca, cb;
int ptot, last_cur;
// Counting-sort buffers (states ordered by dep).
int bucketA[MAX_N], bucketB[MAX_N];
// pos[i][j]: SAM state of S[i..j]; dp[i]: minimal cost to encode S[1..i].
int pos[5001][5001];
int dp[MAX_N];
char str[MAX_N];

// Allocates the root state (node 1 when called first) and makes it the last state.
void sam_initialize()
{
    ptot += 1;
    last_cur = ptot;
}

// Appends character c (0..25) found at 1-indexed position `pos` of S.
// The new state records `pos` as its earliest end position. A clone keeps the
// default "infinity" occur and is fixed later by stringSort().
void insert(int c, int pos)
{
    const int cur = ++ptot;
    nodes[cur].dep = nodes[last_cur].dep + 1;
    nodes[cur].occur = pos;

    int p;
    for (p = last_cur; p && !nodes[p].ch[c]; p = nodes[p].link)
        nodes[p].ch[c] = cur;

    if (!p)
    {
        nodes[cur].link = 1;
    }
    else if (nodes[nodes[p].ch[c]].dep == nodes[p].dep + 1)
    {
        nodes[cur].link = nodes[p].ch[c];
    }
    else
    {
        const int q = nodes[p].ch[c];
        const int clone = ++ptot;
        nodes[clone].dep = nodes[p].dep + 1;
        memcpy(nodes[clone].ch, nodes[q].ch, sizeof nodes[clone].ch);
        nodes[clone].link = nodes[q].link;
        for (; p && nodes[p].ch[c] == q; p = nodes[p].link)
            nodes[p].ch[c] = clone;
        nodes[q].link = clone;
        nodes[cur].link = clone;
    }

    last_cur = cur;
}

void stringSort()
{
    for (int i = 1; i <= ptot; i++)
        bucketA[i] = 0;
    for (int i = 1; i <= ptot; i++)
        bucketA[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucketA[i] += bucketA[i - 1];
    for (int i = 1; i <= ptot; i++)
        bucketB[bucketA[nodes[i].dep]--] = i;

    for (int i = ptot; i >= 1; i--)
        nodes[nodes[bucketB[i]].link].occur = min(nodes[nodes[bucketB[i]].link].occur, nodes[bucketB[i]].occur);
}

int main()
{
    sam_initialize();
    scanf("%d%d%d%s", &n, &ca, &cb, str + 1);
    for (int i = 1; i <= n; i++)
        insert(str[i] - 'a', i);

    // pos[i][j] = state reached by reading S[i..j] from the root.
    for (int i = 1; i <= n; i++)
    {
        int state = 1;
        for (int j = i; j <= n; j++)
            pos[i][j] = state = nodes[state].ch[str[j] - 'a'];
    }

    stringSort();

    // dp[i] = min(dp[i-1] + ca, dp[j] + cb) where S[j+1..i] already occurs in
    // S[1..j], i.e. its earliest end position is <= j. The smallest such j is used.
    for (int i = 1; i <= n; i++)
    {
        int best = dp[i - 1] + ca;
        for (int j = 1; j < i; j++)
        {
            if (nodes[pos[j + 1][i]].occur > j)
                continue;
            best = min(best, dp[j] + cb);
            break;
        }
        dp[i] = best;
    }
    printf("%d", dp[n]);
    return 0;
}

[SDOI2016]生成魔咒

其实这就是统计本质不同子串个数的在线版本。考虑每加入一个字符对答案造成的贡献:新增的本质不同子串恰好是新节点 \(p\) 所代表的那些串(克隆节点只是把已有的串拆开,不产生新串),个数为 \(maxLen(p) - minLen(p) + 1\),也就是 \(maxLen(p) - maxLen(fa[p])\)。

// P4070.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int MAX_N = 1e6 + 200;

// SAM state with a sparse transition map (characters are arbitrary ints read
// with %d).
//   len  : length of the longest string in the state
//   link : suffix-link parent
struct node
{
    map<int, int> ch;
    int link;
    int len;
} nodes[MAX_N];

// n = number of characters to read; opt = current character.
int n, opt;
// Node 1 is the root.
int ptot = 1, last_ptr = 1;
// Running count of distinct substrings.
ll ans;

void insert(int c)
{
    int pre = last_ptr, p = last_ptr = ++ptot;
    nodes[p].len = nodes[pre].len + 1;
    for (; pre != 0 && nodes[pre].ch[c] == 0; pre = nodes[pre].link)
        nodes[pre].ch[c] = p;
    if (pre == 0)
        nodes[p].link = 1;
    else
    {
        int q = nodes[pre].ch[c];
        if (nodes[q].len == nodes[pre].len + 1)
            nodes[p].link = q;
        else
        {
            int clone = ++ptot;
            nodes[clone] = nodes[q];
            nodes[clone].len = nodes[pre].len + 1;
            nodes[q].link = nodes[p].link = clone;
            for (; nodes[pre].ch[c] == q; pre = nodes[pre].link)
                nodes[pre].ch[c] = clone;
        }
    }
    ans += nodes[p].len - nodes[nodes[p].link].len;
}

int main()
{
    scanf("%d", &n);
    // Online: after every appended character print the current number of
    // distinct substrings.
    while (n--)
    {
        scanf("%d", &opt);
        insert(opt);
        printf("%lld\n", ans);
    }
    return 0;
}

[HEOI2016/TJOI2016]字符串

转换题意:问子串 \(S[a..b]\) 的所有子串与 \(S[c..d]\) 的最长公共前缀长度的最大值是多少?考虑把前缀转成后缀(反转字符串),然后对每个询问二分答案 \(mid\)(答案具有单调性),检查 \(S[c..c+mid-1]\) 是否在 \(S[a..b]\) 中出现。其实我们只需要在 Link 树上倍增拿到这个串所在的 \(endpos\) 等价类,并查询类中有没有结束位置使整个串落在区间 \([a, b]\) 内即可(记得坐标已经反转)。线段树合并 + SAM。

// P4094.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 2e5 + 200;

// n = |S|, q = number of queries.
int n;
int q;
// S, 1-indexed (main reverses it in place).
char str[MAX_N];

namespace SAM
{

// SAM state: link = suffix-link parent, dep = length of the longest string,
// ch = transitions (0 = absent). Node 1 is the root, node 0 the null sentinel.
struct node
{
    int link, dep, ch[26];
} nodes[MAX_N << 1];

int ptot = 1, last_ptr = 1;

// Appends character c (0..25); last_ptr becomes the state of the whole prefix.
void insert(int c)
{
    const int cur = ++ptot;
    int walker = last_ptr;
    last_ptr = cur;
    nodes[cur].dep = nodes[walker].dep + 1;

    while (walker != 0 && nodes[walker].ch[c] == 0)
    {
        nodes[walker].ch[c] = cur;
        walker = nodes[walker].link;
    }

    if (walker == 0)
    {
        nodes[cur].link = 1;
        return;
    }

    const int target = nodes[walker].ch[c];
    if (nodes[target].dep == nodes[walker].dep + 1)
    {
        nodes[cur].link = target;
        return;
    }

    const int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[target].link = nodes[cur].link = clone;
    while (walker != 0 && nodes[walker].ch[c] == target)
    {
        nodes[walker].ch[c] = clone;
        walker = nodes[walker].link;
    }
}

} // namespace SAM

namespace SegmentTree
{

// Dynamically allocated segment tree over positions [1, n]; sum = number of
// marked positions in the node's range. Node 0 is the shared empty tree.
struct node
{
    int lson, rson, sum;
} nodes[MAX_N << 6];

int ptot;

void pushup(int p) { nodes[p].sum = nodes[nodes[p].lson].sum + nodes[nodes[p].rson].sum; }

// Returns the root of a tree representing the union of x and y. Neither input
// tree is modified: every node present in both trees (including leaves) is
// freshly allocated. This matters because a child's root can be returned
// unchanged as part of its parent's tree, and the child's root is still
// queried later (check() queries arbitrary states).
int merge(int x, int y, int l, int r)
{
    if (x == 0)
        return y;
    if (y == 0)
        return x;
    int p = ++ptot;
    if (l == r)
    {
        // Allocate a new leaf instead of adding into x in place: x may be shared.
        nodes[p].sum = nodes[x].sum + nodes[y].sum;
        return p;
    }
    int mid = (l + r) >> 1;
    nodes[p].lson = merge(nodes[x].lson, nodes[y].lson, l, mid);
    nodes[p].rson = merge(nodes[x].rson, nodes[y].rson, mid + 1, r);
    pushup(p);
    return p;
}

// Marks position qx (in place) in the tree rooted at p and returns its root.
int update(int qx, int l, int r, int p)
{
    if (p == 0)
        p = ++ptot;
    if (l == r)
    {
        nodes[p].sum++;
        return p;
    }
    int mid = (l + r) >> 1;
    if (qx <= mid)
        nodes[p].lson = update(qx, l, mid, nodes[p].lson);
    else
        nodes[p].rson = update(qx, mid + 1, r, nodes[p].rson);
    pushup(p);
    return p;
}

// Number of marked positions in [ql, qr] in the tree rooted at p.
int query(int ql, int qr, int l, int r, int p)
{
    if (p == 0)
        return 0;
    if (ql <= l && r <= qr)
        return nodes[p].sum;
    int mid = (l + r) >> 1, ret = 0;
    if (ql <= mid)
        ret += query(ql, qr, l, mid, nodes[p].lson);
    if (mid < qr)
        ret += query(ql, qr, mid + 1, r, nodes[p].rson);
    return ret;
}

} // namespace SegmentTree

// roots[u]: segment tree of endpos(u); pos[i]: state of the prefix ending at i.
int roots[MAX_N << 1];
int pos[MAX_N];
// perm: states sorted by dep; bucket: counting-sort counts.
int perm[MAX_N << 1];
// fa[k][u]: 2^k-th suffix-link ancestor of u (0 past the root).
int fa[21][MAX_N << 1];
int bucket[MAX_N << 1];

// Orders SAM states by increasing dep.
bool compare(const int &rhs1, const int &rhs2)
{
    const int lhs_dep = SAM::nodes[rhs1].dep;
    const int rhs_dep = SAM::nodes[rhs2].dep;
    return lhs_dep < rhs_dep;
}

// Lifts p to its highest suffix-link ancestor whose dep is still >= mid. For
// 1 <= mid <= dep(p) that is the state of the length-mid suffix ending at p's
// position. Then checks whether the state has an end position in
// [l + mid - 1, r], i.e. whether that string occurs entirely inside [l, r].
bool check(int mid, int p, int l, int r)
{
    for (int k = 20; k >= 0; k--)
    {
        const int anc = fa[k][p];
        if (anc != 0 && SAM::nodes[anc].dep >= mid)
            p = anc;
    }
    return SegmentTree::query(l + mid - 1, r, 1, n, roots[p]) > 0;
}

int main()
{
    scanf("%d%d%s", &n, &q, str + 1);
    reverse(str + 1, str + 1 + n);
    for (int i = 1; i <= n; i++)
    {
        SAM::insert(str[i] - 'a');
        roots[SAM::last_ptr] = SegmentTree::update(i, 1, n, roots[SAM::last_ptr]);
        pos[i] = SAM::last_ptr;
    }
    for (int i = 1; i <= SAM::ptot; i++)
        perm[i] = i;
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[SAM::nodes[i].dep]++;
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= SAM::ptot; i++)
        perm[bucket[SAM::nodes[i].dep]--] = i;
    for (int i = SAM::ptot; i >= 1; i--)
        if (SAM::nodes[perm[i]].link != 0)
        {
            roots[SAM::nodes[perm[i]].link] =
                SegmentTree::merge(roots[SAM::nodes[perm[i]].link], roots[perm[i]], 1, n);
        }
    for (int i = 1; i <= SAM::ptot; i++)
        fa[0][i] = SAM::nodes[i].link;
    for (int i = 1; i <= 20; i++)
        for (int j = 1; j <= SAM::ptot; j++)
            fa[i][j] = fa[i - 1][fa[i - 1][j]];
    while (q--)
    {
        int a, b, c, d;
        scanf("%d%d%d%d", &a, &b, &c, &d);
        a = n - a + 1, b = n - b + 1;
        c = n - c + 1, d = n - d + 1;
        int l = 0, r = min(a - b + 1, c - d + 1), ret = 0;
        while (l <= r)
        {
            int mid = (l + r) >> 1;
            if (check(mid, pos[c], b, a))
                l = mid + 1, ret = mid;
            else
                r = mid - 1;
        }
        printf("%d\n", ret);
    }
    return 0;
}

CF700E Cool Slogans

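这道题还挺有意思。我们最终需要一个链式结构来表示题目中所描述的东西,并求出这个链式结构的最大长度。其实这个链式结构上的节点可以直接用 SAM 上的节点来构造:假设串 \(T\) 为链的一端,链上每一个串都在上一个串中出现至少两次,且可以取为上一个串的后缀,因此整条链对应 Link 树上一条从根往下的路径。所以我们可以直接在 Link 树上自顶向下 DP:对于每一个节点 \(u\),记 \(top[u]\) 为「极大父亲」,即 \(u\) 的祖先(含自身)中最近一次使 \(dp\) 增加的节点。之所以需要它,是因为某些情况下直接的父子关系并不能满足题目中「出现两次」的要求,所以需要「极大父亲」来表示最近能作为链上前驱的祖先。若 \(top[fa]\) 的最长串在 \(u\) 的最长串中出现了两次(用线段树合并维护的 endpos 查询),则 \(dp[u] = dp[fa] + 1\)、\(top[u] = u\);否则 \(dp[u] = dp[fa]\)、\(top[u] = top[fa]\)。最后对所有 \(dp\) 取最大值就好了。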

// CF700E.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 401000;

// n = |S|.
int n;
// S, 1-indexed.
char str[MAX_N];

namespace SAM
{

// One state of the suffix automaton.
//   dep  : length of the longest string in this state
//   link : suffix-link parent (left 0 only for the root, node 1)
//   ch   : outgoing transitions, 0 meaning "no transition"
//   pos  : an end position of this state's strings (copied into clones)
struct node
{
    int dep, link, ch[26], pos;
} nodes[MAX_N << 1];

// Node 1 is the root (empty string); node 0 acts as the null sentinel.
int ptot = 1, last_ptr = 1;

// Appends character c (0..25), located at position idx of the text.
void insert(int c, int idx)
{
    int cur = ++ptot;
    int walker = last_ptr;
    last_ptr = cur;
    nodes[cur].dep = nodes[walker].dep + 1;
    nodes[cur].pos = idx;

    // Every suffix state lacking a c-transition now leads to the new state.
    for (; walker != 0 && !nodes[walker].ch[c]; walker = nodes[walker].link)
        nodes[walker].ch[c] = cur;

    if (!walker)
    {
        nodes[cur].link = 1;
        return;
    }

    int target = nodes[walker].ch[c];
    if (nodes[walker].dep + 1 == nodes[target].dep)
    {
        nodes[cur].link = target;
        return;
    }

    // target also holds longer strings: split off a clone for the shorter ones.
    int clone = ++ptot;
    nodes[clone] = nodes[target];
    nodes[clone].dep = nodes[walker].dep + 1;
    nodes[cur].link = clone;
    nodes[target].link = clone;
    for (; walker != 0 && nodes[walker].ch[c] == target; walker = nodes[walker].link)
        nodes[walker].ch[c] = clone;
}

} // namespace SAM

namespace SegmentTree
{

// Dynamically allocated segment tree over positions [1, n]; sum = number of
// marked positions in the node's range. Node 0 is the shared empty tree.
struct node
{
    int sum, lson, rson;
} nodes[MAX_N * 25];

int ptot;

void pushup(int p) { nodes[p].sum = nodes[nodes[p].lson].sum + nodes[nodes[p].rson].sum; }

// Marks position qx (in place) in the tree rooted at p and returns its root.
int update(int qx, int l, int r, int p)
{
    if (p == 0)
        p = ++ptot;
    if (l == r)
    {
        nodes[p].sum++;
        return p;
    }
    int mid = (l + r) >> 1;
    if (qx <= mid)
        nodes[p].lson = update(qx, l, mid, nodes[p].lson);
    else
        nodes[p].rson = update(qx, mid + 1, r, nodes[p].rson);
    pushup(p);
    return p;
}

// Returns the root of a tree representing the union of x and y. Neither input
// tree is modified: every node present in both trees (including leaves) is
// freshly allocated. This keeps the trees of descendant states intact; main
// later queries roots[top[fa]] for arbitrary states.
int merge(int x, int y, int l, int r)
{
    if (x == 0 || y == 0)
        return x + y;
    int p = ++ptot;
    if (l == r)
    {
        // Allocate a new leaf instead of adding into x in place: x may be shared.
        nodes[p].sum = nodes[x].sum + nodes[y].sum;
        return p;
    }
    int mid = (l + r) >> 1;
    nodes[p].lson = merge(nodes[x].lson, nodes[y].lson, l, mid);
    nodes[p].rson = merge(nodes[x].rson, nodes[y].rson, mid + 1, r);
    pushup(p);
    return p;
}

// Number of marked positions in [ql, qr] in the tree rooted at p.
int query(int ql, int qr, int l, int r, int p)
{
    if (p == 0)
        return 0;
    if (ql <= l && r <= qr)
        return nodes[p].sum;
    int mid = (l + r) >> 1, ret = 0;
    if (ql <= mid)
        ret += query(ql, qr, l, mid, nodes[p].lson);
    if (mid < qr)
        ret += query(ql, qr, mid + 1, r, nodes[p].rson);
    return ret;
}

} // namespace SegmentTree

// roots[u]: endpos segment tree of state u.
int roots[MAX_N << 1];
// bucket / rnk: states counting-sorted by dep.
int bucket[MAX_N << 1], rnk[MAX_N << 1];
// top[u]: nearest ancestor-or-self of u at which dp last increased; dp[u]: chain length.
int top[MAX_N << 1], dp[MAX_N << 1];

// Builds the SAM of S with an endpos segment tree per state, then runs a
// top-down DP over the suffix-link tree (states visited in increasing dep).
int main()
{
    scanf("%d%s", &n, str + 1);
    // Each prefix state gets its own end position marked.
    for (int i = 1; i <= n; i++)
        SAM::insert(str[i] - 'a', i), roots[SAM::last_ptr] = SegmentTree::update(i, 1, n, roots[SAM::last_ptr]);
    // Counting sort of states by dep.
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[SAM::nodes[i].dep]++;
    for (int i = 1; i <= SAM::ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= SAM::ptot; i++)
        rnk[bucket[SAM::nodes[i].dep]--] = i;
    // Merge endpos trees into suffix-link parents, deepest states first.
    for (int i = SAM::ptot; i >= 1; i--)
        if (SAM::nodes[rnk[i]].link != 0)
            roots[SAM::nodes[rnk[i]].link] = SegmentTree::merge(roots[SAM::nodes[rnk[i]].link], roots[rnk[i]], 1, n);
    int ans = 1;
    // rnk[1] is the root (the only state with dep 0), so start from rnk[2].
    for (int i = 2; i <= SAM::ptot; i++)
    {
        int u = rnk[i], fa = SAM::nodes[u].link;
        // Children of the root start a new chain of length 1.
        if (fa == 1)
        {
            dp[u] = 1, top[u] = u;
            continue;
        }
        // u's longest string occupies [pos - dep(u) + 1, pos]. top[fa]'s longest
        // string (a suffix of it) needs a second occurrence ending in
        // [pos - dep(u) + dep(top[fa]), pos - 1] to lie fully inside it.
        int l = SAM::nodes[u].pos - SAM::nodes[u].dep + SAM::nodes[top[fa]].dep;
        int r = SAM::nodes[u].pos - 1;
        int x = SegmentTree::query(l, r, 1, n, roots[top[fa]]);
        if (x > 0)
            dp[u] = dp[fa] + 1, top[u] = u;
        else
            dp[u] = dp[fa], top[u] = top[fa];
        ans = max(ans, dp[u]);
    }
    printf("%d\n", ans);
    return 0;
}

「NOI2018」你的名字

真的是很恶心的一道题目了。题意转换:每次询问给出 \(T\) 和区间 \([l, r]\),求 \(T\) 有多少个本质不同的子串不是 \(S[l \dots r]\) 的子串。先思考 \(l = 1, r = n\) 的做法:我们在构建 \(T\) 的后缀自动机的同时让 \(T\) 在 \(S\) 的 SAM 上跑,对每个前缀 \(T[1 \dots i]\) 记录下它在 \(S\) 中能匹配的最长后缀长度 \(match_i\),然后 \(SAM_T\) 中每一个节点的贡献就是:

\[ \max\bigl(0,\ dep[p_T] - \max(match_{pos[p_T]},\ dep[fa[p_T]])\bigr) \]

其中 \(p_T\) 是 \(SAM_T\) 中的节点,\(pos[p_T]\) 是它的任意一个结束位置(代码中取第一次出现的位置)。

考虑正解,我们肯定需要一个线段树来维护集合信息,然后我们可以做类似的事情,只不过我们需要对\(match_i\)的处理进行改良兼容区间信息:LCP 暴力往上跳时判断是否在区间内。然后跟之前一样就可以做完了。

// LOJ2720.cpp
#include <bits/stdc++.h>

using namespace std;

// MAX_N: capacity for strings and SAM states; MAX_M: segment-tree node pool.
const int MAX_N = 1e6 + 2000;
const int MAX_M = 2e7 + 200;

// n = |S|, q = number of queries.
int n, q;
// limit[i]: longest suffix of T[1..i] that occurs inside S[l..r].
int limit[MAX_N];
char S[MAX_N], T[MAX_N];

namespace SegmentTree
{

// Dynamically allocated segment tree over positions [1, n]. sum counts the
// marked positions below a node; node 0 is the shared empty tree.
struct node
{
    int sum, lson, rson;
} nodes[MAX_M];

// ptot: last allocated node; roots[u]: tree holding the end positions of SAM state u.
int ptot, roots[MAX_N];

// Marks position qx in the tree rooted at p (allocating nodes as needed) and
// returns the possibly new root. Counts are bumped on the way down.
int update(int qx, int l, int r, int p)
{
    if (!p)
        p = ++ptot;
    ++nodes[p].sum;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        if (qx > mid)
            nodes[p].rson = update(qx, mid + 1, r, nodes[p].rson);
        else
            nodes[p].lson = update(qx, l, mid, nodes[p].lson);
    }
    return p;
}

// Number of marked positions in [ql, qr] inside the tree rooted at p.
int query(int ql, int qr, int l, int r, int p)
{
    if (!p)
        return 0;
    if (ql <= l && r <= qr)
        return nodes[p].sum;
    const int mid = (l + r) >> 1;
    int total = 0;
    if (ql <= mid)
        total += query(ql, qr, l, mid, nodes[p].lson);
    if (qr > mid)
        total += query(ql, qr, mid + 1, r, nodes[p].rson);
    return total;
}

// Persistent merge: returns a tree equal to the union of x and y without
// modifying either input (nodes present in both are freshly allocated).
int merge(int x, int y, int l, int r)
{
    if (!x || !y)
        return x ? x : y;
    const int p = ++ptot;
    nodes[p].sum = nodes[x].sum + nodes[y].sum;
    if (l == r)
        return p;
    const int mid = (l + r) >> 1;
    nodes[p].lson = merge(nodes[x].lson, nodes[y].lson, l, mid);
    nodes[p].rson = merge(nodes[x].rson, nodes[y].rson, mid + 1, r);
    return p;
}

} // namespace SegmentTree

namespace SAM
{

struct node
{
    int dep, ch[26], link;
} nodes[MAX_N];

int last_ptr = 1, ptot = 1, bucket[MAX_N], rnk[MAX_N];

int newnode() { return ++ptot; }

void initialize_collection()
{
    for (int i = 1; i <= ptot; i++)
        bucket[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = 1; i <= ptot; i++)
        rnk[bucket[nodes[i].dep]--] = i;
    for (int i = ptot; i >= 1; i--)
        if (nodes[rnk[i]].link != 0)
            SegmentTree::roots[nodes[rnk[i]].link] = SegmentTree::merge(SegmentTree::roots[nodes[rnk[i]].link], SegmentTree::roots[rnk[i]], 1, n);
}

} // namespace SAM

namespace SAM_T
{

SAM::node nodes[MAX_N];
int ptot = 1, last_ptr = 1, pos[MAX_N];

void clear() { ptot = last_ptr = 1, memset(nodes[1].ch, 0, sizeof(nodes[1].ch)), nodes[1].dep = nodes[1].link = 0; }

int newnode()
{
    int p = ++ptot;
    memset(nodes[p].ch, 0, sizeof(nodes[p].ch));
    nodes[p].dep = nodes[p].link = 0, pos[p] = 0;
    return p;
}

} // namespace SAM_T

// Generic SAM append shared by S's automaton and T's automaton: last_ptr,
// newnode and nodes select which automaton is extended by character c.
// When toggle is true (used for T), SAM_T::pos is also maintained: the new
// state gets its dep (the current prefix length) and a clone inherits the
// pos of the state it was split from. Returns the newly created state.
int insert(int c, int &last_ptr, int (*newnode)(), SAM::node *nodes, bool toggle = false)
{
    int pre = last_ptr, p = last_ptr = newnode();
    nodes[p].dep = nodes[pre].dep + 1;
    if (toggle)
        SAM_T::pos[p] = nodes[p].dep;
    while (pre && nodes[pre].ch[c] == 0)
        nodes[pre].ch[c] = p, pre = nodes[pre].link;
    if (pre == 0)
        nodes[p].link = 1;
    else
    {
        int q = nodes[pre].ch[c];
        if (nodes[q].dep == nodes[pre].dep + 1)
            nodes[p].link = q;
        else
        {
            // Split q: the clone copies q (transitions and link) but is shorter.
            int clone = newnode();
            nodes[clone] = nodes[q];
            if (toggle)
                SAM_T::pos[clone] = SAM_T::pos[q];
            nodes[clone].dep = nodes[pre].dep + 1;
            nodes[q].link = nodes[p].link = clone;
            while (pre && nodes[pre].ch[c] == q)
                nodes[pre].ch[c] = clone, pre = nodes[pre].link;
        }
    }
    return p;
}

// Builds S's SAM with per-state endpos trees once. For every query (T, l, r)
// it builds T's SAM, computes limit[i] (the longest suffix of T[1..i] that
// occurs inside S[l..r]) and counts T's distinct substrings longer than that.
int main()
{
    // File I/O required by the judge.
    freopen("name.in", "r", stdin);
    freopen("name.out", "w", stdout);
    scanf("%s", S + 1), n = strlen(S + 1);
    for (int i = 1; i <= n; i++)
    {
        insert(S[i] - 'a', SAM::last_ptr, SAM::newnode, SAM::nodes);
        SegmentTree::roots[SAM::last_ptr] = SegmentTree::update(i, 1, n, SegmentTree::roots[SAM::last_ptr]);
    }
    SAM::initialize_collection();
    scanf("%d", &q);
    while (q--)
    {
        int l, r, m;
        scanf("%s%d%d", T + 1, &l, &r);
        SAM_T::clear(), m = strlen(T + 1);
        // p / clen: current state in S's SAM and current match length.
        for (int i = 1, p = 1, clen = 0; i <= m; i++)
        {
            int c = T[i] - 'a';
            insert(c, SAM_T::last_ptr, SAM_T::newnode, SAM_T::nodes, true);
            while (true)
            {
                // Extend by c only if the extended string (length clen + 1) has an
                // end position in [l + clen, r], i.e. it lies fully inside [l, r].
                if (SAM::nodes[p].ch[c] && SegmentTree::query(
                                               l + clen, r,
                                               1, n,
                                               SegmentTree::roots[SAM::nodes[p].ch[c]]))
                {
                    clen++, p = SAM::nodes[p].ch[c];
                    break;
                }
                if (clen == 0)
                    break;
                // Otherwise drop one leading character; once the length reaches the
                // suffix-link parent's dep, the string belongs to that parent.
                clen--;
                if (clen == SAM::nodes[SAM::nodes[p].link].dep)
                    p = SAM::nodes[p].link;
            }
            limit[i] = clen;
        }
        // Strings of state i have lengths in (dep(link), dep]; those longer than
        // limit at the state's end position are not substrings of S[l..r].
        long long ans = 0;
        for (int i = 2; i <= SAM_T::ptot; i++)
            ans += max(0, SAM_T::nodes[i].dep - max(SAM_T::nodes[SAM_T::nodes[i].link].dep, limit[SAM_T::pos[i]]));
        printf("%lld\n", ans);
    }
    return 0;
}

「HAOI2016」找相同字符

这道题一眼看上去的做法:把两个串拼在一起,然后遍历每一个 Endpos 集合,在线段树里找前半串和后半串出现的次数再乘起来。

正解:其实不需要复杂的线段树,只需要把 Size 分开记即可,并且在插入第二个串的时候在 SAM 上游走,并把 Last_ptr 重置。

// LOJ2064.cpp
#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int MAX_N = 5e5 + 200;

// SAM state. siz[0] / siz[1]: occurrence counts in the first / second string
// (own marks first, then summed over the suffix-link subtree by radixSort()).
struct node
{
    int ch[26], siz[2], link, dep;
} nodes[MAX_N << 1];

// Node 1 is the root.
int last_ptr = 1, ptot = 1;
// Counting-sort buffers over states.
int bucket[MAX_N << 1], rnk[MAX_N << 1];
// n1 / n2: lengths of the two strings.
int n1, n2;
// Number of pairs of equal substrings taken one from each string.
ll ans;
char str1[MAX_N], str2[MAX_N];

// Appends character c. When typ == 0 (first string) the new state is marked
// with one occurrence in string 0; clones start with zero counts.
void insert(int c, int typ)
{
    const int cur = ++ptot;
    int pre = last_ptr;
    last_ptr = cur;
    nodes[cur].dep = nodes[pre].dep + 1;
    if (!typ)
        nodes[cur].siz[0] = 1;

    for (; pre && !nodes[pre].ch[c]; pre = nodes[pre].link)
        nodes[pre].ch[c] = cur;
    if (!pre)
    {
        nodes[cur].link = 1;
        return;
    }

    const int q = nodes[pre].ch[c];
    if (nodes[pre].dep + 1 == nodes[q].dep)
    {
        nodes[cur].link = q;
        return;
    }

    const int clone = ++ptot;
    nodes[clone] = nodes[q];
    nodes[clone].siz[0] = 0;
    nodes[clone].siz[1] = 0;
    nodes[clone].dep = nodes[pre].dep + 1;
    nodes[cur].link = clone;
    nodes[q].link = clone;
    for (; pre && nodes[pre].ch[c] == q; pre = nodes[pre].link)
        nodes[pre].ch[c] = clone;
}

void radixSort()
{
    for (int i = 1; i <= ptot; i++)
        bucket[nodes[i].dep]++;
    for (int i = 1; i <= ptot; i++)
        bucket[i] += bucket[i - 1];
    for (int i = ptot; i >= 1; i--)
        rnk[bucket[nodes[i].dep]--] = i;
    for (int i = ptot; i >= 1; i--)
        if (nodes[rnk[i]].link != 0)
        {
            nodes[nodes[rnk[i]].link].siz[0] += nodes[rnk[i]].siz[0];
            nodes[nodes[rnk[i]].link].siz[1] += nodes[rnk[i]].siz[1];
        }
    for (int i = 1; i <= ptot; i++)
        ans += 1LL * nodes[i].siz[0] * nodes[i].siz[1] * (nodes[i].dep - nodes[nodes[i].link].dep);
}

int main()
{
    scanf("%s%s", str1 + 1, str2 + 1);
    n1 = strlen(str1 + 1);
    n2 = strlen(str2 + 1);

    // First string: ordinary construction with occurrence marks.
    for (int i = 1; i <= n1; i++)
        insert(str1[i] - 'a', 0);

    // Second string: restart from the root. After each append, walk the prefix
    // str2[1..i] from the root and mark the reached state for string 1.
    last_ptr = 1;
    int walk = 1;
    for (int i = 1; i <= n2; i++)
    {
        const int c = str2[i] - 'a';
        insert(c, 1);
        walk = nodes[walk].ch[c];
        nodes[walk].siz[1] = 1;
    }

    radixSort();
    printf("%lld\n", ans);
    return 0;
}

「雅礼集训 2017 Day1」字符串

好难啊这道题。首先要介绍一种数据结构题中出现过的套路:根号分治。对于某些乘积固定(有上界)的两个参数,可以按其中一个参数的大小分成两类分别处理,从而得到根号级别的时间复杂度。比如:CF103D。

讲讲正解:先建出 SAM。如果字符串大小小于规定的块大小,那我们考虑下面这种处理方式:

\(k\) 较小时,询问串 \(w\) 只有 \(O(k^2)\) 个子串,给定的 \(m\) 个区间也只有 \(O(k^2)\) 种取值。所以我们开一个二维桶,每个桶存一列递增的区间编号(合起来相当于一个三维空间),再 \(O(k^2)\) 枚举子串并在 SAM 上走。如果当前子串在 SAM 中存在,那么就在对应的桶里二分出有多少个编号落在询问 \([l, r]\) 中的区间,乘上该子串的出现次数累加即可。

如果不是:

我们先求出询问串每一位结尾向前能匹配的最大长度和对应的 SAM 节点。然后对于询问,我们只需要从 \(a\) 枚举到 \(b\),对每个区间从其右端点对应的节点出发,在 Link 树上倍增向上跳到恰好包含该区间子串的节点(即 \(dep\) 仍不小于区间长度的最高祖先),累加其出现次数即可。

// LOJ6031.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 2e6 + 200;
// Query strings with k at or below this threshold use solve1, the rest solve2.
const int block_size = 330;

// Adjacency lists of the suffix-link tree (head / current index into edges).
int head[MAX_N], current;
// fa[j][u]: 2^j-th suffix-link ancestor of state u.
int fa[20][MAX_N];
// [li[i], ri[i]]: i-th given interval of w (shifted to 1-indexed in main).
int li[MAX_N], ri[MAX_N];
int m, n, q, k;
// cpos[i] / clen[i]: SAM state and length of the longest suffix of w[1..i] found in str.
int cpos[MAX_N], clen[MAX_N];
char str[MAX_N], w[MAX_N];

// Edge of the suffix-link tree, stored in singly linked adjacency lists.
struct edge
{
    int to, nxt;
} edges[MAX_N << 1];

// Prepends the directed edge src -> dst to src's adjacency list.
void addpath(int src, int dst)
{
    edge &e = edges[current];
    e.to = dst;
    e.nxt = head[src];
    head[src] = current;
    ++current;
}

namespace SAM
{

// siz: number of end positions of the state (1 for each prefix state, 0 for
// clones, then summed over the suffix-link subtree by dfs()).
struct node
{
    int ch[26], link, dep, siz;
} nodes[MAX_N << 1];

// Node 1 is the root.
int ptot = 1, last_ptr = 1;

// Appends character c (0..25) to the automaton.
void insert(int c)
{
    const int cur = ++ptot;
    int pre = last_ptr;
    last_ptr = cur;
    nodes[cur].dep = nodes[pre].dep + 1;
    nodes[cur].siz = 1;
    for (; pre && !nodes[pre].ch[c]; pre = nodes[pre].link)
        nodes[pre].ch[c] = cur;
    if (!pre)
    {
        nodes[cur].link = 1;
        return;
    }
    const int q = nodes[pre].ch[c];
    if (nodes[pre].dep + 1 == nodes[q].dep)
    {
        nodes[cur].link = q;
        return;
    }
    const int clone = ++ptot;
    nodes[clone] = nodes[q];
    nodes[clone].dep = nodes[pre].dep + 1;
    nodes[clone].siz = 0;
    nodes[cur].link = clone;
    nodes[q].link = clone;
    for (; pre && nodes[pre].ch[c] == q; pre = nodes[pre].link)
        nodes[pre].ch[c] = clone;
}

// Adds parent -> child edges of the suffix-link tree and fills fa[0].
void build_graph()
{
    for (int v = 1; v <= ptot; v++)
    {
        const int parent = nodes[v].link;
        if (!parent)
            continue;
        addpath(parent, v);
        fa[0][v] = parent;
    }
}

// Pre-order: fills fa[1..19][u] from already-computed ancestors; post-order:
// accumulates siz over u's subtree.
void dfs(int u)
{
    for (int j = 1; j < 20; j++)
        fa[j][u] = fa[j - 1][fa[j - 1][u]];
    for (int e = head[u]; e != -1; e = edges[e].nxt)
    {
        const int v = edges[e].to;
        dfs(v);
        nodes[u].siz += nodes[v].siz;
    }
}

} // namespace SAM

void solve1()
{
    vector<int> buc[block_size][block_size];
    for (int i = 1; i <= m; i++)
        buc[li[i]][ri[i]].push_back(i);
    while (q--)
    {
        int a, b;
        long long ans = 0;
        scanf("%s%d%d", w + 1, &a, &b);
        a++, b++;
        for (int i = 1; i <= k; i++)
            for (int j = i, p = 1; j <= k; j++)
            {
                p = SAM::nodes[p].ch[w[j] - 'a'];
                if (p == 0)
                    break;
                int l = lower_bound(buc[i][j].begin(), buc[i][j].end(), a) - buc[i][j].begin();
                int r = upper_bound(buc[i][j].begin(), buc[i][j].end(), b) - buc[i][j].begin();
                ans += 1LL * (r - l) * SAM::nodes[p].siz;
            }
        printf("%lld\n", ans);
    }
}

void solve2()
{
    while (q--)
    {
        int a, b;
        long long ans = 0;
        scanf("%s%d%d", w + 1, &a, &b);
        a++, b++;
        for (int i = 1, p = 1, len = 0; i <= k; i++)
        {
            while (p && SAM::nodes[p].ch[w[i] - 'a'] == 0)
                p = SAM::nodes[p].link, len = SAM::nodes[p].dep;
            if (p == 0)
                p = 1;
            else
                p = SAM::nodes[p].ch[w[i] - 'a'];
            len += (p != 1);
            cpos[i] = p, clen[i] = len;
        }
        for (int idx = a; idx <= b; idx++)
        {
            int p = cpos[ri[idx]], dist = ri[idx] - li[idx] + 1;
            if (dist > clen[ri[idx]])
                continue;
            else if (dist == clen[ri[idx]])
            {
                ans += SAM::nodes[p].siz;
                continue;
            }
            for (int i = 19; i >= 0; i--)
                if (fa[i][p] && SAM::nodes[fa[i][p]].dep >= dist)
                    p = fa[i][p];
            ans += SAM::nodes[p].siz;
        }
        printf("%lld\n", ans);
    }
}

int main()
{
    memset(head, -1, sizeof(head));
    scanf("%d%d%d%d%s", &n, &m, &q, &k, str + 1);
    for (int i = 1; i <= n; ++i)
        SAM::insert(str[i] - 'a');
    SAM::build_graph();
    SAM::dfs(1);
    for (int i = 1; i <= m; i++)
    {
        scanf("%d%d", &li[i], &ri[i]);
        ++li[i];
        ++ri[i];
    }
    // Square-root split on the length k of the query strings.
    if (k > block_size)
        solve2();
    else
        solve1();
    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *