第 36 章

随机化与概率算法

第三十六章:随机化与概率算法

你在面试中被问到:"如何公平地从一副 52 张牌中随机抽出 5 张?"你说"每次从剩余牌中随机选一张"。面试官追问:"如何证明每种 5 张牌的组合被选中的概率完全相等?"你愣住了。

再想另一个场景:一个数据流源源不断地到来,你事先不知道总共有多少条记录,内存只够存 k 条。如何保证在任意时刻停止时,内存中的 k 条记录是对已处理数据的均匀随机抽样?

随机化不是"加点随机数就好了"。它是一种设计范式,可以把确定性算法中难以处理的问题(如对抗性输入、海量数据、分布式协调)变得简单而优雅。本章从 Fisher-Yates 洗牌开始,深入水塘抽样的正确性证明,最后讨论随机化在工程系统中的广泛应用。


Level 1 · 你需要知道的

1.1 Fisher-Yates 洗牌算法

问题:给定一个数组,生成其所有元素的一个均匀随机排列(每种排列等概率出现)。

错误做法(很多人的第一反应):

import random

def bad_shuffle(arr):
    """错误的洗牌!概率不均匀"""
    n = len(arr)
    for i in range(n):
        j = random.randint(0, n - 1)  # 从 [0, n-1] 中随机选
        arr[i], arr[j] = arr[j], arr[i]
    return arr

这个做法为什么错?总共有 n^n 种可能的交换序列(每步 n 种选择,共 n 步),但只有 n! 种排列。由于 n^n 通常不能被 n! 整除(例如 3^3 = 27,3! = 6,27/6 不是整数),所以某些排列被映射到的次数一定多于其他排列。概率不均匀。

正确做法:Fisher-Yates(又称 Knuth Shuffle)

import random

def fisher_yates_shuffle(arr: list) -> list:
    """Fisher-Yates 洗牌算法
    
    保证每种排列等概率出现
    时间复杂度: O(n)
    空间复杂度: O(1)(原地洗牌)
    """
    n = len(arr)
    for i in range(n - 1, 0, -1):
        # 从 [0, i] 中均匀随机选一个位置
        j = random.randint(0, i)
        arr[i], arr[j] = arr[j], arr[i]
    return arr

正确性证明(数学归纳法):

需要证明:算法结束后,n! 种排列中每一种出现的概率都是 1/n!。

基础情况:n = 1 时,只有一种排列,概率为 1 = 1/1!。正确。

归纳步骤:假设对 n-1 个元素的 Fisher-Yates 能产生均匀排列。考虑 n 个元素的情况。

第一步(i = n-1):从 [0, n-1] 中随机选 j,将 arr[j] 放到位置 n-1。每个元素被放到最后位置的概率是 1/n。

之后的步骤:对前 n-1 个位置做 Fisher-Yates。由归纳假设,这 n-1 个位置的排列是均匀的(概率 1/(n-1)!)。

综合:任意一个特定排列 (π₁, π₂, ..., π_n) 出现的概率 = P(π_n 在最后) × P(前 n-1 个位置是 (π₁, ..., π_{n-1})) = 1/n × 1/(n-1)! = 1/n!。

另一种实现(从前往后)

def fisher_yates_forward(arr: list) -> list:
    """Fisher-Yates 从前往后版本(等价)"""
    n = len(arr)
    for i in range(n):
        # 从 [i, n-1] 中均匀随机选一个位置
        j = random.randint(i, n - 1)
        arr[i], arr[j] = arr[j], arr[i]
    return arr

这两种版本在数学上等价。关键不变量是:第 i 步选择的范围必须恰好覆盖还未确定的位置。

1.2 水塘抽样(Reservoir Sampling)

问题:从一个大小未知的数据流中,均匀随机地选取 k 个元素。要求只能遍历数据一次,且内存只能存 k 个元素。

算法(Algorithm R,Vitter 1985)

import random

def reservoir_sampling(stream, k: int) -> list:
    """水塘抽样 - Algorithm R
    
    从未知大小的数据流中均匀抽取 k 个元素
    时间复杂度: O(n),n 是流的总大小
    空间复杂度: O(k)
    """
    reservoir = []
    
    for i, item in enumerate(stream):
        if i < k:
            # 前 k 个元素直接放入水塘
            reservoir.append(item)
        else:
            # 第 i+1 个元素(0-indexed 的第 i 个)
            # 以 k/(i+1) 的概率替换水塘中的某个元素
            j = random.randint(0, i)  # [0, i] 均匀随机
            if j < k:
                reservoir[j] = item
    
    return reservoir

正确性证明:需要证明在处理完 n 个元素后,每个元素在水塘中的概率恰好是 k/n。

用数学归纳法

基础情况:当 n = k 时,所有元素都在水塘中,概率 k/k = 1。正确。

归纳步骤:假设处理完前 n-1 个元素后,每个元素在水塘中的概率是 k/(n-1)。

现在处理第 n 个元素:

k = 1 的特殊情况(从流中选一个)

def reservoir_sampling_single(stream) -> object:
    """水塘抽样 k=1: 从流中均匀随机选一个元素"""
    result = None
    for i, item in enumerate(stream):
        # 以 1/(i+1) 的概率选择当前元素
        if random.randint(0, i) == 0:
            result = item
    return result

直觉:第 1 个元素以概率 1 被选中,第 2 个以概率 1/2 被选中(第 1 个存活概率 1/2),第 3 个以概率 1/3 被选中(前面的存活概率 2/3),...,第 n 个以概率 1/n 被选中。最终每个元素被选中的概率都是 1/n。

1.3 随机化快速排序

标准快排的最坏情况是 O(n²)(当输入是已排序数组且 pivot 选第一个元素时)。随机化 pivot 选择可以避免对抗性输入。

import random

def randomized_quicksort(arr: list, low: int = 0, high: int = None) -> None:
    """随机化快速排序
    
    期望时间复杂度: O(n log n)(对任意输入)
    最坏时间复杂度: O(n²)(概率极小,约 1/n!)
    空间复杂度: O(log n) 期望栈深度
    """
    if high is None:
        high = len(arr) - 1
    
    if low < high:
        pivot_idx = randomized_partition(arr, low, high)
        randomized_quicksort(arr, low, pivot_idx - 1)
        randomized_quicksort(arr, pivot_idx + 1, high)

def randomized_partition(arr: list, low: int, high: int) -> int:
    """随机选择 pivot 并分区"""
    # 随机选择 pivot
    pivot_idx = random.randint(low, high)
    arr[pivot_idx], arr[high] = arr[high], arr[pivot_idx]
    
    pivot = arr[high]
    i = low - 1
    
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    
    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    return i + 1

为什么随机化有效?

确定性快排的问题是对手可以构造使其退化的输入。但随机化 pivot 让对手无法预测我们的选择。

期望复杂度分析:设 T(n) 是对 n 个元素排序的期望比较次数。随机 pivot 等概率地将数组分成 (0, n-1), (1, n-2), ..., (n-1, 0) 这 n 种情况:

T(n) = n - 1 + (1/n) · Σᵢ₌₀ⁿ⁻¹ [T(i) + T(n-1-i)]

解这个递推得 T(n) = 2n·H_n - 4n ≈ 2n·ln(n) ≈ 1.39·n·log₂(n)。

即随机化快排的期望比较次数只比理论最优(n·log₂(n))多约 39%,而且这个期望值对任意输入成立——无论输入多么恶意。

与确定性优化的比较

方法 最坏情况 能否被对手攻击
选第一个 O(n²) 是(已排序输入)
选中间 O(n²) 是(可构造)
三数取中 O(n²) 是(更难但可构造)
随机选择 O(n²) 但概率极小 否(与输入无关)
Median of medians O(n log n) 确定性 否(但常数大)

1.4 随机数生成的基本要求

均匀性:每个可能值出现的概率相等。

常见错误:用 rand() % n 产生 [0, n-1] 的随机数

如果 RAND_MAX + 1 不能被 n 整除,那么某些值会比其他值多一个对应的 rand() 输出。例如,RAND_MAX = 32767 时,rand() % 10 产生 0-7 的概率是 3277/32768,而 8-9 的概率是 3276/32768。偏差虽小但在大量抽样时可检测到。

正确做法:拒绝采样

import random

def uniform_random(n: int) -> int:
    """无偏差地产生 [0, n-1] 的均匀随机数
    
    原理: 拒绝超出范围的值
    """
    # Python 的 random.randint 已经处理了这个问题
    # 但如果你只有 rand() 产生 [0, RAND_MAX]:
    RAND_MAX = 32767  # 假设
    limit = RAND_MAX - (RAND_MAX + 1) % n  # 最大的 n 的倍数 - 1
    while True:
        r = random.randint(0, RAND_MAX)  # 模拟 rand()
        if r <= limit:
            return r % n

1.5 常见错误与陷阱

错误 1:洗牌时选择范围不对

# 错误:j 的范围应该是 [0, i] 而不是 [0, n-1]
for i in range(n - 1, 0, -1):
    j = random.randint(0, n - 1)  # 错!应该是 randint(0, i)
    arr[i], arr[j] = arr[j], arr[i]

错误 2:水塘抽样的索引计算

# 错误:第 i 个元素(0-indexed)应以 k/(i+1) 概率进入
# 如果从 1-indexed 开始就是 k/i
for i, item in enumerate(stream, start=1):  # 1-indexed
    if i <= k:
        reservoir.append(item)
    else:
        j = random.randint(1, i)
        if j <= k:
            reservoir[j - 1] = item

错误 3:随机化快排的栈溢出

对于大数组(n > 10⁶),递归深度可能导致栈溢出。解决方案:

  1. 尾递归优化:总是先递归较短的子数组
  2. 当子数组小于某个阈值时切换为插入排序
  3. 使用迭代 + 显式栈
def quicksort_safe(arr, low, high):
    """尾递归优化的快排"""
    while low < high:
        pivot = randomized_partition(arr, low, high)
        # 先递归较短的一侧,尾调用较长的一侧
        if pivot - low < high - pivot:
            quicksort_safe(arr, low, pivot - 1)
            low = pivot + 1  # 尾调用优化
        else:
            quicksort_safe(arr, pivot + 1, high)
            high = pivot - 1

Level 2 · 它是怎么运行的

2.1 蒙特卡洛算法 vs 拉斯维加斯算法

随机化算法分为两大类:

拉斯维加斯(Las Vegas)算法

蒙特卡洛(Monte Carlo)算法

# 拉斯维加斯示例:随机化快速选择(第 k 小)
def randomized_select(arr: list, k: int) -> int:
    """找第 k 小元素
    
    期望时间: O(n)
    最坏时间: O(n²)(但概率极小)
    结果: 永远正确
    """
    if len(arr) == 1:
        return arr[0]
    
    pivot = random.choice(arr)
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    
    if k <= len(left):
        return randomized_select(left, k)
    elif k <= len(left) + len(middle):
        return pivot
    else:
        return randomized_select(right, k - len(left) - len(middle))

# 蒙特卡洛示例:估计 π 的值
def monte_carlo_pi(num_samples: int) -> float:
    """蒙特卡洛估计 π
    
    在单位正方形中随机投点,统计落在四分之一圆内的比例
    π/4 = (圆内点数) / (总点数)
    
    误差: O(1/√n) — 需要 4 倍样本才能减半误差
    """
    inside = 0
    for _ in range(num_samples):
        x = random.random()
        y = random.random()
        if x * x + y * y <= 1:
            inside += 1
    return 4 * inside / num_samples

两类算法的关键区别

特性 Las Vegas Monte Carlo
正确性 保证正确 概率正确
时间 随机(但期望有限) 确定/有上界
能否互转 可(设时间上界,超时重启) 有条件(需验证函数)
典型应用 排序、选择 素性测试、近似计算

重要转换:任何 Las Vegas 算法都可以转为 Monte Carlo——设定时间上限,超时则输出"不知道"。反过来,如果 Monte Carlo 的输出可以高效验证正确性,则可以转为 Las Vegas——反复运行直到验证通过。

2.2 SkipList 的随机层高

SkipList(跳表)是一种基于随机化的数据结构,由 William Pugh 在 1990 年提出,作为平衡搜索树的概率替代品。

核心思想:每个节点有一个随机决定的"层高"。层高为 h 的节点出现在第 1 层到第 h 层的所有链表中。通过高层的"快速通道"跳过大量节点,实现 O(log n) 的期望搜索时间。

import random

class SkipListNode:
    def __init__(self, key, level):
        self.key = key
        self.forward = [None] * (level + 1)  # forward[i] 是第 i 层的下一个节点

class SkipList:
    """跳表实现
    
    期望操作时间: O(log n)
    空间: O(n) 期望
    """
    MAX_LEVEL = 32
    P = 0.5  # 上升概率
    
    def __init__(self):
        self.header = SkipListNode(-float('inf'), self.MAX_LEVEL)
        self.level = 0  # 当前最高层
    
    def random_level(self) -> int:
        """随机决定新节点的层高
        
        每层以概率 P 继续向上。期望层高 = 1/(1-P)
        当 P = 0.5 时,期望层高 = 2
        """
        lvl = 0
        while random.random() < self.P and lvl < self.MAX_LEVEL:
            lvl += 1
        return lvl
    
    def search(self, key) -> bool:
        """搜索 key 是否存在"""
        current = self.header
        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
        current = current.forward[0]
        return current is not None and current.key == key
    
    def insert(self, key) -> None:
        """插入 key"""
        update = [None] * (self.MAX_LEVEL + 1)
        current = self.header
        
        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
            update[i] = current
        
        new_level = self.random_level()
        if new_level > self.level:
            for i in range(self.level + 1, new_level + 1):
                update[i] = self.header
            self.level = new_level
        
        new_node = SkipListNode(key, new_level)
        for i in range(new_level + 1):
            new_node.forward[i] = update[i].forward[i]
            update[i].forward[i] = new_node

为什么 P = 0.5 是好的选择?

P 值的选择权衡

P 期望层高 空间倍率 搜索比较次数
1/2 2 2n (log₂n)/1 + 1/P = log₂n + 2
1/4 4/3 (4/3)n (log₄n)/1 + 1/P ≈ log₂n/2 + 4
1/e e/(e-1) ≈ 1.58 1.58n 最优(但差异很小)

实践中 P = 0.5 或 P = 0.25 都很常见。Redis 的有序集合(zset)使用 P = 0.25, MAX_LEVEL = 32。

SkipList vs 平衡BST的工程权衡

特性 SkipList 红黑树/AVL
实现复杂度 简单 复杂(旋转)
并发友好 好(锁粒度小) 差(旋转影响多节点)
范围查询 天然支持 需要中序遍历
缓存友好性 差(指针跳跃) 也差(但稍好)
最坏情况 O(n)(极小概率) O(log n) 确定性

2.3 随机化在负载均衡中的应用

问题:有 n 台服务器,m 个请求到达。如何分配请求使得负载尽量均匀?

方法一:完全随机分配

每个请求独立均匀随机选一台服务器。根据"球入盒"模型(Balls into Bins),当 m = n 时,最重负载的服务器期望接收 Θ(log n / log log n) 个请求。

方法二:Power of Two Choices(Mitzenmacher 等人, 2001)

每个请求随机选 2 台服务器,然后分配给当前负载较轻的那台。这个简单改进将最大负载从 Θ(log n / log log n) 降到 Θ(log log n) — 指数级改进!

import random

class PowerOfTwoChoices:
    """Power of Two Choices 负载均衡
    
    最大负载从 O(log n / log log n) 降为 O(log log n)
    """
    def __init__(self, num_servers: int):
        self.loads = [0] * num_servers
        self.n = num_servers
    
    def assign_request(self) -> int:
        """分配一个请求,返回选中的服务器编号"""
        # 随机选两台
        s1 = random.randint(0, self.n - 1)
        s2 = random.randint(0, self.n - 1)
        while s2 == s1:  # 确保选的是不同的两台
            s2 = random.randint(0, self.n - 1)
        
        # 选负载较轻的
        if self.loads[s1] <= self.loads[s2]:
            self.loads[s1] += 1
            return s1
        else:
            self.loads[s2] += 1
            return s2

为什么 2 选 1 就能有如此大的改进?

直觉:完全随机时,一旦某台服务器"领先"(负载高于平均),它继续积累新请求的概率与其他服务器相同。但 Power of Two Choices 中,负载越高的服务器越不容易被选中(因为两个随机选择中至少一个可能更轻),形成了自动的负反馈机制。

数学上,这是指数时序分析(exponential backoff)的一个实例:每增加一个"选择"维度,极端值的尾概率就会指数衰减。

方法三:一致性哈希(Consistent Hashing)

当需要保证同一个 key 总是路由到同一台服务器(如缓存)时:

import hashlib
import bisect

class ConsistentHashing:
    """一致性哈希
    
    支持动态增减服务器,每次变更只影响 O(1/n) 的 key
    使用虚拟节点确保负载均匀
    """
    def __init__(self, num_virtual_nodes: int = 150):
        self.num_virtual = num_virtual_nodes
        self.ring = []  # 排序的 hash 值列表
        self.ring_map = {}  # hash值 -> 真实服务器
    
    def _hash(self, key: str) -> int:
        return int(hashlib.md5(key.encode()).hexdigest(), 16)
    
    def add_server(self, server: str) -> None:
        for i in range(self.num_virtual):
            h = self._hash(f"{server}#{i}")
            bisect.insort(self.ring, h)
            self.ring_map[h] = server
    
    def get_server(self, key: str) -> str:
        if not self.ring:
            return None
        h = self._hash(key)
        idx = bisect.bisect_left(self.ring, h) % len(self.ring)
        return self.ring_map[self.ring[idx]]

2.4 随机化 Skip:从 O(n) 到 O(√n)

一个有趣的技巧:在水塘抽样中,大部分元素会被拒绝。能否跳过那些注定被拒绝的元素?

Vitter 的 Algorithm Z(1985):不逐个检查,而是直接计算"下一个被选中的元素"的位置,从而跳过中间的所有元素。

import random
import math

def reservoir_sampling_optimized(stream_size: int, k: int):
    """Vitter's Algorithm Z 的简化版本
    
    关键优化: 计算下一个替换位置,跳过中间元素
    对于大流(n >> k),这显著减少了随机数生成的次数
    """
    reservoir = list(range(k))  # 假设前 k 个元素已在水塘中
    
    # 简化版本: 使用几何分布跳跃
    W = math.exp(math.log(random.random()) / k)
    next_pos = k + int(math.floor(math.log(random.random()) / math.log(1 - W))) + 1
    
    i = k
    while i < stream_size:
        if i == next_pos:
            # 替换水塘中的一个随机元素
            reservoir[random.randint(0, k - 1)] = i  # 这里 i 代表流中第 i 个元素
            W *= math.exp(math.log(random.random()) / k)
            next_pos = i + int(math.floor(math.log(random.random()) / math.log(1 - W))) + 1
        i += 1
    
    return reservoir

当 n >> k 时,Algorithm Z 只需要 O(k · (1 + log(n/k))) 次随机数生成,而 Algorithm R 需要 O(n) 次。

2.5 随机化算法的去随机化

一个有趣的理论问题:随机化算法能否"去随机化"(derandomization),即转换为确定性算法而不损失效率?

条件期望法(Method of Conditional Expectations)

如果随机化算法的期望结果满足某个性质(如期望解 ≥ OPT/2),那么可以贪心地逐步固定随机位,使得条件期望始终不下降,最终得到确定性地达到期望值的解。

# 示例: MAX-SAT 的随机化 → 去随机化
# 随机化: 每个变量独立等概率为 True/False,期望满足 m/2 个子句(m 个子句)
# 去随机化: 对每个变量,选择使"条件期望满足子句数"更大的值

def derandomized_maxsat(clauses: list, num_vars: int) -> list:
    """去随机化的 MAX-SAT 近似算法
    
    保证满足至少 m/2 个子句(m 是子句总数)
    时间: O(n · m)
    """
    assignment = [None] * (num_vars + 1)
    
    for var in range(1, num_vars + 1):
        # 计算将 var 设为 True 时的条件期望
        exp_true = conditional_expected_satisfied(clauses, assignment, var, True)
        exp_false = conditional_expected_satisfied(clauses, assignment, var, False)
        
        assignment[var] = exp_true >= exp_false
    
    return assignment

def conditional_expected_satisfied(clauses, assignment, var, value):
    """计算条件期望: 已固定变量 + 当前决定 + 剩余随机"""
    assignment[var] = value
    expected = 0
    
    for clause in clauses:
        # 检查子句是否已经被满足
        satisfied = False
        all_determined = True
        undetermined_count = 0
        
        for literal in clause:
            v = abs(literal)
            if assignment[v] is not None:
                if (literal > 0 and assignment[v]) or (literal < 0 and not assignment[v]):
                    satisfied = True
                    break
            else:
                all_determined = False
                undetermined_count += 1
        
        if satisfied:
            expected += 1
        elif not all_determined:
            # 至少有一个未确定的变量,子句被满足的概率 = 1 - (1/2)^(未确定变量数)
            expected += 1 - (0.5 ** undetermined_count)
        # else: 所有变量都确定了但子句没满足,贡献 0
    
    assignment[var] = None  # 恢复
    return expected

2.6 哈希函数中的随机性

通用哈希(Universal Hashing, Carter & Wegman, 1979)

从一个哈希函数族中随机选一个函数,保证对任意两个不同的 key,碰撞概率不超过 1/m(m 是表大小)。

import random

class UniversalHash:
    """通用哈希族
    
    h(x) = ((a*x + b) mod p) mod m
    其中 p 是大素数,a ∈ [1, p-1],b ∈ [0, p-1] 随机选择
    
    碰撞概率: 对任意 x ≠ y,P[h(x) = h(y)] ≤ 1/m
    """
    def __init__(self, m: int, p: int = (1 << 61) - 1):
        self.m = m
        self.p = p  # 梅森素数 2^61 - 1
        self.a = random.randint(1, p - 1)
        self.b = random.randint(0, p - 1)
    
    def hash(self, x: int) -> int:
        return ((self.a * x + self.b) % self.p) % self.m

为什么通用哈希重要?

固定的哈希函数(如 hash(x) = x % m)总有对手可以构造使所有 key 碰撞的输入。但从哈希族中随机选择后,对手无法预知选了哪个函数,因此无法构造对抗性输入。这使得哈希表的最坏情况期望变为 O(1)。


Level 3 · 规范怎么定义的

3.1 Fisher 和 Yates 的原始方法(1938)

Ronald A. Fisher 和 Frank Yates 在 1938 年出版的《Statistical Tables for Biological, Agricultural and Medical Research》一书中首次描述了这个算法。他们的原始方法是为人工执行设计的:

  1. 写下数字 1 到 n
  2. 生成一个 [1, k] 的随机数(k 是剩余数字的个数)
  3. 划掉第 k 个剩余数字,写到结果序列中
  4. 重复直到所有数字被划掉

这个"划掉"版本时间复杂度是 O(n²)(每次需要 O(n) 找到第 k 个未划掉的数)。

Knuth (1969) 的改进:Donald Knuth 在《The Art of Computer Programming》第 2 卷中给出了现在广泛使用的 O(n) 原地版本——通过交换而非"划掉"来实现。他在 Algorithm P(第 3.4.2 节)中描述了这个算法,并证明了其正确性。这就是我们今天所说的 Fisher-Yates-Knuth shuffle。

Durstenfeld (1964) 实际上比 Knuth 更早独立发表了这个原地版本(在 Communications of the ACM 上),但 Knuth 的教科书使其广为人知。

形式化正确性要求

一个洗牌算法是"均匀"的当且仅当对 n! 种排列中的每一种 π,算法输出 π 的概率恰好是 1/n!。

Fisher-Yates 满足这一要求是因为:

3.2 Knuth Shuffle 的实现细节

Knuth 在 TAOCP Vol. 2 (1969) 中的 Algorithm P:

Algorithm P (Shuffling).
Given a sequence of n elements a[0], a[1], ..., a[n-1]:
P1. Set j ← n - 1.
P2. Generate a uniform random number U between 0 and 1.
    Set k ← ⌊(j+1)·U⌋. (Now k is uniform on {0, 1, ..., j}.)
P3. Exchange a[k] ↔ a[j].
P4. Decrease j by 1. If j > 0, go to P2. Otherwise, terminate.

Knuth 指出这个算法的关键性质:

  1. 每种排列恰好对应一个选择序列:(k_{n-1}, k_{n-2}, ..., k_1) 其中 0 ≤ k_j ≤ j。这是"阶乘数系统"(factorial number system)的一种表示。

  2. 与随机数生成器的耦合:算法的正确性依赖于 U 是真正均匀的。如果使用伪随机数生成器(PRNG),其周期必须至少为 n!。对于 52 张牌(52! ≈ 8×10⁶⁷),需要至少 226 位的 PRNG 状态——标准的 Mersenne Twister(32位种子,2^19937 周期)是足够的,但简单的线性同余生成器(32位状态,周期 2^32)远远不够。

3.3 Vitter 1985 年的水塘抽样论文

Jeffrey S. Vitter 在 1985 年发表了经典论文"Random Sampling with a Reservoir"(ACM Transactions on Mathematical Software),系统地分析了水塘抽样的多种算法:

Algorithm R(最简单版本):

Algorithm X

Algorithm Z(最优):

Vitter 的贡献不仅是算法本身,更是严格的概率分析方法。他证明了 Algorithm Z 是最优的(在比较模型下,无法用更少的随机位完成均匀抽样)。

加权水塘抽样的变体

A-Res 算法(Efraimidis & Spirakis, 2006):每个元素 i 的权重为 w_i,被选中的概率正比于权重。

import heapq
import random
import math

def weighted_reservoir_sampling(stream, k: int) -> list:
    """加权水塘抽样 (A-Res algorithm)
    
    每个元素 (item, weight) 按权重比例被抽中
    """
    # 使用最小堆存储 k 个最大的 key
    heap = []  # (key, item)
    
    for item, weight in stream:
        # key = random^(1/weight) — weight 越大,key 越可能大
        key = random.random() ** (1.0 / weight)
        
        if len(heap) < k:
            heapq.heappush(heap, (key, item))
        elif key > heap[0][0]:
            heapq.heapreplace(heap, (key, item))
    
    return [item for _, item in heap]

3.4 随机化算法的计算复杂性分类

BPP(Bounded-Error Probabilistic Polynomial Time)

一个决策问题在 BPP 中,如果存在一个概率多项式时间算法,对任何输入:

(2/3 可以通过重复运行 + 多数投票放大到 1 - 2^(-k))

RP(Randomized Polynomial Time)

ZPP(Zero-Error Probabilistic Polynomial Time)

关系图:P ⊆ ZPP ⊆ RP ⊆ BPP ⊆ ... ⊆ PSPACE

普遍猜想是 P = BPP(所有有效的随机化算法都可以去随机化),但目前尚未证明。Impagliazzo & Wigderson (1997) 证明了:如果存在某些难度假设(如 E = DTIME(2^O(n)) 中有指数级电路复杂度的问题),则 P = BPP。

3.5 伪随机数生成器的理论基础

密码学安全的 PRNG(CSPRNG)要求

  1. 给定前 k 个输出位,无法在多项式时间内预测第 k+1 位(即使拥有无限计算资源除外)
  2. 给定内部状态,无法恢复之前的输出(前向安全性)

常用 PRNG 的特性比较

PRNG 周期 安全性 速度 适用场景
LCG (线性同余) 2^32 不安全 极快 不推荐
Mersenne Twister 2^19937 不安全(可逆推) 科学计算、模拟
xorshift128+ 2^128 不安全 极快 游戏、非安全场景
ChaCha20 2^256 密码学安全 中等 安全场景
/dev/urandom 取决于实现 安全 密钥生成

Python 中的随机数


Level 4 · 边界与陷阱

4.1 面试题:打乱数组 (LeetCode #384)

题目:给你一个整数数组 nums,设计算法来打乱一个没有重复元素的数组。实现 Solution 类:

import random

class Solution:
    """LeetCode 384: Shuffle an Array
    
    使用 Fisher-Yates 洗牌保证均匀性
    """
    def __init__(self, nums: list):
        self.original = nums[:]
        self.current = nums[:]
    
    def reset(self) -> list:
        self.current = self.original[:]
        return self.current
    
    def shuffle(self) -> list:
        n = len(self.current)
        for i in range(n - 1, 0, -1):
            j = random.randint(0, i)
            self.current[i], self.current[j] = self.current[j], self.current[i]
        return self.current

面试跟进问题

  1. "如何验证你的洗牌是均匀的?"

    • 对小数组(如 [1,2,3])运行百万次,统计每种排列出现的频率,应接近 1/n! = 1/6 ≈ 16.67%
    • 卡方检验可以量化偏差
  2. "如果数组很大(10^8 元素),如何只洗牌前 k 个?"

    • 只执行 Fisher-Yates 的前 k 步即可得到 k 个均匀随机元素
    • 这本质上就是"从 n 个中随机选 k 个"的算法
  3. "如果不能修改原数组呢?"

    • 使用额外 O(n) 空间复制后洗牌
    • 或者使用 inside-out 版本(不需要原地修改)
def inside_out_shuffle(original: list) -> list:
    """Inside-out Fisher-Yates(不修改原数组)"""
    n = len(original)
    result = [0] * n
    for i in range(n):
        j = random.randint(0, i)
        if j != i:
            result[i] = result[j]
        result[j] = original[i]
    return result

4.2 面试题:链表随机节点 (LeetCode #382)

题目:给你一个单链表,随机选择链表的一个节点,并返回相应的节点值。每个节点被选中的概率应该相同。

分析:这是水塘抽样 k=1 的直接应用。链表长度未知(或不想先遍历一次),需要一次遍历完成。

import random

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

class Solution:
    """LeetCode 382: Linked List Random Node
    
    水塘抽样 k=1
    """
    def __init__(self, head: ListNode):
        self.head = head
    
    def getRandom(self) -> int:
        result = self.head.val
        node = self.head.next
        i = 2  # 当前是第 2 个节点
        
        while node:
            # 以 1/i 的概率选择当前节点
            if random.randint(1, i) == 1:
                result = node.val
            node = node.next
            i += 1
        
        return result

正确性验证(以 3 个节点为例):

设链表为 A → B → C

面试跟进

  1. "如果需要返回 k 个随机节点呢?" — 水塘抽样 k > 1 版本
  2. "如果可以知道链表长度呢?" — 先数长度 n,再生成 [0, n-1] 的随机数,走到那个位置
  3. "调用 getRandom 非常频繁怎么办?" — 第一次遍历时把所有值存入数组,之后 O(1) 随机访问

4.3 面试题:随机数索引 (LeetCode #398)

题目:给定一个可能含有重复元素的整数数组,随机输出给定数字的索引。相同数字的每个索引被返回的概率应相等。

import random

class Solution:
    """LeetCode 398: Random Pick Index
    
    方法一: 水塘抽样(O(n) 时间,O(1) 额外空间)
    方法二: 预处理哈希表(O(1) 查询,O(n) 空间)
    """
    def __init__(self, nums: list):
        self.nums = nums
    
    def pick(self, target: int) -> int:
        """水塘抽样方法"""
        result = -1
        count = 0
        
        for i, num in enumerate(self.nums):
            if num == target:
                count += 1
                # 以 1/count 的概率选择当前索引
                if random.randint(1, count) == 1:
                    result = i
        
        return result


class SolutionOptimized:
    """预处理版本 — 适合多次查询"""
    
    def __init__(self, nums: list):
        from collections import defaultdict
        self.indices = defaultdict(list)
        for i, num in enumerate(nums):
            self.indices[num].append(i)
    
    def pick(self, target: int) -> int:
        idx_list = self.indices[target]
        return random.choice(idx_list)

方法选择

面试中的讨论点

4.4 随机化算法的测试与验证

如何确信你的随机化算法是正确的?这比确定性算法困难得多。

方法一:统计检验

from collections import Counter
import scipy.stats

def verify_shuffle(shuffle_func, arr, num_trials=1000000):
    """验证洗牌的均匀性"""
    n = len(arr)
    expected_count = num_trials / math.factorial(n)
    
    counts = Counter()
    for _ in range(num_trials):
        result = shuffle_func(arr[:])
        counts[tuple(result)] += 1
    
    # 卡方检验
    observed = list(counts.values())
    expected = [expected_count] * math.factorial(n)
    
    # 如果有些排列没出现过,补零
    while len(observed) < math.factorial(n):
        observed.append(0)
    
    chi2, p_value = scipy.stats.chisquare(observed, expected)
    
    print(f"卡方统计量: {chi2:.2f}")
    print(f"p 值: {p_value:.4f}")
    print(f"均匀性: {'通过' if p_value > 0.05 else '不通过'}")

方法二:不变量检查

def verify_reservoir_sampling(sample_func, stream_size, k, num_trials=100000):
    """验证水塘抽样的均匀性"""
    counts = [0] * stream_size
    
    for _ in range(num_trials):
        sample = sample_func(range(stream_size), k)
        for idx in sample:
            counts[idx] += 1
    
    # 每个元素被选中的期望次数 = num_trials * k / stream_size
    expected = num_trials * k / stream_size
    
    # 检查是否所有计数都接近期望值
    max_deviation = max(abs(c - expected) / expected for c in counts)
    print(f"最大相对偏差: {max_deviation:.4f}")
    print(f"均匀性: {'通过' if max_deviation < 0.05 else '不通过'}")

4.5 实际工程中的随机化应用

1. 随机化在数据库中

2. 随机化在分布式系统中

import random
import time

def exponential_backoff_with_jitter(attempt: int, base_delay: float = 1.0, 
                                     max_delay: float = 60.0) -> float:
    """带 jitter 的指数退避
    
    在分布式系统中用于重试逻辑
    """
    # Full Jitter (AWS 推荐)
    delay = min(max_delay, base_delay * (2 ** attempt))
    return random.uniform(0, delay)

3. 随机化在 A/B 测试中

4. 布隆过滤器(Bloom Filter)中的多哈希

4.6 面试应答策略

当面试中遇到随机化问题时:

  1. 识别问题类型

    • "等概率" / "均匀随机" → Fisher-Yates / 水塘抽样
    • "数据流" / "未知大小" → 水塘抽样
    • "概率正比于权重" → 加权水塘 / 轮盘赌选择
    • "避免最坏情况" → 随机化 pivot / 通用哈希
  2. 证明正确性的标准模式

    • 对于洗牌:证明每种排列概率 = 1/n!
    • 对于抽样:证明每个元素被选中概率 = k/n
    • 通常用数学归纳法
  3. 空间-时间权衡

    • 水塘抽样:O(k) 空间,O(n) 时间
    • 预处理:O(n) 空间,O(1) 查询时间
    • 选择取决于查询频率和内存限制
  4. 常见追问

    • "如何验证均匀性?" → 统计检验(卡方检验)
    • "随机数生成器有偏怎么办?" → 拒绝采样
    • "如果数据流特别大怎么优化?" → Vitter's Algorithm Z(跳跃式采样)
本章评分
4.6  / 5  (3 评分)

💬 留言讨论