Problem:

Given a string a of length N and a set s of M strings (each at most 20 characters long).

Return the starting index in a of every substring of a that appears in s.

Example: a = 中国人民银行今日发表新闻 ("The People's Bank of China released news today"), s = [人民银行, 新闻]

Expected output: [2, "人民银行"], [10, "新闻"] (0-based character indices)
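Because every pattern is at most 20 characters long, a brute-force baseline is easy to write and helps pin down the expected output: for each start index, look up the substrings of length 1 to 20 in a hash set. This is only a cross-checking sketch (the class and method names are hypothetical, not from the original). It does about 20·N lookups, each costing up to 20 character operations, and it reports matches ordered by start index, whereas the AC code below reports them by end index.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical brute-force baseline: tries every (start, length) pair, with the
// length bounded by the maximum pattern length (20 in this problem).
class NaiveMatcher {
    static final int MAX_LEN = 20;

    // Returns "index:word" strings, ordered by start index, then by length.
    static List<String> find(String a, List<String> s) {
        Set<String> dict = new HashSet<>(s);
        List<String> out = new ArrayList<>();
        for (int i = 0; i < a.length(); i++) {
            for (int len = 1; len <= MAX_LEN && i + len <= a.length(); len++) {
                String sub = a.substring(i, i + len);
                if (dict.contains(sub)) out.add(i + ":" + sub);
            }
        }
        return out;
    }
}
```

For the example above, NaiveMatcher.find("中国人民银行今日发表新闻", List.of("人民银行", "新闻")) returns [2:人民银行, 10:新闻].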

 

Solution:

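We solve this with the Aho-Corasick (AC) automaton. Anyone who has studied compiler theory will recognize it as a finite automaton; once every missing transition is filled in (as in the construction described below), it is a DFA. What makes it special is that it builds on a trie and borrows the idea behind KMP: when a match fails, it jumps directly to another state and keeps matching, never backing up over text it has already read.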
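The KMP connection can be made concrete. With a single pattern the trie is a chain, and the fail link of the node at depth j+1 is exactly KMP's prefix (failure) function value pi[j]: the length of the longest proper prefix of the pattern that is also a suffix of its first j+1 characters. A minimal sketch (the class name is hypothetical):

```java
// Hypothetical helper: the KMP prefix function. pi[j] is the length of the
// longest proper prefix of p[0..j] that is also a suffix of p[0..j].
// In a one-pattern AC automaton, the node at depth j+1 fails to depth pi[j].
class PrefixFunction {
    static int[] compute(String p) {
        int[] pi = new int[p.length()];
        for (int j = 1; j < p.length(); j++) {
            int k = pi[j - 1];                       // longest border of p[0..j-1]
            while (k > 0 && p.charAt(j) != p.charAt(k)) {
                k = pi[k - 1];                       // fall back, like following a fail link
            }
            if (p.charAt(j) == p.charAt(k)) k++;
            pi[j] = k;
        }
        return pi;
    }
}
```

For example, compute("abab") returns [0, 0, 1, 2], so in a trie holding only "abab", node "abab" fails to "ab". With several patterns the AC fail link generalizes this to the longest proper suffix that is a prefix of any pattern; in the author's test set, which also contains "bab", node "abab" fails to "bab" instead.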

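The details are covered in many other blog posts, so I won't repeat them here. In outline: traverse the pattern trie in BFS order and build the fail links over it. The core formula is p.next(c).fail = p.fail.next(c), where p.fail.next(c) keeps following fail links until some node has a child for c, and falls back to the root if none does.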
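To see the formula at work, the sketch below names each trie node by the prefix it spells, so fail links can be printed and checked as plain strings. It uses the classic he/she/his/hers pattern set purely for illustration; the class and method names are hypothetical. Nodes are processed in BFS order, so every shallower node's fail link is already known when p.fail.next(c) is evaluated.

```java
import java.util.*;

// Hypothetical illustration: each trie node is named by the prefix it spells,
// so fail links can be inspected as plain strings ("" is the root).
class FailLinkDemo {
    static Map<String, String> failLinks(List<String> patterns) {
        // children.get(node) maps a symbol to the child node's prefix string
        Map<String, Map<Character, String>> children = new HashMap<>();
        children.put("", new TreeMap<>());
        for (String w : patterns) {
            for (int i = 1; i <= w.length(); i++) {
                String node = w.substring(0, i);
                children.get(w.substring(0, i - 1)).put(w.charAt(i - 1), node);
                children.putIfAbsent(node, new TreeMap<>());
            }
        }
        Map<String, String> fail = new HashMap<>();
        Deque<String> queue = new ArrayDeque<>();
        for (String child : children.get("").values()) {
            fail.put(child, "");                           // depth-1 nodes fail to the root
            queue.add(child);
        }
        while (!queue.isEmpty()) {
            String p = queue.poll();
            for (Map.Entry<Character, String> e : children.get(p).entrySet()) {
                char c = e.getKey();
                // p.next(c).fail = p.fail.next(c): climb fail links until a node has child c
                String f = fail.get(p);
                while (!f.isEmpty() && !children.get(f).containsKey(c)) f = fail.get(f);
                String target = children.get(f).get(c);   // null if even the root lacks c
                fail.put(e.getValue(), target == null ? "" : target);
                queue.add(e.getValue());
            }
        }
        return fail;
    }
}
```

For these four patterns the sketch gives fail("she") = "he", fail("sh") = "h", fail("his") = "s" and fail("hers") = "s"; every other node fails to the root ("").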

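During construction, every node's fail link initially points to the root. The BFS then visits the nodes level by level, and for each node it iterates over every symbol of the alphabet. If the current node has a child for the symbol, that child is appended to bfsList and its fail link is set to the current node's fail.next(symbol). If the current node has no child for the symbol, the transition (current node, symbol) is pointed directly at fail.next(symbol). The code below takes a lighter route: it stores only the real trie edges in a HashMap, computes fail links for those edges alone, and resolves missing transitions lazily in next() by walking fail links during matching.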
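The fully filled-in construction just described, where every missing transition borrows its fail node's transition so the automaton becomes a true DFA, can be sketched with a dense transition table. This is not the author's code (which keeps a HashMap of real edges). It assumes, for compactness, that patterns and text use only the letters a to z; the Chinese example would need a map or a larger alphabet. All names are hypothetical. Matching then costs exactly one table lookup per text character.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;

// Hypothetical DFA-style AC automaton; assumes patterns and text use only 'a'..'z'.
// After build(), go.get(s)[x] is defined for every state s and symbol x.
class DenseAC {
    private static final int SIGMA = 26;
    private final List<int[]> go = new ArrayList<>();           // transition table
    private final List<Integer> fail = new ArrayList<>();       // fail link per state
    private final List<List<String>> out = new ArrayList<>();   // patterns reported at each state

    DenseAC(List<String> patterns) {
        newState();                                              // state 0 is the root
        for (String w : patterns) {
            int s = 0;
            for (char ch : w.toCharArray()) {
                int x = ch - 'a';
                // 0 can mean "no edge yet" because the root is never anyone's child
                if (go.get(s)[x] == 0) go.get(s)[x] = newState();
                s = go.get(s)[x];
            }
            out.get(s).add(w);
        }
        build();
    }

    private int newState() {
        go.add(new int[SIGMA]);
        fail.add(0);
        out.add(new ArrayList<>());
        return go.size() - 1;
    }

    private void build() {
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        // Depth-1 states keep fail = 0 (the root); the root's missing edges stay 0 (self-loop).
        for (int x = 0; x < SIGMA; x++) {
            if (go.get(0)[x] != 0) queue.add(go.get(0)[x]);
        }
        while (!queue.isEmpty()) {
            int p = queue.poll();
            // p's fail target is the root or was dequeued earlier, so its output list is final.
            out.get(p).addAll(out.get(fail.get(p)));
            for (int x = 0; x < SIGMA; x++) {
                int q = go.get(p)[x];
                int viaFail = go.get(fail.get(p))[x];   // p.fail.next(x), already complete
                if (q != 0) {            // reachable symbol: child's fail = p.fail.next(x)
                    fail.set(q, viaFail);
                    queue.add(q);
                } else {                 // unreachable symbol: reuse p.fail's transition
                    go.get(p)[x] = viaFail;
                }
            }
        }
    }

    // Returns "index:word" for every occurrence, ordered by end position.
    List<String> match(String text) {
        List<String> hits = new ArrayList<>();
        int s = 0;
        for (int i = 0; i < text.length(); i++) {
            s = go.get(s)[text.charAt(i) - 'a'];   // exactly one lookup per character
            for (String w : out.get(s)) hits.add((i - w.length() + 1) + ":" + w);
        }
        return hits;
    }
}
```

On the author's test data (patterns abab, abc, bca, cc, cac, bab and text "abababcacc"), this sketch returns [0:abab, 1:bab, 2:abab, 3:bab, 4:abc, 5:bca, 6:cac, 8:cc]. The dense table trades memory (26 ints per state) for a branch-free inner loop.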

Code:

package learning;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;

/**
 * @AUTHOR Linhui
 * @DATE 2023-11-17
 */
public class AC {

    public static class ACNode {

        private int id;
        private final int deep;
        private final ACNode father;
        private final char current;
        private ACNode fail;
        private Map<Character, ACNode> childMap;

        private String word;

        private ACNode(ACNode father, String word) {
            this.father = father;
            this.deep = father.deep + 1;
            this.current = word.charAt(deep - 1);
        }

        public ACNode() {
            this.father = null;
            this.deep = 0;
            this.current = 0;
        }

        public void addWord(String word) {
            if (deep > 0) return;
            ACNode p = this;
            for (int i = 0; i < word.length(); i++) {
                char c = word.charAt(i);
                ACNode finalP = p;
                if (p.childMap == null) p.childMap = new HashMap<>();
                ACNode q = p.childMap.computeIfAbsent(c, k -> new ACNode(finalP, word));
                p = q;
            }
            p.word = word;
        }

        @Override
        public String toString() {
            return String.valueOf(id);
        }

        private void build() {
            if (deep > 0) return;
            Deque<ACNode> bfsList = new ArrayDeque<>();
            // Depth-1 nodes always fail back to the root.
            if (childMap != null) {
                for (ACNode acNode : childMap.values()) {
                    acNode.fail = this;
                    bfsList.addLast(acNode);
                }
            }
            int i = 0;
            while (!bfsList.isEmpty()) {
                ACNode point = bfsList.removeFirst();
                point.id = ++i;
                if (point.childMap == null) continue;
                for (ACNode acNode : point.childMap.values()) {
                    bfsList.addLast(acNode);
                    // p.next(c).fail = p.fail.next(c)
                    // next() already follows fail links up to the root, and every shallower node
                    // has its fail link set (BFS order), so one call suffices. Each child starts
                    // its lookup from point.fail; nothing is carried over between siblings.
                    // null means no proper suffix of this node's string is in the trie: use the root.
                    ACNode target = point.fail.next(acNode.current);
                    acNode.fail = target == null ? this : target;
                }
            }
        }

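        /**
         * Matching procedure:
         * 1) Point the cursor curr at the root of the AC automaton: curr = root.
         * 2) Read the next character from the text.
         * 3) Look among the children of the current node for one that matches the character.
         * <p>
         * On success: check whether the child, and each node on its fail chain, marks the end of a pattern,
         * and record every pattern found (the path from the root to that end node). Move curr to the child
         * and go back to step 2.
         * On failure: go to step 4.
         * 4) If fail == null (curr is the root), no pattern prefix ends at this character; set curr = root
         * and go back to step 2.
         * If fail != null, move curr to the fail node and repeat step 3 (that is, fail.next(c)).
         * <p>
         * This method performs steps 3 and 4 for one character: it returns the node reached from this node
         * on c, following fail links as needed, or null if even the root has no child for c.
         */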
        public ACNode next(char c) {
            ACNode next = childMap == null ? null : childMap.get(c);
            if (next != null) return next;
            if (fail != null) return fail.next(c);
            return null;
        }

        public static class Result {
            private int index;
            private String word;

            public Result(int index, String word) {
                this.index = index;
                this.word = word;
            }
        }

        public void match(String str, BiConsumer<Integer, String> matchResultConsumer) {
            if (deep > 0) return;
            ACNode point = this;
            for (int i = 0; i < str.length(); i++) {
                char c = str.charAt(i);
                // Follow trie edges, falling back through fail links on a mismatch.
                ACNode next = point.next(c);
                if (next == null) {
                    // No pattern prefix ends at i: restart from the root.
                    point = this;
                    continue;
                }
                point = next;
                // Report every pattern ending at index i.
                consumerAllWordsFromPoint(point, i, matchResultConsumer);
            }
        }

        // Reports the pattern stored at point (if any) and, recursively, every pattern on its
        // fail chain. Each one ends at index, so it starts at index - deep + 1.
        private void consumerAllWordsFromPoint(ACNode point, int index, BiConsumer<Integer, String> matchResultConsumer) {
            if (point == null || point.deep == 0) return;
            if (point.word != null) {
                matchResultConsumer.accept(index - point.deep + 1, point.word);
            }
            consumerAllWordsFromPoint(point.fail, index, matchResultConsumer);
        }


    }

    public static void main(String[] args) {
        test();
    }


    public static void test() {
        ACNode root = new ACNode();
        root.addWord("abab");
        root.addWord("abc");
        root.addWord("bca");
        root.addWord("cc");
        root.addWord("cac");
        root.addWord("bab");
        root.build();
        String str = "abababcacc";
        root.match(str, (index, word) -> {
            System.out.println("index: "+index + "\t word: " + word);
        });
    }
}
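consumerAllWordsFromPoint visits every node on the fail chain for each text character, including nodes that end no pattern. A standard refinement, not part of the code above, stores a dictionary link on each node: the nearest node further along its fail chain that ends a pattern, so reporting touches only real matches. The self-contained sketch below keeps the same HashMap-based trie and fail-walking style, so it also handles Chinese text; the class and field names are hypothetical, and it assumes all patterns are non-empty.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical map-based AC automaton with dictionary links (assumes non-empty patterns).
class DictLinkAC {
    static final class Node {
        final Map<Character, Node> next = new HashMap<>();
        Node fail;
        Node dict;      // nearest node on the fail chain (excluding self) that ends a pattern
        String word;    // pattern ending exactly here, or null
    }

    private final Node root = new Node();

    DictLinkAC(List<String> patterns) {
        for (String w : patterns) {
            Node p = root;
            for (char c : w.toCharArray()) p = p.next.computeIfAbsent(c, k -> new Node());
            p.word = w;
        }
        ArrayDeque<Node> queue = new ArrayDeque<>();
        for (Node child : root.next.values()) {
            child.fail = root;                           // depth-1 nodes fail to the root
            queue.add(child);
        }
        while (!queue.isEmpty()) {
            Node p = queue.poll();
            Node f = p.fail;                             // root or an already-processed node
            p.dict = (f != root && f.word != null) ? f : f.dict;
            for (Map.Entry<Character, Node> e : p.next.entrySet()) {
                // p.next(c).fail = p.fail.next(c)
                Node g = p.fail;
                while (g != root && !g.next.containsKey(e.getKey())) g = g.fail;
                Node target = g.next.get(e.getKey());
                e.getValue().fail = target != null ? target : root;
                queue.add(e.getValue());
            }
        }
    }

    // Returns "index:word" for every occurrence, ordered by end position.
    List<String> match(String text) {
        List<String> hits = new ArrayList<>();
        Node p = root;
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            while (p != root && !p.next.containsKey(c)) p = p.fail;
            p = p.next.getOrDefault(c, root);
            // Own word first, then only the fail-chain nodes that end a pattern.
            for (Node q = p.word != null ? p : p.dict; q != null; q = q.dict) {
                hits.add((i - q.word.length() + 1) + ":" + q.word);
            }
        }
        return hits;
    }
}
```

Applied to the problem's example, new DictLinkAC(List.of("人民银行", "新闻")).match("中国人民银行今日发表新闻") yields [2:人民银行, 10:新闻]. The links cost one extra pointer per node and are filled in during the same BFS that builds the fail links.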