目录
算法能解决的问题
给定 n n n个模式串, 给定主串 s s s, 非常高的效率检查主串 s s s中是否存在某个模式串
算法原理
AC自动机本质是在Trie树上用KMP算法的思想的结果
假设在Trie 插入如下模式串she, he, say, shr, her

将KMP算法的ne数组类比到AC自动机中

假设考虑当前子串she, 节点 e e e存储的是以 e e e结尾的某个后缀(后缀有e, he), 和某个最长真前缀 匹配的最长前缀的结尾下标

在上述Trie树中, she的最长匹配真前缀 在 h e he he的结尾

上图是样例构建出AC自动机后的情况
问题就变成了如何计算这些指针 ?
因为KMP算法计算ne数组的过程是根据 [ 0 , i − 1 ] [0, i - 1] [0,i−1]的ne值递推出来的
类比于KMP算法计算ne数组的过程
- 计算AC自动机的
fail数组, 用 b f s bfs bfs搜索 , 将 u u u视为上一层 , 当前层是 t r [ u ] [ i ] tr[u][i] tr[u][i] j = fail[u]就类比于KMP算法的j = ne[i - 1]- 然后向前找, 如果失败
j = fail[j], 类比到KMP算法,j = ne[j - 1]或者j = ne[j], 两种写法都是对的, 只不过初始下标不同 - 如果找到了匹配的位置
if (tr[j][i]), 那么j = tr[j][i], 类比到KMP算法if (s[j] == s[i]) j++或者if (s[j + 1] == s[i]) j++ - 然后记录当前
fail值,fail[c] = j, 类比于KMP算法ne[i] = j

伪代码如下
q是队列, h是队头, t是队尾, tr是Trie树, c是当前位置i, u是上一层位置i - 1, j = fail[u]
因为对于Trie树来说, 节点 0 0 0是根节点, 因此插入的时候是从节点 1 1 1开始的, 因此AC自动机向前寻找
fail指针的过程是j = fail[j], 而不是j = fail[j - 1]
cpp
while (h <= t) {
// 类比为i - 1
int u = q[h++];
for (int i = 0; i < 26; ++i) {
int c = tr[u][i];
// j = ne[i - 1]
int j = fail[u];
while (j && !tr[j][i]) j = fail[j];
if (tr[j][i]) j = tr[j][i];
fail[c] = j;
q[++t] = c;
}
}
类比于KMP算法, AC自动机的匹配方式和构建AC自动机类似
模板代码实现

使用cnt[p]记录如果当前位置的节点编号是单词结尾 , 将该位置 + 1 +1 +1, 因此在统计单次数量的时候, 需要将指针 p p p能到达的字符串都遍历一遍, 具体的来说

假设当前指针指向了she的结尾e, p p p指针需要移动到右上方, 将he的cnt值进行累加, 因为she出现了, he也必定出现
代码实现
cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 1e4 + 10, M = 1e6 + 10, S = 55;
int n;
int tr[N * S][26], idx, cnt[N * S];
int fail[N * S];
int q[N * S], h, t;
void insert(string &s) {
int p = 0;
for (int i = 0; s[i]; ++i) {
int c = s[i] - 'a';
if (!tr[p][c]) tr[p][c] = ++idx;
p = tr[p][c];
}
cnt[p]++;
}
void build() {
h = 0, t = -1;
for (int i = 0; i < 26; ++i) {
if (tr[0][i]) q[++t] = tr[0][i];
}
while (h <= t) {
int u = q[h++];
for (int i = 0; i < 26; ++i) {
int v = tr[u][i];
if (!v) continue;
int j = fail[u];
while (j && !tr[j][i]) j = fail[j];
if (tr[j][i]) j = tr[j][i];
fail[v] = j;
q[++t] = v;
}
}
}
void solve() {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(fail, 0, sizeof fail);
idx = 0;
cin >> n;
for (int i = 0; i < n; ++i) {
string p;
cin >> p;
insert(p);
}
build();
string s;
cin >> s;
int ans = 0, j = 0;
// 注意在计算主串的时候子节点下标是c = s[i] - 'a';
for (int i = 0; s[i]; ++i) {
int c = s[i] - 'a';
while (j && !tr[j][c]) j = fail[j];
if (tr[j][c]) j = tr[j][c];
int p = j;
while (p) {
ans += cnt[p];
cnt[p] = 0;
p = fail[p];
}
}
cout << ans << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int T;
cin >> T;
while (T--) solve();
return 0;
}
Trier图
将代码的向前寻找fail指针的过程进行优化, 具体的来说

对于当前儿子 v v v
- 如果当前儿子(是父节点的第 i i i个儿子)不存在, 那么当前儿子节点指向父节点的失败指针的第 i i i个儿子上
- 如果当前儿子存在, 将当前儿子节点的失败指针指向父节点的失败指针的第 i i i个儿子上 , 然后将当前节点 v v v入队
因为所有的失败指针都指向某个最终节点 , 因此在查询过程中, 可以直接查询 , 具体的来说

这一部分, 可以直接写成j = tr[j][c]
代码优化后
cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 1e4 + 10, M = 1e6 + 10, S = 55;
int n;
int tr[N * S][26], idx, cnt[N * S];
int fail[N * S];
int q[N * S], h, t;
void insert(string &s) {
int p = 0;
for (int i = 0; s[i]; ++i) {
int c = s[i] - 'a';
if (!tr[p][c]) tr[p][c] = ++idx;
p = tr[p][c];
}
cnt[p]++;
}
void build() {
h = 0, t = -1;
for (int i = 0; i < 26; ++i) {
if (tr[0][i]) q[++t] = tr[0][i];
}
while (h <= t) {
int u = q[h++];
for (int i = 0; i < 26; ++i) {
int v = tr[u][i];
if (!v) tr[u][i] = tr[fail[u]][i];
else {
fail[v] = tr[fail[u]][i];
q[++t] = v;
}
}
}
}
void solve() {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(fail, 0, sizeof fail);
idx = 0;
cin >> n;
for (int i = 0; i < n; ++i) {
string p;
cin >> p;
insert(p);
}
build();
string s;
cin >> s;
int ans = 0, j = 0;
// 注意在计算主串的时候子节点下标是c = s[i] - 'a';
for (int i = 0; s[i]; ++i) {
int c = s[i] - 'a';
j = tr[j][c];
int p = j;
while (p) {
ans += cnt[p];
cnt[p] = 0;
p = fail[p];
}
}
cout << ans << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int T;
cin >> T;
while (T--) solve();
return 0;
}
例题

定义状态表示 f ( i , j ) f(i, j) f(i,j)表示考虑生成前 i i i个字符, 并且当前指向了AC自动机的第 j j j个位置 , 的所有方案中需要改变的字符数量最少的值
AC自动机朴素写法
cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 55, M = 1010, K = 30, INF = 0x3f3f3f3f;
int T, n, m;
int tr[N * K][4], idx, cnt[N * K];
string s;
int q[N * K], h, t;
int fail[N * K];
int f[M][N * K];
int get(char c) {
if (c == 'A') return 0;
else if (c == 'G') return 1;
else if (c == 'C') return 2;
else return 3;
}
void insert(string &s) {
int p = 0;
for (int i = 0; s[i]; ++i) {
int t = get(s[i]);
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p] = 1;
}
void build() {
h = 0, t = -1;
for (int i = 0; i < 4; ++i) {
if (tr[0][i]) q[++t] = tr[0][i];
}
while (h <= t) {
int u = q[h++];
for (int i = 0; i < 4; ++i) {
int v = tr[u][i];
if (!v) continue;
int j = fail[u];
while (j && !tr[j][i]) j = fail[j];
if (tr[j][i]) j = tr[j][i];
fail[v] = j;
q[++t] = v;
}
}
}
void solve() {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(fail, 0, sizeof fail);
memset(f, 0x3f, sizeof f);
idx = 0;
f[0][0] = 0;
for (int i = 0; i < n; ++i) {
string s;
cin >> s;
insert(s);
}
build();
cin >> s;
m = s.size();
s = ' ' + s;
for (int i = 0; i < m; ++i) {
for (int j = 0; j <= idx; ++j) {
for (int k = 0; k < 4; ++k) {
int cost = get(s[i + 1]) != k;
int p = j;
while (p && !tr[p][k]) p = fail[p];
if (tr[p][k]) p = tr[p][k];
else p = 0;
// 关键点, 检查在后缀匹配的前缀中是否有某个前缀是致病片段
int x = p;
bool flag = true;
while (x) {
if (cnt[x]) {
flag = false;
break;
}
x = fail[x];
}
if (flag) f[i + 1][p] = min(f[i + 1][p], f[i][j] + cost);
}
}
}
int ans = INF;
for (int i = 0; i <= idx; ++i) ans = min(ans, f[m][i]);
if (ans == INF) ans = -1;
printf("Case %d: %d\n", ++T, ans);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
while (cin >> n, n) solve();
return 0;
}
Trie图 + 路径优化写法
如果当前后缀所匹配的某个前缀 有致病片段, 因为是匹配的, 当前后缀也包含那个前缀, 因此当前后缀也有致病片段, 因此可以做cnt[v] |= cnt[fail[v]]这样的优化
cpp
#include <bits/stdc++.h>
using namespace std;
const int N = 55, M = 1010, K = 30, INF = 0x3f3f3f3f;
int T, n, m;
int tr[N * K][4], idx, cnt[N * K];
string s;
int q[N * K], h, t;
int fail[N * K];
int f[M][N * K];
int get(char c) {
if (c == 'A') return 0;
else if (c == 'G') return 1;
else if (c == 'C') return 2;
else return 3;
}
void insert(string &s) {
int p = 0;
for (int i = 0; s[i]; ++i) {
int t = get(s[i]);
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p] = 1;
}
void build() {
h = 0, t = -1;
for (int i = 0; i < 4; ++i) {
if (tr[0][i]) q[++t] = tr[0][i];
}
while (h <= t) {
int u = q[h++];
for (int i = 0; i < 4; ++i) {
int v = tr[u][i];
if (!v) tr[u][i] = tr[fail[u]][i];
else {
fail[v] = tr[fail[u]][i];
// 能够跳转到的某个前缀是否含有致病片段, 如果某个前缀有, 当前位置也有
cnt[v] |= cnt[fail[v]];
q[++t] = v;
}
}
}
}
void solve() {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(fail, 0, sizeof fail);
memset(f, 0x3f, sizeof f);
idx = 0;
f[0][0] = 0;
for (int i = 0; i < n; ++i) {
string s;
cin >> s;
insert(s);
}
build();
cin >> s;
m = s.size();
s = ' ' + s;
for (int i = 0; i < m; ++i) {
for (int j = 0; j <= idx; ++j) {
for (int k = 0; k < 4; ++k) {
int cost = get(s[i + 1]) != k;
int p = tr[j][k];
if (!cnt[p]) f[i + 1][p] = min(f[i + 1][p], f[i][j] + cost);
}
}
}
int ans = INF;
for (int i = 0; i <= idx; ++i) ans = min(ans, f[m][i]);
if (ans == INF) ans = -1;
printf("Case %d: %d\n", ++T, ans);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
while (cin >> n, n) solve();
return 0;
}