AC自动机算法原理及应用

ac自动机的背景

在字符串搜索中,某些场景下——如在一篇长文档中搜索一系列的子串(常见的敏感词过滤),会出现需要对文档进行多次扫描的问题,时间复杂度达到O(mn)。为了提高性能,ac自动机应运而生。

原理

ac自动机的核心思想是减少相同前缀的重复搜索。为了达到这一目的,需要以下条件:

  • 需要一种结构,将具有相同前缀的pattern绑定在一起——Trie树
  • 需要一直规则,在匹配失败的时候能够快速跳转或者回溯——失败指针

Trie树

Trie树又称字典树,它利用树结构的特点,将具有公共前缀的数据聚集到同一个节点之下,从而在搜索的时候省略掉公共前缀的搜索时间。
Trie树需要满足以下几个特征:

  • 根节点无数据,其余节点含单元数据
  • 根节点到叶节点的路径为数据集中的一条数据元素
  • 每个节点的子节点互斥

失败指针

失败指针(Fail)也可以成为跳转指针(Turnto),一般在树结构填充完毕后进行统计,当然也可以在插入的时候动态更新,他可以在匹配失败时快速定位到其他具有相同前缀的数据,或者回溯到上一层的前缀。一般来说,所有的失败指针连接的串都是平行的(除root),而他们的最终归宿又都在root。

实现

/**
 * ac.h
 * Author : erdao
 * Date : 2016/8/5
 */

#ifndef _AC_H__
#define _AC_H__

#include <stdint.h>
#include <unordered_map>
#include <vector>

//< 转换规则,用作map的key
class TransformPolicy {
public:
    virtual uint64_t transform(void *) = 0;
};


// Trie树结构的定义
template <typename valueType, typename tagType, class Trans>
struct TrieNode {
    //unordered_map快速定位子节点
    //using ChildMapType = std::unordered_map<uint64_t, TrieNode *>;

    valueType     value;
    tagType     tag;
    bool        flag;

    std::unordered_map<uint64_t, TrieNode *>    children;
    TrieNode        *fail;

    TrieNode()
        : value{}
        , tag{}
        , flag{false}
        , fail{ nullptr }
    {}
};

template <typename valueType, typename tagType, class Trans>
class AC_Automation {
    using ACNode = TrieNode<valueType, tagType, Trans>;

public:
    AC_Automation()
        : pRoot(nullptr) {
        pRoot = new ACNode{};
    }
    ~AC_Automation() {
        this->removeNode(pRoot);
        pRoot = nullptr;
    }

    void addData(valueType* pValArray, uint32_t uValNum, tagType vTag);
    void buildFail();

    //返回的格式包括位置和tagType
    std::vector<std::pair<uint32_t, tagType>> Search(valueType *pValArray, uint32_t uValNum);

private:
    void removeNode(ACNode *pNode) {
        if (pNode) {
            for (auto &child : pNode->children) {
                removeNode(child.second);
            }
            pNode->children.clear();
            delete pNode;
        }
    }

private:
    ACNode *pRoot;
};

#endif//_AC_H__
/**
 * ac.cpp
 * Author : erdao
 * Date : 2016/8/5
 */

#include <stdint.h>
#include <queue>
#include <vector>

#include "ac.h"

template <typename valueType, typename tagType, class Trans>
void AC_Automation<valueType, tagType, Trans>::addData(valueType* pValArray, uint32_t uValNum, tagType vTag)  {
    using ACNode = TrieNode<valueType, tagType, Trans>;

    Trans t{};
    ACNode *p = pRoot;
    for (uint32_t i = 0; i < uValNum; ++i) {
        uint64_t key = t.transform(&pValArray[i]);
        if (p->children.find(key) != p->children.end()) {
            p = p->children[key];
        }
        else {
            p->children[key] = new ACNode{};
            p = p->children[key];
            p->value = pValArray[i];
        }
    }
    p->tag = vTag;
    p->flag = true;
}

template <typename valueType, typename tagType, class Trans>
void AC_Automation<valueType, tagType, Trans>::buildFail() {
    using ACNode = TrieNode<valueType, tagType, Trans>;

    //实际上,由于所有的fail指针都相交于pRoot,且相互间平行,不存在交叉情况,可以采用自上而下的遍历方式进行创建
    Trans t{};
    std::queue<ACNode *> qNode;
    qNode.push(pRoot);
    /*pRoot默认指向nullptr,可以作为fail跳转终止条件;另外pRoot本身也是终止条件,按照喜好处理*/

    while (!qNode.empty()) {
        ACNode *p = qNode.front();
        //为p的所有child创建fail
        for (const auto &child : p->children) {
            if (p == pRoot) {
                //深度为1的节点失败都跳转到root
                child.second->fail = pRoot;
            }
            //否则的话,寻找父节点失败指针串中与child值相同的节点
            else{
                uint64_t key = t.transform(&child.second->value);
                ACNode *pFail = p->fail;
                while (pFail) {
                    if (pFail->children.find(key) != pFail->children.end()) {
                        child.second->fail = pFail->children[key];
                        break;
                    }
                    pFail = pFail->fail;
                }
                if (!pFail)
                    child.second->fail = pRoot;
            }
            //将本身处理好的节点放入队列,准备处理子节点
            qNode.push(child.second);
        }
        //将子节点处理好的节点弹出队列
        qNode.pop();
    }    

}

template <typename valueType, typename tagType, class Trans>
std::vector<std::pair<uint32_t, tagType>> AC_Automation<valueType, tagType, Trans>::Search(valueType *pValArray, uint32_t uValNum) {
    using ACNode = TrieNode<valueType, tagType, Trans>;

    //搜索的入口从root开始,不停的迭代Trie树中的fail指针,匹配上时进行记录
    std::vector<std::pair<uint32_t, tagType>> vRes;
    ACNode *p = pRoot;
    Trans t{};
    for (uint32_t i = 0; i < uValNum; ++i) {
        uint64_t key = t.transform(&pValArray[i]);
        //寻找到当前p下fail串中能匹配上的p
        while (p->children.find(key) == p->children.end() && p != pRoot) {
            p = p->fail;
        }
        //如果遍历完没找到的话,那么终止条件是p=pRoot(所有fail最终交会)
        if (p->children.find(key) == p->children.end())
            continue;
        p = p->children[key];

        //如果找到了,检查这一层的fail中是否具有终结标记
        auto pFail = p;
        while (pFail) {
            if (pFail->flag)
                vRes.push_back(std::move (std::make_pair(i, pFail->tag)));
            pFail = pFail->fail;
        }
    }

    return std::move(vRes);
}

在进行自动机的使用时,我们需要给定valueType,tagType以及valueType转换成key的TransformPolicy来进行构造。
以下是一个测试:

//    test.cpp
int main() {
    //转换规则
    class TransChar : public TransformPolicy {
    public:
        uint64_t transform(void *a) { return (uint64_t)*(char *)a; }
    };

    char test_words[][10] {
        "her", "she", "shy", "here", "hi", "he"
    };

    AC_Automation<char, int, TransChar> ac{};
    for (uint32_t i = 0; i < sizeof(test_words) / sizeof(test_words[0]); i++) {
        ac.addData(test_words[i], strlen(test_words[i]), i);
    }
    ac.buildFail();

    const char *query = "Oh, she is there so shy, let's go say hi.";
    auto ret = ac.Search((char *)query, strlen(query));

    std::cout << query << std::endl;
    for (auto &r : ret) {
        uint32_t pos = r.first + 1 - strlen(test_words[r.second]);
        for (uint32_t k = 0; k < pos; ++k)
            std::cout << "-";
        std::cout << test_words[r.second] << std::endl;
    }

    return 0;
}

案例的运行结果为:

Oh, she is there so shy, let's go say hi.
----she
-----he
------------he
------------her
------------here
--------------------shy
--------------------------------------hi
right