题目:
给定一个长度为N的字符串a,以及一个大小为M的字符串集合s(其中每个字符串的长度不超过20)。
求出a中所有出现在s中的子串,以及它们在a中的起始下标。
如:a=中国人民银行今日发表新闻 s=[人民银行, 新闻]
返回:【2,"人民银行"】, 【0, "新闻"】
解:
我们使用AC自动机(Aho–Corasick)算法来解决该问题。学过编译原理的应该都能看出来,AC自动机是一种DFA。它的特殊之处在于:在字典树(Trie)的基础上结合了KMP算法的思路,在匹配失败时高效地跳转到一个新的状态继续匹配,而不需要回退文本指针重复匹配。
具体思路很多博文都讲过,细节就不重复了。大致思路是:用BFS按层遍历由模式串构成的字典树,构建fail指针(即fail树)。核心公式:p.next(c).fail = p.fail.next(c);其中p.fail.next(c)会沿fail链不断向上查找,若一直到根节点都不存在字符c的转移,则fail指向根节点。
在构建的过程中,最开始所有节点默认的fail都指向root节点,然后开始BFS遍历。对于每个节点,需要遍历所有的符号:对于可达的符号,将(当前节点+该可达符号)指向的子节点加入bfsList,并将该子节点的fail指向当前节点的fail.next(符号);对于当前节点不可达的符号,则直接将(当前节点+不可达符号)指向当前节点的fail.next(符号)。(下面的代码没有显式补全不可达符号的转移,而是在next(c)中沿fail链逐级查找,效果相同。)
代码:
package learning;
import java.io.*;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;
/**
 * Aho–Corasick multi-pattern substring matcher.
 *
 * <p>Usage: create a root with {@code new ACNode()}, call {@link ACNode#addWord(String)} for every
 * pattern, call {@link ACNode#build()} once after the last pattern has been added, then call
 * {@link ACNode#match(String, BiConsumer)} for each text to search.
 *
 * @AUTHOR Linhui
 * @DATE 2023-11-17
 */
public class AC {

    /**
     * A node of the pattern trie. The node created with the public no-arg constructor is the root;
     * {@link #addWord}, {@link #build} and {@link #match} only act when called on the root and
     * silently do nothing on any other node.
     *
     * <p>{@link #build()} must be called again after adding more words: nodes created after the last
     * build have no failure link, so matches that need one can be missed.
     */
    public static class ACNode {
        /** BFS order number assigned by {@link #build()} (root stays 0); only used by toString. */
        private int id;
        /** Depth in the trie (root = 0), i.e. the length of the prefix this node spells. */
        private final int deep;
        /** Parent node; null for the root. */
        private final ACNode father;
        /** Character on the edge from {@code father} to this node; 0 for the root. */
        private final char current;
        /**
         * Failure link: the node to continue from when this node has no child for the next input
         * character. Null for the root and for nodes created after the last build().
         */
        private ACNode fail;
        /** Children keyed by edge character; created lazily on first insertion. */
        private Map<Character, ACNode> childMap;
        /** The pattern that ends exactly at this node, or null if none does. */
        private String word;

        private ACNode(ACNode father, String word) {
            this.father = father;
            this.deep = father.deep + 1;
            // A node at depth d is labelled with the d-th character of the word being inserted.
            this.current = word.charAt(deep - 1);
        }

        /** Creates an empty root node. */
        public ACNode() {
            this.father = null;
            this.deep = 0;
            this.current = 0;
        }

        /**
         * Inserts a pattern into the trie rooted at this node.
         *
         * @param word the pattern to add; must not be null
         */
        public void addWord(String word) {
            if (deep > 0) {
                return;
            }
            ACNode p = this;
            for (int i = 0; i < word.length(); i++) {
                char c = word.charAt(i);
                ACNode parent = p; // effectively-final copy for the lambda below
                if (p.childMap == null) {
                    p.childMap = new HashMap<>();
                }
                p = p.childMap.computeIfAbsent(c, k -> new ACNode(parent, word));
            }
            p.word = word;
        }

        @Override
        public String toString() {
            return String.valueOf(id);
        }

        /**
         * Computes the failure link of every node, visiting the trie breadth-first so that every
         * shallower node already has its link when a deeper one is processed. Also renumbers
         * {@code id} in BFS order. Safe to call again after adding more words.
         */
        public void build() {
            if (deep > 0) {
                return;
            }
            Deque<ACNode> queue = new ArrayDeque<>();
            if (childMap != null) {
                // Depth-1 nodes have no proper non-empty suffix, so they fail to the root.
                for (ACNode child : childMap.values()) {
                    child.fail = this;
                    queue.addLast(child);
                }
            }
            int order = 0;
            while (!queue.isEmpty()) {
                ACNode point = queue.removeFirst();
                point.id = ++order;
                if (point.childMap == null) {
                    continue;
                }
                for (ACNode child : point.childMap.values()) {
                    queue.addLast(child);
                    // p.next(c).fail = p.fail.next(c). point is below the root, so point.fail is set.
                    // next(c) already walks the whole failure chain; null means no node on it (root
                    // included) has a c-child, so the link falls back to the root. The search must
                    // start from point.fail for EVERY child: sharing one cursor across siblings lets
                    // a failed search for one child wrongly send later siblings to the root.
                    ACNode target = point.fail.next(child.current);
                    child.fail = target != null ? target : this;
                }
            }
        }

        /**
         * Returns the state reached by reading {@code c} from this node.
         *
         * <p>That state is the {@code c}-child of this node or, failing that, the {@code c}-child of
         * the nearest node on this node's failure chain that has one.
         *
         * @param c the next input character
         * @return the reached node, or null if no node on the chain (root included) has a
         *     {@code c}-child
         */
        public ACNode next(char c) {
            for (ACNode node = this; node != null; node = node.fail) {
                ACNode child = node.childMap == null ? null : node.childMap.get(c);
                if (child != null) {
                    return child;
                }
            }
            return null;
        }

        /** An immutable (start index, matched word) pair. */
        public static class Result {
            private final int index;
            private final String word;

            public Result(int index, String word) {
                this.index = index;
                this.word = word;
            }
        }

        /**
         * Scans {@code str} once and reports every occurrence of every added pattern.
         *
         * <p>Occurrences are reported in order of their end position. For the same end position, the
         * longer pattern is reported first. Overlapping occurrences are all reported. Requires
         * {@link #build()} to have been called after the last {@link #addWord}.
         *
         * @param str the text to search; must not be null
         * @param matchResultConsumer receives (start index in {@code str}, matched pattern) for each
         *     occurrence
         */
        public void match(String str, BiConsumer<Integer, String> matchResultConsumer) {
            if (deep > 0) {
                return;
            }
            ACNode point = this;
            for (int i = 0; i < str.length(); i++) {
                ACNode next = point.next(str.charAt(i));
                if (next == null) {
                    // No suffix of the text read so far can be extended: restart from the root.
                    point = this;
                    continue;
                }
                point = next;
                consumerAllWordsFromPoint(point, i, matchResultConsumer);
            }
        }

        /**
         * Reports every pattern ending at text position {@code index}. These are the words stored on
         * {@code point} and on each node of its failure chain (each shallower than the previous),
         * stopping at the root.
         */
        private void consumerAllWordsFromPoint(ACNode point, int index, BiConsumer<Integer, String> matchResultConsumer) {
            for (ACNode node = point; node != null && node.deep > 0; node = node.fail) {
                if (node.word != null) {
                    // deep equals the word length, so the match started deep - 1 characters earlier.
                    matchResultConsumer.accept(index - node.deep + 1, node.word);
                }
            }
        }
    }

    public static void main(String[] args) {
        test();
    }

    /** Small demo: builds an automaton over a few patterns and prints every match in a sample text. */
    public static void test() {
        ACNode root = new ACNode();
        root.addWord("abab");
        root.addWord("abc");
        root.addWord("bca");
        root.addWord("cc");
        root.addWord("cac");
        root.addWord("bab");
        root.build();
        String str = "abababcacc";
        root.match(str, (index, word) -> {
            System.out.println("index: " + index + "\t word: " + word);
        });
    }
}