第 13 章

字典树与AC自动机

第十三章:字典树与AC自动机

在文本处理领域,有一类问题反复出现:给定一组关键词,如何在一段文本中高效地查找它们?搜索引擎的自动补全、IDE 的代码提示、网络防火墙的敏感词过滤、DNA 序列的模式匹配——这些场景背后都站着同一个数据结构家族:Trie(字典树)

当我们进一步将 Trie 与 KMP 算法的失配思想结合,就得到了 Aho-Corasick 自动机(AC 自动机)——一个能在 O(n + m + z) 时间内同时匹配多个模式串的算法,其中 n 是文本长度,m 是所有模式串的总长度,z 是匹配次数。这个算法在 1975 年由 Alfred Aho 和 Margaret Corasick 在贝尔实验室提出,至今仍是多模式匹配的工业标准。

本章将从 Trie 的最基本操作开始,逐步深入到 AC 自动机的完整实现,最后用真实的面试题和工程案例将理论落地。


Level 1 · 你需要知道的

13.1 Trie 的基本概念

Trie(发音为 "try",源自 retrieval)是一种多叉树结构,每条从根到叶子的路径表示一个字符串。与哈希表不同,Trie 天然支持前缀查询——这是它最核心的优势。

想象一个存储了 ["apple", "app", "apt", "bat", "bad"] 的 Trie:

        root
       /    \
      a      b
     /        \
    p          a
   / \        / \
  p   t      t   d
  |
  l
  |
  e

每个节点代表一个字符(严格来说是一条边上的字符),从根到任意标记节点的路径构成一个完整的单词。

为什么不直接用哈希表? 哈希表查找单个 key 是 O(L)(L 为字符串长度),但它无法回答"所有以 'ap' 开头的单词有哪些?"这类前缀问题。哈希表需要遍历所有 key 逐一检查前缀,复杂度为 O(N·L);而 Trie 只需走到 'a' → 'p' 节点,然后遍历其子树,复杂度为 O(P + K),P 是前缀长度,K 是匹配的结果数量。

13.2 Trie 的基本实现

class TrieNode:
    def __init__(self):
        self.children = {}  # char -> TrieNode
        self.is_end = False  # 标记是否为完整单词的结尾


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str) -> None:
        """插入一个单词,时间复杂度 O(L),L = len(word)"""
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end = True

    def search(self, word: str) -> bool:
        """精确查找,时间复杂度 O(L)"""
        node = self._find_node(word)
        return node is not None and node.is_end

    def starts_with(self, prefix: str) -> bool:
        """前缀查询,时间复杂度 O(P),P = len(prefix)"""
        return self._find_node(prefix) is not None

    def _find_node(self, prefix: str):
        """找到前缀对应的节点"""
        node = self.root
        for char in prefix:
            if char not in node.children:
                return None
            node = node.children[char]
        return node

使用示例

trie = Trie()
trie.insert("apple")
trie.insert("app")
trie.insert("application")

print(trie.search("app"))        # True
print(trie.search("ap"))         # False(不是完整单词)
print(trie.starts_with("ap"))    # True
print(trie.starts_with("b"))     # False

复杂度分析

操作 时间复杂度 空间复杂度
插入 O(L) O(L)(最坏情况,无公共前缀)
查找 O(L) O(1)
前缀查询 O(P) O(1)
构建整棵树 O(N·L_avg) O(N·L_avg)(最坏)

其中 L 是单词长度,P 是前缀长度,N 是单词数量,L_avg 是平均单词长度。

13.3 自动补全功能实现

自动补全是 Trie 最经典的应用场景。当用户输入一个前缀时,我们需要返回所有以该前缀开头的单词:

class AutoCompleteTrie(Trie):
    def autocomplete(self, prefix: str, limit: int = 10) -> list:
        """返回所有以 prefix 开头的单词,最多返回 limit 个"""
        node = self._find_node(prefix)
        if node is None:
            return []

        results = []
        self._dfs_collect(node, prefix, results, limit)
        return results

    def _dfs_collect(self, node, current_word, results, limit):
        """DFS 收集所有完整单词"""
        if len(results) >= limit:
            return
        if node.is_end:
            results.append(current_word)
        for char in sorted(node.children.keys()):  # 按字典序
            self._dfs_collect(
                node.children[char],
                current_word + char,
                results,
                limit
            )

带权重的自动补全(按热度排序):

class WeightedTrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False
        self.weight = 0  # 搜索热度/频率


class WeightedAutoComplete:
    def __init__(self):
        self.root = WeightedTrieNode()

    def insert(self, word: str, weight: int) -> None:
        node = self.root
        for char in word:
            if char not in node.children:
                node.children[char] = WeightedTrieNode()
            node = node.children[char]
        node.is_end = True
        node.weight = weight

    def top_k_suggestions(self, prefix: str, k: int = 5) -> list:
        """返回权重最高的 k 个补全结果"""
        import heapq
        node = self._find_node(prefix)
        if node is None:
            return []

        # 用最小堆维护 top-k
        heap = []  # (weight, word)
        self._collect_all(node, prefix, heap, k)
        # 按权重降序返回
        return [word for _, word in sorted(heap, reverse=True)]

    def _collect_all(self, node, current, heap, k):
        import heapq
        if node.is_end:
            if len(heap) < k:
                heapq.heappush(heap, (node.weight, current))
            elif node.weight > heap[0][0]:
                heapq.heapreplace(heap, (node.weight, current))
        for char, child in node.children.items():
            self._collect_all(child, current + char, heap, k)

    def _find_node(self, prefix):
        node = self.root
        for char in prefix:
            if char not in node.children:
                return None
            node = node.children[char]
        return node

13.4 压缩 Trie(Patricia Tree / Radix Tree)

标准 Trie 的一个问题是:当路径上只有一个子节点时,会浪费大量空间。例如存储 "romane", "romanus", "romulus",从 'r' 到 'rom' 的路径每个节点都只有一个子节点。

Patricia Tree(Practical Algorithm to Retrieve Information Coded in Alphanumeric,由 Donald R. Morrison 1968 年提出)将这些单链路径压缩为一条边:

标准 Trie:           压缩 Trie (Patricia Tree):
    r                     rom
    |                    /   \
    o                  an     ulus
    |                 / \
    m               e    us
   / \
  a   u
  |   |
  n   l
 / \  |
e   u u
    |  s
    s
class PatriciaNode:
    def __init__(self, label=""):
        self.label = label       # 边上的字符串(可以是多个字符)
        self.children = {}       # 第一个字符 -> PatriciaNode
        self.is_end = False


class PatriciaTrie:
    def __init__(self):
        self.root = PatriciaNode()

    def insert(self, word: str) -> None:
        node = self.root
        i = 0
        while i < len(word):
            char = word[i]
            if char not in node.children:
                # 直接插入剩余部分作为新边
                new_node = PatriciaNode(word[i:])
                new_node.is_end = True
                node.children[char] = new_node
                return
            
            child = node.children[char]
            label = child.label
            # 找到 word[i:] 与 label 的最长公共前缀
            j = 0
            while j < len(label) and i + j < len(word) and label[j] == word[i + j]:
                j += 1

            if j == len(label):
                # label 完全匹配,继续向下
                i += j
                node = child
            else:
                # 需要分裂节点
                # 创建分裂点
                split_node = PatriciaNode(label[:j])
                node.children[char] = split_node

                # 原来的 child 变成 split_node 的子节点
                child.label = label[j:]
                split_node.children[label[j]] = child

                # 如果 word 还有剩余,创建新分支
                if i + j < len(word):
                    new_node = PatriciaNode(word[i + j:])
                    new_node.is_end = True
                    split_node.children[word[i + j]] = new_node
                else:
                    split_node.is_end = True
                return

        node.is_end = True

    def search(self, word: str) -> bool:
        node = self.root
        i = 0
        while i < len(word):
            char = word[i]
            if char not in node.children:
                return False
            child = node.children[char]
            label = child.label
            # 检查 label 是否完全匹配
            if not word[i:i+len(label)] == label:
                return False
            i += len(label)
            node = child
        return node.is_end

Patricia Tree 的优势

13.5 常见错误与陷阱

错误 1:混淆 search 和 starts_with

# 错误:忘记检查 is_end 标记
def search(self, word):
    node = self._find_node(word)
    return node is not None  # BUG! "app" 在 "apple" 中会返回 True

# 正确:
def search(self, word):
    node = self._find_node(word)
    return node is not None and node.is_end

错误 2:删除操作不回收空节点

# 正确的删除实现
def delete(self, word: str) -> bool:
    """删除单词,返回是否成功删除"""
    def _delete(node, word, depth):
        if depth == len(word):
            if not node.is_end:
                return False  # 单词不存在
            node.is_end = False
            return len(node.children) == 0  # 返回是否可以删除该节点

        char = word[depth]
        if char not in node.children:
            return False

        should_delete = _delete(node.children[char], word, depth + 1)
        if should_delete:
            del node.children[char]
            # 如果当前节点也没有其他子节点且不是单词结尾,可以继续删除
            return not node.is_end and len(node.children) == 0
        return False

    _delete(self.root, word, 0)

错误 3:自动补全时未设置上限导致 OOM

当 Trie 中存储百万级单词时,一个很短的前缀(如 "a")可能匹配几十万个结果。必须设置 limit 参数。


Level 2 · 它是怎么运行的

13.6 Trie 的空间优化策略

13.6.1 子节点存储方式对比

Trie 节点的 children 字段有多种实现方式,选择直接影响时间和空间性能:

方式一:固定大小数组

class TrieNodeArray:
    def __init__(self):
        # 26 个字母,直接用索引访问
        self.children = [None] * 26
        self.is_end = False

    def get_child(self, char):
        return self.children[ord(char) - ord('a')]

    def set_child(self, char, node):
        self.children[ord(char) - ord('a')] = node

方式二:哈希表(字典)

class TrieNodeHash:
    def __init__(self):
        self.children = {}  # 动态分配
        self.is_end = False

方式三:有序数组(适合静态 Trie)

class TrieNodeSorted:
    def __init__(self):
        self.keys = []    # 有序字符列表
        self.values = []  # 对应的子节点列表
        self.is_end = False

    def get_child(self, char):
        # 二分查找
        import bisect
        idx = bisect.bisect_left(self.keys, char)
        if idx < len(self.keys) and self.keys[idx] == char:
            return self.values[idx]
        return None

量化对比(存储 10 万个英文单词,平均长度 8):

实现方式 内存占用 插入时间 查找时间
固定数组(26) ~100 MB 最快 最快
哈希表 ~40 MB 中等 中等
有序数组 ~25 MB 最慢(需要移动元素) 略慢(二分)
压缩 Trie ~15 MB 中等

13.6.2 双数组 Trie(Double-Array Trie)

在中文分词、日文输入法等需要极致性能的场景中,双数组 Trie(Aoe Jun-ichi, 1989)是工业级选择。它用两个整数数组 base[]check[] 来模拟整棵 Trie:

class DoubleArrayTrie:
    """
    双数组 Trie 的核心思想:
    - base[s] + c = t:从状态 s 经过字符 c 转移到状态 t
    - check[t] = s:验证状态 t 的父节点确实是 s
    """
    def __init__(self, size=1000000):
        self.base = [0] * size
        self.check = [-1] * size  # -1 表示未使用
        self.base[0] = 1  # 根节点

    def transition(self, state, char_code):
        """从 state 经过 char_code 转移"""
        t = self.base[state] + char_code
        if t < len(self.check) and self.check[t] == state:
            return t
        return -1  # 转移失败

    def search(self, word):
        state = 0
        for ch in word:
            code = ord(ch) - ord('a') + 1  # 1-based
            state = self.transition(state, code)
            if state == -1:
                return False
        # 检查是否为终止状态(用特殊标记)
        end_state = self.transition(state, 0)  # 0 作为终止符
        return end_state != -1

双数组 Trie 的优势在于:

缺点是构建复杂、动态插入困难,适合静态词典(如分词词典)。

13.7 Trie 在 IP 路由表中的应用

网络路由器的核心任务之一是最长前缀匹配(Longest Prefix Match, LPM):给定一个目标 IP 地址,在路由表中找到最长的匹配前缀。这正是 Trie 的天然应用场景。

class IPRoutingTrie:
    """
    IP 路由表的 Trie 实现
    每个 bit 作为一个字符(0 或 1),构建二叉 Trie
    """
    class Node:
        def __init__(self):
            self.children = [None, None]  # 0 和 1
            self.next_hop = None  # 路由下一跳

    def __init__(self):
        self.root = self.Node()

    def insert_route(self, prefix: str, prefix_len: int, next_hop: str):
        """
        插入路由条目
        prefix: IP 地址的二进制表示
        prefix_len: 前缀长度(CIDR 记法中的 /24 等)
        next_hop: 下一跳地址
        """
        node = self.root
        for i in range(prefix_len):
            bit = int(prefix[i])
            if node.children[bit] is None:
                node.children[bit] = self.Node()
            node = node.children[bit]
        node.next_hop = next_hop

    def longest_prefix_match(self, ip_binary: str) -> str:
        """
        最长前缀匹配
        返回最具体的路由的下一跳
        """
        node = self.root
        last_match = None
        for bit_char in ip_binary:
            bit = int(bit_char)
            if node.children[bit] is None:
                break
            node = node.children[bit]
            if node.next_hop is not None:
                last_match = node.next_hop
        return last_match

    @staticmethod
    def ip_to_binary(ip: str) -> str:
        """将 IP 地址转换为 32 位二进制字符串"""
        parts = ip.split('.')
        binary = ''
        for part in parts:
            binary += format(int(part), '08b')
        return binary


# 使用示例
router = IPRoutingTrie()

# 插入路由表项
# 192.168.0.0/16 -> Gateway A
router.insert_route(
    IPRoutingTrie.ip_to_binary("192.168.0.0"), 16, "Gateway A"
)
# 192.168.1.0/24 -> Gateway B
router.insert_route(
    IPRoutingTrie.ip_to_binary("192.168.1.0"), 24, "Gateway B"
)
# 0.0.0.0/0 -> Default Gateway
router.insert_route(
    IPRoutingTrie.ip_to_binary("0.0.0.0"), 0, "Default"
)

# 查询
ip = IPRoutingTrie.ip_to_binary("192.168.1.100")
print(router.longest_prefix_match(ip))  # "Gateway B"(匹配 /24,更长)

ip2 = IPRoutingTrie.ip_to_binary("192.168.2.50")
print(router.longest_prefix_match(ip2))  # "Gateway A"(只匹配 /16)

为什么路由器用 Trie 而不是哈希表? 因为路由查找本质上是前缀匹配问题。一个数据包的目的 IP 可能匹配多条路由规则(如 /8、/16、/24),路由器需要找到最长的那个。哈希表只能做精确匹配,无法高效处理前缀关系。

现代路由器的优化

13.8 AC 自动机:多模式匹配的终极武器

13.8.1 问题定义

单模式匹配:在文本 T 中查找一个模式串 P(KMP 算法解决)。 多模式匹配:在文本 T 中同时查找多个模式串 P₁, P₂, ..., Pₖ。

朴素方法:对每个模式串分别运行 KMP,总复杂度 O(n·k + m)。当模式串数量 k 很大时(如敏感词库有上万个词),这不可接受。

AC 自动机的核心洞察是:将所有模式串构建成一棵 Trie,然后在 Trie 上添加失配指针(failure links),使得扫描文本时不需要回退——这与 KMP 的思想一脉相承,只不过从一维(单模式串)推广到了树形结构(多模式串)。

13.8.2 AC 自动机的三个组成部分

  1. Goto 函数:即 Trie 的转移,从当前状态经过一个字符到达下一个状态
  2. Failure 函数:当 Goto 失败时,跳转到最长的真后缀对应的状态(类比 KMP 的 next 数组)
  3. Output 函数:记录每个状态对应的所有匹配模式串

13.8.3 完整实现

from collections import deque


class AhoCorasick:
    class State:
        def __init__(self):
            self.goto = {}        # char -> state_id
            self.failure = 0      # 失配指针,指向状态编号
            self.output = []      # 该状态匹配的所有模式串

    def __init__(self):
        self.states = [self.State()]  # states[0] 是根节点

    def _new_state(self):
        self.states.append(self.State())
        return len(self.states) - 1

    def build(self, patterns: list) -> None:
        """构建 AC 自动机(三步)"""
        # 第一步:构建 Trie(Goto 函数)
        for pattern in patterns:
            state = 0  # 从根开始
            for char in pattern:
                if char not in self.states[state].goto:
                    self.states[state].goto[char] = self._new_state()
                state = self.states[state].goto[char]
            self.states[state].output.append(pattern)

        # 第二步:BFS 构建 Failure 函数
        queue = deque()
        # 根的直接子节点的 failure 都指向根
        for char, next_state in self.states[0].goto.items():
            self.states[next_state].failure = 0
            queue.append(next_state)

        while queue:
            current = queue.popleft()
            for char, next_state in self.states[current].goto.items():
                queue.append(next_state)
                # 找 failure 指针:沿着父节点的 failure 链查找
                failure = self.states[current].failure
                while failure != 0 and char not in self.states[failure].goto:
                    failure = self.states[failure].failure
                self.states[next_state].failure = (
                    self.states[failure].goto[char]
                    if char in self.states[failure].goto
                    else 0
                )
                # 合并 output:当前状态的输出 = 自身输出 + failure 指向状态的输出
                self.states[next_state].output += (
                    self.states[self.states[next_state].failure].output
                )

            queue.append(next_state) if False else None  # placeholder removed

    def search(self, text: str) -> list:
        """
        在文本中搜索所有匹配
        返回 [(position, pattern), ...]
        """
        results = []
        state = 0
        for i, char in enumerate(text):
            # 沿着 failure 链找到能转移的状态
            while state != 0 and char not in self.states[state].goto:
                state = self.states[state].failure
            state = self.states[state].goto.get(char, 0)
            # 收集所有匹配
            for pattern in self.states[state].output:
                results.append((i - len(pattern) + 1, pattern))
        return results


# 使用示例
ac = AhoCorasick()
patterns = ["he", "she", "his", "hers"]
ac.build(patterns)

text = "ahishers"
matches = ac.search(text)
for pos, pattern in matches:
    print(f"位置 {pos}: '{pattern}'")
# 输出:
# 位置 1: 'his'
# 位置 3: 'she'
# 位置 3: 'he'(she 包含 he)
# 位置 4: 'hers'

等等,上面的代码有一个 bug——在 BFS 中多了一行无意义的代码。让我给出修正版本:

from collections import deque


class AhoCorasick:
    """AC 自动机 - 多模式串匹配"""

    class State:
        def __init__(self):
            self.goto = {}        # 转移函数: char -> state_id
            self.failure = 0      # 失配指针
            self.output = []      # 输出函数: 该状态匹配的模式串列表

    def __init__(self):
        self.states = [self.State()]  # 初始只有根节点(id=0)

    def _new_state(self):
        self.states.append(self.State())
        return len(self.states) - 1

    def build(self, patterns: list) -> None:
        """
        三步构建 AC 自动机:
        1. 构建 Trie
        2. BFS 建立 failure 指针
        3. 合并 output
        """
        # Step 1: 构建 Goto(Trie)
        for pattern in patterns:
            state = 0
            for char in pattern:
                if char not in self.states[state].goto:
                    self.states[state].goto[char] = self._new_state()
                state = self.states[state].goto[char]
            self.states[state].output.append(pattern)

        # Step 2 & 3: BFS 构建 Failure 指针 + 合并 Output
        queue = deque()

        # 深度为 1 的节点:failure 指向根
        for char, s in self.states[0].goto.items():
            self.states[s].failure = 0
            queue.append(s)

        while queue:
            u = queue.popleft()
            for char, v in self.states[u].goto.items():
                queue.append(v)

                # 计算 v 的 failure 指针
                f = self.states[u].failure
                while f != 0 and char not in self.states[f].goto:
                    f = self.states[f].failure

                self.states[v].failure = (
                    self.states[f].goto[char] if char in self.states[f].goto else 0
                )

                # 避免 failure 指向自己
                if self.states[v].failure == v:
                    self.states[v].failure = 0

                # 合并 output(字典后缀链上的所有模式)
                self.states[v].output = (
                    self.states[v].output +
                    self.states[self.states[v].failure].output
                )

    def search(self, text: str) -> list:
        """
        扫描文本,返回所有匹配 [(start_pos, pattern), ...]
        时间复杂度: O(n + z),n = len(text),z = 匹配总数
        """
        results = []
        state = 0

        for i, char in enumerate(text):
            while state != 0 and char not in self.states[state].goto:
                state = self.states[state].failure
            state = self.states[state].goto.get(char, 0)

            for pattern in self.states[state].output:
                results.append((i - len(pattern) + 1, pattern))

        return results

13.8.4 AC 自动机的执行过程可视化

以模式串集合 {"he", "she", "his", "hers"} 为例,构建的自动机状态图(省略部分):

状态转移 (Goto):
  0 --h--> 1 --e--> 2 [output: "he"]
                     \--r--> 8 --s--> 9 [output: "hers"]
           \--i--> 6 --s--> 7 [output: "his"]
  0 --s--> 3 --h--> 4 --e--> 5 [output: "she", "he"]

Failure 指针:
  状态 5 (she) --failure--> 状态 2 (he)
  因为 "she" 的最长真后缀 "he" 恰好也是一个 Trie 中的前缀路径

当扫描文本 "ahishers" 时:

  1. 'a':根节点无 'a' 转移,停留在根(状态 0)
  2. 'h':转移到状态 1
  3. 'i':转移到状态 6
  4. 's':转移到状态 7,输出 "his"
  5. 'h':状态 7 无 'h' 转移,沿 failure 链回到状态 0,再转移到状态 1... 然后到状态 4(因为 "sh" 路径存在)
  6. 'e':转移到状态 5,输出 "she" 和 "he"
  7. 'r':从状态 5 的 failure 状态 2 出发,转移到状态 8
  8. 's':转移到状态 9,输出 "hers"

13.9 敏感词过滤实战

class SensitiveWordFilter:
    """
    基于 AC 自动机的敏感词过滤器
    支持:
    - 动态添加/删除敏感词
    - 全半角字符统一
    - 大小写不敏感
    - 过滤结果替换为 ***
    """
    def __init__(self):
        self.ac = None
        self.patterns = set()
        self._dirty = True  # 标记是否需要重建

    def add_word(self, word: str) -> None:
        """添加敏感词"""
        normalized = self._normalize(word)
        if normalized:
            self.patterns.add(normalized)
            self._dirty = True

    def remove_word(self, word: str) -> None:
        """删除敏感词"""
        normalized = self._normalize(word)
        self.patterns.discard(normalized)
        self._dirty = True

    def _normalize(self, text: str) -> str:
        """规范化文本:全角转半角、转小写"""
        result = []
        for ch in text:
            code = ord(ch)
            # 全角字符转半角
            if 0xFF01 <= code <= 0xFF5E:
                ch = chr(code - 0xFEE0)
            elif code == 0x3000:
                ch = ' '
            result.append(ch.lower())
        return ''.join(result)

    def _ensure_built(self):
        """确保自动机已构建"""
        if self._dirty:
            self.ac = AhoCorasick()
            self.ac.build(list(self.patterns))
            self._dirty = False

    def contains(self, text: str) -> bool:
        """检测文本是否包含敏感词"""
        self._ensure_built()
        normalized = self._normalize(text)
        return len(self.ac.search(normalized)) > 0

    def filter(self, text: str, replacement: str = "***") -> str:
        """将敏感词替换为指定字符"""
        self._ensure_built()
        normalized = self._normalize(text)
        matches = self.ac.search(normalized)

        if not matches:
            return text

        # 标记需要替换的位置
        mask = [False] * len(text)
        for start, pattern in matches:
            for i in range(start, start + len(pattern)):
                mask[i] = True

        # 构建结果
        result = []
        i = 0
        while i < len(text):
            if mask[i]:
                result.append(replacement)
                while i < len(text) and mask[i]:
                    i += 1
            else:
                result.append(text[i])
                i += 1
        return ''.join(result)

    def find_all(self, text: str) -> list:
        """找出所有敏感词及位置"""
        self._ensure_built()
        normalized = self._normalize(text)
        return self.ac.search(normalized)


# 使用示例
filter = SensitiveWordFilter()
filter.add_word("赌博")
filter.add_word("色情")
filter.add_word("暴力")
filter.add_word("毒品")

text = "这个网站包含赌博和色情内容,请远离"
print(filter.contains(text))  # True
print(filter.filter(text))    # "这个网站包含***和***内容,请远离"
print(filter.find_all(text))  # [(6, '赌博'), (9, '色情')]

工程优化建议

  1. 缓存友好的状态存储:将所有状态存入连续数组而非对象链表
  2. 增量构建:对于频繁变化的敏感词库,可以维护多个小型 AC 自动机分批查询
  3. 跳过无关字符:中文敏感词之间可能插入特殊字符(如 "赌◆博"),预处理时去除干扰字符
  4. 多级过滤:先用布隆过滤器快速排除无敏感词的文本,再用 AC 自动机精确匹配

Level 3 · 规范怎么定义的

13.10 Edward Fredkin 与 Trie 的起源(1960)

"Trie" 这个名字的来历,本身就是计算机科学史上一个有趣的命名之争。

1959 年,René de la Briandais 在 Western Joint Computer Conference 上发表了一篇论文 "File Searching Using Variable Length Keys",描述了一种利用字符串的逐字符比较来组织字典的树形结构。但他没有给这种结构起一个专门的名字。

1960 年,Edward Fredkin 在他的论文 "Trie Memory" (Communications of the ACM, Vol. 3, No. 9) 中正式提出了 "Trie" 这个术语。Fredkin 解释说,这个词取自 "retrieval" 的中间部分,暗示其核心用途是信息检索

命名争议:Fredkin 本人将 "Trie" 读作 "tree"(与 "树" 同音),认为这强调了它本质上是一种树结构。但后来大多数学术界和工业界将其读作 "try",以避免与通用的 "tree" 混淆。Donald Knuth 在《The Art of Computer Programming》Vol. 3 (Sorting and Searching, 1973) 中对此进行了详细讨论,并倾向于 "try" 的发音。

Fredkin 的原始动机:他在 Bolt, Beranek and Newman(BBN,后来参与了 ARPANET 的建设)工作时,需要高效存储和检索大量字符串。哈希表虽然单次查找快,但无法支持按前缀范围查询——而这在信息检索系统中是基本需求。Trie 结构的提出直接解决了这个问题。

Fredkin 的另一贡献:他同时提出了 Trie 可以推广到任意基数(radix)的思想。当每个字符有 r 种可能值时,每个节点最多有 r 个子节点。对于二进制字符串,r=2,这就是后来广泛应用于 IP 路由的二叉 Trie。

13.11 Aho-Corasick 自动机的诞生(1975)

Alfred V. Aho 和 Margaret J. Corasick 于 1975 年在 Communications of the ACM (Vol. 18, No. 6) 上发表了里程碑式的论文 "Efficient String Matching: An Aid to Bibliographic Search"。

研究背景:当时贝尔实验室正在开发文本处理工具。Unix 系统的 fgrep 命令需要在文件中同时搜索多个固定字符串。如果对每个模式串独立搜索,当模式串数量达到数百个时,性能不可接受。Aho 和 Corasick 需要一种能"一次扫描文本,同时匹配所有模式"的算法。

核心创新

  1. 将所有模式串构建为一棵 Trie(他们称之为 "goto function")
  2. 在 Trie 上增加 "failure function"——当当前字符无法在 Trie 中继续匹配时,跳转到最长的正确后缀对应的状态。这与 KMP 的思想相同,但推广到了多模式情况
  3. 增加 "output function"——通过 failure 链的传递,确保经过一个状态时能报告所有在该位置结束的模式

复杂度证明(论文中的关键定理):

设文本长度为 n,所有模式串的总长度为 m,匹配次数为 z:

为什么搜索是 O(n + z)? 关键观察是:文本中每个字符最多被处理一次(沿 failure 链回退的总次数在整个搜索过程中是 O(n) 的,这与 KMP 的分析完全类似)。

13.12 AC 自动机与 KMP 的关系

KMP 算法(Knuth-Morris-Pratt, 1977)和 AC 自动机有深刻的内在联系。从形式上看:

KMP 是 AC 自动机在单模式串情况下的特例。

概念 KMP AC 自动机
结构基础 单个模式串(线性) 多模式串的 Trie(树形)
失配处理 next[j]:模式串前 j 个字符的最长真前后缀 failure[s]:状态 s 对应字符串的最长真后缀(在 Trie 中)
预处理 O(m) O(m·σ) 或 O(m)
搜索 O(n) O(n + z)
提出时间 Knuth, Morris, Pratt 1977(实际发现于 1970) Aho, Corasick 1975

有趣的历史细节:AC 算法的论文(1975)比 KMP 论文(1977)发表得更早。实际上 KMP 算法在 1970 年就已被发现,但论文经过了长时间的修改才正式发表。Aho 在设计 AC 自动机时明确借鉴了 KMP 的失配跳转思想(他是 Knuth 的同事,了解 KMP 的未发表工作)。

从 KMP 到 AC 的推广过程

# KMP 的 next 数组:对单个模式串
def build_kmp_next(pattern):
    """
    next[i] = pattern[0:i] 的最长真前后缀长度
    等价于:如果在位置 i 失配,应该跳到 next[i] 继续
    """
    n = len(pattern)
    next_arr = [0] * n
    j = 0
    for i in range(1, n):
        while j > 0 and pattern[i] != pattern[j]:
            j = next_arr[j - 1]
        if pattern[i] == pattern[j]:
            j += 1
        next_arr[i] = j
    return next_arr

# AC 的 failure 指针:推广到多模式(树形结构)
# 对比:KMP 的 j = next[j-1] 对应 AC 的 failure = states[failure].failure
# KMP 是一条链上的回退,AC 是树上的回退(但逻辑完全相同)

本质统一:KMP 和 AC 自动机都是**确定有限自动机(DFA)**的特例。KMP 构建的是一个线性 DFA(每个状态对应模式串的一个前缀),AC 自动机构建的是一个树形 DFA(每个状态对应 Trie 中的一个前缀)。两者的 failure 函数本质上都是在实现 DFA 的状态压缩——避免显式存储所有可能的转移。

13.13 后缀树简介(Weiner, 1973)

后缀树(Suffix Tree) 是 Trie 家族中另一个极其重要的成员。Peter Weiner 在 1973 年的 FOCS(IEEE Symposium on Foundations of Computer Science)会议上发表了论文 "Linear Pattern Matching Algorithms",首次提出了线性时间构建后缀树的算法。Knuth 称之为 "1973 年的年度算法"。

什么是后缀树? 对于字符串 S = "banana$"($ 是终止符),其所有后缀为:

banana$
anana$
nana$
ana$
na$
a$
$

将这些后缀全部插入一棵压缩 Trie(Patricia Tree),就得到了后缀树。

后缀树的威力

问题 复杂度
判断 P 是否是 S 的子串 O(
计算 P 在 S 中出现几次 O(
找 S 的最长重复子串 O(
找两个字符串的最长公共子串 O(
找 S 的第 k 小后缀(后缀数组) O(

构建算法的演变

与 Trie 的关系:后缀树本质上是对一个字符串的所有后缀构建的压缩 Trie。它将 Trie 的应用从"多个短字符串的集合"扩展到了"一个长字符串的子串问题"。

为什么后缀树在实践中被后缀数组取代? Manber 和 Myers (1993) 提出的后缀数组(Suffix Array) 只需要一个整数数组来存储后缀的排序顺序,配合 LCP 数组能解决后缀树的大部分问题,且空间开销远小于后缀树(后缀树每个节点需要多个指针)。在基因组学等需要处理超长字符串的领域,后缀数组是事实标准。

13.14 Trie 的理论复杂度分析

平均情况分析(假设字符串在字符集 Σ 上均匀随机生成):

Philippe Flajolet 和 Robert Sedgewick 在《Analytic Combinatorics》(2009) 中对 Trie 进行了精确的平均情况分析:

最坏情况:当所有字符串共享很长的公共前缀时(如 "aaa...a1", "aaa...a2"),Trie 退化为接近线性链,高度接近最长字符串的长度 L_max。这就是压缩 Trie 存在的意义。


Level 4 · 边界与陷阱

13.15 LeetCode #208:实现 Trie(前缀树)

题目:实现 Trie 类,支持 insert、search、startsWith 三个操作。

这是最基础的 Trie 面试题,直接考察数据结构的实现能力:

class Trie:
    def __init__(self):
        self.children = {}
        self.is_end = False

    def insert(self, word: str) -> None:
        node = self
        for ch in word:
            if ch not in node.children:
                node.children[ch] = Trie()
            node = node.children[ch]
        node.is_end = True

    def search(self, word: str) -> bool:
        node = self._search_prefix(word)
        return node is not None and node.is_end

    def startsWith(self, prefix: str) -> bool:
        return self._search_prefix(prefix) is not None

    def _search_prefix(self, prefix: str):
        node = self
        for ch in prefix:
            if ch not in node.children:
                return None
            node = node.children[ch]
        return node

面试要点

13.16 LeetCode #211:添加与搜索单词

题目:设计数据结构 WordDictionary,支持 addWord 和 search,search 中 '.' 可以匹配任意一个字符。

class WordDictionary:
    def __init__(self):
        self.children = {}
        self.is_end = False

    def addWord(self, word: str) -> None:
        node = self
        for ch in word:
            if ch not in node.children:
                node.children[ch] = WordDictionary()
            node = node.children[ch]
        node.is_end = True

    def search(self, word: str) -> bool:
        return self._search(word, 0)

    def _search(self, word: str, index: int) -> bool:
        if index == len(word):
            return self.is_end

        ch = word[index]
        if ch == '.':
            # 通配符:尝试所有子节点
            for child in self.children.values():
                if child._search(word, index + 1):
                    return True
            return False
        else:
            if ch not in self.children:
                return False
            return self.children[ch]._search(word, index + 1)

复杂度分析

面试追问

13.17 LeetCode #212:单词搜索 II

题目:给定一个 m×n 的字符网格和一个单词列表,找出所有同时出现在网格和列表中的单词。网格中的单词必须由相邻单元格(上下左右)的字母构成,同一单元格不能重复使用。

这是 Trie + DFS 回溯的经典组合题:

class Solution:
    def findWords(self, board: list, words: list) -> list:
        # 第一步:将所有单词构建成 Trie
        root = {}
        for word in words:
            node = root
            for ch in word:
                node = node.setdefault(ch, {})
            node['#'] = word  # '#' 标记单词结尾,存储完整单词

        m, n = len(board), len(board[0])
        result = []

        def dfs(i, j, parent):
            ch = board[i][j]
            node = parent.get(ch)
            if node is None:
                return

            # 检查是否匹配到一个完整单词
            if '#' in node:
                result.append(node['#'])
                del node['#']  # 避免重复添加

            # 标记已访问
            board[i][j] = '@'

            # 四个方向 DFS
            for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                ni, nj = i + di, j + dj
                if 0 <= ni < m and 0 <= nj < n and board[ni][nj] != '@':
                    dfs(ni, nj, node)

            # 恢复
            board[i][j] = ch

            # 优化:如果当前节点已无子节点,剪掉它(剪枝)
            if not node:
                del parent[ch]

        # 从每个位置开始搜索
        for i in range(m):
            for j in range(n):
                dfs(i, j, root)

        return result

关键优化

  1. 用 Trie 替代逐词搜索:朴素方法对每个单词做一次 DFS,复杂度 O(k·m·n·4^L)。用 Trie 后,所有共享前缀的单词只需搜索一次
  2. 剪枝:找到一个单词后从 Trie 中删除,减少后续搜索空间
  3. 节点清理:当一个 Trie 节点的所有子孙都被匹配完毕后删除该节点,后续搜索不会再进入这条路径

复杂度

13.18 Trie vs 哈希表:如何选择

这是面试中常见的设计讨论题。答案取决于具体需求:

维度 Trie 哈希表
精确查找 O(L) O(L)(哈希计算也是 O(L))
前缀查询 O(P + K),天然支持 不支持(需遍历所有 key)
排序遍历 天然有序(DFS = 字典序) 无序
最长前缀匹配 O(L),天然支持 需要多次查找
内存占用 通常更大(每个字符一个节点) 通常更小
缓存性能 差(指针追踪,随机访问) 好(连续内存)
动态插入/删除 O(L) 均摊 O(L)
哈希冲突 不存在 存在,最坏 O(n)
热点问题 不存在 热 key 重哈希

选 Trie 的场景

选哈希表的场景

13.19 面试中的经典变体题

变体 1:MapSum — 键值映射

class MapSum:
    """
    LeetCode #677: 实现 insert(key, val) 和 sum(prefix)
    sum 返回所有以 prefix 为前缀的 key 的 val 之和
    """
    def __init__(self):
        self.root = {}
        self.map = {}  # 记录已插入的 key -> val

    def insert(self, key: str, val: int) -> None:
        delta = val - self.map.get(key, 0)
        self.map[key] = val
        node = self.root
        for ch in key:
            node = node.setdefault(ch, {'_sum': 0})
            node['_sum'] += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            if ch not in node:
                return 0
            node = node[ch]
        return node.get('_sum', 0)

变体 2:回文对

class Solution:
    """
    LeetCode #336: 给定一组唯一的单词,找出所有 (i, j) 对使得
    words[i] + words[j] 是回文串
    
    思路:将每个单词的逆序插入 Trie,搜索时检查剩余部分是否为回文
    """
    def palindromePairs(self, words: list) -> list:
        # 构建逆序 Trie
        root = {}
        for idx, word in enumerate(words):
            node = root
            reversed_word = word[::-1]
            for i, ch in enumerate(reversed_word):
                # 如果 reversed_word[i:] 是回文,记录这个索引
                if self._is_palindrome(reversed_word, i, len(reversed_word) - 1):
                    node.setdefault('_palindrome_indices', []).append(idx)
                node = node.setdefault(ch, {})
            node['_idx'] = idx
            node.setdefault('_palindrome_indices', []).append(idx)

        result = []
        for idx, word in enumerate(words):
            node = root
            for i, ch in enumerate(word):
                # Case 1: word 比 Trie 中某个逆序词长
                if '_idx' in node and node['_idx'] != idx:
                    if self._is_palindrome(word, i, len(word) - 1):
                        result.append([idx, node['_idx']])
                if ch not in node:
                    break
                node = node[ch]
            else:
                # Case 2: word 完全匹配或更短
                if '_idx' in node and node['_idx'] != idx:
                    result.append([idx, node['_idx']])
                # Case 3: Trie 中有更长的逆序词,且剩余是回文
                for j in node.get('_palindrome_indices', []):
                    if j != idx and j != node.get('_idx', -1):
                        result.append([idx, j])

        return result

    def _is_palindrome(self, s, left, right):
        while left < right:
            if s[left] != s[right]:
                return False
            left += 1
            right -= 1
        return True

13.20 工程实践中的注意事项

1. 内存管理

在存储大量字符串时,Trie 的内存开销可能是哈希表的 3-10 倍。以 Python 为例,每个 TrieNode 对象的开销约 100-200 bytes(对象头 + dict 开销),而哈希表中一个字符串 key 只需要字符串本身的内存。

优化策略

class TrieNodeSlots:
    """使用 __slots__ 减少内存开销"""
    __slots__ = ['children', 'is_end']

    def __init__(self):
        self.children = {}
        self.is_end = False

2. 并发安全

生产环境中 Trie 通常是多线程访问的(如搜索引擎的自动补全服务)。选择方案:

3. 持久化

大型 Trie 需要持久化到磁盘。常见方案:

4. 性能基准测试

import time
import random
import string


def benchmark_trie_vs_hashset(n_words=100000, word_len=10):
    """Trie vs HashSet 性能对比"""
    # 生成随机单词
    words = [
        ''.join(random.choices(string.ascii_lowercase, k=word_len))
        for _ in range(n_words)
    ]

    # Trie 构建
    trie = Trie()
    start = time.time()
    for w in words:
        trie.insert(w)
    trie_build = time.time() - start

    # HashSet 构建
    start = time.time()
    hash_set = set(words)
    hash_build = time.time() - start

    # Trie 查找
    start = time.time()
    for w in words[:10000]:
        trie.search(w)
    trie_search = time.time() - start

    # HashSet 查找
    start = time.time()
    for w in words[:10000]:
        w in hash_set
    hash_search = time.time() - start

    # Trie 前缀查询(HashSet 无法高效完成)
    start = time.time()
    for w in words[:10000]:
        trie.starts_with(w[:3])
    trie_prefix = time.time() - start

    print(f"构建 {n_words} 个单词:")
    print(f"  Trie:    {trie_build:.3f}s")
    print(f"  HashSet: {hash_build:.3f}s")
    print(f"查找 10000 次:")
    print(f"  Trie:    {trie_search:.4f}s")
    print(f"  HashSet: {hash_search:.4f}s")
    print(f"前缀查询 10000 次:")
    print(f"  Trie:    {trie_prefix:.4f}s")
    print(f"  HashSet: N/A (不支持高效前缀查询)")

13.21 总结与选择指南

需要前缀匹配? ──── 是 ──→ Trie / Patricia Tree
       │
       否
       │
需要多模式匹配? ── 是 ──→ AC 自动机
       │
       否
       │
需要子串匹配? ─── 是 ──→ 后缀树 / 后缀数组
       │
       否
       │
只需精确查找? ─── 是 ──→ 哈希表

本章关键公式

记住:数据结构的选择不是"哪个更好"的问题,而是"哪个更适合你的具体需求"的问题。Trie 在前缀相关的问题上无可替代,但如果你只需要判断一个字符串是否存在于集合中,哈希表的简单和高效是更明智的选择。

本章评分
4.7  / 5  (26 评分)

💬 留言讨论