第 25 章

回溯:暴力搜索的优雅写法

第二十五章:回溯 — 暴力搜索的优雅写法

你有一个组合爆炸的搜索空间。可能是所有排列、所有子集、所有合法的棋盘布局、所有可能的密码组合。朴素的做法是枚举每一种可能,但当搜索空间是指数级甚至阶乘级时,盲目枚举在时间和空间上都不可行。

回溯(Backtracking)的核心洞察在于:不必枚举完所有可能再判断哪些合法——你可以在构造解的过程中,一旦发现当前路径不可能产生合法解,就立即回头(backtrack)尝试其他选择。这就像走迷宫时遇到死胡同会折返,而不是把整个迷宫走完再判断哪条路通向出口。

回溯算法的形式之美在于它的模板极其统一。无论是排列、组合、子集、N 皇后、数独还是正则匹配,底层都是同一个框架:选择 → 递归 → 撤销选择。一旦你理解了这个框架,所有回溯问题都变成了"在这个问题中,选择是什么、约束是什么、目标是什么"的填空题。

本章将从最基本的模板出发,逐步覆盖排列、组合、子集三大类经典问题,然后深入到 N 皇后、数独这类约束满足问题,再上升到算法理论层面讨论回溯与 NP 完全问题的关系。


Level 1 · 你需要知道的

1.1 回溯的核心思想

回溯算法本质上是对决策树的深度优先搜索。在搜索过程中,你维护一条从根到当前节点的路径(称为"当前选择序列"),每到一个节点就面临若干分支选择。当你发现某个分支不满足约束条件时,就"回溯"到上一个节点,尝试下一个分支。

三步曲

  1. 做选择(Choose):从候选集中选出一个元素,加入当前解
  2. 递归探索(Explore):在当前选择的基础上,继续搜索下一层
  3. 撤销选择(Unchoose):把刚才的选择从当前解中移除,恢复状态

为什么必须"撤销选择"?因为递归返回后,你需要尝试同一层的其他分支。如果不撤销,当前解中就会残留上一次的选择,导致后续分支在错误的状态上构建。

来看一个具体例子。假设我们要生成集合 {1, 2, 3} 的所有排列。决策树如下:

                    []
           /        |        \
         [1]       [2]       [3]
        /   \     /   \     /   \
     [1,2] [1,3] [2,1] [2,3] [3,1] [3,2]
      |      |     |      |     |      |
  [1,2,3][1,3,2][2,1,3][2,3,1][3,1,2][3,2,1]

每一层代表一个"位置"的选择。第一层选第一个数,第二层选第二个数,以此类推。在任何时刻,已经被选过的数不能再选(这是排列的约束)。

1.2 通用回溯模板

def backtrack(candidates, path, result, **constraints):
    """通用回溯模板
    
    Args:
        candidates: 候选集合(可供选择的元素)
        path: 当前已做的选择序列
        result: 存储所有合法解
        constraints: 问题特定的约束条件
    """
    # 终止条件:当前路径构成一个完整解
    if is_solution(path):
        result.append(path[:])  # 注意:必须复制 path
        return
    
    for choice in candidates:
        # 剪枝:跳过不满足约束的选择
        if not is_valid(choice, path, constraints):
            continue
        
        # 1. 做选择
        path.append(choice)
        
        # 2. 递归探索
        backtrack(next_candidates(choice), path, result, **constraints)
        
        # 3. 撤销选择
        path.pop()

关键点

1.3 排列问题

1.3.1 全排列(LeetCode #46)

问题:给定一个不含重复数字的数组 nums,返回其所有可能的全排列。

分析

from typing import List

def permute(nums: List[int]) -> List[List[int]]:
    """全排列
    
    时间复杂度: O(n! × n)  生成 n! 个排列,每个复制需要 O(n)
    空间复杂度: O(n)  递归栈深度 + used 数组
    """
    result = []
    used = [False] * len(nums)
    
    def backtrack(path: List[int]):
        # 终止条件
        if len(path) == len(nums):
            result.append(path[:])
            return
        
        for i in range(len(nums)):
            # 剪枝:跳过已使用的元素
            if used[i]:
                continue
            
            # 做选择
            path.append(nums[i])
            used[i] = True
            
            # 递归
            backtrack(path)
            
            # 撤销选择
            path.pop()
            used[i] = False
    
    backtrack([])
    return result

# 测试
print(permute([1, 2, 3]))
# [[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]

为什么用 used 数组而不是 if nums[i] not in path

因为 in 操作对列表是 O(n) 的,而 used[i] 是 O(1)。当 n 较大时,这个优化很重要。此外,使用索引而非值来判断"是否已用"可以正确处理数组中有相同值的情况(虽然本题保证无重复)。

1.3.2 全排列 II(LeetCode #47,含重复元素)

问题:给定一个可包含重复数字的序列 nums,返回所有不重复的全排列。

分析:相比 #46,难点在于去重。例如 nums = [1, 1, 2],如果不去重会得到两个 [1, 1, 2](第一个 1 在前 vs 第二个 1 在前)。

去重策略:先排序,然后在同一层递归中,如果当前元素和前一个元素相同,且前一个元素没有被使用(说明前一个元素是在同一层被撤销的),就跳过当前元素。

def permute_unique(nums: List[int]) -> List[List[int]]:
    """含重复元素的全排列
    
    核心去重逻辑:排序后,同一层中相同元素只选第一个
    
    时间复杂度: O(n! × n)  最坏情况(无重复时)
    空间复杂度: O(n)
    """
    result = []
    nums.sort()  # 排序是去重的前提
    used = [False] * len(nums)
    
    def backtrack(path: List[int]):
        if len(path) == len(nums):
            result.append(path[:])
            return
        
        for i in range(len(nums)):
            # 基本剪枝:跳过已使用的
            if used[i]:
                continue
            
            # 去重剪枝:同一层中,相同值只选第一个未用的
            # 条件解读:nums[i] == nums[i-1] 说明值相同
            # not used[i-1] 说明 nums[i-1] 在当前层已被撤销(不在路径中)
            # 这意味着我们正在同一层尝试一个相同的值 -> 跳过
            if i > 0 and nums[i] == nums[i - 1] and not used[i - 1]:
                continue
            
            path.append(nums[i])
            used[i] = True
            backtrack(path)
            path.pop()
            used[i] = False
    
    backtrack([])
    return result

# 测试
print(permute_unique([1, 1, 2]))
# [[1,1,2],[1,2,1],[2,1,1]]

深入理解去重条件 not used[i-1]

这个条件容易让人困惑。换一种理解方式:我们规定对于值相同的元素,必须按照它们在排序后数组中的相对顺序来使用。即:如果 nums[i-1] 和 nums[i] 相同,那么 nums[i] 只有在 nums[i-1] 已经被选中的情况下才能被选中。这样就保证了同一组相同值的元素不会产生重复排列。

1.4 组合问题

1.4.1 组合(LeetCode #77)

问题:给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。

分析

def combine(n: int, k: int) -> List[List[int]]:
    """组合
    
    通过 start 参数确保每次只选后面的数,避免重复
    
    时间复杂度: O(C(n,k) × k)  生成 C(n,k) 个组合,每个复制需要 O(k)
    空间复杂度: O(k)  递归栈深度
    """
    result = []
    
    def backtrack(start: int, path: List[int]):
        # 终止条件:组合长度达到 k
        if len(path) == k:
            result.append(path[:])
            return
        
        # 剪枝:如果剩余元素不够凑齐 k 个,提前终止
        # 还需要 k - len(path) 个元素,但从 start 到 n 只有 n - start + 1 个
        if n - start + 1 < k - len(path):
            return
        
        for i in range(start, n + 1):
            path.append(i)
            backtrack(i + 1, path)  # 注意是 i+1,不是 start+1
            path.pop()
    
    backtrack(1, [])
    return result

# 测试
print(combine(4, 2))
# [[1,2],[1,3],[1,4],[2,3],[2,4],[3,4]]

剪枝优化解释if n - start + 1 < k - len(path): return 这行看似简单,实际效果惊人。以 n=20, k=10 为例,不加这行需要访问约 100 万个节点,加了这行后只需约 18 万个节点——减少了 80% 以上的无效搜索。

1.4.2 组合总和(LeetCode #39)

问题:给定一个无重复元素的正整数数组 candidates 和一个目标数 target,找出 candidates 中所有可以使数字之和为 target 的组合。每个数字可以被无限次使用。

分析

def combination_sum(candidates: List[int], target: int) -> List[List[int]]:
    """组合总和(元素可重复使用)
    
    时间复杂度: O(n^(target/min)) 最坏情况  
    空间复杂度: O(target/min)  递归深度
    """
    result = []
    candidates.sort()  # 排序便于剪枝
    
    def backtrack(start: int, path: List[int], remaining: int):
        # 找到合法解
        if remaining == 0:
            result.append(path[:])
            return
        
        for i in range(start, len(candidates)):
            # 剪枝:如果当前元素已经大于剩余值,后面的更大,全部跳过
            if candidates[i] > remaining:
                break
            
            path.append(candidates[i])
            # 注意:传入 i 而非 i+1,因为同一个数可以重复使用
            backtrack(i, path, remaining - candidates[i])
            path.pop()
    
    backtrack(0, [], target)
    return result

# 测试
print(combination_sum([2, 3, 6, 7], 7))
# [[2,2,3],[7]]

排序 + break 的剪枝威力:如果 candidates = [2,3,6,7], target = 7,当我们尝试到 6 时,remaining = 7-6 = 1,此时再尝试 6,6 > 1,break。如果不排序不剪枝,则需要递归下去直到 remaining < 0 才返回,浪费大量时间。

1.5 子集问题

1.5.1 子集(LeetCode #78)

问题:给定一个整数数组 nums,数组中的元素互不相同,返回该数组所有可能的子集。

分析:子集问题与组合问题的核心区别——不需要终止条件来收集结果,每个中间状态本身就是一个子集

def subsets(nums: List[int]) -> List[List[int]]:
    """子集
    
    策略:每到一个节点就收集结果(包括空集)
    
    时间复杂度: O(2^n × n)  共 2^n 个子集,每个复制需要 O(n)
    空间复杂度: O(n)  递归深度
    """
    result = []
    
    def backtrack(start: int, path: List[int]):
        # 每个节点都是一个合法子集,直接收集
        result.append(path[:])
        
        for i in range(start, len(nums)):
            path.append(nums[i])
            backtrack(i + 1, path)
            path.pop()
    
    backtrack(0, [])
    return result

# 测试
print(subsets([1, 2, 3]))
# [[], [1], [1,2], [1,2,3], [1,3], [2], [2,3], [3]]

另一种理解方式——二进制枚举:对于 n 个元素的数组,每个元素要么选要么不选,共 2^n 种可能。可以用 n 位二进制数的每一位表示对应元素是否被选中:

def subsets_bitmask(nums: List[int]) -> List[List[int]]:
    """子集(位掩码法)"""
    n = len(nums)
    result = []
    for mask in range(1 << n):  # 0 到 2^n - 1
        subset = []
        for i in range(n):
            if mask & (1 << i):
                subset.append(nums[i])
        result.append(subset)
    return result

两种方法产生相同结果,但回溯法更容易添加剪枝条件,在搜索空间大时更高效。

1.5.2 子集 II(LeetCode #90,含重复元素)

问题:给定一个可能包含重复元素的整数数组 nums,返回所有可能的不重复子集。

def subsets_with_dup(nums: List[int]) -> List[List[int]]:
    """含重复元素的子集
    
    去重策略:排序 + 同一层跳过相同元素
    """
    result = []
    nums.sort()
    
    def backtrack(start: int, path: List[int]):
        result.append(path[:])
        
        for i in range(start, len(nums)):
            # 去重:同一层中跳过重复元素
            if i > start and nums[i] == nums[i - 1]:
                continue
            
            path.append(nums[i])
            backtrack(i + 1, path)
            path.pop()
    
    backtrack(0, [])
    return result

# 测试
print(subsets_with_dup([1, 2, 2]))
# [[], [1], [1,2], [1,2,2], [2], [2,2]]

1.6 排列 vs 组合 vs 子集的区别

维度 排列 组合 子集
是否关注顺序 是([1,2] ≠ [2,1]) 否([1,2] = [2,1])
结果长度 固定为 n 固定为 k 0 到 n 不等
如何避免重复 used 数组 start 起始位置 start 起始位置
何时收集结果 path 长度 == n path 长度 == k 每个节点都收集
搜索空间大小 n! C(n,k) 2^n
下一层候选集 所有未用元素 start 之后的元素 start 之后的元素

核心区别的直觉

理解了这三者的统一性和差异性,你就掌握了回溯问题的 80%。


Level 2 · 深入理解

2.1 N 皇后问题

N 皇后问题是回溯算法最经典的应用之一:在 N×N 的棋盘上放置 N 个皇后,使得任何两个皇后都不能在同一行、同一列、或同一对角线上。

2.1.1 基本实现

def solve_n_queens(n: int) -> List[List[str]]:
    """N 皇后问题
    
    策略:逐行放置皇后,每行必须且只能放一个
    约束检查:列冲突、主对角线冲突、副对角线冲突
    
    时间复杂度: O(n!)  实际由于剪枝远小于此
    空间复杂度: O(n)
    """
    result = []
    # 用集合记录已占用的列和对角线
    cols = set()         # 已占用的列
    diag1 = set()        # 已占用的主对角线(行-列 为常数)
    diag2 = set()        # 已占用的副对角线(行+列 为常数)
    
    board = [['.'] * n for _ in range(n)]
    
    def backtrack(row: int):
        # 终止条件:所有行都放置了皇后
        if row == n:
            # 将棋盘转为字符串形式
            result.append([''.join(r) for r in board])
            return
        
        for col in range(n):
            # 剪枝:检查列和对角线是否冲突
            if col in cols or (row - col) in diag1 or (row + col) in diag2:
                continue
            
            # 做选择:放置皇后
            board[row][col] = 'Q'
            cols.add(col)
            diag1.add(row - col)
            diag2.add(row + col)
            
            # 递归到下一行
            backtrack(row + 1)
            
            # 撤销选择
            board[row][col] = '.'
            cols.remove(col)
            diag1.remove(row - col)
            diag2.remove(row + col)
    
    backtrack(0)
    return result

# 测试
solutions = solve_n_queens(4)
for sol in solutions:
    for row in sol:
        print(row)
    print()
# .Q..
# ...Q
# Q...
# ..Q.
#
# ..Q.
# Q...
# ...Q
# .Q..

为什么用 row - colrow + col 表示对角线?

在棋盘上,同一条主对角线(左上到右下)上的所有格子满足 row - col 相等。例如 (0,0), (1,1), (2,2) 都在同一条主对角线上,它们的 row - col 都是 0。同理,同一条副对角线(右上到左下)上的所有格子满足 row + col 相等。

2.1.2 位运算优化

对于 N 皇后问题,当 N 较大时(比如 N=15),集合操作的常数因子成为瓶颈。位运算可以把三个集合的操作压缩为整数的位操作,极大提升速度。

def total_n_queens(n: int) -> int:
    """N 皇后问题 - 位运算优化版本(只计数,不记录方案)
    
    核心思想:用 n 位整数的每一位表示该列/对角线是否被占用
    位运算将三次集合查找 + 插入 + 删除压缩为几次位操作
    
    在 N=15 时,位运算版本比集合版本快约 10 倍
    """
    count = 0
    all_ones = (1 << n) - 1  # n 位全 1,表示所有列
    
    def backtrack(cols: int, diag1: int, diag2: int):
        nonlocal count
        
        if cols == all_ones:
            # 所有列都被占用,说明放满了 n 个皇后
            count += 1
            return
        
        # available: 当前行可以放置皇后的位置
        # ~(cols | diag1 | diag2) 取反得到未被攻击的位置
        # & all_ones 确保只看 n 位
        available = all_ones & ~(cols | diag1 | diag2)
        
        while available:
            # 取最低位的 1(lowbit 技巧)
            position = available & (-available)
            available -= position  # 或 available &= available - 1
            
            # 递归
            # cols | position: 标记这一列被占用
            # (diag1 | position) << 1: 主对角线向下一行传播时左移一位
            # (diag2 | position) >> 1: 副对角线向下一行传播时右移一位
            backtrack(
                cols | position,
                (diag1 | position) << 1,
                (diag2 | position) >> 1
            )
    
    backtrack(0, 0, 0)
    return count

# 测试
print(total_n_queens(8))   # 92
print(total_n_queens(12))  # 14200

位运算版本的精妙之处

  1. available & (-available) 提取最低位的 1——这是补码表示的性质:-x = ~x + 1
  2. 对角线的传播用移位实现:主对角线每下一行等于左移一位,副对角线右移一位。这完美模拟了对角线从上往下的覆盖范围扩展
  3. 不需要显式的"撤销选择"——因为状态完全通过函数参数传递,递归返回后自然恢复

2.2 数独求解器

数独是另一个经典的约束满足问题(Constraint Satisfaction Problem, CSP)。

def solve_sudoku(board: List[List[str]]) -> None:
    """数独求解器(原地修改 board)
    
    策略:逐格填入 1-9,通过行/列/宫约束剪枝
    优化:优先填约束最多(候选数最少)的格子(MRV 启发式)
    
    时间复杂度: 最坏 O(9^81),实际剪枝后远小于此
    """
    rows = [set() for _ in range(9)]
    cols = [set() for _ in range(9)]
    boxes = [set() for _ in range(9)]
    empty = []
    
    # 初始化约束集合
    for i in range(9):
        for j in range(9):
            if board[i][j] != '.':
                num = board[i][j]
                rows[i].add(num)
                cols[j].add(num)
                boxes[(i // 3) * 3 + j // 3].add(num)
            else:
                empty.append((i, j))
    
    def backtrack(idx: int) -> bool:
        # 所有空格都填完了
        if idx == len(empty):
            return True
        
        i, j = empty[idx]
        box_id = (i // 3) * 3 + j // 3
        
        for num in '123456789':
            # 约束检查
            if num in rows[i] or num in cols[j] or num in boxes[box_id]:
                continue
            
            # 做选择
            board[i][j] = num
            rows[i].add(num)
            cols[j].add(num)
            boxes[box_id].add(num)
            
            # 递归
            if backtrack(idx + 1):
                return True  # 找到解就返回
            
            # 撤销选择
            board[i][j] = '.'
            rows[i].remove(num)
            cols[j].remove(num)
            boxes[box_id].remove(num)
        
        return False  # 当前格子无合法数字,回溯
    
    backtrack(0)

数独与 N 皇后的对比

方面 N 皇后 数独
变量 每行皇后的列位置 每个空格的数字
值域 0 到 N-1 1 到 9
约束 列、两条对角线 行、列、3×3 宫
解的数量 可能有多个 合法数独只有唯一解
搜索策略 按行递增搜索 按空格顺序(或 MRV)

2.3 剪枝策略

剪枝是回溯算法的灵魂。好的剪枝可以让算法在合理时间内解决实际规模的问题。

2.3.1 可行性剪枝

在选择时就检查约束是否被违反,而不是到终止条件才检查。这是最基本也是最有效的剪枝。

# 不好的写法:到终止条件才检查
def backtrack_bad(path):
    if is_complete(path):
        if is_valid(path):  # 到最后才检查
            result.append(path[:])
        return
    for choice in candidates:
        path.append(choice)
        backtrack_bad(path)
        path.pop()

# 好的写法:每一步都检查
def backtrack_good(path):
    if is_complete(path):
        result.append(path[:])
        return
    for choice in candidates:
        if not is_valid_choice(choice, path):  # 提前检查
            continue
        path.append(choice)
        backtrack_good(path)
        path.pop()

2.3.2 排序后剪枝

对于"求和等于 target"类问题,排序后可以在当前元素大于剩余所需时直接 break,而非 continue。breakcontinue 更强:它不仅跳过当前元素,还跳过了后面所有更大的元素。

def combination_sum2(candidates: List[int], target: int) -> List[List[int]]:
    """组合总和 II(每个元素只能用一次,含重复元素)"""
    result = []
    candidates.sort()
    
    def backtrack(start: int, path: List[int], remaining: int):
        if remaining == 0:
            result.append(path[:])
            return
        
        for i in range(start, len(candidates)):
            # 排序后剪枝:当前值超过剩余,后面更大,全部不可能
            if candidates[i] > remaining:
                break  # 注意是 break 不是 continue!
            
            # 去重:同一层跳过重复
            if i > start and candidates[i] == candidates[i - 1]:
                continue
            
            path.append(candidates[i])
            backtrack(i + 1, path, remaining - candidates[i])
            path.pop()
    
    backtrack(0, [], target)
    return result

2.3.3 最优性剪枝

如果问题要求最优解(最大/最小),可以维护当前已知最优值。当发现某条分支不可能超越当前最优时,提前终止。

def min_cost_backtrack(graph, n, current_cost, best_cost, path, visited):
    """旅行商问题(TSP)的回溯解法 + 最优性剪枝"""
    if len(path) == n:
        # 加上回到起点的代价
        total = current_cost + graph[path[-1]][path[0]]
        return min(best_cost, total)
    
    for next_city in range(n):
        if visited[next_city]:
            continue
        
        new_cost = current_cost + graph[path[-1]][next_city]
        
        # 最优性剪枝:当前代价已超过已知最优,不可能更好
        if new_cost >= best_cost:
            continue
        
        visited[next_city] = True
        path.append(next_city)
        best_cost = min_cost_backtrack(graph, n, new_cost, best_cost, path, visited)
        path.pop()
        visited[next_city] = False
    
    return best_cost

2.4 回溯与 DFS 的关系

回溯和 DFS(深度优先搜索)经常被混为一谈。它们确实密切相关,但有微妙区别:

概念 深度优先搜索(DFS) 回溯(Backtracking)
本质 图/树的遍历策略 问题求解方法论
关注点 访问顺序 决策 + 撤销
状态管理 通常不修改(或用 visited 标记) 显式做选择和撤销选择
使用场景 遍历所有节点 搜索满足约束的解
搜索空间 显式图/树 隐式搜索树(由决策构造)

类比:DFS 是"走遍地图的方式",回溯是"在迷宫中找出口的方法"。走迷宫用的是 DFS 的方向策略,但加上了"此路不通就折返"的回溯逻辑。

从实现角度看,回溯就是 DFS + 状态恢复。如果你在 DFS 遍历中加上了"做选择→递归→撤销选择"的模式,那就是回溯了。

2.5 时间复杂度分析

回溯算法的时间复杂度通常是指数级或阶乘级,这不是实现的问题,而是问题本身的搜索空间决定的。

常见问题的搜索空间

问题 搜索空间 分析
全排列 O(n!) 第1层n种选择,第2层n-1种,...
子集 O(2^n) 每个元素选或不选
组合 C(n,k) O(C(n,k)) 组合数
N皇后 O(n!) 第1行n列,第2行最多n-1列,...
数独 O(9^m) m为空格数,每个空格最多9种选择

为什么剪枝不改变渐近复杂度?

从理论角度,最坏情况下所有分支都需要搜索,剪枝并不能改变渐近上界。但在实际应用中,剪枝可以将平均搜索规模减少几个数量级。这就是为什么理论上 O(n!) 的 N 皇后问题在 N=15 时仍然可以在几秒内求解。

分析技巧——节点计数法

回溯的时间 = 搜索树的节点数 × 每个节点的工作量。对于组合问题,搜索树节点数就是答案的数量(因为每个叶子产生一个解,内部节点是通往解的路径)。对于排列问题,搜索树有 n! 个叶子和约 e × n! 个内部节点(指数生成函数给出)。


Level 3 · 理论与历史

3.1 回溯算法的历史

回溯算法的形式化可以追溯到 1950 年代。以下是几个关键里程碑:

1850 年代:N 皇后问题的手工解法

八皇后问题最早由国际象棋棋手 Max Bezzel 在 1848 年提出。数学家 Franz Nauck 在 1850 年给出了所有 92 个解。当时没有计算机,求解完全依赖手工推导和数学对称性。

1960 年:Golomb & Baumert 的系统化

Solomon Golomb 和 Leonard Baumert 在 1965 年发表了论文 "Backtrack Programming"(Journal of the ACM, 12(4):516-524),首次系统性地描述了回溯作为一种通用的算法设计范式。他们指出回溯适用于所有可以表示为"扩展偏序解"的问题。

1970 年代:约束满足问题(CSP)框架

随着人工智能研究的兴起,回溯被纳入约束满足问题的统一框架。CSP 包含三个要素:变量集合、每个变量的值域、变量间的约束。回溯是求解 CSP 的最基本方法——逐个为变量赋值,赋值时检查约束,违反则回溯。

1970-80 年代:改进技术

这些技术将朴素回溯的实际性能提升了几个数量级,使得大规模 CSP(如日程安排、排课)变得可解。

3.2 回溯与分支限界

分支限界(Branch and Bound, B&B)是回溯的"优化版"变体,专门用于求解最优化问题。

核心区别

方面 回溯 分支限界
目标 找到所有解 / 任一解 找到最优解
搜索顺序 通常 DFS BFS、最佳优先搜索
剪枝依据 可行性约束 最优性界(bound)
典型应用 CSP、枚举 整数规划、组合优化

B&B 的关键思想——界(Bound):

对每个搜索树的节点,计算一个"乐观估计"(上界或下界):如果从当前节点出发的最优解都不如已知最优解,就剪掉整棵子树。

def branch_and_bound_knapsack(weights, values, capacity):
    """0-1 背包的分支限界解法
    
    使用贪心法(按价值密度排序取分数物品)计算上界
    """
    n = len(weights)
    # 按价值密度降序排列
    items = sorted(range(n), key=lambda i: values[i] / weights[i], reverse=True)
    best_value = 0
    
    def upper_bound(idx, current_weight, current_value):
        """贪心上界:假设可以取分数物品"""
        bound = current_value
        w = current_weight
        for i in range(idx, n):
            item = items[i]
            if w + weights[item] <= capacity:
                w += weights[item]
                bound += values[item]
            else:
                # 取分数部分
                bound += values[item] * (capacity - w) / weights[item]
                break
        return bound
    
    def solve(idx, current_weight, current_value):
        nonlocal best_value
        
        if idx == n:
            best_value = max(best_value, current_value)
            return
        
        item = items[idx]
        
        # 分支1:选当前物品
        if current_weight + weights[item] <= capacity:
            solve(idx + 1, current_weight + weights[item],
                  current_value + values[item])
        
        # 分支2:不选当前物品(先计算上界看是否值得搜索)
        if upper_bound(idx + 1, current_weight, current_value) > best_value:
            solve(idx + 1, current_weight, current_value)
    
    solve(0, 0, 0)
    return best_value

B&B 在工业中的应用:几乎所有商用的整数线性规划(ILP)求解器——CPLEX、Gurobi、SCIP——都以 B&B 为核心框架。它们结合了线性松弛(LP relaxation)来计算上界、切割平面(cutting planes)来收紧松弛、以及复杂的搜索策略来选择分支变量。

3.3 回溯在编译器中的应用——正则表达式引擎

你可能没想到,回溯最广泛的日常应用之一是正则表达式匹配。大多数编程语言(Python、JavaScript、Java、Ruby、PHP)的正则引擎都是基于回溯的 NFA(非确定性有限自动机)模拟器。

正则回溯的工作原理

当正则引擎遇到分支(如 a|b)或量词(如 a*)时,它会选择一个方向尝试匹配。如果失败,就回溯到选择点,尝试另一个方向。

def regex_match(pattern: str, text: str) -> bool:
    """简化的正则匹配引擎(支持 . 和 *)
    
    这本质上是回溯:尝试一种匹配方式,失败则回退尝试其他方式
    """
    def match(p_idx: int, t_idx: int) -> bool:
        # 模式用完了,看文本是否也用完
        if p_idx == len(pattern):
            return t_idx == len(text)
        
        # 当前字符是否匹配
        first_match = (t_idx < len(text) and 
                       (pattern[p_idx] == text[t_idx] or pattern[p_idx] == '.'))
        
        # 处理 * 量词
        if p_idx + 1 < len(pattern) and pattern[p_idx + 1] == '*':
            # 两个选择:
            # 1. * 匹配 0 次:跳过 pattern 中的 x*
            # 2. * 匹配 1+ 次:消耗 text 的一个字符,pattern 不动
            return (match(p_idx + 2, t_idx) or  # 匹配 0 次
                    (first_match and match(p_idx, t_idx + 1)))  # 匹配 1+ 次
        
        # 普通匹配
        if first_match:
            return match(p_idx + 1, t_idx + 1)
        
        return False
    
    return match(0, 0)

灾难性回溯(Catastrophic Backtracking)

回溯式正则引擎有一个严重缺陷:对于某些恶意构造的模式和输入,回溯次数可能呈指数增长,导致程序挂起。这被称为 ReDoS(Regular Expression Denial of Service)攻击。

经典例子:模式 (a+)+$ 匹配字符串 aaaaaaaaaaaaaaaaab。引擎会尝试 2^n 种分组方式后才能确定不匹配。

import re
import time

# 危险!以下代码可能导致程序挂起
# pattern = r"(a+)+$"
# text = "a" * 25 + "b"
# re.match(pattern, text)  # 可能需要几十秒甚至更长

# 安全替代方案:使用非回溯的正则引擎(如 Go 的 regexp 或 RE2)
# 或者改写正则:r"a+$" 语义等价但不会灾难性回溯

为什么不用 DFA? Thompson NFA(1968)或 DFA 可以在 O(n) 时间内匹配任何正则,不存在灾难性回溯。但 DFA 不支持反向引用(backreference)——像 (.)\1 这样的模式引用了先前捕获的内容,这在理论上超出了正则语言的表达能力。由于反向引用在实践中广泛使用,大多数语言选择了回溯引擎作为权衡。Go 是少数坚持使用 RE2(DFA 实现)而放弃反向引用的语言之一。

3.4 NP 完全问题与回溯

为什么很多 NP 问题只能用回溯/暴力?

P vs NP 是计算机科学最重要的开放问题之一。NP 完全(NP-Complete)问题是 NP 中"最难"的一类——如果任何一个 NP 完全问题有多项式时间算法,那所有 NP 问题都有。目前没有人找到任何 NP 完全问题的多项式算法(也没人证明不存在)。

回溯与 NP 完全的关系

NP 的定义是"可以在多项式时间内验证解的正确性"的问题类。这意味着:

回溯恰好符合这个模式:它通过搜索来寻找解,用约束检查来验证和剪枝。对于 NP 完全问题(如旅行商、图着色、布尔可满足性),回溯(配合启发式和剪枝)是最主要的精确求解方法。

典型的 NP 完全问题与回溯解法

问题 决策版本 回溯策略
SAT 布尔公式是否可满足 DPLL 算法(带单元传播和纯文字消除的回溯)
图着色 图是否可用 k 色着色 逐个顶点赋色,约束检查相邻颜色
哈密尔顿路径 是否存在经过所有顶点恰一次的路径 逐步扩展路径,检查连通性和唯一访问
子集和 是否存在子集和为 target 每个元素选/不选,排序后剪枝

一个关键洞察:虽然 NP 完全问题没有已知的多项式精确算法,但对于许多实际规模的实例,配合好的剪枝和启发式的回溯算法足够高效。SAT 求解器(如 MiniSat、CaDiCaL)可以处理数百万变量的工业实例——尽管最坏情况仍是指数级的。这告诉我们:平均情况和最坏情况可能有天壤之别。


Level 4 · 实战与面试

4.1 去重的正确写法

去重是回溯问题中最容易出错的部分。让我们系统地总结两种去重方法。

4.1.1 方法一:排序 + 跳过相同元素

适用于组合/子集问题(同一层跳过):

# 在 for 循环中:
if i > start and nums[i] == nums[i - 1]:
    continue

适用于排列问题(同一层跳过 + used 判断):

# 在 for 循环中:
if i > 0 and nums[i] == nums[i - 1] and not used[i - 1]:
    continue

为什么排列的去重条件多了 not used[i-1]

因为排列的循环每次从 0 开始(不像组合从 start 开始),所以 i > start 的技巧不适用。我们需要另一种方式来区分"同一层的重复"和"不同层的合法选择":

4.1.2 方法二:使用集合记录同一层已选值

当数组无法排序,或排序代价过高时:

def permute_unique_set(nums: List[int]) -> List[List[int]]:
    """用集合去重的排列(不需要排序)"""
    result = []
    used = [False] * len(nums)
    
    def backtrack(path):
        if len(path) == len(nums):
            result.append(path[:])
            return
        
        seen = set()  # 记录当前层已尝试过的值
        for i in range(len(nums)):
            if used[i]:
                continue
            if nums[i] in seen:  # 同一层用过相同值,跳过
                continue
            seen.add(nums[i])
            
            path.append(nums[i])
            used[i] = True
            backtrack(path)
            path.pop()
            used[i] = False
    
    backtrack([])
    return result

两种方法的对比

方面 排序+跳过 集合去重
前置条件 需要排序 O(n log n) 不需要排序
额外空间 O(1)(不算递归栈) O(n) 每层一个集合
适用场景 数组可排序 数组不宜排序(如需保持原始顺序)
常数因子 更小 哈希表开销

4.2 面试高频题详解

4.2.1 电话号码的字母组合(LeetCode #17)

问题:给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。

def letter_combinations(digits: str) -> List[str]:
    """电话号码的字母组合
    
    时间复杂度: O(4^n × n)  最多4个字母(如7和9),n为digits长度
    空间复杂度: O(n)  递归深度
    """
    if not digits:
        return []
    
    phone_map = {
        '2': 'abc', '3': 'def', '4': 'ghi', '5': 'jkl',
        '6': 'mno', '7': 'pqrs', '8': 'tuv', '9': 'wxyz'
    }
    
    result = []
    
    def backtrack(idx: int, path: List[str]):
        if idx == len(digits):
            result.append(''.join(path))
            return
        
        for char in phone_map[digits[idx]]:
            path.append(char)
            backtrack(idx + 1, path)
            path.pop()
    
    backtrack(0, [])
    return result

# 测试
print(letter_combinations("23"))
# ["ad","ae","af","bd","be","bf","cd","ce","cf"]

面试考点:这题看似简单,实际考察你是否理解——这不是排列/组合/子集中的任何一种,而是"多层选择"问题。每层的候选集不同(由 digits 的每一位决定),且没有去重需求。

4.2.2 分割回文串(LeetCode #131)

问题:给定一个字符串 s,将 s 分割成一些子串,使每个子串都是回文串。返回所有可能的分割方案。

def partition(s: str) -> List[List[str]]:
    """分割回文串
    
    策略:枚举第一刀的位置 -> 如果前缀是回文 -> 递归处理剩余部分
    优化:预计算回文表避免重复判断
    
    时间复杂度: O(n × 2^n)  最坏情况所有子串都是回文(如 "aaa")
    空间复杂度: O(n^2)  回文预计算表
    """
    n = len(s)
    
    # 预计算:dp[i][j] 表示 s[i:j+1] 是否为回文
    # 使用动态规划 O(n^2) 预处理,避免每次 O(n) 判断
    is_palindrome = [[False] * n for _ in range(n)]
    for i in range(n - 1, -1, -1):
        for j in range(i, n):
            if s[i] == s[j] and (j - i <= 2 or is_palindrome[i + 1][j - 1]):
                is_palindrome[i][j] = True
    
    result = []
    
    def backtrack(start: int, path: List[str]):
        if start == n:
            result.append(path[:])
            return
        
        for end in range(start, n):
            # 剪枝:只有前缀是回文才继续
            if not is_palindrome[start][end]:
                continue
            
            path.append(s[start:end + 1])
            backtrack(end + 1, path)
            path.pop()
    
    backtrack(0, [])
    return result

# 测试
print(partition("aab"))
# [["a","a","b"],["aa","b"]]

关键优化:预计算回文表是 O(n²) 的,但它把每次判断回文的成本从 O(n) 降到 O(1)。在最坏情况下(如 "aaaa...a"),这个优化将总时间从 O(n² × 2^n) 降到 O(n × 2^n)。

4.2.3 单词搜索(LeetCode #79)

问题:给定一个 m×n 二维字符网格 board 和一个字符串 word,判断 word 是否存在于网格中(相邻字符水平或垂直连接,每个字符只能使用一次)。

def exist(board: List[List[str]], word: str) -> bool:
    """单词搜索
    
    策略:从每个可能的起点出发,DFS + 回溯搜索四个方向
    优化:原地修改 board 做 visited 标记(避免额外空间)
    
    时间复杂度: O(m × n × 3^L)  L为word长度,每步最多3个方向(排除来源)
    空间复杂度: O(L)  递归深度
    """
    m, n = len(board), len(board[0])
    
    def backtrack(i: int, j: int, k: int) -> bool:
        """在位置 (i,j) 尝试匹配 word[k:]"""
        # 终止:所有字符都匹配了
        if k == len(word):
            return True
        
        # 边界检查 + 字符匹配
        if (i < 0 or i >= m or j < 0 or j >= n or 
            board[i][j] != word[k]):
            return False
        
        # 做选择:标记已访问
        temp = board[i][j]
        board[i][j] = '#'  # 原地修改,O(1) 空间
        
        # 探索四个方向
        found = (backtrack(i + 1, j, k + 1) or
                 backtrack(i - 1, j, k + 1) or
                 backtrack(i, j + 1, k + 1) or
                 backtrack(i, j - 1, k + 1))
        
        # 撤销选择:恢复原始字符
        board[i][j] = temp
        
        return found
    
    # 从每个位置作为起点尝试
    for i in range(m):
        for j in range(n):
            if backtrack(i, j, 0):
                return True
    return False

# 测试
board = [
    ['A','B','C','E'],
    ['S','F','C','S'],
    ['A','D','E','E']
]
print(exist(board, "ABCCED"))  # True
print(exist(board, "SEE"))     # True
print(exist(board, "ABCB"))    # False

面试追问

Q: 为什么用原地修改而不是 visited 矩阵? A: visited 矩阵需要 O(m×n) 额外空间。原地修改将字符替换为特殊标记,回溯时恢复,空间复杂度从 O(m×n) 降到 O(L)(仅递归栈)。

Q: 能否进一步优化? A: 可以在搜索前做字符频率检查——如果 word 中某字符的出现次数超过 board 中的次数,直接返回 False。还可以比较 word 首尾字符在 board 中的出现频率,从出现少的那端开始搜索。

4.2.4 括号生成(LeetCode #22)

问题:生成所有由 n 对括号组成的合法组合。

def generate_parenthesis(n: int) -> List[str]:
    """括号生成
    
    约束:任何前缀中左括号数 >= 右括号数
    选择:每步可以加左括号(如果还有余额)或右括号(如果左>右)
    
    时间复杂度: O(4^n / sqrt(n))  第n个卡特兰数
    空间复杂度: O(n)
    """
    result = []
    
    def backtrack(path: List[str], left: int, right: int):
        """
        left: 已使用的左括号数
        right: 已使用的右括号数
        """
        # 终止条件
        if len(path) == 2 * n:
            result.append(''.join(path))
            return
        
        # 选择1:加左括号(如果还有余额)
        if left < n:
            path.append('(')
            backtrack(path, left + 1, right)
            path.pop()
        
        # 选择2:加右括号(如果不会导致不合法)
        if right < left:
            path.append(')')
            backtrack(path, left, right + 1)
            path.pop()
    
    backtrack([], 0, 0)
    return result

# 测试
print(generate_parenthesis(3))
# ["((()))","(()())","(())()","()(())","()()()"]

数学背景——卡特兰数

n 对括号的合法组合数等于第 n 个卡特兰数 C_n = C(2n, n) / (n+1)。前几项为:1, 1, 2, 5, 14, 42, 132, 429, ...

卡特兰数还出现在:n 个节点的不同二叉搜索树数量、从 (0,0) 到 (n,n) 不穿越对角线的路径数、n+1 个矩阵相乘的不同加括号方式数。

4.3 时间复杂度估算技巧

面试中估算回溯的时间复杂度常让人头疼。以下是实用的估算框架:

方法一:搜索树计数

  1. 确定搜索树的深度 d(通常是解的长度)
  2. 确定每层的分支因子 b(通常是候选集大小)
  3. 总节点数约为 b^d(上界)

例如:

方法二:输出大小下界

如果问题要求输出所有解,那么时间复杂度至少是输出大小:

方法三:带剪枝的实际估计

对于有强剪枝的问题,理论上界和实际运行时间可能差数量级。此时通过实验计数搜索树节点更现实。

import functools

def count_nodes_example():
    """计数搜索树节点来估算实际复杂度"""
    count = 0
    
    def backtrack_with_counting(path, ...):
        nonlocal count
        count += 1
        # ... 正常逻辑
    
    backtrack_with_counting([], ...)
    print(f"搜索了 {count} 个节点")

4.4 回溯 vs 动态规划:如何抉择

面试中最常见的问题之一:这道题该用回溯还是动态规划?

核心区别——重叠子问题

判断方法

  1. 画出搜索树
  2. 看是否有相同的子树出现多次
  3. 如果有 → 记忆化/DP;如果没有 → 回溯

灰色地带——回溯 + 记忆化

有些问题兼有两者特征。例如"单词拆分 II"(LeetCode #140)需要找出所有合法分割,逻辑上是回溯(枚举所有分割),但子问题有重叠(相同后缀可以复用结果),可以用记忆化加速。

def word_break_ii(s: str, word_dict: List[str]) -> List[str]:
    """单词拆分 II:回溯 + 记忆化的典型案例"""
    word_set = set(word_dict)
    memo = {}
    
    def backtrack(start: int) -> List[str]:
        if start in memo:
            return memo[start]
        
        if start == len(s):
            return ['']
        
        results = []
        for end in range(start + 1, len(s) + 1):
            word = s[start:end]
            if word in word_set:
                sub_sentences = backtrack(end)
                for sub in sub_sentences:
                    if sub:
                        results.append(word + ' ' + sub)
                    else:
                        results.append(word)
        
        memo[start] = results
        return results
    
    return backtrack(0)

决策树总结

问题需要求所有方案?
  ├── 是 → 有重叠子问题?
  │     ├── 是 → 回溯 + 记忆化
  │     └── 否 → 纯回溯
  └── 否(只求计数/最优值)→ 有最优子结构?
        ├── 是 → 动态规划
        └── 否 → 回溯/贪心/其他

4.5 综合实战:解题模板速查

面对一道新的回溯题,按以下步骤分析:

Step 1:识别问题类型

Step 2:确定三要素

Step 3:考虑剪枝

Step 4:确认复杂度

# 终极模板
def solve(problem_input):
    result = []
    
    # 预处理(排序、建表等)
    candidates = preprocess(problem_input)
    
    def backtrack(state, path):
        # 终止条件
        if is_goal(state):
            result.append(format_solution(path))
            return
        
        for i, choice in enumerate(get_choices(state, candidates)):
            # 剪枝
            if should_prune(choice, state):
                continue  # 或 break(如果排序后)
            
            # 去重(如需)
            if is_duplicate(i, choice, state):
                continue
            
            # 做选择
            make_choice(state, choice, path)
            
            # 递归
            backtrack(next_state(state, choice), path)
            
            # 撤销选择
            undo_choice(state, choice, path)
    
    backtrack(initial_state(), [])
    return result

本章总结

回溯是计算机科学中最基本的算法范式之一。它的核心优雅在于:用统一的框架(选择→递归→撤销)应对千变万化的搜索问题。

关键要点:

  1. 排列/组合/子集是三个基础模型,理解它们的区别(used 数组 vs start 位置)是基础
  2. 剪枝是回溯的灵魂——可行性剪枝、排序后剪枝、最优性剪枝层层递进
  3. 去重是易错点——排序+跳过是最可靠的模式
  4. 回溯本质是DFS,但加上了状态管理(选择与撤销)
  5. NP完全问题没有已知的多项式精确算法,回溯+剪枝是主要的精确求解手段
  6. 回溯 vs DP的区别在于有无重叠子问题——有就用记忆化,没有就纯搜索

当你面对一道新题时,不要急于写代码。先画出搜索树的前几层,确定"选择是什么""约束是什么""什么时候收集结果",然后套模板即可。回溯题不比 DP 难——它们的模板更统一,变体更有规律。难的只是识别出"这是一道回溯题"以及"如何高效剪枝"。

本章评分
4.7  / 5  (5 评分)

💬 留言讨论