随机化与概率算法
第三十六章:随机化与概率算法
你在面试中被问到:"如何公平地从一副 52 张牌中随机抽出 5 张?"你说"每次从剩余牌中随机选一张"。面试官追问:"如何证明每种 5 张牌的组合被选中的概率完全相等?"你愣住了。
再想另一个场景:一个数据流源源不断地到来,你事先不知道总共有多少条记录,内存只够存 k 条。如何保证在任意时刻停止时,内存中的 k 条记录是对已处理数据的均匀随机抽样?
随机化不是"加点随机数就好了"。它是一种设计范式,可以把确定性算法中难以处理的问题(如对抗性输入、海量数据、分布式协调)变得简单而优雅。本章从 Fisher-Yates 洗牌开始,深入水塘抽样的正确性证明,最后讨论随机化在工程系统中的广泛应用。
Level 1 · 你需要知道的
1.1 Fisher-Yates 洗牌算法
问题:给定一个数组,生成其所有元素的一个均匀随机排列(每种排列等概率出现)。
错误做法(很多人的第一反应):
import random
def bad_shuffle(arr):
"""错误的洗牌!概率不均匀"""
n = len(arr)
for i in range(n):
j = random.randint(0, n - 1) # 从 [0, n-1] 中随机选
arr[i], arr[j] = arr[j], arr[i]
return arr
这个做法为什么错?总共有 n^n 种可能的交换序列(每步 n 种选择,共 n 步),但只有 n! 种排列。由于 n^n 通常不能被 n! 整除(例如 3^3 = 27,3! = 6,27/6 不是整数),所以某些排列被映射到的次数一定多于其他排列。概率不均匀。
正确做法:Fisher-Yates(又称 Knuth Shuffle)
import random
def fisher_yates_shuffle(arr: list) -> list:
"""Fisher-Yates 洗牌算法
保证每种排列等概率出现
时间复杂度: O(n)
空间复杂度: O(1)(原地洗牌)
"""
n = len(arr)
for i in range(n - 1, 0, -1):
# 从 [0, i] 中均匀随机选一个位置
j = random.randint(0, i)
arr[i], arr[j] = arr[j], arr[i]
return arr
正确性证明(数学归纳法):
需要证明:算法结束后,n! 种排列中每一种出现的概率都是 1/n!。
基础情况:n = 1 时,只有一种排列,概率为 1 = 1/1!。正确。
归纳步骤:假设对 n-1 个元素的 Fisher-Yates 能产生均匀排列。考虑 n 个元素的情况。
第一步(i = n-1):从 [0, n-1] 中随机选 j,将 arr[j] 放到位置 n-1。每个元素被放到最后位置的概率是 1/n。
之后的步骤:对前 n-1 个位置做 Fisher-Yates。由归纳假设,这 n-1 个位置的排列是均匀的(概率 1/(n-1)!)。
综合:任意一个特定排列 (π₁, π₂, ..., π_n) 出现的概率 = P(π_n 在最后) × P(前 n-1 个位置是 (π₁, ..., π_{n-1})) = 1/n × 1/(n-1)! = 1/n!。
另一种实现(从前往后):
def fisher_yates_forward(arr: list) -> list:
"""Fisher-Yates 从前往后版本(等价)"""
n = len(arr)
for i in range(n):
# 从 [i, n-1] 中均匀随机选一个位置
j = random.randint(i, n - 1)
arr[i], arr[j] = arr[j], arr[i]
return arr
这两种版本在数学上等价。关键不变量是:第 i 步选择的范围必须恰好覆盖还未确定的位置。
1.2 水塘抽样(Reservoir Sampling)
问题:从一个大小未知的数据流中,均匀随机地选取 k 个元素。要求只能遍历数据一次,且内存只能存 k 个元素。
算法(Algorithm R,Vitter 1985):
import random
def reservoir_sampling(stream, k: int) -> list:
"""水塘抽样 - Algorithm R
从未知大小的数据流中均匀抽取 k 个元素
时间复杂度: O(n),n 是流的总大小
空间复杂度: O(k)
"""
reservoir = []
for i, item in enumerate(stream):
if i < k:
# 前 k 个元素直接放入水塘
reservoir.append(item)
else:
# 第 i+1 个元素(0-indexed 的第 i 个)
# 以 k/(i+1) 的概率替换水塘中的某个元素
j = random.randint(0, i) # [0, i] 均匀随机
if j < k:
reservoir[j] = item
return reservoir
正确性证明:需要证明在处理完 n 个元素后,每个元素在水塘中的概率恰好是 k/n。
用数学归纳法:
基础情况:当 n = k 时,所有元素都在水塘中,概率 k/k = 1。正确。
归纳步骤:假设处理完前 n-1 个元素后,每个元素在水塘中的概率是 k/(n-1)。
现在处理第 n 个元素:
- 它以概率 k/n 进入水塘。✓
- 对于已在水塘中的元素 x:
- x 在水塘中(由归纳假设概率 k/(n-1))
- x 不被第 n 个元素替换的概率 = 1 - P(第 n 个进入) × P(恰好替换 x) = 1 - (k/n) × (1/k) = 1 - 1/n = (n-1)/n
- 所以 x 最终在水塘中的概率 = k/(n-1) × (n-1)/n = k/n ✓
k = 1 的特殊情况(从流中选一个):
def reservoir_sampling_single(stream) -> object:
"""水塘抽样 k=1: 从流中均匀随机选一个元素"""
result = None
for i, item in enumerate(stream):
# 以 1/(i+1) 的概率选择当前元素
if random.randint(0, i) == 0:
result = item
return result
直觉:第 1 个元素以概率 1 被选中,第 2 个以概率 1/2 被选中(第 1 个存活概率 1/2),第 3 个以概率 1/3 被选中(前面的存活概率 2/3),...,第 n 个以概率 1/n 被选中。最终每个元素被选中的概率都是 1/n。
1.3 随机化快速排序
标准快排的最坏情况是 O(n²)(当输入是已排序数组且 pivot 选第一个元素时)。随机化 pivot 选择可以避免对抗性输入。
import random
def randomized_quicksort(arr: list, low: int = 0, high: int = None) -> None:
"""随机化快速排序
期望时间复杂度: O(n log n)(对任意输入)
最坏时间复杂度: O(n²)(概率极小,约 1/n!)
空间复杂度: O(log n) 期望栈深度
"""
if high is None:
high = len(arr) - 1
if low < high:
pivot_idx = randomized_partition(arr, low, high)
randomized_quicksort(arr, low, pivot_idx - 1)
randomized_quicksort(arr, pivot_idx + 1, high)
def randomized_partition(arr: list, low: int, high: int) -> int:
"""随机选择 pivot 并分区"""
# 随机选择 pivot
pivot_idx = random.randint(low, high)
arr[pivot_idx], arr[high] = arr[high], arr[pivot_idx]
pivot = arr[high]
i = low - 1
for j in range(low, high):
if arr[j] <= pivot:
i += 1
arr[i], arr[j] = arr[j], arr[i]
arr[i + 1], arr[high] = arr[high], arr[i + 1]
return i + 1
为什么随机化有效?
确定性快排的问题是对手可以构造使其退化的输入。但随机化 pivot 让对手无法预测我们的选择。
期望复杂度分析:设 T(n) 是对 n 个元素排序的期望比较次数。随机 pivot 等概率地将数组分成 (0, n-1), (1, n-2), ..., (n-1, 0) 这 n 种情况:
T(n) = n - 1 + (1/n) · Σᵢ₌₀ⁿ⁻¹ [T(i) + T(n-1-i)]
解这个递推得 T(n) = 2n·H_n - 4n ≈ 2n·ln(n) ≈ 1.39·n·log₂(n)。
即随机化快排的期望比较次数只比理论最优(n·log₂(n))多约 39%,而且这个期望值对任意输入成立——无论输入多么恶意。
与确定性优化的比较:
| 方法 | 最坏情况 | 能否被对手攻击 |
|---|---|---|
| 选第一个 | O(n²) | 是(已排序输入) |
| 选中间 | O(n²) | 是(可构造) |
| 三数取中 | O(n²) | 是(更难但可构造) |
| 随机选择 | O(n²) 但概率极小 | 否(与输入无关) |
| Median of medians | O(n log n) 确定性 | 否(但常数大) |
1.4 随机数生成的基本要求
均匀性:每个可能值出现的概率相等。
常见错误:用 rand() % n 产生 [0, n-1] 的随机数
如果 RAND_MAX + 1 不能被 n 整除,那么某些值会比其他值多一个对应的 rand() 输出。例如,RAND_MAX = 32767 时,rand() % 10 产生 0-7 的概率是 3277/32768,而 8-9 的概率是 3276/32768。偏差虽小但在大量抽样时可检测到。
正确做法:拒绝采样
import random
def uniform_random(n: int) -> int:
"""无偏差地产生 [0, n-1] 的均匀随机数
原理: 拒绝超出范围的值
"""
# Python 的 random.randint 已经处理了这个问题
# 但如果你只有 rand() 产生 [0, RAND_MAX]:
RAND_MAX = 32767 # 假设
limit = RAND_MAX - (RAND_MAX + 1) % n # 最大的 n 的倍数 - 1
while True:
r = random.randint(0, RAND_MAX) # 模拟 rand()
if r <= limit:
return r % n
1.5 常见错误与陷阱
错误 1:洗牌时选择范围不对
# 错误:j 的范围应该是 [0, i] 而不是 [0, n-1]
for i in range(n - 1, 0, -1):
j = random.randint(0, n - 1) # 错!应该是 randint(0, i)
arr[i], arr[j] = arr[j], arr[i]
错误 2:水塘抽样的索引计算
# 错误:第 i 个元素(0-indexed)应以 k/(i+1) 概率进入
# 如果从 1-indexed 开始就是 k/i
for i, item in enumerate(stream, start=1): # 1-indexed
if i <= k:
reservoir.append(item)
else:
j = random.randint(1, i)
if j <= k:
reservoir[j - 1] = item
错误 3:随机化快排的栈溢出
对于大数组(n > 10⁶),递归深度可能导致栈溢出。解决方案:
- 尾递归优化:总是先递归较短的子数组
- 当子数组小于某个阈值时切换为插入排序
- 使用迭代 + 显式栈
def quicksort_safe(arr, low, high):
"""尾递归优化的快排"""
while low < high:
pivot = randomized_partition(arr, low, high)
# 先递归较短的一侧,尾调用较长的一侧
if pivot - low < high - pivot:
quicksort_safe(arr, low, pivot - 1)
low = pivot + 1 # 尾调用优化
else:
quicksort_safe(arr, pivot + 1, high)
high = pivot - 1
Level 2 · 它是怎么运行的
2.1 蒙特卡洛算法 vs 拉斯维加斯算法
随机化算法分为两大类:
拉斯维加斯(Las Vegas)算法:
- 结果总是正确的
- 运行时间是随机的(期望有限)
- 例子:随机化快排、随机化 pivot 的快速选择
蒙特卡洛(Monte Carlo)算法:
- 运行时间是确定的(或有上界)
- 结果可能有错误(错误概率可控)
- 例子:Miller-Rabin 素性测试、随机化的 min-cut 算法
# 拉斯维加斯示例:随机化快速选择(第 k 小)
def randomized_select(arr: list, k: int) -> int:
"""找第 k 小元素
期望时间: O(n)
最坏时间: O(n²)(但概率极小)
结果: 永远正确
"""
if len(arr) == 1:
return arr[0]
pivot = random.choice(arr)
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
if k <= len(left):
return randomized_select(left, k)
elif k <= len(left) + len(middle):
return pivot
else:
return randomized_select(right, k - len(left) - len(middle))
# 蒙特卡洛示例:估计 π 的值
def monte_carlo_pi(num_samples: int) -> float:
"""蒙特卡洛估计 π
在单位正方形中随机投点,统计落在四分之一圆内的比例
π/4 = (圆内点数) / (总点数)
误差: O(1/√n) — 需要 4 倍样本才能减半误差
"""
inside = 0
for _ in range(num_samples):
x = random.random()
y = random.random()
if x * x + y * y <= 1:
inside += 1
return 4 * inside / num_samples
两类算法的关键区别:
| 特性 | Las Vegas | Monte Carlo |
|---|---|---|
| 正确性 | 保证正确 | 概率正确 |
| 时间 | 随机(但期望有限) | 确定/有上界 |
| 能否互转 | 可(设时间上界,超时重启) | 有条件(需验证函数) |
| 典型应用 | 排序、选择 | 素性测试、近似计算 |
重要转换:任何 Las Vegas 算法都可以转为 Monte Carlo——设定时间上限,超时则输出"不知道"。反过来,如果 Monte Carlo 的输出可以高效验证正确性,则可以转为 Las Vegas——反复运行直到验证通过。
2.2 SkipList 的随机层高
SkipList(跳表)是一种基于随机化的数据结构,由 William Pugh 在 1990 年提出,作为平衡搜索树的概率替代品。
核心思想:每个节点有一个随机决定的"层高"。层高为 h 的节点出现在第 1 层到第 h 层的所有链表中。通过高层的"快速通道"跳过大量节点,实现 O(log n) 的期望搜索时间。
import random
class SkipListNode:
def __init__(self, key, level):
self.key = key
self.forward = [None] * (level + 1) # forward[i] 是第 i 层的下一个节点
class SkipList:
"""跳表实现
期望操作时间: O(log n)
空间: O(n) 期望
"""
MAX_LEVEL = 32
P = 0.5 # 上升概率
def __init__(self):
self.header = SkipListNode(-float('inf'), self.MAX_LEVEL)
self.level = 0 # 当前最高层
def random_level(self) -> int:
"""随机决定新节点的层高
每层以概率 P 继续向上。期望层高 = 1/(1-P)
当 P = 0.5 时,期望层高 = 2
"""
lvl = 0
while random.random() < self.P and lvl < self.MAX_LEVEL:
lvl += 1
return lvl
def search(self, key) -> bool:
"""搜索 key 是否存在"""
current = self.header
for i in range(self.level, -1, -1):
while current.forward[i] and current.forward[i].key < key:
current = current.forward[i]
current = current.forward[0]
return current is not None and current.key == key
def insert(self, key) -> None:
"""插入 key"""
update = [None] * (self.MAX_LEVEL + 1)
current = self.header
for i in range(self.level, -1, -1):
while current.forward[i] and current.forward[i].key < key:
current = current.forward[i]
update[i] = current
new_level = self.random_level()
if new_level > self.level:
for i in range(self.level + 1, new_level + 1):
update[i] = self.header
self.level = new_level
new_node = SkipListNode(key, new_level)
for i in range(new_level + 1):
new_node.forward[i] = update[i].forward[i]
update[i].forward[i] = new_node
为什么 P = 0.5 是好的选择?
- 层高 h 的概率是 P^(h-1) · (1-P) = (1/2)^h(几何分布)
- 期望层高 = 1/(1-P) = 2(P=0.5时)
- n 个节点的最大层高期望为 log_{1/P}(n) = log₂(n)
- 搜索时间:每层期望前进 1/P = 2 步,共 log₂(n) 层,总计 O(log n) 步
P 值的选择权衡:
| P | 期望层高 | 空间倍率 | 搜索比较次数 |
|---|---|---|---|
| 1/2 | 2 | 2n | (log₂n)/1 + 1/P = log₂n + 2 |
| 1/4 | 4/3 | (4/3)n | (log₄n)/1 + 1/P ≈ log₂n/2 + 4 |
| 1/e | e/(e-1) ≈ 1.58 | 1.58n | 最优(但差异很小) |
实践中 P = 0.5 或 P = 0.25 都很常见。Redis 的有序集合(zset)使用 P = 0.25, MAX_LEVEL = 32。
SkipList vs 平衡BST的工程权衡:
| 特性 | SkipList | 红黑树/AVL |
|---|---|---|
| 实现复杂度 | 简单 | 复杂(旋转) |
| 并发友好 | 好(锁粒度小) | 差(旋转影响多节点) |
| 范围查询 | 天然支持 | 需要中序遍历 |
| 缓存友好性 | 差(指针跳跃) | 也差(但稍好) |
| 最坏情况 | O(n)(极小概率) | O(log n) 确定性 |
2.3 随机化在负载均衡中的应用
问题:有 n 台服务器,m 个请求到达。如何分配请求使得负载尽量均匀?
方法一:完全随机分配
每个请求独立均匀随机选一台服务器。根据"球入盒"模型(Balls into Bins),当 m = n 时,最重负载的服务器期望接收 Θ(log n / log log n) 个请求。
方法二:Power of Two Choices(Mitzenmacher 等人, 2001)
每个请求随机选 2 台服务器,然后分配给当前负载较轻的那台。这个简单改进将最大负载从 Θ(log n / log log n) 降到 Θ(log log n) — 指数级改进!
import random
class PowerOfTwoChoices:
"""Power of Two Choices 负载均衡
最大负载从 O(log n / log log n) 降为 O(log log n)
"""
def __init__(self, num_servers: int):
self.loads = [0] * num_servers
self.n = num_servers
def assign_request(self) -> int:
"""分配一个请求,返回选中的服务器编号"""
# 随机选两台
s1 = random.randint(0, self.n - 1)
s2 = random.randint(0, self.n - 1)
while s2 == s1: # 确保选的是不同的两台
s2 = random.randint(0, self.n - 1)
# 选负载较轻的
if self.loads[s1] <= self.loads[s2]:
self.loads[s1] += 1
return s1
else:
self.loads[s2] += 1
return s2
为什么 2 选 1 就能有如此大的改进?
直觉:完全随机时,一旦某台服务器"领先"(负载高于平均),它继续积累新请求的概率与其他服务器相同。但 Power of Two Choices 中,负载越高的服务器越不容易被选中(因为两个随机选择中至少一个可能更轻),形成了自动的负反馈机制。
数学上,这是指数时序分析(exponential backoff)的一个实例:每增加一个"选择"维度,极端值的尾概率就会指数衰减。
方法三:一致性哈希(Consistent Hashing)
当需要保证同一个 key 总是路由到同一台服务器(如缓存)时:
import hashlib
import bisect
class ConsistentHashing:
"""一致性哈希
支持动态增减服务器,每次变更只影响 O(1/n) 的 key
使用虚拟节点确保负载均匀
"""
def __init__(self, num_virtual_nodes: int = 150):
self.num_virtual = num_virtual_nodes
self.ring = [] # 排序的 hash 值列表
self.ring_map = {} # hash值 -> 真实服务器
def _hash(self, key: str) -> int:
return int(hashlib.md5(key.encode()).hexdigest(), 16)
def add_server(self, server: str) -> None:
for i in range(self.num_virtual):
h = self._hash(f"{server}#{i}")
bisect.insort(self.ring, h)
self.ring_map[h] = server
def get_server(self, key: str) -> str:
if not self.ring:
return None
h = self._hash(key)
idx = bisect.bisect_left(self.ring, h) % len(self.ring)
return self.ring_map[self.ring[idx]]
2.4 随机化 Skip:从 O(n) 到 O(√n)
一个有趣的技巧:在水塘抽样中,大部分元素会被拒绝。能否跳过那些注定被拒绝的元素?
Vitter 的 Algorithm Z(1985):不逐个检查,而是直接计算"下一个被选中的元素"的位置,从而跳过中间的所有元素。
import random
import math
def reservoir_sampling_optimized(stream_size: int, k: int):
"""Vitter's Algorithm Z 的简化版本
关键优化: 计算下一个替换位置,跳过中间元素
对于大流(n >> k),这显著减少了随机数生成的次数
"""
reservoir = list(range(k)) # 假设前 k 个元素已在水塘中
# 简化版本: 使用几何分布跳跃
W = math.exp(math.log(random.random()) / k)
next_pos = k + int(math.floor(math.log(random.random()) / math.log(1 - W))) + 1
i = k
while i < stream_size:
if i == next_pos:
# 替换水塘中的一个随机元素
reservoir[random.randint(0, k - 1)] = i # 这里 i 代表流中第 i 个元素
W *= math.exp(math.log(random.random()) / k)
next_pos = i + int(math.floor(math.log(random.random()) / math.log(1 - W))) + 1
i += 1
return reservoir
当 n >> k 时,Algorithm Z 只需要 O(k · (1 + log(n/k))) 次随机数生成,而 Algorithm R 需要 O(n) 次。
2.5 随机化算法的去随机化
一个有趣的理论问题:随机化算法能否"去随机化"(derandomization),即转换为确定性算法而不损失效率?
条件期望法(Method of Conditional Expectations):
如果随机化算法的期望结果满足某个性质(如期望解 ≥ OPT/2),那么可以贪心地逐步固定随机位,使得条件期望始终不下降,最终得到确定性地达到期望值的解。
# 示例: MAX-SAT 的随机化 → 去随机化
# 随机化: 每个变量独立等概率为 True/False,期望满足 m/2 个子句(m 个子句)
# 去随机化: 对每个变量,选择使"条件期望满足子句数"更大的值
def derandomized_maxsat(clauses: list, num_vars: int) -> list:
"""去随机化的 MAX-SAT 近似算法
保证满足至少 m/2 个子句(m 是子句总数)
时间: O(n · m)
"""
assignment = [None] * (num_vars + 1)
for var in range(1, num_vars + 1):
# 计算将 var 设为 True 时的条件期望
exp_true = conditional_expected_satisfied(clauses, assignment, var, True)
exp_false = conditional_expected_satisfied(clauses, assignment, var, False)
assignment[var] = exp_true >= exp_false
return assignment
def conditional_expected_satisfied(clauses, assignment, var, value):
"""计算条件期望: 已固定变量 + 当前决定 + 剩余随机"""
assignment[var] = value
expected = 0
for clause in clauses:
# 检查子句是否已经被满足
satisfied = False
all_determined = True
undetermined_count = 0
for literal in clause:
v = abs(literal)
if assignment[v] is not None:
if (literal > 0 and assignment[v]) or (literal < 0 and not assignment[v]):
satisfied = True
break
else:
all_determined = False
undetermined_count += 1
if satisfied:
expected += 1
elif not all_determined:
# 至少有一个未确定的变量,子句被满足的概率 = 1 - (1/2)^(未确定变量数)
expected += 1 - (0.5 ** undetermined_count)
# else: 所有变量都确定了但子句没满足,贡献 0
assignment[var] = None # 恢复
return expected
2.6 哈希函数中的随机性
通用哈希(Universal Hashing, Carter & Wegman, 1979):
从一个哈希函数族中随机选一个函数,保证对任意两个不同的 key,碰撞概率不超过 1/m(m 是表大小)。
import random
class UniversalHash:
"""通用哈希族
h(x) = ((a*x + b) mod p) mod m
其中 p 是大素数,a ∈ [1, p-1],b ∈ [0, p-1] 随机选择
碰撞概率: 对任意 x ≠ y,P[h(x) = h(y)] ≤ 1/m
"""
def __init__(self, m: int, p: int = (1 << 61) - 1):
self.m = m
self.p = p # 梅森素数 2^61 - 1
self.a = random.randint(1, p - 1)
self.b = random.randint(0, p - 1)
def hash(self, x: int) -> int:
return ((self.a * x + self.b) % self.p) % self.m
为什么通用哈希重要?
固定的哈希函数(如 hash(x) = x % m)总有对手可以构造使所有 key 碰撞的输入。但从哈希族中随机选择后,对手无法预知选了哪个函数,因此无法构造对抗性输入。这使得哈希表的最坏情况期望变为 O(1)。
Level 3 · 规范怎么定义的
3.1 Fisher 和 Yates 的原始方法(1938)
Ronald A. Fisher 和 Frank Yates 在 1938 年出版的《Statistical Tables for Biological, Agricultural and Medical Research》一书中首次描述了这个算法。他们的原始方法是为人工执行设计的:
- 写下数字 1 到 n
- 生成一个 [1, k] 的随机数(k 是剩余数字的个数)
- 划掉第 k 个剩余数字,写到结果序列中
- 重复直到所有数字被划掉
这个"划掉"版本时间复杂度是 O(n²)(每次需要 O(n) 找到第 k 个未划掉的数)。
Knuth (1969) 的改进:Donald Knuth 在《The Art of Computer Programming》第 2 卷中给出了现在广泛使用的 O(n) 原地版本——通过交换而非"划掉"来实现。他在 Algorithm P(第 3.4.2 节)中描述了这个算法,并证明了其正确性。这就是我们今天所说的 Fisher-Yates-Knuth shuffle。
Durstenfeld (1964) 实际上比 Knuth 更早独立发表了这个原地版本(在 Communications of the ACM 上),但 Knuth 的教科书使其广为人知。
形式化正确性要求:
一个洗牌算法是"均匀"的当且仅当对 n! 种排列中的每一种 π,算法输出 π 的概率恰好是 1/n!。
Fisher-Yates 满足这一要求是因为:
- 算法做了 n-1 次独立选择
- 第 i 次有 (n-i+1) 种选择(从 [0, n-i] 中选)
- 总选择序列数 = n × (n-1) × ... × 2 × 1 = n!
- 每个选择序列等概率(各 1/n!)
- 不同选择序列产生不同排列(双射)
3.2 Knuth Shuffle 的实现细节
Knuth 在 TAOCP Vol. 2 (1969) 中的 Algorithm P:
Algorithm P (Shuffling).
Given a sequence of n elements a[0], a[1], ..., a[n-1]:
P1. Set j ← n - 1.
P2. Generate a uniform random number U between 0 and 1.
Set k ← ⌊(j+1)·U⌋. (Now k is uniform on {0, 1, ..., j}.)
P3. Exchange a[k] ↔ a[j].
P4. Decrease j by 1. If j > 0, go to P2. Otherwise, terminate.
Knuth 指出这个算法的关键性质:
-
每种排列恰好对应一个选择序列:(k_{n-1}, k_{n-2}, ..., k_1) 其中 0 ≤ k_j ≤ j。这是"阶乘数系统"(factorial number system)的一种表示。
-
与随机数生成器的耦合:算法的正确性依赖于 U 是真正均匀的。如果使用伪随机数生成器(PRNG),其周期必须至少为 n!。对于 52 张牌(52! ≈ 8×10⁶⁷),需要至少 226 位的 PRNG 状态——标准的 Mersenne Twister(32位种子,2^19937 周期)是足够的,但简单的线性同余生成器(32位状态,周期 2^32)远远不够。
3.3 Vitter 1985 年的水塘抽样论文
Jeffrey S. Vitter 在 1985 年发表了经典论文"Random Sampling with a Reservoir"(ACM Transactions on Mathematical Software),系统地分析了水塘抽样的多种算法:
Algorithm R(最简单版本):
- 前 k 个元素直接入库
- 第 i 个元素(i > k)以概率 k/i 替换库中随机一个
- 每个元素需要一次随机数生成
- 总计 n-k 次随机数生成
Algorithm X:
- 计算到下一个替换的"跳跃距离"
- 跳跃距离 S 满足 P(S > s) = ∏_{j=0}^{s} (1 - k/(t+1+j)),其中 t 是当前位置
- 需要 O(k(1+log(n/k))) 次随机数生成
- 但计算 S 本身需要 O(S) 时间
Algorithm Z(最优):
- 用拒绝采样的方式高效生成跳跃距离
- 关键洞察:当 n >> k 时,跳跃距离近似服从某个可快速采样的分布
- 期望时间 O(k(1+log(n/k)))
- 期望随机数生成次数也是 O(k(1+log(n/k)))
Vitter 的贡献不仅是算法本身,更是严格的概率分析方法。他证明了 Algorithm Z 是最优的(在比较模型下,无法用更少的随机位完成均匀抽样)。
加权水塘抽样的变体:
A-Res 算法(Efraimidis & Spirakis, 2006):每个元素 i 的权重为 w_i,被选中的概率正比于权重。
import heapq
import random
import math
def weighted_reservoir_sampling(stream, k: int) -> list:
"""加权水塘抽样 (A-Res algorithm)
每个元素 (item, weight) 按权重比例被抽中
"""
# 使用最小堆存储 k 个最大的 key
heap = [] # (key, item)
for item, weight in stream:
# key = random^(1/weight) — weight 越大,key 越可能大
key = random.random() ** (1.0 / weight)
if len(heap) < k:
heapq.heappush(heap, (key, item))
elif key > heap[0][0]:
heapq.heapreplace(heap, (key, item))
return [item for _, item in heap]
3.4 随机化算法的计算复杂性分类
BPP(Bounded-Error Probabilistic Polynomial Time):
一个决策问题在 BPP 中,如果存在一个概率多项式时间算法,对任何输入:
- 若答案为 YES,算法以 ≥ 2/3 概率输出 YES
- 若答案为 NO,算法以 ≥ 2/3 概率输出 NO
(2/3 可以通过重复运行 + 多数投票放大到 1 - 2^(-k))
RP(Randomized Polynomial Time):
- 若答案为 YES,算法以 ≥ 1/2 概率输出 YES
- 若答案为 NO,算法总是输出 NO(无假阳性)
ZPP(Zero-Error Probabilistic Polynomial Time):
- 算法总是给出正确答案
- 期望运行时间是多项式的
- ZPP = RP ∩ coRP
关系图:P ⊆ ZPP ⊆ RP ⊆ BPP ⊆ ... ⊆ PSPACE
普遍猜想是 P = BPP(所有有效的随机化算法都可以去随机化),但目前尚未证明。Impagliazzo & Wigderson (1997) 证明了:如果存在某些难度假设(如 E = DTIME(2^O(n)) 中有指数级电路复杂度的问题),则 P = BPP。
3.5 伪随机数生成器的理论基础
密码学安全的 PRNG(CSPRNG)要求:
- 给定前 k 个输出位,无法在多项式时间内预测第 k+1 位(即使拥有无限计算资源除外)
- 给定内部状态,无法恢复之前的输出(前向安全性)
常用 PRNG 的特性比较:
| PRNG | 周期 | 安全性 | 速度 | 适用场景 |
|---|---|---|---|---|
| LCG (线性同余) | 2^32 | 不安全 | 极快 | 不推荐 |
| Mersenne Twister | 2^19937 | 不安全(可逆推) | 快 | 科学计算、模拟 |
| xorshift128+ | 2^128 | 不安全 | 极快 | 游戏、非安全场景 |
| ChaCha20 | 2^256 | 密码学安全 | 中等 | 安全场景 |
| /dev/urandom | 取决于实现 | 安全 | 慢 | 密钥生成 |
Python 中的随机数:
random模块使用 Mersenne Twister,不适合安全场景secrets模块使用操作系统的 CSPRNG,适合安全场景os.urandom(n)直接获取 n 字节的安全随机数
Level 4 · 边界与陷阱
4.1 面试题:打乱数组 (LeetCode #384)
题目:给你一个整数数组 nums,设计算法来打乱一个没有重复元素的数组。实现 Solution 类:
Solution(int[] nums)使用整数数组 nums 初始化对象int[] reset()重设数组到它的初始状态并返回int[] shuffle()返回数组随机打乱后的结果
import random
class Solution:
"""LeetCode 384: Shuffle an Array
使用 Fisher-Yates 洗牌保证均匀性
"""
def __init__(self, nums: list):
self.original = nums[:]
self.current = nums[:]
def reset(self) -> list:
self.current = self.original[:]
return self.current
def shuffle(self) -> list:
n = len(self.current)
for i in range(n - 1, 0, -1):
j = random.randint(0, i)
self.current[i], self.current[j] = self.current[j], self.current[i]
return self.current
面试跟进问题:
-
"如何验证你的洗牌是均匀的?"
- 对小数组(如 [1,2,3])运行百万次,统计每种排列出现的频率,应接近 1/n! = 1/6 ≈ 16.67%
- 卡方检验可以量化偏差
-
"如果数组很大(10^8 元素),如何只洗牌前 k 个?"
- 只执行 Fisher-Yates 的前 k 步即可得到 k 个均匀随机元素
- 这本质上就是"从 n 个中随机选 k 个"的算法
-
"如果不能修改原数组呢?"
- 使用额外 O(n) 空间复制后洗牌
- 或者使用 inside-out 版本(不需要原地修改)
def inside_out_shuffle(original: list) -> list:
"""Inside-out Fisher-Yates(不修改原数组)"""
n = len(original)
result = [0] * n
for i in range(n):
j = random.randint(0, i)
if j != i:
result[i] = result[j]
result[j] = original[i]
return result
4.2 面试题:链表随机节点 (LeetCode #382)
题目:给你一个单链表,随机选择链表的一个节点,并返回相应的节点值。每个节点被选中的概率应该相同。
分析:这是水塘抽样 k=1 的直接应用。链表长度未知(或不想先遍历一次),需要一次遍历完成。
import random
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
class Solution:
"""LeetCode 382: Linked List Random Node
水塘抽样 k=1
"""
def __init__(self, head: ListNode):
self.head = head
def getRandom(self) -> int:
result = self.head.val
node = self.head.next
i = 2 # 当前是第 2 个节点
while node:
# 以 1/i 的概率选择当前节点
if random.randint(1, i) == 1:
result = node.val
node = node.next
i += 1
return result
正确性验证(以 3 个节点为例):
设链表为 A → B → C
- A 被选中:需要第 2 步不选 B(概率 1/2)且第 3 步不选 C(概率 2/3)= 1 × 1/2 × 2/3 = 1/3 ✓
- B 被选中:第 2 步选 B(概率 1/2)且第 3 步不选 C(概率 2/3)= 1/2 × 2/3 = 1/3 ✓
- C 被选中:第 3 步选 C(概率 1/3)= 1/3 ✓
面试跟进:
- "如果需要返回 k 个随机节点呢?" — 水塘抽样 k > 1 版本
- "如果可以知道链表长度呢?" — 先数长度 n,再生成 [0, n-1] 的随机数,走到那个位置
- "调用 getRandom 非常频繁怎么办?" — 第一次遍历时把所有值存入数组,之后 O(1) 随机访问
4.3 面试题:随机数索引 (LeetCode #398)
题目:给定一个可能含有重复元素的整数数组,随机输出给定数字的索引。相同数字的每个索引被返回的概率应相等。
import random
class Solution:
"""LeetCode 398: Random Pick Index
方法一: 水塘抽样(O(n) 时间,O(1) 额外空间)
方法二: 预处理哈希表(O(1) 查询,O(n) 空间)
"""
def __init__(self, nums: list):
self.nums = nums
def pick(self, target: int) -> int:
"""水塘抽样方法"""
result = -1
count = 0
for i, num in enumerate(self.nums):
if num == target:
count += 1
# 以 1/count 的概率选择当前索引
if random.randint(1, count) == 1:
result = i
return result
class SolutionOptimized:
"""预处理版本 — 适合多次查询"""
def __init__(self, nums: list):
from collections import defaultdict
self.indices = defaultdict(list)
for i, num in enumerate(nums):
self.indices[num].append(i)
def pick(self, target: int) -> int:
idx_list = self.indices[target]
return random.choice(idx_list)
方法选择:
- 如果 pick 调用次数少(如只调一次):水塘抽样,O(1) 初始化 + O(n) 每次查询
- 如果 pick 调用次数多:预处理哈希表,O(n) 初始化 + O(1) 每次查询
面试中的讨论点:
- 水塘抽样的优势:不需要额外空间,适合流式数据
- 哈希表的优势:查询 O(1)
- Follow-up:如果数组动态变化(可以 insert/delete)怎么办?
4.4 随机化算法的测试与验证
如何确信你的随机化算法是正确的?这比确定性算法困难得多。
方法一:统计检验
from collections import Counter
import scipy.stats
def verify_shuffle(shuffle_func, arr, num_trials=1000000):
"""验证洗牌的均匀性"""
n = len(arr)
expected_count = num_trials / math.factorial(n)
counts = Counter()
for _ in range(num_trials):
result = shuffle_func(arr[:])
counts[tuple(result)] += 1
# 卡方检验
observed = list(counts.values())
expected = [expected_count] * math.factorial(n)
# 如果有些排列没出现过,补零
while len(observed) < math.factorial(n):
observed.append(0)
chi2, p_value = scipy.stats.chisquare(observed, expected)
print(f"卡方统计量: {chi2:.2f}")
print(f"p 值: {p_value:.4f}")
print(f"均匀性: {'通过' if p_value > 0.05 else '不通过'}")
方法二:不变量检查
def verify_reservoir_sampling(sample_func, stream_size, k, num_trials=100000):
"""验证水塘抽样的均匀性"""
counts = [0] * stream_size
for _ in range(num_trials):
sample = sample_func(range(stream_size), k)
for idx in sample:
counts[idx] += 1
# 每个元素被选中的期望次数 = num_trials * k / stream_size
expected = num_trials * k / stream_size
# 检查是否所有计数都接近期望值
max_deviation = max(abs(c - expected) / expected for c in counts)
print(f"最大相对偏差: {max_deviation:.4f}")
print(f"均匀性: {'通过' if max_deviation < 0.05 else '不通过'}")
4.5 实际工程中的随机化应用
1. 随机化在数据库中:
- MySQL 的
ORDER BY RAND()实际上是全表扫描 + 排序(O(n log n)),非常慢 - 更好的做法:用 ID 范围 + 随机 offset,或预生成随机 ID 列表
2. 随机化在分布式系统中:
- Exponential Backoff:网络重试时随机等待时间,避免惊群效应(thundering herd)
- Jitter:在固定间隔上添加随机偏移,防止定时任务同时触发
import random
import time
def exponential_backoff_with_jitter(attempt: int, base_delay: float = 1.0,
max_delay: float = 60.0) -> float:
"""带 jitter 的指数退避
在分布式系统中用于重试逻辑
"""
# Full Jitter (AWS 推荐)
delay = min(max_delay, base_delay * (2 ** attempt))
return random.uniform(0, delay)
3. 随机化在 A/B 测试中:
- 用户 ID 的哈希值决定分组(确定性分流,同一用户总是同组)
- 哈希函数选择影响分组的均匀性
4. 布隆过滤器(Bloom Filter)中的多哈希:
- k 个独立哈希函数 → 假阳性率 = (1 - e^(-kn/m))^k
- 最优 k = (m/n) · ln 2
4.6 面试应答策略
当面试中遇到随机化问题时:
-
识别问题类型:
- "等概率" / "均匀随机" → Fisher-Yates / 水塘抽样
- "数据流" / "未知大小" → 水塘抽样
- "概率正比于权重" → 加权水塘 / 轮盘赌选择
- "避免最坏情况" → 随机化 pivot / 通用哈希
-
证明正确性的标准模式:
- 对于洗牌:证明每种排列概率 = 1/n!
- 对于抽样:证明每个元素被选中概率 = k/n
- 通常用数学归纳法
-
空间-时间权衡:
- 水塘抽样:O(k) 空间,O(n) 时间
- 预处理:O(n) 空间,O(1) 查询时间
- 选择取决于查询频率和内存限制
-
常见追问:
- "如何验证均匀性?" → 统计检验(卡方检验)
- "随机数生成器有偏怎么办?" → 拒绝采样
- "如果数据流特别大怎么优化?" → Vitter's Algorithm Z(跳跃式采样)