Loading [MathJax]/extensions/tex2jax.js

AC 自动机 | Aho-Corasick Automaton

概述

AC 自动机是一种有限状态自动机,常用于多模式串的字符串匹配。

模版

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// P3808.cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity shared by every table below (2,000,200 entries).
const int MAX_N = 2000200;

int nodes[MAX_N][26]; // nodes[u][c]: node reached from u on letter c; 0 is the root and also means "no child"
int tag[MAX_N];       // tag[u]: how many inserted patterns end at u (query() overwrites it with -1 once counted)
int fail[MAX_N];      // fail[u]: fail link of node u, filled in by inti_automaton()
int ptot;             // id of the most recently allocated trie node
int n;                // number of patterns, read by main()
char opt[MAX_N];      // input buffer; main() stores every string starting at opt[1]
/**
 * Adds one pattern to the trie.
 *
 * The pattern is read 1-based: characters start at str[1] and run up to the
 * terminating '\0'; str[0] is never inspected. Each character selects a child
 * slot via `c - 'a'`; no range check is performed.
 *
 * @param str buffer whose pattern text begins at index 1
 */
void insert(char *str)
{
    int cur = 0; // start at the root
    for (const char *ch = str + 1; *ch != '\0'; ++ch)
    {
        int &child = nodes[cur][*ch - 'a'];
        if (child == 0)
            child = ++ptot; // allocate the next unused node id
        cur = child;
    }
    ++tag[cur]; // one more pattern ends at this node
}
/**
 * Computes the fail links with a breadth-first walk over the trie and, at the
 * same time, fills every missing transition with the transition of the fail
 * node, so that nodes[u][c] is defined for all 26 letters afterwards.
 */
void inti_automaton()
{
    std::queue<int> pending;

    // Children of the root fail back to the root.
    for (int c = 0; c < 26; ++c)
    {
        const int child = nodes[0][c];
        if (child == 0)
            continue;
        fail[child] = 0;
        pending.push(child);
    }

    while (!pending.empty())
    {
        const int cur = pending.front();
        pending.pop();
        const int fallback = fail[cur];
        for (int c = 0; c < 26; ++c)
        {
            const int child = nodes[cur][c];
            if (child == 0)
            {
                // No real child: borrow the transition of the fail node.
                nodes[cur][c] = nodes[fallback][c];
                continue;
            }
            fail[child] = nodes[fallback][c];
            pending.push(child);
        }
    }
}
/**
 * Runs the text through the automaton and sums the tags of every node reached
 * directly or through fail links.
 *
 * The text is read 1-based (from str[1] to the terminating '\0'). Each visited
 * node has its tag replaced by -1, so it contributes only once and later walks
 * stop as soon as they reach it. The tag table is therefore consumed by the call.
 *
 * @param str buffer whose text begins at index 1
 * @return accumulated tag total
 */
int query(char *str)
{
    int state = 0;
    int matched = 0;
    for (const char *ch = str + 1; *ch != '\0'; ++ch)
    {
        state = nodes[state][*ch - 'a'];
        int hop = state;
        while (hop != 0 && tag[hop] != -1)
        {
            matched += tag[hop];
            tag[hop] = -1; // mark as already counted
            hop = fail[hop];
        }
    }
    return matched;
}
/**
 * Reads the pattern count and the patterns, builds the automaton, then reads
 * the text and prints the value returned by query().
 */
int main()
{
    scanf("%d", &n);
    for (int remaining = n; remaining > 0; --remaining)
    {
        scanf("%s", opt + 1); // strings are stored from index 1
        insert(opt);
    }
    inti_automaton();

    scanf("%s", opt + 1);
    const int answer = query(opt);
    printf("%d\n", answer);
    return 0;
}
// P3808.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 2e6 + 200; int nodes[MAX_N][26], tag[MAX_N], fail[MAX_N], ptot, n; char opt[MAX_N]; void insert(char *str) { int p = 0; for (int i = 1; str[i] != '\0'; i++) { int digit = str[i] - 'a'; if (nodes[p][digit] == 0) nodes[p][digit] = ++ptot; p = nodes[p][digit]; } tag[p]++; } void inti_automaton() { queue<int> q; for (int i = 0; i < 26; i++) if (nodes[0][i] != 0) fail[nodes[0][i]] = 0, q.push(nodes[0][i]); while (!q.empty()) { int u = q.front(); q.pop(); for (int i = 0; i < 26; i++) if (nodes[u][i] != 0) fail[nodes[u][i]] = nodes[fail[u]][i], q.push(nodes[u][i]); else nodes[u][i] = nodes[fail[u]][i]; } } int query(char *str) { int p = 0, ret = 0; for (int i = 1; str[i] != '\0'; i++) { p = nodes[p][str[i] - 'a']; for (int pt = p; pt != 0 && tag[pt] != -1; pt = fail[pt]) ret += tag[pt], tag[pt] = -1; } return ret; } int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%s", opt + 1), insert(opt); inti_automaton(); scanf("%s", opt + 1); printf("%d\n", query(opt)); return 0; }
// P3808.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity shared by every table below (2,000,200 entries).
const int MAX_N = 2000200;

int nodes[MAX_N][26]; // nodes[u][c]: node reached from u on letter c; 0 is the root and also means "no child"
int tag[MAX_N];       // tag[u]: how many inserted patterns end at u (query() overwrites it with -1 once counted)
int fail[MAX_N];      // fail[u]: fail link of node u, filled in by inti_automaton()
int ptot;             // id of the most recently allocated trie node
int n;                // number of patterns, read by main()
char opt[MAX_N];      // input buffer; main() stores every string starting at opt[1]

/**
 * Adds one pattern to the trie.
 *
 * The pattern is read 1-based: characters start at str[1] and run up to the
 * terminating '\0'; str[0] is never inspected. Each character selects a child
 * slot via `c - 'a'`; no range check is performed.
 *
 * @param str buffer whose pattern text begins at index 1
 */
void insert(char *str)
{
    int cur = 0; // start at the root
    for (const char *ch = str + 1; *ch != '\0'; ++ch)
    {
        int &child = nodes[cur][*ch - 'a'];
        if (child == 0)
            child = ++ptot; // allocate the next unused node id
        cur = child;
    }
    ++tag[cur]; // one more pattern ends at this node
}

/**
 * Computes the fail links with a breadth-first walk over the trie and, at the
 * same time, fills every missing transition with the transition of the fail
 * node, so that nodes[u][c] is defined for all 26 letters afterwards.
 */
void inti_automaton()
{
    std::queue<int> pending;

    // Children of the root fail back to the root.
    for (int c = 0; c < 26; ++c)
    {
        const int child = nodes[0][c];
        if (child == 0)
            continue;
        fail[child] = 0;
        pending.push(child);
    }

    while (!pending.empty())
    {
        const int cur = pending.front();
        pending.pop();
        const int fallback = fail[cur];
        for (int c = 0; c < 26; ++c)
        {
            const int child = nodes[cur][c];
            if (child == 0)
            {
                // No real child: borrow the transition of the fail node.
                nodes[cur][c] = nodes[fallback][c];
                continue;
            }
            fail[child] = nodes[fallback][c];
            pending.push(child);
        }
    }
}

/**
 * Runs the text through the automaton and sums the tags of every node reached
 * directly or through fail links.
 *
 * The text is read 1-based (from str[1] to the terminating '\0'). Each visited
 * node has its tag replaced by -1, so it contributes only once and later walks
 * stop as soon as they reach it. The tag table is therefore consumed by the call.
 *
 * @param str buffer whose text begins at index 1
 * @return accumulated tag total
 */
int query(char *str)
{
    int state = 0;
    int matched = 0;
    for (const char *ch = str + 1; *ch != '\0'; ++ch)
    {
        state = nodes[state][*ch - 'a'];
        int hop = state;
        while (hop != 0 && tag[hop] != -1)
        {
            matched += tag[hop];
            tag[hop] = -1; // mark as already counted
            hop = fail[hop];
        }
    }
    return matched;
}

/**
 * Reads the pattern count and the patterns, builds the automaton, then reads
 * the text and prints the value returned by query().
 */
int main()
{
    scanf("%d", &n);
    for (int remaining = n; remaining > 0; --remaining)
    {
        scanf("%s", opt + 1); // strings are stored from index 1
        insert(opt);
    }
    inti_automaton();

    scanf("%s", opt + 1);
    const int answer = query(opt);
    printf("%d\n", answer);
    return 0;
}

用法

A – 普通字符串匹配

将文本串逐个字符送入 AC 自动机,沿 Trie 图上的转移边行走;每到达一个节点,就沿 fail 链向上检查各节点的 Tag,即可统计出现过的模式串。

B – Fail 树上 DP

分析 Fail 树的性质:父节点所代表的字符串一定是当前节点所代表字符串的一个后缀。

C – 可持久化 AC 自动机

考虑维护一个栈来记录走过的节点:每读入一个字符,就把转移后的节点压栈;一旦匹配到某个模式串,就将栈顶指针减去该模式串的长度进行回弹,并把当前节点恢复为新栈顶所记录的节点(栈为空时回到根)。例题:[USACO15FEB]审查(Gold)

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// P3121.cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity shared by every table below (100,200 entries).
const int MAX_N = 100200;

int n;                // number of patterns, read by main()
int nodes[MAX_N][26]; // nodes[u][c]: node reached from u on letter c; 0 is the root and also means "no child"
int ptot;             // id of the most recently allocated trie node
int stk[MAX_N];       // stk[k]: automaton node after the k-th kept character
int tag[MAX_N];       // tag[u]: length of the pattern ending at u, 0 if none
int top;              // current height of stk / cstk
int fail[MAX_N];      // fail[u]: fail link of node u, filled in by build_automaton()
char passage[MAX_N];  // text to process, stored from passage[1]
char opt[MAX_N];      // pattern input buffer, stored from opt[1]
char cstk[MAX_N];     // cstk[k]: the k-th kept character
/**
 * Adds one pattern to the trie and records its length at the final node.
 *
 * The pattern is read 1-based: characters start at str[1] and run up to the
 * terminating '\0'. Each character selects a child slot via `c - 'a'`; no
 * range check is performed.
 *
 * @param str buffer whose pattern text begins at index 1
 */
void insert(char *str)
{
    int cur = 0; // start at the root
    int len = 0; // characters consumed so far
    for (const char *ch = str + 1; *ch != '\0'; ++ch, ++len)
    {
        int &child = nodes[cur][*ch - 'a'];
        if (child == 0)
            child = ++ptot; // allocate the next unused node id
        cur = child;
    }
    tag[cur] = len; // pattern length, used by main() to pop the stack
}
/**
 * Computes the fail links with a breadth-first walk over the trie and fills
 * every missing transition with the transition of the fail node, so that
 * nodes[u][c] is defined for all 26 letters afterwards.
 */
void build_automaton()
{
    std::queue<int> pending;

    // Children of the root fail back to the root.
    for (int c = 0; c < 26; ++c)
    {
        const int child = nodes[0][c];
        if (child != 0)
        {
            pending.push(child);
            fail[child] = 0;
        }
    }

    for (; !pending.empty(); pending.pop())
    {
        const int cur = pending.front();
        const int fallback = fail[cur];
        for (int c = 0; c < 26; ++c)
        {
            const int via_fail = nodes[fallback][c];
            const int child = nodes[cur][c];
            if (child != 0)
            {
                fail[child] = via_fail;
                pending.push(child);
            }
            else
            {
                // No real child: borrow the transition of the fail node.
                nodes[cur][c] = via_fail;
            }
        }
    }
}
/**
 * Reads the text and the patterns, builds the automaton, then scans the text
 * while keeping a stack of (automaton node, character) pairs. Whenever the
 * current node carries a pattern length, that many entries are popped and the
 * scan resumes from the node stored at the new stack top. The characters left
 * on the stack are printed followed by a newline.
 */
int main()
{
    scanf("%s%d", passage + 1, &n);
    for (int remaining = n; remaining > 0; --remaining)
    {
        scanf("%s", opt + 1); // strings are stored from index 1
        insert(opt);
    }
    build_automaton();

    int state = 0;
    for (const char *ch = passage + 1; *ch != '\0'; ++ch)
    {
        state = nodes[state][*ch - 'a'];
        ++top;
        stk[top] = state;
        cstk[top] = *ch;

        const int hit = tag[state];
        if (hit == 0)
            continue;

        // A pattern ends here: drop its characters and roll the state back.
        top -= hit;
        state = (top == 0) ? 0 : stk[top];
    }

    for (int k = 1; k <= top; ++k)
        putchar(cstk[k]);
    putchar('\n');
    return 0;
}
// P3121.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 1e5 + 200; int n, nodes[MAX_N][26], ptot, stk[MAX_N], tag[MAX_N], top, fail[MAX_N]; char passage[MAX_N], opt[MAX_N], cstk[MAX_N]; void insert(char *str) { int p = 0; for (int i = 1; str[i] != '\0'; i++) { if (nodes[p][str[i] - 'a'] == 0) nodes[p][str[i] - 'a'] = ++ptot; p = nodes[p][str[i] - 'a']; } tag[p] = strlen(str + 1); } void build_automaton() { queue<int> q; for (int i = 0; i < 26; i++) if (nodes[0][i] != 0) q.push(nodes[0][i]), fail[nodes[0][i]] = 0; while (!q.empty()) { int u = q.front(); q.pop(); for (int i = 0; i < 26; i++) if (nodes[u][i] != 0) fail[nodes[u][i]] = nodes[fail[u]][i], q.push(nodes[u][i]); else nodes[u][i] = nodes[fail[u]][i]; } } int main() { scanf("%s%d", passage + 1, &n); for (int i = 1; i <= n; i++) scanf("%s", opt + 1), insert(opt); build_automaton(); for (int i = 1, p = 0; passage[i] != '\0'; i++) { p = nodes[p][passage[i] - 'a'], stk[++top] = p, cstk[top] = passage[i]; if (tag[p]) { top -= tag[p]; if (top == 0) p = 0; else p = stk[top]; } } for (int i = 1; i <= top; i++) putchar(cstk[i]); putchar('\n'); return 0; }
// P3121.cpp
#include <bits/stdc++.h>

using namespace std;

// Capacity shared by every table below (100,200 entries).
const int MAX_N = 100200;

int n;                // number of patterns, read by main()
int nodes[MAX_N][26]; // nodes[u][c]: node reached from u on letter c; 0 is the root and also means "no child"
int ptot;             // id of the most recently allocated trie node
int stk[MAX_N];       // stk[k]: automaton node after the k-th kept character
int tag[MAX_N];       // tag[u]: length of the pattern ending at u, 0 if none
int top;              // current height of stk / cstk
int fail[MAX_N];      // fail[u]: fail link of node u, filled in by build_automaton()
char passage[MAX_N];  // text to process, stored from passage[1]
char opt[MAX_N];      // pattern input buffer, stored from opt[1]
char cstk[MAX_N];     // cstk[k]: the k-th kept character

/**
 * Adds one pattern to the trie and records its length at the final node.
 *
 * The pattern is read 1-based: characters start at str[1] and run up to the
 * terminating '\0'. Each character selects a child slot via `c - 'a'`; no
 * range check is performed.
 *
 * @param str buffer whose pattern text begins at index 1
 */
void insert(char *str)
{
    int cur = 0; // start at the root
    int len = 0; // characters consumed so far
    for (const char *ch = str + 1; *ch != '\0'; ++ch, ++len)
    {
        int &child = nodes[cur][*ch - 'a'];
        if (child == 0)
            child = ++ptot; // allocate the next unused node id
        cur = child;
    }
    tag[cur] = len; // pattern length, used by main() to pop the stack
}

/**
 * Computes the fail links with a breadth-first walk over the trie and fills
 * every missing transition with the transition of the fail node, so that
 * nodes[u][c] is defined for all 26 letters afterwards.
 */
void build_automaton()
{
    std::queue<int> pending;

    // Children of the root fail back to the root.
    for (int c = 0; c < 26; ++c)
    {
        const int child = nodes[0][c];
        if (child != 0)
        {
            pending.push(child);
            fail[child] = 0;
        }
    }

    for (; !pending.empty(); pending.pop())
    {
        const int cur = pending.front();
        const int fallback = fail[cur];
        for (int c = 0; c < 26; ++c)
        {
            const int via_fail = nodes[fallback][c];
            const int child = nodes[cur][c];
            if (child != 0)
            {
                fail[child] = via_fail;
                pending.push(child);
            }
            else
            {
                // No real child: borrow the transition of the fail node.
                nodes[cur][c] = via_fail;
            }
        }
    }
}

/**
 * Reads the text and the patterns, builds the automaton, then scans the text
 * while keeping a stack of (automaton node, character) pairs. Whenever the
 * current node carries a pattern length, that many entries are popped and the
 * scan resumes from the node stored at the new stack top. The characters left
 * on the stack are printed followed by a newline.
 */
int main()
{
    scanf("%s%d", passage + 1, &n);
    for (int remaining = n; remaining > 0; --remaining)
    {
        scanf("%s", opt + 1); // strings are stored from index 1
        insert(opt);
    }
    build_automaton();

    int state = 0;
    for (const char *ch = passage + 1; *ch != '\0'; ++ch)
    {
        state = nodes[state][*ch - 'a'];
        ++top;
        stk[top] = state;
        cstk[top] = *ch;

        const int hit = tag[state];
        if (hit == 0)
            continue;

        // A pattern ends here: drop its characters and roll the state back.
        top -= hit;
        state = (top == 0) ? 0 : stk[top];
    }

    for (int k = 1; k <= top; ++k)
        putchar(cstk[k]);
    putchar('\n');
    return 0;
}

D – Trie 上 DP

一般可以将状态设置成 \(dp[i][j]\):已经处理到第 \(i\) 个字符且当前位于自动机节点 \(j\) 时的方案数。例题:[JSOI2007]文本生成器

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
// P4052.cpp
#include <bits/stdc++.h>
using namespace std;
const int MAX_N = 10010; // capacity of the node tables and the input buffer
const int mod = 10007;   // modulus applied to every count

int nodes[MAX_N][26]; // nodes[u][c]: node reached from u on letter c; 0 is the root and also means "no child"
int ptot;             // id of the most recently allocated trie node
int fail[MAX_N];      // fail[u]: fail link of node u, filled in by build_fail()
int n;                // number of patterns, read by main()
int m;                // length of the generated text, read by main()
int dp[105][MAX_N];   // dp[i][j]: count (mod `mod`) of length-i strings that end at node j without entering a tagged node
bool tag[MAX_N];      // tag[u]: a pattern ends at u, or (after build_fail()) at a node on its fail chain
char opt[MAX_N];      // pattern input buffer, stored from opt[1]
/**
 * Adds one pattern to the trie and marks its final node.
 *
 * The pattern is read 1-based: characters start at str[1] and run up to the
 * terminating '\0'. Each character selects a child slot via `c - 'A'`; no
 * range check is performed.
 *
 * @param str buffer whose pattern text begins at index 1
 */
void insert(char *str)
{
    int cur = 0; // start at the root
    for (const char *ch = str + 1; *ch != '\0'; ++ch)
    {
        int &child = nodes[cur][*ch - 'A'];
        if (child == 0)
            child = ++ptot; // allocate the next unused node id
        cur = child;
    }
    tag[cur] = true; // a pattern ends here
}
/**
 * Computes the fail links with a breadth-first walk over the trie, fills every
 * missing transition with the transition of the fail node, and propagates the
 * tag flag from each node's fail target to the node itself.
 */
void build_fail()
{
    std::queue<int> pending;

    // Children of the root fail back to the root.
    for (int c = 0; c < 26; ++c)
    {
        const int child = nodes[0][c];
        if (child == 0)
            continue;
        pending.push(child);
        fail[child] = 0;
    }

    while (!pending.empty())
    {
        const int cur = pending.front();
        pending.pop();
        const int fallback = fail[cur];
        for (int c = 0; c < 26; ++c)
        {
            const int via_fail = nodes[fallback][c];
            const int child = nodes[cur][c];
            if (child == 0)
            {
                // No real child: borrow the transition of the fail node.
                nodes[cur][c] = via_fail;
                continue;
            }
            fail[child] = via_fail;
            // Inherit the flag so a pattern ending on the fail chain is seen here too.
            tag[child] |= tag[via_fail];
            pending.push(child);
        }
    }
}
/**
 * Reads n patterns and the length m, builds the automaton, counts (mod `mod`)
 * the length-m strings over 26 letters that never enter a tagged node, and
 * prints 26^m minus that count, reduced mod `mod`.
 */
int main()
{
    scanf("%d%d", &n, &m);
    for (int remaining = n; remaining > 0; --remaining)
    {
        scanf("%s", opt + 1); // strings are stored from index 1
        insert(opt);
    }
    build_fail();

    dp[0][0] = 1; // the empty string sits at the root
    for (int len = 1; len <= m; ++len)
    {
        for (int from = 0; from <= ptot; ++from)
        {
            const int ways = dp[len - 1][from];
            for (int c = 0; c < 26; ++c)
            {
                const int to = nodes[from][c];
                if (tag[to])
                    continue; // stepping here would complete a pattern
                dp[len][to] = (dp[len][to] + ways) % mod;
            }
        }
    }

    // Strings of length m that avoid every tagged node.
    int safe = 0;
    for (int node = 0; node <= ptot; ++node)
        safe = (safe + dp[m][node]) % mod;

    // All 26^m strings of length m.
    int total = 1;
    for (int len = 1; len <= m; ++len)
        total = total * 26 % mod;

    printf("%d\n", (total - safe + mod) % mod);
    return 0;
}
// P4052.cpp #include <bits/stdc++.h> using namespace std; const int MAX_N = 10010, mod = 10007; int nodes[MAX_N][26], ptot, fail[MAX_N], n, m, dp[105][MAX_N]; bool tag[MAX_N]; char opt[MAX_N]; void insert(char *str) { int p = 0; for (int i = 1; str[i] != '\0'; i++) { if (nodes[p][str[i] - 'A'] == 0) nodes[p][str[i] - 'A'] = ++ptot; p = nodes[p][str[i] - 'A']; } tag[p] |= true; } void build_fail() { queue<int> q; for (int i = 0; i < 26; i++) if (nodes[0][i] != 0) q.push(nodes[0][i]), fail[nodes[0][i]] = 0; while (!q.empty()) { int u = q.front(); q.pop(); for (int i = 0; i < 26; i++) if (nodes[u][i] == 0) nodes[u][i] = nodes[fail[u]][i]; else { fail[nodes[u][i]] = nodes[fail[u]][i]; tag[nodes[u][i]] |= tag[nodes[fail[u]][i]]; q.push(nodes[u][i]); } } } int main() { scanf("%d%d", &n, &m); for (int i = 1; i <= n; i++) scanf("%s", opt + 1), insert(opt); build_fail(), dp[0][0] = 1; for (int i = 1; i <= m; i++) for (int j = 0; j <= ptot; j++) for (int bit = 0; bit < 26; bit++) if (tag[nodes[j][bit]] == false) dp[i][nodes[j][bit]] = (1LL * dp[i][nodes[j][bit]] + dp[i - 1][j]) % mod; int ans = 0; for (int i = 0; i <= ptot; i++) ans = (1LL * ans + dp[m][i]) % mod; int sum = 1; for (int i = 1; i <= m; i++) sum = 1LL * sum * 26 % mod; printf("%d\n", (sum - ans + mod) % mod); return 0; }
// P4052.cpp
#include <bits/stdc++.h>

using namespace std;

const int MAX_N = 10010; // capacity of the node tables and the input buffer
const int mod = 10007;   // modulus applied to every count

int nodes[MAX_N][26]; // nodes[u][c]: node reached from u on letter c; 0 is the root and also means "no child"
int ptot;             // id of the most recently allocated trie node
int fail[MAX_N];      // fail[u]: fail link of node u, filled in by build_fail()
int n;                // number of patterns, read by main()
int m;                // length of the generated text, read by main()
int dp[105][MAX_N];   // dp[i][j]: count (mod `mod`) of length-i strings that end at node j without entering a tagged node
bool tag[MAX_N];      // tag[u]: a pattern ends at u, or (after build_fail()) at a node on its fail chain
char opt[MAX_N];      // pattern input buffer, stored from opt[1]

/**
 * Adds one pattern to the trie and marks its final node.
 *
 * The pattern is read 1-based: characters start at str[1] and run up to the
 * terminating '\0'. Each character selects a child slot via `c - 'A'`; no
 * range check is performed.
 *
 * @param str buffer whose pattern text begins at index 1
 */
void insert(char *str)
{
    int cur = 0; // start at the root
    for (const char *ch = str + 1; *ch != '\0'; ++ch)
    {
        int &child = nodes[cur][*ch - 'A'];
        if (child == 0)
            child = ++ptot; // allocate the next unused node id
        cur = child;
    }
    tag[cur] = true; // a pattern ends here
}

/**
 * Computes the fail links with a breadth-first walk over the trie, fills every
 * missing transition with the transition of the fail node, and propagates the
 * tag flag from each node's fail target to the node itself.
 */
void build_fail()
{
    std::queue<int> pending;

    // Children of the root fail back to the root.
    for (int c = 0; c < 26; ++c)
    {
        const int child = nodes[0][c];
        if (child == 0)
            continue;
        pending.push(child);
        fail[child] = 0;
    }

    while (!pending.empty())
    {
        const int cur = pending.front();
        pending.pop();
        const int fallback = fail[cur];
        for (int c = 0; c < 26; ++c)
        {
            const int via_fail = nodes[fallback][c];
            const int child = nodes[cur][c];
            if (child == 0)
            {
                // No real child: borrow the transition of the fail node.
                nodes[cur][c] = via_fail;
                continue;
            }
            fail[child] = via_fail;
            // Inherit the flag so a pattern ending on the fail chain is seen here too.
            tag[child] |= tag[via_fail];
            pending.push(child);
        }
    }
}

/**
 * Reads n patterns and the length m, builds the automaton, counts (mod `mod`)
 * the length-m strings over 26 letters that never enter a tagged node, and
 * prints 26^m minus that count, reduced mod `mod`.
 */
int main()
{
    scanf("%d%d", &n, &m);
    for (int remaining = n; remaining > 0; --remaining)
    {
        scanf("%s", opt + 1); // strings are stored from index 1
        insert(opt);
    }
    build_fail();

    dp[0][0] = 1; // the empty string sits at the root
    for (int len = 1; len <= m; ++len)
    {
        for (int from = 0; from <= ptot; ++from)
        {
            const int ways = dp[len - 1][from];
            for (int c = 0; c < 26; ++c)
            {
                const int to = nodes[from][c];
                if (tag[to])
                    continue; // stepping here would complete a pattern
                dp[len][to] = (dp[len][to] + ways) % mod;
            }
        }
    }

    // Strings of length m that avoid every tagged node.
    int safe = 0;
    for (int node = 0; node <= ptot; ++node)
        safe = (safe + dp[m][node]) % mod;

    // All 26^m strings of length m.
    int total = 1;
    for (int len = 1; len <= m; ++len)
        total = total * 26 % mod;

    printf("%d\n", (total - safe + mod) % mod);
    return 0;
}

 

Leave a Reply

Your email address will not be published. Required fields are marked *