线段树与树状数组
第十四章:线段树与树状数组
当你需要对一个数组执行大量的区间查询和单点/区间修改操作时,朴素方法要么查询 O(n),要么修改 O(n),无论如何都有一端是线性的。线段树(Segment Tree)和树状数组(Binary Indexed Tree / Fenwick Tree)是两种专为这类问题设计的数据结构,它们将查询和修改的时间复杂度都降到 O(log n)。
这两种数据结构在竞赛、面试和工程实践中的出现频率极高。理解它们不仅仅是学会一种"高级数据结构",更是深入理解"用空间换时间"和"分治思想在数据结构中的应用"的绝佳案例。
Level 1 · 你需要知道的
14.1 从问题出发:为什么需要线段树
考虑这样一个场景:你有一个长度为 n 的数组 nums,需要反复执行两种操作:
- 区间查询:求
nums[l..r]的和(或最大值、最小值等) - 单点修改:将
nums[i]修改为某个新值
| 方案 | 查询复杂度 | 修改复杂度 | 问题 |
|---|---|---|---|
| 原始数组 | O(n) | O(1) | 查询太慢 |
| 前缀和数组 | O(1) | O(n) | 修改要重建整个前缀和 |
| 线段树 | O(log n) | O(log n) | 两端都是对数级 |
核心洞察:线段树之所以能做到双 O(log n),是因为它将区间 [0, n-1] 递归地二分,形成一棵完全二叉树。每个节点存储对应区间的聚合信息(如区间和)。当你修改某个元素时,只需要更新从叶子到根的路径上的节点,共 O(log n) 个;当你查询某个区间时,最多需要访问 O(log n) 个节点就能覆盖整个查询区间。
14.2 线段树的结构
线段树是一棵二叉树,其结构如下:
- 根节点表示整个数组区间
[0, n-1] - 每个内部节点表示一个区间
[l, r],它的左子节点表示[l, mid],右子节点表示[mid+1, r](其中mid = (l + r) // 2) - 叶子节点表示单个元素
[i, i] - 每个节点存储其对应区间的聚合值(如区间和、区间最大值等)
对于一个长度为 n 的数组,线段树有以下性质:
- 叶子节点数 = n
- 总节点数 ≤ 4n(实际上是 2 * 2^⌈log₂n⌉ - 1,但通常分配 4n 空间足够)
- 树的高度 = ⌈log₂n⌉
为什么分配 4n 空间? 这是一个常见的困惑点。线段树是一棵接近完全二叉树的结构,如果 n 恰好是 2 的幂,则需要 2n - 1 个节点。但如果 n 不是 2 的幂,最后一层会有空洞,此时需要的空间上界是 4n。在实践中直接分配 4n 是安全且简单的做法。
14.3 用数组实现线段树
线段树最常见的实现方式是用数组模拟完全二叉树(类似堆的存储方式):
- 根节点索引为 1
- 节点 i 的左子节点索引为
2*i - 节点 i 的右子节点索引为
2*i + 1 - 节点 i 的父节点索引为
i // 2
class SegmentTree:
"""基础线段树 — 支持单点修改、区间求和查询"""
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n) # 分配 4n 空间
if self.n > 0:
self._build(nums, 1, 0, self.n - 1)
def _build(self, nums, node, start, end):
"""递归建树:自底向上聚合"""
if start == end:
# 叶子节点:直接存储原始值
self.tree[node] = nums[start]
return
mid = (start + end) // 2
self._build(nums, 2 * node, start, mid) # 构建左子树
self._build(nums, 2 * node + 1, mid + 1, end) # 构建右子树
# 内部节点:左右子树聚合
self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
def update(self, idx, val):
"""单点修改:将 nums[idx] 改为 val"""
self._update(1, 0, self.n - 1, idx, val)
def _update(self, node, start, end, idx, val):
if start == end:
# 到达叶子节点,直接修改
self.tree[node] = val
return
mid = (start + end) // 2
if idx <= mid:
self._update(2 * node, start, mid, idx, val)
else:
self._update(2 * node + 1, mid + 1, end, idx, val)
# 回溯时更新父节点
self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
def query(self, l, r):
"""区间查询:求 nums[l..r] 的和"""
return self._query(1, 0, self.n - 1, l, r)
def _query(self, node, start, end, l, r):
# 当前节点区间完全在查询区间内 — 直接返回
if l <= start and end <= r:
return self.tree[node]
# 当前节点区间与查询区间无交集
if end < l or start > r:
return 0
# 部分重叠 — 递归查询左右子树
mid = (start + end) // 2
left_sum = self._query(2 * node, start, mid, l, r)
right_sum = self._query(2 * node + 1, mid + 1, end, l, r)
return left_sum + right_sum
理解递归过程:
_build:从叶子向上构建。每个叶子存原始值,每个父节点等于左右子节点之和。时间复杂度 O(n)。_update:从根向下找到对应叶子修改,然后回溯更新路径上所有祖先。时间复杂度 O(log n)。_query:从根向下,如果当前节点的区间完全包含在查询区间内则直接返回,否则继续递归。每层最多访问常数个节点,时间复杂度 O(log n)。
为什么 query 是 O(log n)? 关键观察是:对于任意查询区间 [l, r],在线段树的每一层中,最多只有两个节点会被"部分覆盖"(左边界和右边界各一个),其他被访问的节点都是"完全覆盖"而直接返回。因此总访问节点数不超过 4⌈log₂n⌉。
14.4 使用示例
# 创建线段树
nums = [1, 3, 5, 7, 9, 11]
st = SegmentTree(nums)
# 查询区间 [1, 4] 的和:3 + 5 + 7 + 9 = 24
print(st.query(1, 4)) # 输出 24
# 将 nums[2] 修改为 10
st.update(2, 10)
# 再次查询 [1, 4]:3 + 10 + 7 + 9 = 29
print(st.query(1, 4)) # 输出 29
14.5 树状数组(BIT / Fenwick Tree)
树状数组是另一种处理前缀和查询与单点修改的数据结构,它比线段树更简洁、常数更小,但功能相对有限(基础版本只支持前缀查询,不直接支持任意区间最大值/最小值)。
核心思想:树状数组利用二进制表示的最低位(lowbit)将原始数组划分成不同长度的"负责区间",每个位置 i 负责存储从 i - lowbit(i) + 1 到 i 这一段的聚合值。
lowbit 操作:lowbit(x) = x & (-x),它提取 x 的二进制表示中最低位的 1。
x = 12 = 1100₂
-x = ...10100₂ (补码)
x & (-x) = 00100₂ = 4
所以 lowbit(12) = 4,表示位置 12 负责管理 4 个元素的区间 [9, 12]
为什么 lowbit 能工作? 这不是巧合,而是精心设计。Peter Fenwick 在 1994 年的论文中指出,如果我们按照 lowbit 将数组分层管理,那么:
- 更新位置 i 时,需要更新的所有位置恰好是
i, i + lowbit(i), i + 2*lowbit(更新后的i), ...,每次 lowbit 都会增大,所以最多 O(log n) 步 - 查询前缀和
[1..i]时,需要累加的位置恰好是i, i - lowbit(i), i - 2*lowbit(更新后的i), ...,每次向下消去最低位的 1,最多 O(log n) 步
class BinaryIndexedTree:
"""树状数组 — 支持单点修改、前缀和查询"""
def __init__(self, n):
self.n = n
self.tree = [0] * (n + 1) # 1-indexed
@staticmethod
def lowbit(x):
return x & (-x)
def update(self, i, delta):
"""将位置 i 增加 delta(1-indexed)"""
while i <= self.n:
self.tree[i] += delta
i += self.lowbit(i)
def prefix_sum(self, i):
"""查询前缀和 [1..i](1-indexed)"""
s = 0
while i > 0:
s += self.tree[i]
i -= self.lowbit(i)
return s
def range_sum(self, l, r):
"""查询区间和 [l..r](1-indexed)"""
return self.prefix_sum(r) - self.prefix_sum(l - 1)
从原始数组构建树状数组:
def build_from_array(nums):
"""O(n) 建树"""
n = len(nums)
bit = BinaryIndexedTree(n)
# 方法 1:逐个 update,O(n log n)
for i, val in enumerate(nums):
bit.update(i + 1, val)
return bit
def build_from_array_linear(nums):
"""O(n) 建树 — 利用父子关系"""
n = len(nums)
bit = BinaryIndexedTree(n)
for i in range(1, n + 1):
bit.tree[i] += nums[i - 1]
parent = i + (i & (-i))
if parent <= n:
bit.tree[parent] += bit.tree[i]
return bit
14.6 线段树 vs 树状数组:何时用哪个
| 特性 | 线段树 | 树状数组 |
|---|---|---|
| 空间 | 4n | n+1 |
| 代码量 | 多(递归) | 少(循环) |
| 常数因子 | 较大 | 较小(2-5 倍快) |
| 区间修改 | 支持(懒传播) | 支持(差分技巧) |
| 区间最值 | 支持 | 不直接支持 |
| 动态开点 | 支持 | 不支持 |
| 持久化 | 支持 | 困难 |
经验法则:
- 如果只需要前缀和 + 单点修改 → 用树状数组
- 如果需要区间最值、区间修改 + 区间查询 → 用线段树
- 如果追求代码简短且不差功能 → 用树状数组
- 如果题目需要持久化或动态开点 → 只能用线段树
14.7 常见错误与调试
错误 1:数组越界
# 错误:分配 2n 空间
self.tree = [0] * (2 * self.n) # 当 n 不是 2 的幂时会越界!
# 正确:分配 4n 空间
self.tree = [0] * (4 * self.n)
错误 2:树状数组用 0-indexed
# 错误:i 从 0 开始,lowbit(0) = 0,死循环!
def update(self, i, delta):
while i <= self.n:
self.tree[i] += delta
i += self.lowbit(i) # 如果 i=0,lowbit(0)=0,永远不会增加
# 正确:树状数组必须 1-indexed
错误 3:query 中忘记处理无交集的情况
# 错误:没有判断无交集
def _query(self, node, start, end, l, r):
if l <= start and end <= r:
return self.tree[node]
mid = (start + end) // 2
# 如果不判断,即使 l > end 也会继续递归
return self._query(2*node, start, mid, l, r) + \
self._query(2*node+1, mid+1, end, l, r)
# 正确:加上无交集的判断
def _query(self, node, start, end, l, r):
if l <= start and end <= r:
return self.tree[node]
if end < l or start > r:
return 0 # 无交集时返回单位元
mid = (start + end) // 2
return self._query(2*node, start, mid, l, r) + \
self._query(2*node+1, mid+1, end, l, r)
错误 4:update 后忘记回溯更新父节点
# 错误
def _update(self, node, start, end, idx, val):
if start == end:
self.tree[node] = val
return
mid = (start + end) // 2
if idx <= mid:
self._update(2*node, start, mid, idx, val)
else:
self._update(2*node+1, mid+1, end, idx, val)
# 忘记这一行就完蛋了!
# self.tree[node] = self.tree[2*node] + self.tree[2*node+1]
Level 2 · 它是怎么运行的
14.8 懒传播(Lazy Propagation)
基础线段树只支持单点修改。但如果需要区间修改(将 nums[l..r] 的所有元素都加上某个值),朴素做法是逐个修改 r - l + 1 个点,退化为 O(n log n)。
懒传播(Lazy Propagation)的核心思想是延迟更新:当一个区间修改完全覆盖某个节点的区间时,不立即下推到子节点,而是在该节点打上一个"懒标记"(lazy tag),表示"这个节点的子树还有未传播的修改"。只有当后续操作需要访问子节点时,才将懒标记下推(pushdown)。
为什么这是正确的? 因为如果后续操作不访问子节点,那么子节点的精确值就不需要。当需要时再计算,这就是"按需计算"(lazy evaluation)的思想——在函数式编程中也有类似概念。
class SegmentTreeLazy:
"""带懒传播的线段树 — 支持区间修改、区间查询"""
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n)
self.lazy = [0] * (4 * self.n) # 懒标记数组
if self.n > 0:
self._build(nums, 1, 0, self.n - 1)
def _build(self, nums, node, start, end):
if start == end:
self.tree[node] = nums[start]
return
mid = (start + end) // 2
self._build(nums, 2 * node, start, mid)
self._build(nums, 2 * node + 1, mid + 1, end)
self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
def _pushdown(self, node, start, end):
"""将懒标记下推到子节点"""
if self.lazy[node] != 0:
mid = (start + end) // 2
left_len = mid - start + 1
right_len = end - mid
# 更新左子节点的值和懒标记
self.tree[2 * node] += self.lazy[node] * left_len
self.lazy[2 * node] += self.lazy[node]
# 更新右子节点的值和懒标记
self.tree[2 * node + 1] += self.lazy[node] * right_len
self.lazy[2 * node + 1] += self.lazy[node]
# 清除当前节点的懒标记
self.lazy[node] = 0
def range_update(self, l, r, val):
"""区间修改:将 nums[l..r] 都加上 val"""
self._range_update(1, 0, self.n - 1, l, r, val)
def _range_update(self, node, start, end, l, r, val):
# 完全覆盖 — 打懒标记,不下推
if l <= start and end <= r:
self.tree[node] += val * (end - start + 1)
self.lazy[node] += val
return
# 无交集
if end < l or start > r:
return
# 部分覆盖 — 先下推已有的懒标记,再递归
self._pushdown(node, start, end)
mid = (start + end) // 2
self._range_update(2 * node, start, mid, l, r, val)
self._range_update(2 * node + 1, mid + 1, end, l, r, val)
self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
def range_query(self, l, r):
"""区间查询:求 nums[l..r] 的和"""
return self._range_query(1, 0, self.n - 1, l, r)
def _range_query(self, node, start, end, l, r):
if l <= start and end <= r:
return self.tree[node]
if end < l or start > r:
return 0
# 关键:查询子节点前必须先下推
self._pushdown(node, start, end)
mid = (start + end) // 2
return self._range_query(2 * node, start, mid, l, r) + \
self._range_query(2 * node + 1, mid + 1, end, l, r)
懒传播的时间复杂度分析:
区间修改和区间查询的复杂度仍为 O(log n)。证明与单点修改类似——每层最多访问常数个"部分覆盖"的节点,而下推操作本身是 O(1) 的。
懒传播的正确性条件:
懒标记必须满足可合并性:多次累积的懒标记能正确合并。例如"区间加"的懒标记可以直接累加;"区间赋值"的懒标记需要用新值覆盖旧值。如果同时有"区间加"和"区间乘"两种操作,懒标记的设计就更复杂了(需要维护 (乘数, 加数) 对,且下推时要注意顺序)。
14.9 懒传播的应用示例
# 场景:频繁的区间加与区间求和
nums = [1, 3, 5, 7, 9, 11]
st = SegmentTreeLazy(nums)
# 将 [1, 4] 所有元素加 3:[1, 6, 8, 10, 12, 11]
st.range_update(1, 4, 3)
# 查询 [2, 5] 的和:8 + 10 + 12 + 11 = 41
print(st.range_query(2, 5)) # 输出 41
# 将 [0, 2] 所有元素加 -1:[0, 5, 7, 10, 12, 11]
st.range_update(0, 2, -1)
# 查询 [0, 5] 的和:0 + 5 + 7 + 10 + 12 + 11 = 45
print(st.range_query(0, 5)) # 输出 45
14.10 线段树的动态开点
标准线段树用数组存储,需要预先分配 4n 空间。但如果值域很大(比如 [0, 10^9]),开不了 4×10^9 的数组。动态开点的思想是:不预先分配所有节点,而是在需要时才创建。
class DynamicSegTree:
"""动态开点线段树 — 值域可以很大"""
def __init__(self):
self.tree = {} # node_id -> value
self.lazy = {} # node_id -> lazy_value
self.left_child = {} # node_id -> left_child_id
self.right_child = {} # node_id -> right_child_id
self.cnt = 0 # 节点计数器
self.root = self._new_node()
def _new_node(self):
self.cnt += 1
self.tree[self.cnt] = 0
self.lazy[self.cnt] = 0
return self.cnt
def _pushdown(self, node, left_len, right_len):
if self.lazy.get(node, 0) == 0:
return
# 确保子节点存在
if node not in self.left_child:
self.left_child[node] = self._new_node()
if node not in self.right_child:
self.right_child[node] = self._new_node()
lc, rc = self.left_child[node], self.right_child[node]
self.tree[lc] += self.lazy[node] * left_len
self.lazy[lc] = self.lazy.get(lc, 0) + self.lazy[node]
self.tree[rc] += self.lazy[node] * right_len
self.lazy[rc] = self.lazy.get(rc, 0) + self.lazy[node]
self.lazy[node] = 0
def update(self, node, start, end, l, r, val):
"""区间 [l,r] 加 val"""
if l <= start and end <= r:
self.tree[node] += val * (end - start + 1)
self.lazy[node] = self.lazy.get(node, 0) + val
return
mid = (start + end) // 2
self._pushdown(node, mid - start + 1, end - mid)
if node not in self.left_child:
self.left_child[node] = self._new_node()
if node not in self.right_child:
self.right_child[node] = self._new_node()
if l <= mid:
self.update(self.left_child[node], start, mid, l, r, val)
if r > mid:
self.update(self.right_child[node], mid + 1, end, l, r, val)
self.tree[node] = self.tree[self.left_child[node]] + \
self.tree[self.right_child[node]]
def query(self, node, start, end, l, r):
"""查询区间 [l,r] 的和"""
if l <= start and end <= r:
return self.tree.get(node, 0)
if end < l or start > r:
return 0
mid = (start + end) // 2
self._pushdown(node, mid - start + 1, end - mid)
res = 0
if l <= mid and node in self.left_child:
res += self.query(self.left_child[node], start, mid, l, r)
if r > mid and node in self.right_child:
res += self.query(self.right_child[node], mid + 1, end, l, r)
return res
动态开点的空间复杂度:每次修改操作最多创建 O(log n) 个新节点(从根到叶子的路径)。如果有 q 次操作,总空间为 O(q log n)。
典型应用场景:
- 值域很大但操作次数有限(如坐标范围
[0, 10^9],但只有 10^5 次操作) - 需要持久化的线段树(每次修改创建新节点而不修改旧节点)
14.11 树状数组求逆序对
逆序对问题:给定数组 nums,求满足 i < j 且 nums[i] > nums[j] 的 (i, j) 对数。
这是树状数组的经典应用之一。思路是:从右向左遍历数组,对于每个元素 nums[i],查询"当前已经出现过的、比 nums[i] 小的元素有多少个",这就是以 nums[i] 结尾的逆序对数。
def count_inversions(nums):
"""用树状数组统计逆序对数"""
if not nums:
return 0
# 离散化:将值映射到 [1, n] 的范围
sorted_unique = sorted(set(nums))
rank = {v: i + 1 for i, v in enumerate(sorted_unique)}
n = len(sorted_unique)
bit = BinaryIndexedTree(n)
inversions = 0
# 从右向左遍历
for i in range(len(nums) - 1, -1, -1):
r = rank[nums[i]]
# 查询比 nums[i] 小的元素个数(已经在 BIT 中的)
inversions += bit.prefix_sum(r - 1)
# 将 nums[i] 加入 BIT
bit.update(r, 1)
return inversions
# 示例
print(count_inversions([5, 2, 6, 1])) # 输出 5
# 逆序对:(5,2), (5,1), (2,1), (6,1) — 等等让我数一下
# (0,1): 5>2 ✓
# (0,3): 5>1 ✓
# (1,3): 2>1 ✓
# (2,3): 6>1 ✓
# 共 4 对... 不对,再看 (5,2,6,1)
# i=0,j=1: 5>2 ✓
# i=0,j=3: 5>1 ✓
# i=1,j=3: 2>1 ✓
# i=2,j=3: 6>1 ✓
# 结果是 4,不是 5。上面的代码是正确的,我的手算有误。
让我纠正上面的示例:
print(count_inversions([2, 4, 1, 3, 5])) # 输出 3
# 逆序对:(2,1), (4,1), (4,3)
时间复杂度:O(n log n)(离散化 O(n log n) + 遍历中每次 BIT 操作 O(log n))。
14.12 二维树状数组
二维树状数组用于处理矩阵上的"单点修改 + 矩形区域求和"问题。
class BIT2D:
"""二维树状数组"""
def __init__(self, rows, cols):
self.rows = rows
self.cols = cols
self.tree = [[0] * (cols + 1) for _ in range(rows + 1)]
def update(self, x, y, delta):
"""将位置 (x, y) 增加 delta(1-indexed)"""
i = x
while i <= self.rows:
j = y
while j <= self.cols:
self.tree[i][j] += delta
j += j & (-j)
i += i & (-i)
def prefix_sum(self, x, y):
"""查询 (1,1) 到 (x,y) 的矩形和"""
s = 0
i = x
while i > 0:
j = y
while j > 0:
s += self.tree[i][j]
j -= j & (-j)
i -= i & (-i)
return s
def range_sum(self, x1, y1, x2, y2):
"""查询 (x1,y1) 到 (x2,y2) 的矩形和(容斥原理)"""
return (self.prefix_sum(x2, y2)
- self.prefix_sum(x1 - 1, y2)
- self.prefix_sum(x2, y1 - 1)
+ self.prefix_sum(x1 - 1, y1 - 1))
二维树状数组的复杂度:单次修改/查询为 O(log m × log n),空间为 O(m × n)。
与二维前缀和的对比:
- 如果矩阵不会修改,用二维前缀和即可(O(1) 查询)
- 如果矩阵需要修改,用二维树状数组(O(log² n) 查询和修改)
14.13 线段树与树状数组的内部执行模型
理解这两种数据结构的关键在于理解它们如何"分解"区间。
线段树的区间分解:
以 n=8 为例,查询 [2, 7] 时,线段树如何将这个查询分解:
[0,7]
/ \
[0,3] [4,7]
/ \ / \
[0,1] [2,3] [4,5] [6,7]
/ \ / \ / \ / \
[0][1] [2][3] [4][5] [6][7]
查询 [2,7]:
- [0,7] 部分覆盖,递归
- [0,3] 部分覆盖,递归
- [0,1] 无交集,返回 0
- [2,3] 完全覆盖,返回 tree[[2,3]] ✓
- [4,7] 完全覆盖,返回 tree[[4,7]] ✓
结果 = tree[[2,3]] + tree[[4,7]]
只需要 2 个节点的值就覆盖了 [2,7]!
树状数组的前缀分解:
查询 prefix_sum(7) 时(1-indexed):
i = 7 = 111₂, tree[7] 管理 [7,7](lowbit=1)
i -= 1 → i = 6 = 110₂, tree[6] 管理 [5,6](lowbit=2)
i -= 2 → i = 4 = 100₂, tree[4] 管理 [1,4](lowbit=4)
i -= 4 → i = 0,结束
prefix_sum(7) = tree[7] + tree[6] + tree[4]
= [7,7] + [5,6] + [1,4]
= [1,7] ✓
这个分解的数学本质是:将 i 的二进制表示逐位消去。i 有多少个 1,就需要累加多少个 tree 值。最多 ⌊log₂n⌋ + 1 个。
14.14 更多变体:区间修改 + 单点查询(差分 BIT)
如果需要"区间修改"但只需"单点查询",可以对差分数组建树状数组:
class DiffBIT:
"""差分树状数组:区间修改 + 单点查询"""
def __init__(self, n):
self.bit = BinaryIndexedTree(n)
def range_add(self, l, r, val):
"""将 [l, r] 所有元素加 val(1-indexed)"""
self.bit.update(l, val)
if r + 1 <= self.bit.n:
self.bit.update(r + 1, -val)
def point_query(self, i):
"""查询位置 i 的当前值"""
return self.bit.prefix_sum(i)
原理:对差分数组 d 建 BIT。range_add(l, r, val) 相当于 d[l] += val, d[r+1] -= val。而 point_query(i) = prefix_sum(d, i) = 原数组位置 i 的累积增量。
如果需要"区间修改 + 区间查询",可以用两个树状数组配合:
class RangeAddRangeSum:
"""用两个 BIT 实现区间加 + 区间求和"""
def __init__(self, nums):
self.n = len(nums)
self.bit1 = BinaryIndexedTree(self.n) # 存储 d[i]
self.bit2 = BinaryIndexedTree(self.n) # 存储 i * d[i]
# 原始前缀和
self.prefix = [0] * (self.n + 1)
for i in range(self.n):
self.prefix[i + 1] = self.prefix[i] + nums[i]
def range_add(self, l, r, val):
"""区间 [l, r] 加 val(1-indexed)"""
self.bit1.update(l, val)
self.bit1.update(r + 1, -val)
self.bit2.update(l, val * l)
self.bit2.update(r + 1, -val * (r + 1))
def prefix_sum(self, i):
"""查询前缀和 [1..i]"""
return self.prefix[i] + \
(i + 1) * self.bit1.prefix_sum(i) - \
self.bit2.prefix_sum(i)
def range_sum(self, l, r):
"""查询区间和 [l, r]"""
return self.prefix_sum(r) - self.prefix_sum(l - 1)
推导过程:设原始数组为 a,差分数组为 d(d[i] = a[i] - a[i-1])。
前缀和 a[1] + a[2] + ... + a[i] 可以展开为:
Σ(k=1 to i) a[k] = Σ(k=1 to i) Σ(j=1 to k) d[j]
= Σ(j=1 to i) d[j] * (i - j + 1)
= (i+1) * Σ(j=1 to i) d[j] - Σ(j=1 to i) j * d[j]
所以只需要维护两个前缀和:Σd[j] 和 Σj*d[j],各用一个 BIT 即可。
Level 3 · 规范怎么定义的
14.15 Peter Fenwick 的原始论文(1994)
树状数组的正式提出来自 Peter Fenwick 于 1994 年发表的论文"A New Data Structure for Cumulative Frequency Tables"(Software: Practice and Experience, Vol. 24(3), pp. 327-336, March 1994)。
论文的背景:Fenwick 的原始动机是算术编码(Arithmetic Coding)中的累积频率表(Cumulative Frequency Table)。在数据压缩中,需要频繁执行两种操作:
- 更新某个符号的频率(单点修改)
- 查询所有频率不超过某个阈值的符号的累积频率(前缀和查询)
传统的做法是维护一个平坦的累积频率表,修改为 O(n);或者使用平衡树,但实现复杂。Fenwick 提出了一种巧妙的方法,利用整数的二进制表示,在 O(log n) 时间内完成两种操作,且实现极为简洁。
关键设计决策:
Fenwick 在论文中解释了为什么选择"去掉最低位的 1"作为核心操作:
"The key observation is that any positive integer can be uniquely represented as a sum of distinct powers of 2. The tree structure exploits this by assigning responsibility for cumulative frequencies in a manner that mirrors binary representation."
也就是说,任何正整数 i 都可以唯一分解为 2 的幂之和。树状数组的结构正是利用这一点——位置 i 负责管理 lowbit(i) 个连续元素的聚合值,而从位置 i 出发通过反复减去 lowbit(i),恰好能把 [1, i] 分解为 O(log i) 个不重叠的段。
Fenwick 论文的核心定理:
定理:对于任意正整数 i,从 i 开始反复执行 i -= lowbit(i) 操作,得到的序列 i₁, i₂, ..., iₖ 满足:
- 序列严格递减,终止于 0
- 区间
[iⱼ - lowbit(iⱼ) + 1, iⱼ](j = 1, ..., k)互不重叠 - 这些区间的并集恰好是
[1, i] - k ≤ ⌊log₂i⌋ + 1
证明思路:每次减去 lowbit(i) 就是消去 i 的二进制表示中最低位的 1。如果 i 有 k 个 1,则恰好 k 步后到达 0。每步对应的区间长度恰好是被消去的那个 2 的幂。由于 i 的二进制分解是唯一的,这些区间互不重叠且覆盖 [1, i]。
14.16 线段树的历史
线段树(Segment Tree)的历史比树状数组更早,但其发展路径并非一篇论文定义一切,而是由多个计算几何和算法领域的研究者逐步完善。
1977 年 — Jon Bentley:Bentley 在"Solutions to Klee's Rectangle Problems"(Carnegie-Mellon University, 1977)中描述了一种基于区间分割的树结构,用于解决矩形面积联合(Klee's Rectangle Problem)。这是线段树思想的早期来源之一。
1979 年 — Bentley 和 Wood:"An Optimal Worst-Case Algorithm for Reporting Intersections of Rectangles"(IEEE Transactions on Computers, 1980)进一步完善了区间树的理论基础。
1980 年代 — McCreight, Willard 等人:线段树逐渐演化为一种通用的区间管理结构,在计算几何中被广泛使用(线段求交、窗口查询等)。
重要澄清:学术文献中的"Segment Tree"通常指一种特定的区间存储结构(每个区间被分解存储在 O(log n) 个节点中),与竞赛编程中的"线段树"(实质上是区间树 / Statistic Tree)有微妙区别。竞赛中的线段树更接近于一种"分治树"或"递归区间聚合树"。但在当代算法社区中,"Segment Tree"已经约定俗成地指代竞赛中的这种结构。
14.17 持久化线段树(主席树)
持久化线段树(Persistent Segment Tree)能够保存线段树在每次修改前后的所有历史版本。它由黄嘉泰(网名"主席")在竞赛中推广使用,因此在中文社区被称为"主席树"。
核心思想:每次修改时,不修改原有节点,而是创建新的节点。由于每次修改只影响从根到某个叶子的一条路径(O(log n) 个节点),所以只需创建 O(log n) 个新节点,其余节点与旧版本共享。
class PersistentSegTree:
"""持久化线段树 — 每次修改创建新版本"""
def __init__(self, max_nodes=2000000):
self.lc = [0] * max_nodes # 左子节点 id
self.rc = [0] * max_nodes # 右子节点 id
self.val = [0] * max_nodes # 节点值
self.cnt = 0 # 当前使用的节点数
self.roots = [] # 每个版本的根节点 id
def _new_node(self):
self.cnt += 1
return self.cnt
def build(self, start, end):
"""建立初始版本"""
node = self._new_node()
if start == end:
return node
mid = (start + end) // 2
self.lc[node] = self.build(start, mid)
self.rc[node] = self.build(mid + 1, end)
return node
def update(self, prev, start, end, pos, delta):
"""基于版本 prev 创建新版本,在 pos 处加 delta"""
node = self._new_node()
self.lc[node] = self.lc[prev]
self.rc[node] = self.rc[prev]
self.val[node] = self.val[prev] + delta
if start == end:
return node
mid = (start + end) // 2
if pos <= mid:
self.lc[node] = self.update(self.lc[prev], start, mid, pos, delta)
else:
self.rc[node] = self.update(self.rc[prev], mid + 1, end, pos, delta)
return node
def query(self, left_root, right_root, start, end, k):
"""查询两个版本之间的第 k 小值"""
if start == end:
return start
mid = (start + end) // 2
left_count = self.val[self.lc[right_root]] - self.val[self.lc[left_root]]
if k <= left_count:
return self.query(self.lc[left_root], self.lc[right_root],
start, mid, k)
else:
return self.query(self.rc[left_root], self.rc[right_root],
mid + 1, end, k - left_count)
经典应用:区间第 k 小值
给定数组,查询任意区间 [l, r] 中第 k 小的元素。
思路:
- 对值域建线段树,每个叶子表示一个值
- 按顺序插入元素,第 i 个版本表示"前 i 个元素"的线段树
- 查询
[l, r]的第 k 小 = 用版本 r 减去版本 l-1,得到区间[l, r]的频率分布,然后二分即可
空间复杂度:初始版本 O(n) 节点,每次修改新增 O(log n) 节点。n 次修改后总空间 O(n log n)。
14.18 线段树在竞赛中的地位
线段树是算法竞赛(ACM-ICPC, Codeforces, AtCoder, NOI/IOI)中最核心的数据结构之一。根据 Codeforces 的题目标签统计,带有"segment tree"或"data structures"标签的题目占比超过 20%,尤其在 Div. 1 级别(rating 1900+)的题目中更为集中。
为什么线段树如此重要? 因为它具有极强的通用性和可扩展性:
-
通用性:只要满足"区间可合并"条件的信息,都可以用线段树维护。区间和、区间最值、区间 GCD、区间矩阵乘积、区间哈希值等等。
-
可扩展性:
- 加上懒传播 → 支持区间修改
- 加上持久化 → 支持历史版本查询
- 加上动态开点 → 支持大值域
- 线段树合并 → 解决树上问题
- 李超线段树 → 维护凸包/直线集合
- 线段树二分 → 不需要额外的二分搜索
Codeforces 评级与线段树的关系:
| 评级区间 | 线段树出现频率 | 典型难度 |
|---|---|---|
| 800-1200 | 极少 | 前缀和就够 |
| 1200-1600 | 偶尔 | 基础线段树/BIT |
| 1600-2000 | 经常 | 懒传播/BIT 变体 |
| 2000-2400 | 非常频繁 | 持久化/线段树合并 |
| 2400+ | 几乎必考 | 各种线段树变体组合 |
Level 4 · 边界与陷阱
14.19 面试题:区域和检索 — 数组可修改(LeetCode #307)
题目:实现 NumArray 类:
NumArray(nums)— 用数组nums初始化update(index, val)— 将nums[index]修改为valsumRange(left, right)— 返回nums[left..right]的和
分析:这是线段树/树状数组最基础的应用场景。两种都可以,但树状数组代码更短。
class NumArray:
"""解法 1:树状数组"""
def __init__(self, nums):
self.n = len(nums)
self.nums = nums[:]
self.bit = [0] * (self.n + 1)
for i in range(self.n):
self._add(i + 1, nums[i])
def _add(self, i, delta):
while i <= self.n:
self.bit[i] += delta
i += i & (-i)
def _prefix(self, i):
s = 0
while i > 0:
s += self.bit[i]
i -= i & (-i)
return s
def update(self, index, val):
delta = val - self.nums[index]
self.nums[index] = val
self._add(index + 1, delta)
def sumRange(self, left, right):
return self._prefix(right + 1) - self._prefix(left)
class NumArray2:
"""解法 2:线段树"""
def __init__(self, nums):
self.n = len(nums)
self.tree = [0] * (4 * self.n)
if self.n > 0:
self._build(nums, 1, 0, self.n - 1)
def _build(self, nums, node, s, e):
if s == e:
self.tree[node] = nums[s]
return
mid = (s + e) // 2
self._build(nums, 2*node, s, mid)
self._build(nums, 2*node+1, mid+1, e)
self.tree[node] = self.tree[2*node] + self.tree[2*node+1]
def update(self, index, val):
self._update(1, 0, self.n-1, index, val)
def _update(self, node, s, e, idx, val):
if s == e:
self.tree[node] = val
return
mid = (s + e) // 2
if idx <= mid:
self._update(2*node, s, mid, idx, val)
else:
self._update(2*node+1, mid+1, e, idx, val)
self.tree[node] = self.tree[2*node] + self.tree[2*node+1]
def sumRange(self, left, right):
return self._query(1, 0, self.n-1, left, right)
def _query(self, node, s, e, l, r):
if l <= s and e <= r:
return self.tree[node]
if e < l or s > r:
return 0
mid = (s + e) // 2
return self._query(2*node, s, mid, l, r) + \
self._query(2*node+1, mid+1, e, l, r)
面试要点:
- 两种解法都要能写出来
- 解释为什么不能用前缀和(修改代价 O(n))
- BIT 解法注意
update中要计算差值delta = val - nums[index] - 时间复杂度:构建 O(n log n),每次操作 O(log n)
14.20 面试题:计算右侧小于当前元素的个数(LeetCode #315)
题目:给定数组 nums,返回一个数组 counts,其中 counts[i] 是 nums[i] 右侧严格小于 nums[i] 的元素个数。
分析:这本质上是"逆序对"的变体——对于每个位置 i,统计它与右侧形成的逆序对数。
def countSmaller(nums):
"""从右向左遍历,用 BIT 统计"""
if not nums:
return []
# 离散化
sorted_unique = sorted(set(nums))
rank = {v: i + 1 for i, v in enumerate(sorted_unique)}
n = len(sorted_unique)
bit = [0] * (n + 1)
def update(i):
while i <= n:
bit[i] += 1
i += i & (-i)
def query(i):
s = 0
while i > 0:
s += bit[i]
i -= i & (-i)
return s
result = []
for i in range(len(nums) - 1, -1, -1):
r = rank[nums[i]]
# 查询已经在 BIT 中的、比 nums[i] 小的元素个数
result.append(query(r - 1))
update(r)
return result[::-1]
# 示例
print(countSmaller([5, 2, 6, 1])) # [2, 1, 1, 0]
另一种解法:用线段树
def countSmaller_segtree(nums):
"""用线段树统计每个元素右侧比它小的元素数"""
if not nums:
return []
# 离散化
sorted_unique = sorted(set(nums))
rank = {v: i for i, v in enumerate(sorted_unique)}
m = len(sorted_unique)
# 线段树维护每个值出现的次数
tree = [0] * (4 * m)
def update(node, s, e, idx):
if s == e:
tree[node] += 1
return
mid = (s + e) // 2
if idx <= mid:
update(2*node, s, mid, idx)
else:
update(2*node+1, mid+1, e, idx)
tree[node] = tree[2*node] + tree[2*node+1]
def query(node, s, e, l, r):
if l > r:
return 0
if l <= s and e <= r:
return tree[node]
if e < l or s > r:
return 0
mid = (s + e) // 2
return query(2*node, s, mid, l, r) + query(2*node+1, mid+1, e, l, r)
result = []
for i in range(len(nums) - 1, -1, -1):
r = rank[nums[i]]
# 查询 [0, r-1] 范围内的元素个数
result.append(query(1, 0, m-1, 0, r-1))
update(1, 0, m-1, r)
return result[::-1]
面试拓展:
- 还可以用归并排序解这题(在 merge 过程中统计逆序)
- BIT 解法代码最短,面试中推荐
- 离散化是关键步骤,面试官可能会追问为什么需要离散化(因为 BIT 的下标必须是正整数且不能太大)
14.21 线段树的常见实现错误总结
在竞赛和面试中,线段树的实现有很多微妙的错误来源。以下是按严重程度排序的 Top 10 错误:
1. 空间分配不足
# 错误:n=5 时, 完全二叉树需要的空间超过 2*5
tree = [0] * (2 * n) # 不够!
# 正确
tree = [0] * (4 * n)
为什么?当 n=5 时,高度 h = ⌈log₂5⌉ = 3,满二叉树节点数 = 2^4 - 1 = 15,所以需要 16(1-indexed)。2n = 10,明显不够。
2. 懒传播中 pushdown 遗漏
# 错误:在 update 中没有先 pushdown
def _range_update(self, node, start, end, l, r, val):
if l <= start and end <= r:
self.tree[node] += val * (end - start + 1)
self.lazy[node] += val
return
# 忘记 pushdown!子节点可能还有旧的 lazy 没处理
mid = (start + end) // 2
self._range_update(2*node, start, mid, l, r, val)
...
3. pushdown 中区间长度计算错误
# 错误
left_len = mid - start # 少了 1!
right_len = end - mid - 1 # 少了 1!
# 正确
left_len = mid - start + 1
right_len = end - mid
4. 叶子节点判断条件错误
# 错误:用 node >= n 作为叶子判断
if node >= self.n: # 这对于数组实现的线段树是错的
# 正确:用 start == end 判断
if start == end:
...
5. 区间查询的边界返回值错误
# 求区间最小值时
if end < l or start > r:
return 0 # 错误!应该返回无穷大
# 正确
if end < l or start > r:
return float('inf') # 最小值查询的单位元
不同操作的单位元:求和 → 0,最小值 → +∞,最大值 → -∞,GCD → 0,乘积 → 1。
6. update 时"修改为"和"增加"混淆
# "修改为 val" 的版本
def _update(self, node, start, end, idx, val):
if start == end:
self.tree[node] = val # 直接赋值
return
...
# "增加 delta" 的版本
def _update(self, node, start, end, idx, delta):
if start == end:
self.tree[node] += delta # 累加
return
...
面试中一定要先明确题目要求的是"改为"还是"加上"。
7. 树状数组的 0-indexing 死循环
# 如果传入 i=0,lowbit(0)=0,update 和 query 都会死循环
# 树状数组必须从 1 开始!
8. 离散化后忘记用 rank 替代原值
# 错误
bit.update(nums[i], 1) # nums[i] 可能是 10^9,数组开不了这么大!
# 正确
bit.update(rank[nums[i]], 1) # rank 在 [1, n] 范围内
9. 线段树合并时忘记处理空节点
# 错误
def merge(a, b, s, e):
new_node = ...
new_node.left = merge(a.left, b.left, ...)
new_node.right = merge(a.right, b.right, ...)
# 如果 a 或 b 是 None 会出错
# 正确
def merge(a, b, s, e):
if a is None: return b
if b is None: return a
...
10. 多种懒传播操作的优先级问题
当同时存在"区间赋值"和"区间加"两种操作时,pushdown 的顺序至关重要:
# 如果先赋值后加:子节点 = 赋值的值 + 后来加的值
# 如果先加后赋值:子节点 = 赋值的值(加的被覆盖了)
# 必须记录操作的时间顺序,或者将操作统一为 (乘, 加) 对
14.22 什么时候用前缀和,什么时候用线段树
这是面试中最常被问到的"选择"类问题。
| 场景 | 最佳选择 | 原因 |
|---|---|---|
| 数组不变,多次区间求和 | 前缀和 | O(1) 查询,O(n) 预处理 |
| 数组不变,多次区间最值 | Sparse Table | O(1) 查询,O(n log n) 预处理 |
| 单点修改 + 区间求和 | 树状数组 | O(log n) 两端,代码短 |
| 单点修改 + 区间最值 | 线段树 | 树状数组不直接支持 |
| 区间修改 + 区间求和 | 线段树(懒传播) | 树状数组也行但更复杂 |
| 区间修改 + 区间最值 | 线段树(懒传播) | 唯一方案 |
| 大值域 + 稀疏操作 | 动态开点线段树 | 不需预分配全部空间 |
| 需要历史版本 | 持久化线段树 | 树状数组无法持久化 |
| 二维区域修改+查询 | 二维线段树/BIT | 按需选择 |
面试回答模板:
"首先我会分析操作类型:如果数组是静态的(不修改),前缀和就够了,O(1) 查询;如果需要修改,就看修改是单点还是区间。单点修改 + 区间和用树状数组最简洁;区间修改或需要维护最值就用线段树加懒传播。具体到这道题..."
14.23 高级技巧:线段树上二分
线段树的一个强大特性是可以在树上直接进行二分搜索,避免额外的 O(log n) 开销。
例如:"找到最左边的位置 i,使得 prefix_sum(0, i) >= target"。
朴素方法:在外部二分 + 线段树查询 → O(log² n)
线段树上二分:O(log n)
def find_first(self, target):
"""找到最小的 i 使得 prefix[0..i] >= target"""
return self._find_first(1, 0, self.n - 1, target)
def _find_first(self, node, start, end, target):
if start == end:
return start if self.tree[node] >= target else -1
mid = (start + end) // 2
# 如果左子树的和 >= target,答案在左子树
if self.tree[2 * node] >= target:
return self._find_first(2 * node, start, mid, target)
else:
# 否则答案在右子树,但目标要减去左子树的贡献
return self._find_first(2 * node + 1, mid + 1, end,
target - self.tree[2 * node])
为什么这比 "外部二分 + 查询" 快? 外部二分需要 O(log n) 次查询,每次查询 O(log n),总共 O(log² n)。而线段树上二分从根到叶子只走一条路径,O(log n)。这在时间紧张的竞赛题中可能是 AC 和 TLE 的区别。
14.24 实战经验总结
竞赛中的线段树技巧清单:
-
确定单位元:开始写之前先确定你的
query在无交集时返回什么值(和→0, 最小值→INF, 最大值→-INF, GCD→0, 异或→0) -
确定合并方式:内部节点的值如何从子节点计算。对于复杂信息(如最大子段和),合并函数可能需要维护多个字段(前缀最大和、后缀最大和、区间和、最大子段和)
-
懒标记的设计:
- 标记的初始值("无标记"状态)是什么
- 两个标记如何合并
- 标记如何应用到节点值
- 这三个问题想清楚,代码就不会出错
-
测试建议:
- 用暴力 O(n) 解法对拍
- 特别测试 n=1, n=2 的边界
- 测试操作区间 = 全区间的情况
- 测试操作区间 = 单点的情况
# 对拍模板
import random
def stress_test(n=100, q=1000, max_val=100):
"""暴力对拍验证线段树正确性"""
nums = [random.randint(0, max_val) for _ in range(n)]
st = SegmentTree(nums[:])
for _ in range(q):
op = random.randint(0, 1)
if op == 0: # update
idx = random.randint(0, n-1)
val = random.randint(0, max_val)
nums[idx] = val
st.update(idx, val)
else: # query
l = random.randint(0, n-1)
r = random.randint(l, n-1)
expected = sum(nums[l:r+1])
got = st.query(l, r)
assert expected == got, f"Mismatch at [{l},{r}]: expected {expected}, got {got}"
print("All tests passed!")
stress_test()
面试中的线段树思维:
面试官问线段树相关问题时,通常不期望你写出完美的线段树代码(太长了),而是考察你的思维方式:
- 能否识别出这是一个"区间查询 + 修改"的问题
- 能否选择合适的数据结构(前缀和 / BIT / 线段树)
- 能否分析时间空间复杂度
- 对于 BIT 能否写出完整代码
- 对于线段树能否描述清楚建树、查询、修改的过程
不需要在白板上写出完整的带懒传播的线段树——那更适合竞赛选手在 IDE 中完成。但基础线段树和树状数组的代码你应该能在 10 分钟内写出来。
14.25 延伸阅读与练习
必练题目:
| 题目 | 难度 | 考点 |
|---|---|---|
| LeetCode #303 区域和检索(不可变) | Easy | 前缀和(对比基线) |
| LeetCode #307 区域和检索(可修改) | Medium | BIT / 线段树基础 |
| LeetCode #315 计算右侧小于当前元素的个数 | Hard | BIT + 离散化 |
| LeetCode #493 翻转对 | Hard | BIT / 归并排序 |
| LeetCode #327 区间和的个数 | Hard | BIT + 离散化 |
| LeetCode #218 天际线问题 | Hard | 线段树 / 扫描线 |
| LeetCode #699 掉落的方块 | Hard | 线段树 / 坐标压缩 |
进阶题目(Codeforces):
| 题目 | 考点 |
|---|---|
| CF #558E | 线段树 + 计数排序 |
| CF #914D | 线段树维护 GCD |
| CF #877E | 线段树 + 异或操作 |
| CF #242E | 线段树 + 位运算 |
推荐学习资源:
- Peter Fenwick, "A New Data Structure for Cumulative Frequency Tables", Software: Practice and Experience, 1994
- Thomas H. Cormen et al., "Introduction to Algorithms" (CLRS), Chapter 14 (Augmenting Data Structures)
- Competitive Programmer's Handbook (Antti Laaksonen), Chapter 9: Range Queries
- CP-Algorithms (cp-algorithms.com) — Segment Tree 专题(有大量变体和代码模板)