AC 自动机 | Aho-Corasick Automaton

概述

AC 自动机是一种有限状态自动机,常用于多模式串的字符串匹配。做法是先把所有模式串插入 Trie,再用 BFS 为每个节点求出 fail 指针。fail 指针指向当前串在 Trie 中存在的最长真后缀。这样只需把文本串扫描一遍,就能找出所有模式串的出现位置。

模板

// P3808.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 2e6 + 200;

// nodes[v][c]: trie child of node v on letter 'a' + c; after inti_automaton()
//              every entry is a full trie-graph transition. Node 0 is the root.
// tag[v]:      number of inserted patterns that end at node v; query()
//              overwrites every entry it visits with -1.
// fail[v]:     fail link of node v (filled in by inti_automaton()).
// ptot:        index of the most recently allocated trie node.
// n:           number of patterns read by main().
// NOTE(review): nodes alone is MAX_N * 26 ints (roughly 200 MB); confirm this
//               fits the judge's memory limit, or size it by the total pattern length.
int nodes[MAX_N][26], tag[MAX_N], fail[MAX_N], ptot, n;
// opt: 1-indexed input buffer (opt[0] unused), reused for each pattern and the text.
char opt[MAX_N];

// Inserts one pattern into the trie.
// `str` is 1-indexed: str[0] is unused and the pattern occupies str[1..] up to
// the terminating '\0'. Letters are mapped with ch - 'a', so only 'a'..'z' are
// expected. The end node's tag counts how many identical patterns end there,
// so query() counts duplicate patterns separately.
void insert(char *str)
{
    int cur = 0;
    for (char *ch = str + 1; *ch; ++ch)
    {
        int &child = nodes[cur][*ch - 'a'];
        if (!child)
            child = ++ptot;
        cur = child;
    }
    ++tag[cur];
}

// Computes fail links by BFS from the root and turns the trie into a trie
// graph. Every missing edge nodes[u][c] is redirected to nodes[fail[u]][c], so
// afterwards each transition is a single array lookup.
// The (misspelled) name is kept because main() calls it.
void inti_automaton()
{
    queue<int> bfs;
    // Depth-1 nodes always fail back to the root.
    for (int c = 0; c < 26; c++)
    {
        int child = nodes[0][c];
        if (child)
        {
            fail[child] = 0;
            bfs.push(child);
        }
    }
    while (!bfs.empty())
    {
        int u = bfs.front();
        bfs.pop();
        // fail[u] is shallower than u, so BFS has already completed its row.
        int f = fail[u];
        for (int c = 0; c < 26; c++)
        {
            int &child = nodes[u][c];
            if (child)
            {
                fail[child] = nodes[f][c];
                bfs.push(child);
            }
            else
                child = nodes[f][c];
        }
    }
}

// Runs the 1-indexed text str[1..] through the automaton. Returns how many
// inserted patterns, counting duplicates, occur in it at least once.
// A node's tag is consumed (set to -1) the first time the node is reached.
// Later fail-chain walks stop there, because everything above it on its fail
// chain was counted in that same walk. Each node is therefore counted at most
// once, but tag[] is destroyed, so query() can only be run once.
int query(char *str)
{
    int state = 0, found = 0;
    for (char *ch = str + 1; *ch != '\0'; ++ch)
    {
        state = nodes[state][*ch - 'a'];
        int v = state;
        while (v && tag[v] != -1)
        {
            found += tag[v];
            tag[v] = -1;
            v = fail[v];
        }
    }
    return found;
}

// Reads n patterns and then the text, all 1-indexed into opt. Prints how many
// of the patterns (counted with multiplicity) occur in the text.
int main()
{
    scanf("%d", &n);
    for (int k = 0; k < n; k++)
    {
        scanf("%s", opt + 1);
        insert(opt);
    }
    inti_automaton();
    scanf("%s", opt + 1);
    int answer = query(opt);
    printf("%d\n", answer);
    return 0;
}

用法

A – 普通字符串匹配

直接让文本串在 AC 自动机(Trie 图)上逐字符转移。每到达一个节点,就沿其 fail 链向上累加各节点的 Tag 即可。

B – Fail 树上 DP

分析 Fail 树的性质:以 fail 指针作为父边构成 Fail 树。父节点对应的字符串,一定是当前节点对应字符串在 Trie 中存在的最长真后缀。因此,一个节点的全部祖先恰好就是它在 Trie 中出现过的所有后缀,常通过把子树信息向上累加来统计出现次数。

C – 可持久化 AC 自动机

维护一个栈,记录每读入一个字符后所在的自动机节点,以及该字符本身。一旦当前节点匹配到某个模式串,就从栈顶弹出该模式串长度个元素,再回到新栈顶记录的节点继续匹配(栈空则回到根)。这样删除后无需重新扫描。例题:[USACO15FEB]审查(Gold)

// P3121.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 1e5 + 200;

// n:          number of censored words.
// nodes/fail: trie-graph transitions (letter 'a' + c) and fail links; node 0 is the root.
// ptot:       index of the most recently allocated trie node.
// tag[v]:     number of characters to delete when state v is reached (0 = no
//             match); insert() stores the word length at the word's end node.
// stk[k]:     automaton state reached after the k-th kept character.
// cstk[k]:    the k-th kept character.
// top:        current height of stk/cstk.
// passage:    1-indexed text; opt: 1-indexed buffer for each censored word.
int n, nodes[MAX_N][26], ptot, stk[MAX_N], tag[MAX_N], top, fail[MAX_N];
char passage[MAX_N], opt[MAX_N], cstk[MAX_N];

// Adds one censored word (1-indexed: str[1..], letters 'a'..'z') to the trie.
// The word's length is stored at its terminal node, so main() knows how many
// characters to pop when the word is matched.
void insert(char *str)
{
    int cur = 0, len = 0;
    while (str[len + 1] != '\0')
    {
        int &next = nodes[cur][str[len + 1] - 'a'];
        if (next == 0)
            next = ++ptot;
        cur = next;
        ++len;
    }
    tag[cur] = len;
}

void build_automaton()
{
    queue<int> q;
    for (int i = 0; i < 26; i++)
        if (nodes[0][i] != 0)
            q.push(nodes[0][i]), fail[nodes[0][i]] = 0;
    while (!q.empty())
    {
        int u = q.front();
        q.pop();
        for (int i = 0; i < 26; i++)
            if (nodes[u][i] != 0)
                fail[nodes[u][i]] = nodes[fail[u]][i], q.push(nodes[u][i]);
            else
                nodes[u][i] = nodes[fail[u]][i];
    }
}

// Reads the text and the censored words, then scans the text once.
// Each character is pushed together with the automaton state reached after it.
// When that state carries a tag, tag[state] characters are popped. Matching
// then resumes from the state on the new stack top, or from the root if the
// stack is empty. The characters left on the stack form the censored text.
int main()
{
    scanf("%s%d", passage + 1, &n);
    for (int k = 0; k < n; k++)
    {
        scanf("%s", opt + 1);
        insert(opt);
    }
    build_automaton();
    int state = 0;
    for (char *ch = passage + 1; *ch != '\0'; ++ch)
    {
        state = nodes[state][*ch - 'a'];
        ++top;
        stk[top] = state;
        cstk[top] = *ch;
        if (tag[state] != 0)
        {
            top -= tag[state];
            state = top ? stk[top] : 0;
        }
    }
    for (int k = 1; k <= top; ++k)
        putchar(cstk[k]);
    putchar('\n');
    return 0;
}

D – Trie 上 DP

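一般可设状态 \(dp[i][j]\) 表示已生成前 \(i\) 个字符、当前位于自动机节点 \(j\) 的方案数。转移时沿 Trie 图的边走一步,并跳过被标记的节点(即 fail 链上存在模式串的节点)。例题:[JSOI2007]文本生成器,答案为总方案数 \(26^m\) 减去不含任何模式串的方案数 \(\sum_j dp[m][j]\)。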

// P4052.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 10010, mod = 10007;

// nodes/fail: trie-graph transitions (letter 'A' + c) and fail links; node 0 is the root.
// ptot:       index of the most recently allocated trie node.
// n, m:       number of words and length of the generated text.
// dp[i][j]:   number of length-i strings that never enter a tagged state and
//             end at state j. main() indexes dp[m], so m must be at most 104.
// tag[v]:     after build_fail(), true if some word is a suffix of the string spelled by v.
// opt:        1-indexed input buffer for each word.
int nodes[MAX_N][26], ptot, fail[MAX_N], n, m, dp[105][MAX_N];
bool tag[MAX_N];
char opt[MAX_N];

// Inserts one word (1-indexed: str[1..], uppercase 'A'..'Z') into the trie and
// marks its terminal node as containing a word.
void insert(char *str)
{
    int cur = 0;
    for (int k = 1; str[k]; ++k)
    {
        int c = str[k] - 'A';
        if (!nodes[cur][c])
            nodes[cur][c] = ++ptot;
        cur = nodes[cur][c];
    }
    tag[cur] = true;
}

// Computes fail links by BFS and completes the trie into a trie graph. It also
// propagates the "contains a word" flag along fail links, so tag[v] becomes
// true if any word is a suffix of the string spelled by v. The DP in main()
// relies on this to reject every state whose prefix already contains a word.
void build_fail()
{
    queue<int> q;
    for (int c = 0; c < 26; c++)
        if (int child = nodes[0][c])
        {
            fail[child] = 0;
            q.push(child);
        }
    while (!q.empty())
    {
        const int u = q.front();
        q.pop();
        for (int c = 0; c < 26; c++)
        {
            int &v = nodes[u][c];
            // fail[u] is shallower than u, so its row is already complete.
            const int via_fail = nodes[fail[u]][c];
            if (!v)
            {
                v = via_fail;
                continue;
            }
            fail[v] = via_fail;
            tag[v] = tag[v] || tag[via_fail];
            q.push(v);
        }
    }
}

// Counts length-m strings over 26 letters that contain at least one word,
// modulo 10007. dp[len][j] counts length-len strings that end in state j
// without ever entering a tagged state. The answer is 26^m minus the sum of
// dp[m][*].
int main()
{
    scanf("%d%d", &n, &m);
    for (int k = 0; k < n; k++)
    {
        scanf("%s", opt + 1);
        insert(opt);
    }
    build_fail();
    dp[0][0] = 1;
    for (int len = 1; len <= m; len++)
        for (int from = 0; from <= ptot; from++)
        {
            // Row len-1 is read-only while row len is being filled.
            const int ways = dp[len - 1][from];
            for (int c = 0; c < 26; c++)
            {
                const int to = nodes[from][c];
                if (!tag[to])
                    dp[len][to] = (dp[len][to] + ways) % mod; // both < mod, no overflow
            }
        }
    int safe = 0;
    for (int j = 0; j <= ptot; j++)
        safe = (safe + dp[m][j]) % mod;
    int total = 1;
    for (int k = 0; k < m; k++)
        total = total * 26 % mod;
    printf("%d\n", (total - safe + mod) % mod);
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *