第 14 章

线段树与树状数组

第十四章:线段树与树状数组

当你需要对一个数组执行大量的区间查询单点/区间修改操作时,朴素方法要么查询 O(n),要么修改 O(n),无论如何都有一端是线性的。线段树(Segment Tree)和树状数组(Binary Indexed Tree / Fenwick Tree)是两种专为这类问题设计的数据结构,它们将查询和修改的时间复杂度都降到 O(log n)。

这两种数据结构在竞赛、面试和工程实践中的出现频率极高。理解它们不仅仅是学会一种"高级数据结构",更是深入理解"用空间换时间"和"分治思想在数据结构中的应用"的绝佳案例。


Level 1 · 你需要知道的

14.1 从问题出发:为什么需要线段树

考虑这样一个场景:你有一个长度为 n 的数组 nums,需要反复执行两种操作:

  1. 区间查询:求 nums[l..r] 的和(或最大值、最小值等)
  2. 单点修改:将 nums[i] 修改为某个新值
方案 查询复杂度 修改复杂度 问题
原始数组 O(n) O(1) 查询太慢
前缀和数组 O(1) O(n) 修改要重建整个前缀和
线段树 O(log n) O(log n) 两端都是对数级

核心洞察:线段树之所以能做到双 O(log n),是因为它将区间 [0, n-1] 递归地二分,形成一棵完全二叉树。每个节点存储对应区间的聚合信息(如区间和)。当你修改某个元素时,只需要更新从叶子到根的路径上的节点,共 O(log n) 个;当你查询某个区间时,最多需要访问 O(log n) 个节点就能覆盖整个查询区间。

14.2 线段树的结构

线段树是一棵二叉树,其结构如下:

对于一个长度为 n 的数组,线段树有以下性质:

为什么分配 4n 空间? 这是一个常见的困惑点。线段树是一棵接近完全二叉树的结构,如果 n 恰好是 2 的幂,则需要 2n - 1 个节点。但如果 n 不是 2 的幂,最后一层会有空洞,此时需要的空间上界是 4n。在实践中直接分配 4n 是安全且简单的做法。

14.3 用数组实现线段树

线段树最常见的实现方式是用数组模拟完全二叉树(类似堆的存储方式):

class SegmentTree:
    """基础线段树 — 支持单点修改、区间求和查询"""
    
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)  # 分配 4n 空间
        if self.n > 0:
            self._build(nums, 1, 0, self.n - 1)
    
    def _build(self, nums, node, start, end):
        """递归建树:自底向上聚合"""
        if start == end:
            # 叶子节点:直接存储原始值
            self.tree[node] = nums[start]
            return
        mid = (start + end) // 2
        self._build(nums, 2 * node, start, mid)       # 构建左子树
        self._build(nums, 2 * node + 1, mid + 1, end) # 构建右子树
        # 内部节点:左右子树聚合
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    
    def update(self, idx, val):
        """单点修改:将 nums[idx] 改为 val"""
        self._update(1, 0, self.n - 1, idx, val)
    
    def _update(self, node, start, end, idx, val):
        if start == end:
            # 到达叶子节点,直接修改
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self._update(2 * node, start, mid, idx, val)
        else:
            self._update(2 * node + 1, mid + 1, end, idx, val)
        # 回溯时更新父节点
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    
    def query(self, l, r):
        """区间查询:求 nums[l..r] 的和"""
        return self._query(1, 0, self.n - 1, l, r)
    
    def _query(self, node, start, end, l, r):
        # 当前节点区间完全在查询区间内 — 直接返回
        if l <= start and end <= r:
            return self.tree[node]
        # 当前节点区间与查询区间无交集
        if end < l or start > r:
            return 0
        # 部分重叠 — 递归查询左右子树
        mid = (start + end) // 2
        left_sum = self._query(2 * node, start, mid, l, r)
        right_sum = self._query(2 * node + 1, mid + 1, end, l, r)
        return left_sum + right_sum

理解递归过程

  1. _build:从叶子向上构建。每个叶子存原始值,每个父节点等于左右子节点之和。时间复杂度 O(n)。
  2. _update:从根向下找到对应叶子修改,然后回溯更新路径上所有祖先。时间复杂度 O(log n)。
  3. _query:从根向下,如果当前节点的区间完全包含在查询区间内则直接返回,否则继续递归。每层最多访问常数个节点,时间复杂度 O(log n)。

为什么 query 是 O(log n)? 关键观察是:对于任意查询区间 [l, r],在线段树的每一层中,最多只有两个节点会被"部分覆盖"(左边界和右边界各一个),其他被访问的节点都是"完全覆盖"而直接返回。因此总访问节点数不超过 4⌈log₂n⌉。

14.4 使用示例

# 创建线段树
nums = [1, 3, 5, 7, 9, 11]
st = SegmentTree(nums)

# 查询区间 [1, 4] 的和:3 + 5 + 7 + 9 = 24
print(st.query(1, 4))  # 输出 24

# 将 nums[2] 修改为 10
st.update(2, 10)

# 再次查询 [1, 4]:3 + 10 + 7 + 9 = 29
print(st.query(1, 4))  # 输出 29

14.5 树状数组(BIT / Fenwick Tree)

树状数组是另一种处理前缀和查询与单点修改的数据结构,它比线段树更简洁、常数更小,但功能相对有限(基础版本只支持前缀查询,不直接支持任意区间最大值/最小值)。

核心思想:树状数组利用二进制表示的最低位(lowbit)将原始数组划分成不同长度的"负责区间",每个位置 i 负责存储从 i - lowbit(i) + 1i 这一段的聚合值。

lowbit 操作lowbit(x) = x & (-x),它提取 x 的二进制表示中最低位的 1。

x = 12 = 1100₂
-x = ...10100₂ (补码)
x & (-x) = 00100₂ = 4
所以 lowbit(12) = 4,表示位置 12 负责管理 4 个元素的区间 [9, 12]

为什么 lowbit 能工作? 这不是巧合,而是精心设计。Peter Fenwick 在 1994 年的论文中指出,如果我们按照 lowbit 将数组分层管理,那么:

class BinaryIndexedTree:
    """树状数组 — 支持单点修改、前缀和查询"""
    
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed
    
    @staticmethod
    def lowbit(x):
        return x & (-x)
    
    def update(self, i, delta):
        """将位置 i 增加 delta(1-indexed)"""
        while i <= self.n:
            self.tree[i] += delta
            i += self.lowbit(i)
    
    def prefix_sum(self, i):
        """查询前缀和 [1..i](1-indexed)"""
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= self.lowbit(i)
        return s
    
    def range_sum(self, l, r):
        """查询区间和 [l..r](1-indexed)"""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)

从原始数组构建树状数组

def build_from_array(nums):
    """O(n) 建树"""
    n = len(nums)
    bit = BinaryIndexedTree(n)
    # 方法 1:逐个 update,O(n log n)
    for i, val in enumerate(nums):
        bit.update(i + 1, val)
    return bit

def build_from_array_linear(nums):
    """O(n) 建树 — 利用父子关系"""
    n = len(nums)
    bit = BinaryIndexedTree(n)
    for i in range(1, n + 1):
        bit.tree[i] += nums[i - 1]
        parent = i + (i & (-i))
        if parent <= n:
            bit.tree[parent] += bit.tree[i]
    return bit

14.6 线段树 vs 树状数组:何时用哪个

特性 线段树 树状数组
空间 4n n+1
代码量 多(递归) 少(循环)
常数因子 较大 较小(2-5 倍快)
区间修改 支持(懒传播) 支持(差分技巧)
区间最值 支持 不直接支持
动态开点 支持 不支持
持久化 支持 困难

经验法则

14.7 常见错误与调试

错误 1:数组越界

# 错误:分配 2n 空间
self.tree = [0] * (2 * self.n)  # 当 n 不是 2 的幂时会越界!

# 正确:分配 4n 空间
self.tree = [0] * (4 * self.n)

错误 2:树状数组用 0-indexed

# 错误:i 从 0 开始,lowbit(0) = 0,死循环!
def update(self, i, delta):
    while i <= self.n:
        self.tree[i] += delta
        i += self.lowbit(i)  # 如果 i=0,lowbit(0)=0,永远不会增加

# 正确:树状数组必须 1-indexed

错误 3:query 中忘记处理无交集的情况

# 错误:没有判断无交集
def _query(self, node, start, end, l, r):
    if l <= start and end <= r:
        return self.tree[node]
    mid = (start + end) // 2
    # 如果不判断,即使 l > end 也会继续递归
    return self._query(2*node, start, mid, l, r) + \
           self._query(2*node+1, mid+1, end, l, r)

# 正确:加上无交集的判断
def _query(self, node, start, end, l, r):
    if l <= start and end <= r:
        return self.tree[node]
    if end < l or start > r:
        return 0  # 无交集时返回单位元
    mid = (start + end) // 2
    return self._query(2*node, start, mid, l, r) + \
           self._query(2*node+1, mid+1, end, l, r)

错误 4:update 后忘记回溯更新父节点

# 错误
def _update(self, node, start, end, idx, val):
    if start == end:
        self.tree[node] = val
        return
    mid = (start + end) // 2
    if idx <= mid:
        self._update(2*node, start, mid, idx, val)
    else:
        self._update(2*node+1, mid+1, end, idx, val)
    # 忘记这一行就完蛋了!
    # self.tree[node] = self.tree[2*node] + self.tree[2*node+1]

Level 2 · 它是怎么运行的

14.8 懒传播(Lazy Propagation)

基础线段树只支持单点修改。但如果需要区间修改(将 nums[l..r] 的所有元素都加上某个值),朴素做法是逐个修改 r - l + 1 个点,退化为 O(n log n)。

懒传播(Lazy Propagation)的核心思想是延迟更新:当一个区间修改完全覆盖某个节点的区间时,不立即下推到子节点,而是在该节点打上一个"懒标记"(lazy tag),表示"这个节点的子树还有未传播的修改"。只有当后续操作需要访问子节点时,才将懒标记下推(pushdown)。

为什么这是正确的? 因为如果后续操作不访问子节点,那么子节点的精确值就不需要。当需要时再计算,这就是"按需计算"(lazy evaluation)的思想——在函数式编程中也有类似概念。

class SegmentTreeLazy:
    """带懒传播的线段树 — 支持区间修改、区间查询"""
    
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)  # 懒标记数组
        if self.n > 0:
            self._build(nums, 1, 0, self.n - 1)
    
    def _build(self, nums, node, start, end):
        if start == end:
            self.tree[node] = nums[start]
            return
        mid = (start + end) // 2
        self._build(nums, 2 * node, start, mid)
        self._build(nums, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    
    def _pushdown(self, node, start, end):
        """将懒标记下推到子节点"""
        if self.lazy[node] != 0:
            mid = (start + end) // 2
            left_len = mid - start + 1
            right_len = end - mid
            
            # 更新左子节点的值和懒标记
            self.tree[2 * node] += self.lazy[node] * left_len
            self.lazy[2 * node] += self.lazy[node]
            
            # 更新右子节点的值和懒标记
            self.tree[2 * node + 1] += self.lazy[node] * right_len
            self.lazy[2 * node + 1] += self.lazy[node]
            
            # 清除当前节点的懒标记
            self.lazy[node] = 0
    
    def range_update(self, l, r, val):
        """区间修改:将 nums[l..r] 都加上 val"""
        self._range_update(1, 0, self.n - 1, l, r, val)
    
    def _range_update(self, node, start, end, l, r, val):
        # 完全覆盖 — 打懒标记,不下推
        if l <= start and end <= r:
            self.tree[node] += val * (end - start + 1)
            self.lazy[node] += val
            return
        # 无交集
        if end < l or start > r:
            return
        # 部分覆盖 — 先下推已有的懒标记,再递归
        self._pushdown(node, start, end)
        mid = (start + end) // 2
        self._range_update(2 * node, start, mid, l, r, val)
        self._range_update(2 * node + 1, mid + 1, end, l, r, val)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    
    def range_query(self, l, r):
        """区间查询:求 nums[l..r] 的和"""
        return self._range_query(1, 0, self.n - 1, l, r)
    
    def _range_query(self, node, start, end, l, r):
        if l <= start and end <= r:
            return self.tree[node]
        if end < l or start > r:
            return 0
        # 关键:查询子节点前必须先下推
        self._pushdown(node, start, end)
        mid = (start + end) // 2
        return self._range_query(2 * node, start, mid, l, r) + \
               self._range_query(2 * node + 1, mid + 1, end, l, r)

懒传播的时间复杂度分析

区间修改和区间查询的复杂度仍为 O(log n)。证明与单点修改类似——每层最多访问常数个"部分覆盖"的节点,而下推操作本身是 O(1) 的。

懒传播的正确性条件

懒标记必须满足可合并性:多次累积的懒标记能正确合并。例如"区间加"的懒标记可以直接累加;"区间赋值"的懒标记需要用新值覆盖旧值。如果同时有"区间加"和"区间乘"两种操作,懒标记的设计就更复杂了(需要维护 (乘数, 加数) 对,且下推时要注意顺序)。

14.9 懒传播的应用示例

# 场景:频繁的区间加与区间求和
nums = [1, 3, 5, 7, 9, 11]
st = SegmentTreeLazy(nums)

# 将 [1, 4] 所有元素加 3:[1, 6, 8, 10, 12, 11]
st.range_update(1, 4, 3)

# 查询 [2, 5] 的和:8 + 10 + 12 + 11 = 41
print(st.range_query(2, 5))  # 输出 41

# 将 [0, 2] 所有元素加 -1:[0, 5, 7, 10, 12, 11]
st.range_update(0, 2, -1)

# 查询 [0, 5] 的和:0 + 5 + 7 + 10 + 12 + 11 = 45
print(st.range_query(0, 5))  # 输出 45

14.10 线段树的动态开点

标准线段树用数组存储,需要预先分配 4n 空间。但如果值域很大(比如 [0, 10^9]),开不了 4×10^9 的数组。动态开点的思想是:不预先分配所有节点,而是在需要时才创建。

class DynamicSegTree:
    """动态开点线段树 — 值域可以很大"""
    
    def __init__(self):
        self.tree = {}   # node_id -> value
        self.lazy = {}   # node_id -> lazy_value
        self.left_child = {}   # node_id -> left_child_id
        self.right_child = {}  # node_id -> right_child_id
        self.cnt = 0     # 节点计数器
        self.root = self._new_node()
    
    def _new_node(self):
        self.cnt += 1
        self.tree[self.cnt] = 0
        self.lazy[self.cnt] = 0
        return self.cnt
    
    def _pushdown(self, node, left_len, right_len):
        if self.lazy.get(node, 0) == 0:
            return
        # 确保子节点存在
        if node not in self.left_child:
            self.left_child[node] = self._new_node()
        if node not in self.right_child:
            self.right_child[node] = self._new_node()
        
        lc, rc = self.left_child[node], self.right_child[node]
        self.tree[lc] += self.lazy[node] * left_len
        self.lazy[lc] = self.lazy.get(lc, 0) + self.lazy[node]
        self.tree[rc] += self.lazy[node] * right_len
        self.lazy[rc] = self.lazy.get(rc, 0) + self.lazy[node]
        self.lazy[node] = 0
    
    def update(self, node, start, end, l, r, val):
        """区间 [l,r] 加 val"""
        if l <= start and end <= r:
            self.tree[node] += val * (end - start + 1)
            self.lazy[node] = self.lazy.get(node, 0) + val
            return
        mid = (start + end) // 2
        self._pushdown(node, mid - start + 1, end - mid)
        if node not in self.left_child:
            self.left_child[node] = self._new_node()
        if node not in self.right_child:
            self.right_child[node] = self._new_node()
        if l <= mid:
            self.update(self.left_child[node], start, mid, l, r, val)
        if r > mid:
            self.update(self.right_child[node], mid + 1, end, l, r, val)
        self.tree[node] = self.tree[self.left_child[node]] + \
                          self.tree[self.right_child[node]]
    
    def query(self, node, start, end, l, r):
        """查询区间 [l,r] 的和"""
        if l <= start and end <= r:
            return self.tree.get(node, 0)
        if end < l or start > r:
            return 0
        mid = (start + end) // 2
        self._pushdown(node, mid - start + 1, end - mid)
        res = 0
        if l <= mid and node in self.left_child:
            res += self.query(self.left_child[node], start, mid, l, r)
        if r > mid and node in self.right_child:
            res += self.query(self.right_child[node], mid + 1, end, l, r)
        return res

动态开点的空间复杂度:每次修改操作最多创建 O(log n) 个新节点(从根到叶子的路径)。如果有 q 次操作,总空间为 O(q log n)。

典型应用场景

14.11 树状数组求逆序对

逆序对问题:给定数组 nums,求满足 i < jnums[i] > nums[j] 的 (i, j) 对数。

这是树状数组的经典应用之一。思路是:从右向左遍历数组,对于每个元素 nums[i],查询"当前已经出现过的、比 nums[i] 小的元素有多少个",这就是以 nums[i] 结尾的逆序对数。

def count_inversions(nums):
    """用树状数组统计逆序对数"""
    if not nums:
        return 0
    
    # 离散化:将值映射到 [1, n] 的范围
    sorted_unique = sorted(set(nums))
    rank = {v: i + 1 for i, v in enumerate(sorted_unique)}
    
    n = len(sorted_unique)
    bit = BinaryIndexedTree(n)
    inversions = 0
    
    # 从右向左遍历
    for i in range(len(nums) - 1, -1, -1):
        r = rank[nums[i]]
        # 查询比 nums[i] 小的元素个数(已经在 BIT 中的)
        inversions += bit.prefix_sum(r - 1)
        # 将 nums[i] 加入 BIT
        bit.update(r, 1)
    
    return inversions

# 示例
print(count_inversions([5, 2, 6, 1]))  # 输出 5
# 逆序对:(5,2), (5,1), (2,1), (6,1) — 等等让我数一下
# (0,1): 5>2 ✓
# (0,3): 5>1 ✓
# (1,3): 2>1 ✓
# (2,3): 6>1 ✓
# 共 4 对... 不对,再看 (5,2,6,1)
# i=0,j=1: 5>2 ✓
# i=0,j=3: 5>1 ✓
# i=1,j=3: 2>1 ✓
# i=2,j=3: 6>1 ✓
# 结果是 4,不是 5。上面的代码是正确的,我的手算有误。

让我纠正上面的示例:

print(count_inversions([2, 4, 1, 3, 5]))  # 输出 3
# 逆序对:(2,1), (4,1), (4,3)

时间复杂度:O(n log n)(离散化 O(n log n) + 遍历中每次 BIT 操作 O(log n))。

14.12 二维树状数组

二维树状数组用于处理矩阵上的"单点修改 + 矩形区域求和"问题。

class BIT2D:
    """二维树状数组"""
    
    def __init__(self, rows, cols):
        self.rows = rows
        self.cols = cols
        self.tree = [[0] * (cols + 1) for _ in range(rows + 1)]
    
    def update(self, x, y, delta):
        """将位置 (x, y) 增加 delta(1-indexed)"""
        i = x
        while i <= self.rows:
            j = y
            while j <= self.cols:
                self.tree[i][j] += delta
                j += j & (-j)
            i += i & (-i)
    
    def prefix_sum(self, x, y):
        """查询 (1,1) 到 (x,y) 的矩形和"""
        s = 0
        i = x
        while i > 0:
            j = y
            while j > 0:
                s += self.tree[i][j]
                j -= j & (-j)
            i -= i & (-i)
        return s
    
    def range_sum(self, x1, y1, x2, y2):
        """查询 (x1,y1) 到 (x2,y2) 的矩形和(容斥原理)"""
        return (self.prefix_sum(x2, y2)
                - self.prefix_sum(x1 - 1, y2)
                - self.prefix_sum(x2, y1 - 1)
                + self.prefix_sum(x1 - 1, y1 - 1))

二维树状数组的复杂度:单次修改/查询为 O(log m × log n),空间为 O(m × n)。

与二维前缀和的对比

14.13 线段树与树状数组的内部执行模型

理解这两种数据结构的关键在于理解它们如何"分解"区间。

线段树的区间分解

以 n=8 为例,查询 [2, 7] 时,线段树如何将这个查询分解:

               [0,7]
           /          \
       [0,3]          [4,7]
      /     \        /     \
   [0,1]  [2,3]  [4,5]  [6,7]
   / \    / \    / \    / \
 [0][1] [2][3] [4][5] [6][7]

查询 [2,7]:
- [0,7] 部分覆盖,递归
  - [0,3] 部分覆盖,递归
    - [0,1] 无交集,返回 0
    - [2,3] 完全覆盖,返回 tree[[2,3]]  ✓
  - [4,7] 完全覆盖,返回 tree[[4,7]]  ✓

结果 = tree[[2,3]] + tree[[4,7]]
只需要 2 个节点的值就覆盖了 [2,7]!

树状数组的前缀分解

查询 prefix_sum(7) 时(1-indexed):

i = 7 = 111₂, tree[7] 管理 [7,7](lowbit=1)
i -= 1 → i = 6 = 110₂, tree[6] 管理 [5,6](lowbit=2)
i -= 2 → i = 4 = 100₂, tree[4] 管理 [1,4](lowbit=4)
i -= 4 → i = 0,结束

prefix_sum(7) = tree[7] + tree[6] + tree[4]
             = [7,7] + [5,6] + [1,4]
             = [1,7]  ✓

这个分解的数学本质是:将 i 的二进制表示逐位消去。i 有多少个 1,就需要累加多少个 tree 值。最多 ⌊log₂n⌋ + 1 个。

14.14 更多变体:区间修改 + 单点查询(差分 BIT)

如果需要"区间修改"但只需"单点查询",可以对差分数组建树状数组:

class DiffBIT:
    """差分树状数组:区间修改 + 单点查询"""
    
    def __init__(self, n):
        self.bit = BinaryIndexedTree(n)
    
    def range_add(self, l, r, val):
        """将 [l, r] 所有元素加 val(1-indexed)"""
        self.bit.update(l, val)
        if r + 1 <= self.bit.n:
            self.bit.update(r + 1, -val)
    
    def point_query(self, i):
        """查询位置 i 的当前值"""
        return self.bit.prefix_sum(i)

原理:对差分数组 d 建 BIT。range_add(l, r, val) 相当于 d[l] += val, d[r+1] -= val。而 point_query(i) = prefix_sum(d, i) = 原数组位置 i 的累积增量。

如果需要"区间修改 + 区间查询",可以用两个树状数组配合:

class RangeAddRangeSum:
    """用两个 BIT 实现区间加 + 区间求和"""
    
    def __init__(self, nums):
        self.n = len(nums)
        self.bit1 = BinaryIndexedTree(self.n)  # 存储 d[i]
        self.bit2 = BinaryIndexedTree(self.n)  # 存储 i * d[i]
        # 原始前缀和
        self.prefix = [0] * (self.n + 1)
        for i in range(self.n):
            self.prefix[i + 1] = self.prefix[i] + nums[i]
    
    def range_add(self, l, r, val):
        """区间 [l, r] 加 val(1-indexed)"""
        self.bit1.update(l, val)
        self.bit1.update(r + 1, -val)
        self.bit2.update(l, val * l)
        self.bit2.update(r + 1, -val * (r + 1))
    
    def prefix_sum(self, i):
        """查询前缀和 [1..i]"""
        return self.prefix[i] + \
               (i + 1) * self.bit1.prefix_sum(i) - \
               self.bit2.prefix_sum(i)
    
    def range_sum(self, l, r):
        """查询区间和 [l, r]"""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)

推导过程:设原始数组为 a,差分数组为 d(d[i] = a[i] - a[i-1])。

前缀和 a[1] + a[2] + ... + a[i] 可以展开为:

Σ(k=1 to i) a[k] = Σ(k=1 to i) Σ(j=1 to k) d[j]
                  = Σ(j=1 to i) d[j] * (i - j + 1)
                  = (i+1) * Σ(j=1 to i) d[j] - Σ(j=1 to i) j * d[j]

所以只需要维护两个前缀和:Σd[j]Σj*d[j],各用一个 BIT 即可。


Level 3 · 规范怎么定义的

14.15 Peter Fenwick 的原始论文(1994)

树状数组的正式提出来自 Peter Fenwick 于 1994 年发表的论文"A New Data Structure for Cumulative Frequency Tables"(Software: Practice and Experience, Vol. 24(3), pp. 327-336, March 1994)。

论文的背景:Fenwick 的原始动机是算术编码(Arithmetic Coding)中的累积频率表(Cumulative Frequency Table)。在数据压缩中,需要频繁执行两种操作:

  1. 更新某个符号的频率(单点修改)
  2. 查询所有频率不超过某个阈值的符号的累积频率(前缀和查询)

传统的做法是维护一个平坦的累积频率表,修改为 O(n);或者使用平衡树,但实现复杂。Fenwick 提出了一种巧妙的方法,利用整数的二进制表示,在 O(log n) 时间内完成两种操作,且实现极为简洁。

关键设计决策

Fenwick 在论文中解释了为什么选择"去掉最低位的 1"作为核心操作:

"The key observation is that any positive integer can be uniquely represented as a sum of distinct powers of 2. The tree structure exploits this by assigning responsibility for cumulative frequencies in a manner that mirrors binary representation."

也就是说,任何正整数 i 都可以唯一分解为 2 的幂之和。树状数组的结构正是利用这一点——位置 i 负责管理 lowbit(i) 个连续元素的聚合值,而从位置 i 出发通过反复减去 lowbit(i),恰好能把 [1, i] 分解为 O(log i) 个不重叠的段。

Fenwick 论文的核心定理

定理:对于任意正整数 i,从 i 开始反复执行 i -= lowbit(i) 操作,得到的序列 i₁, i₂, ..., iₖ 满足:

  1. 序列严格递减,终止于 0
  2. 区间 [iⱼ - lowbit(iⱼ) + 1, iⱼ](j = 1, ..., k)互不重叠
  3. 这些区间的并集恰好是 [1, i]
  4. k ≤ ⌊log₂i⌋ + 1

证明思路:每次减去 lowbit(i) 就是消去 i 的二进制表示中最低位的 1。如果 i 有 k 个 1,则恰好 k 步后到达 0。每步对应的区间长度恰好是被消去的那个 2 的幂。由于 i 的二进制分解是唯一的,这些区间互不重叠且覆盖 [1, i]

14.16 线段树的历史

线段树(Segment Tree)的历史比树状数组更早,但其发展路径并非一篇论文定义一切,而是由多个计算几何和算法领域的研究者逐步完善。

1977 年 — Jon Bentley:Bentley 在"Solutions to Klee's Rectangle Problems"(Carnegie-Mellon University, 1977)中描述了一种基于区间分割的树结构,用于解决矩形面积联合(Klee's Rectangle Problem)。这是线段树思想的早期来源之一。

1979 年 — Bentley 和 Wood:"An Optimal Worst-Case Algorithm for Reporting Intersections of Rectangles"(IEEE Transactions on Computers, 1980)进一步完善了区间树的理论基础。

1980 年代 — McCreight, Willard 等人:线段树逐渐演化为一种通用的区间管理结构,在计算几何中被广泛使用(线段求交、窗口查询等)。

重要澄清:学术文献中的"Segment Tree"通常指一种特定的区间存储结构(每个区间被分解存储在 O(log n) 个节点中),与竞赛编程中的"线段树"(实质上是区间树 / Statistic Tree)有微妙区别。竞赛中的线段树更接近于一种"分治树"或"递归区间聚合树"。但在当代算法社区中,"Segment Tree"已经约定俗成地指代竞赛中的这种结构。

14.17 持久化线段树(主席树)

持久化线段树(Persistent Segment Tree)能够保存线段树在每次修改前后的所有历史版本。它由黄嘉泰(网名"主席")在竞赛中推广使用,因此在中文社区被称为"主席树"。

核心思想:每次修改时,不修改原有节点,而是创建新的节点。由于每次修改只影响从根到某个叶子的一条路径(O(log n) 个节点),所以只需创建 O(log n) 个新节点,其余节点与旧版本共享。

class PersistentSegTree:
    """持久化线段树 — 每次修改创建新版本"""
    
    def __init__(self, max_nodes=2000000):
        self.lc = [0] * max_nodes   # 左子节点 id
        self.rc = [0] * max_nodes   # 右子节点 id
        self.val = [0] * max_nodes  # 节点值
        self.cnt = 0                # 当前使用的节点数
        self.roots = []             # 每个版本的根节点 id
    
    def _new_node(self):
        self.cnt += 1
        return self.cnt
    
    def build(self, start, end):
        """建立初始版本"""
        node = self._new_node()
        if start == end:
            return node
        mid = (start + end) // 2
        self.lc[node] = self.build(start, mid)
        self.rc[node] = self.build(mid + 1, end)
        return node
    
    def update(self, prev, start, end, pos, delta):
        """基于版本 prev 创建新版本,在 pos 处加 delta"""
        node = self._new_node()
        self.lc[node] = self.lc[prev]
        self.rc[node] = self.rc[prev]
        self.val[node] = self.val[prev] + delta
        if start == end:
            return node
        mid = (start + end) // 2
        if pos <= mid:
            self.lc[node] = self.update(self.lc[prev], start, mid, pos, delta)
        else:
            self.rc[node] = self.update(self.rc[prev], mid + 1, end, pos, delta)
        return node
    
    def query(self, left_root, right_root, start, end, k):
        """查询两个版本之间的第 k 小值"""
        if start == end:
            return start
        mid = (start + end) // 2
        left_count = self.val[self.lc[right_root]] - self.val[self.lc[left_root]]
        if k <= left_count:
            return self.query(self.lc[left_root], self.lc[right_root],
                            start, mid, k)
        else:
            return self.query(self.rc[left_root], self.rc[right_root],
                            mid + 1, end, k - left_count)

经典应用:区间第 k 小值

给定数组,查询任意区间 [l, r] 中第 k 小的元素。

思路:

  1. 对值域建线段树,每个叶子表示一个值
  2. 按顺序插入元素,第 i 个版本表示"前 i 个元素"的线段树
  3. 查询 [l, r] 的第 k 小 = 用版本 r 减去版本 l-1,得到区间 [l, r] 的频率分布,然后二分即可

空间复杂度:初始版本 O(n) 节点,每次修改新增 O(log n) 节点。n 次修改后总空间 O(n log n)。

14.18 线段树在竞赛中的地位

线段树是算法竞赛(ACM-ICPC, Codeforces, AtCoder, NOI/IOI)中最核心的数据结构之一。根据 Codeforces 的题目标签统计,带有"segment tree"或"data structures"标签的题目占比超过 20%,尤其在 Div. 1 级别(rating 1900+)的题目中更为集中。

为什么线段树如此重要? 因为它具有极强的通用性可扩展性

  1. 通用性:只要满足"区间可合并"条件的信息,都可以用线段树维护。区间和、区间最值、区间 GCD、区间矩阵乘积、区间哈希值等等。

  2. 可扩展性

    • 加上懒传播 → 支持区间修改
    • 加上持久化 → 支持历史版本查询
    • 加上动态开点 → 支持大值域
    • 线段树合并 → 解决树上问题
    • 李超线段树 → 维护凸包/直线集合
    • 线段树二分 → 不需要额外的二分搜索

Codeforces 评级与线段树的关系

评级区间 线段树出现频率 典型难度
800-1200 极少 前缀和就够
1200-1600 偶尔 基础线段树/BIT
1600-2000 经常 懒传播/BIT 变体
2000-2400 非常频繁 持久化/线段树合并
2400+ 几乎必考 各种线段树变体组合

Level 4 · 边界与陷阱

14.19 面试题:区域和检索 — 数组可修改(LeetCode #307)

题目:实现 NumArray 类:

分析:这是线段树/树状数组最基础的应用场景。两种都可以,但树状数组代码更短。

class NumArray:
    """解法 1:树状数组"""
    
    def __init__(self, nums):
        self.n = len(nums)
        self.nums = nums[:]
        self.bit = [0] * (self.n + 1)
        for i in range(self.n):
            self._add(i + 1, nums[i])
    
    def _add(self, i, delta):
        while i <= self.n:
            self.bit[i] += delta
            i += i & (-i)
    
    def _prefix(self, i):
        s = 0
        while i > 0:
            s += self.bit[i]
            i -= i & (-i)
        return s
    
    def update(self, index, val):
        delta = val - self.nums[index]
        self.nums[index] = val
        self._add(index + 1, delta)
    
    def sumRange(self, left, right):
        return self._prefix(right + 1) - self._prefix(left)
class NumArray2:
    """解法 2:线段树"""
    
    def __init__(self, nums):
        self.n = len(nums)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(nums, 1, 0, self.n - 1)
    
    def _build(self, nums, node, s, e):
        if s == e:
            self.tree[node] = nums[s]
            return
        mid = (s + e) // 2
        self._build(nums, 2*node, s, mid)
        self._build(nums, 2*node+1, mid+1, e)
        self.tree[node] = self.tree[2*node] + self.tree[2*node+1]
    
    def update(self, index, val):
        self._update(1, 0, self.n-1, index, val)
    
    def _update(self, node, s, e, idx, val):
        if s == e:
            self.tree[node] = val
            return
        mid = (s + e) // 2
        if idx <= mid:
            self._update(2*node, s, mid, idx, val)
        else:
            self._update(2*node+1, mid+1, e, idx, val)
        self.tree[node] = self.tree[2*node] + self.tree[2*node+1]
    
    def sumRange(self, left, right):
        return self._query(1, 0, self.n-1, left, right)
    
    def _query(self, node, s, e, l, r):
        if l <= s and e <= r:
            return self.tree[node]
        if e < l or s > r:
            return 0
        mid = (s + e) // 2
        return self._query(2*node, s, mid, l, r) + \
               self._query(2*node+1, mid+1, e, l, r)

面试要点

14.20 面试题:计算右侧小于当前元素的个数(LeetCode #315)

题目:给定数组 nums,返回一个数组 counts,其中 counts[i]nums[i] 右侧严格小于 nums[i] 的元素个数。

分析:这本质上是"逆序对"的变体——对于每个位置 i,统计它与右侧形成的逆序对数。

def countSmaller(nums):
    """从右向左遍历,用 BIT 统计"""
    if not nums:
        return []
    
    # 离散化
    sorted_unique = sorted(set(nums))
    rank = {v: i + 1 for i, v in enumerate(sorted_unique)}
    
    n = len(sorted_unique)
    bit = [0] * (n + 1)
    
    def update(i):
        while i <= n:
            bit[i] += 1
            i += i & (-i)
    
    def query(i):
        s = 0
        while i > 0:
            s += bit[i]
            i -= i & (-i)
        return s
    
    result = []
    for i in range(len(nums) - 1, -1, -1):
        r = rank[nums[i]]
        # 查询已经在 BIT 中的、比 nums[i] 小的元素个数
        result.append(query(r - 1))
        update(r)
    
    return result[::-1]

# 示例
print(countSmaller([5, 2, 6, 1]))  # [2, 1, 1, 0]

另一种解法:用线段树

def countSmaller_segtree(nums):
    """用线段树统计每个元素右侧比它小的元素数"""
    if not nums:
        return []
    
    # 离散化
    sorted_unique = sorted(set(nums))
    rank = {v: i for i, v in enumerate(sorted_unique)}
    m = len(sorted_unique)
    
    # 线段树维护每个值出现的次数
    tree = [0] * (4 * m)
    
    def update(node, s, e, idx):
        if s == e:
            tree[node] += 1
            return
        mid = (s + e) // 2
        if idx <= mid:
            update(2*node, s, mid, idx)
        else:
            update(2*node+1, mid+1, e, idx)
        tree[node] = tree[2*node] + tree[2*node+1]
    
    def query(node, s, e, l, r):
        if l > r:
            return 0
        if l <= s and e <= r:
            return tree[node]
        if e < l or s > r:
            return 0
        mid = (s + e) // 2
        return query(2*node, s, mid, l, r) + query(2*node+1, mid+1, e, l, r)
    
    result = []
    for i in range(len(nums) - 1, -1, -1):
        r = rank[nums[i]]
        # 查询 [0, r-1] 范围内的元素个数
        result.append(query(1, 0, m-1, 0, r-1))
        update(1, 0, m-1, r)
    
    return result[::-1]

面试拓展

14.21 线段树的常见实现错误总结

在竞赛和面试中,线段树的实现有很多微妙的错误来源。以下是按严重程度排序的 Top 10 错误:

1. 空间分配不足

# 错误:n=5 时, 完全二叉树需要的空间超过 2*5
tree = [0] * (2 * n)  # 不够!

# 正确
tree = [0] * (4 * n)

为什么?当 n=5 时,高度 h = ⌈log₂5⌉ = 3,满二叉树节点数 = 2^4 - 1 = 15,所以需要 16(1-indexed)。2n = 10,明显不够。

2. 懒传播中 pushdown 遗漏

# 错误:在 update 中没有先 pushdown
def _range_update(self, node, start, end, l, r, val):
    if l <= start and end <= r:
        self.tree[node] += val * (end - start + 1)
        self.lazy[node] += val
        return
    # 忘记 pushdown!子节点可能还有旧的 lazy 没处理
    mid = (start + end) // 2
    self._range_update(2*node, start, mid, l, r, val)
    ...

3. pushdown 中区间长度计算错误

# 错误
left_len = mid - start  # 少了 1!
right_len = end - mid - 1  # 少了 1!

# 正确
left_len = mid - start + 1
right_len = end - mid

4. 叶子节点判断条件错误

# 错误:用 node >= n 作为叶子判断
if node >= self.n:  # 这对于数组实现的线段树是错的

# 正确:用 start == end 判断
if start == end:
    ...

5. 区间查询的边界返回值错误

# 求区间最小值时
if end < l or start > r:
    return 0  # 错误!应该返回无穷大

# 正确
if end < l or start > r:
    return float('inf')  # 最小值查询的单位元

不同操作的单位元:求和 → 0,最小值 → +∞,最大值 → -∞,GCD → 0,乘积 → 1。

6. update 时"修改为"和"增加"混淆

# "修改为 val" 的版本
def _update(self, node, start, end, idx, val):
    if start == end:
        self.tree[node] = val  # 直接赋值
        return
    ...

# "增加 delta" 的版本
def _update(self, node, start, end, idx, delta):
    if start == end:
        self.tree[node] += delta  # 累加
        return
    ...

面试中一定要先明确题目要求的是"改为"还是"加上"。

7. 树状数组的 0-indexing 死循环

# 如果传入 i=0,lowbit(0)=0,update 和 query 都会死循环
# 树状数组必须从 1 开始!

8. 离散化后忘记用 rank 替代原值

# 错误
bit.update(nums[i], 1)  # nums[i] 可能是 10^9,数组开不了这么大!

# 正确
bit.update(rank[nums[i]], 1)  # rank 在 [1, n] 范围内

9. 线段树合并时忘记处理空节点

# 错误
def merge(a, b, s, e):
    new_node = ...
    new_node.left = merge(a.left, b.left, ...)
    new_node.right = merge(a.right, b.right, ...)
    # 如果 a 或 b 是 None 会出错

# 正确
def merge(a, b, s, e):
    if a is None: return b
    if b is None: return a
    ...

10. 多种懒传播操作的优先级问题

当同时存在"区间赋值"和"区间加"两种操作时,pushdown 的顺序至关重要:

# 如果先赋值后加:子节点 = 赋值的值 + 后来加的值
# 如果先加后赋值:子节点 = 赋值的值(加的被覆盖了)
# 必须记录操作的时间顺序,或者将操作统一为 (乘, 加) 对

14.22 什么时候用前缀和,什么时候用线段树

这是面试中最常被问到的"选择"类问题。

场景 最佳选择 原因
数组不变,多次区间求和 前缀和 O(1) 查询,O(n) 预处理
数组不变,多次区间最值 Sparse Table O(1) 查询,O(n log n) 预处理
单点修改 + 区间求和 树状数组 O(log n) 两端,代码短
单点修改 + 区间最值 线段树 树状数组不直接支持
区间修改 + 区间求和 线段树(懒传播) 树状数组也行但更复杂
区间修改 + 区间最值 线段树(懒传播) 唯一方案
大值域 + 稀疏操作 动态开点线段树 不需预分配全部空间
需要历史版本 持久化线段树 树状数组无法持久化
二维区域修改+查询 二维线段树/BIT 按需选择

面试回答模板

"首先我会分析操作类型:如果数组是静态的(不修改),前缀和就够了,O(1) 查询;如果需要修改,就看修改是单点还是区间。单点修改 + 区间和用树状数组最简洁;区间修改或需要维护最值就用线段树加懒传播。具体到这道题..."

14.23 高级技巧:线段树上二分

线段树的一个强大特性是可以在树上直接进行二分搜索,避免额外的 O(log n) 开销。

例如:"找到最左边的位置 i,使得 prefix_sum(0, i) >= target"。

朴素方法:在外部二分 + 线段树查询 → O(log² n)

线段树上二分:O(log n)

def find_first(self, target):
    """找到最小的 i 使得 prefix[0..i] >= target"""
    return self._find_first(1, 0, self.n - 1, target)

def _find_first(self, node, start, end, target):
    if start == end:
        return start if self.tree[node] >= target else -1
    mid = (start + end) // 2
    # 如果左子树的和 >= target,答案在左子树
    if self.tree[2 * node] >= target:
        return self._find_first(2 * node, start, mid, target)
    else:
        # 否则答案在右子树,但目标要减去左子树的贡献
        return self._find_first(2 * node + 1, mid + 1, end,
                               target - self.tree[2 * node])

为什么这比 "外部二分 + 查询" 快? 外部二分需要 O(log n) 次查询,每次查询 O(log n),总共 O(log² n)。而线段树上二分从根到叶子只走一条路径,O(log n)。这在时间紧张的竞赛题中可能是 AC 和 TLE 的区别。

14.24 实战经验总结

竞赛中的线段树技巧清单

  1. 确定单位元:开始写之前先确定你的 query 在无交集时返回什么值(和→0, 最小值→INF, 最大值→-INF, GCD→0, 异或→0)

  2. 确定合并方式:内部节点的值如何从子节点计算。对于复杂信息(如最大子段和),合并函数可能需要维护多个字段(前缀最大和、后缀最大和、区间和、最大子段和)

  3. 懒标记的设计

    • 标记的初始值("无标记"状态)是什么
    • 两个标记如何合并
    • 标记如何应用到节点值
    • 这三个问题想清楚,代码就不会出错
  4. 测试建议

    • 用暴力 O(n) 解法对拍
    • 特别测试 n=1, n=2 的边界
    • 测试操作区间 = 全区间的情况
    • 测试操作区间 = 单点的情况
# 对拍模板
import random

def stress_test(n=100, q=1000, max_val=100):
    """暴力对拍验证线段树正确性"""
    nums = [random.randint(0, max_val) for _ in range(n)]
    st = SegmentTree(nums[:])
    
    for _ in range(q):
        op = random.randint(0, 1)
        if op == 0:  # update
            idx = random.randint(0, n-1)
            val = random.randint(0, max_val)
            nums[idx] = val
            st.update(idx, val)
        else:  # query
            l = random.randint(0, n-1)
            r = random.randint(l, n-1)
            expected = sum(nums[l:r+1])
            got = st.query(l, r)
            assert expected == got, f"Mismatch at [{l},{r}]: expected {expected}, got {got}"
    
    print("All tests passed!")

stress_test()

面试中的线段树思维

面试官问线段树相关问题时,通常不期望你写出完美的线段树代码(太长了),而是考察你的思维方式

  1. 能否识别出这是一个"区间查询 + 修改"的问题
  2. 能否选择合适的数据结构(前缀和 / BIT / 线段树)
  3. 能否分析时间空间复杂度
  4. 对于 BIT 能否写出完整代码
  5. 对于线段树能否描述清楚建树、查询、修改的过程

不需要在白板上写出完整的带懒传播的线段树——那更适合竞赛选手在 IDE 中完成。但基础线段树和树状数组的代码你应该能在 10 分钟内写出来。

14.25 延伸阅读与练习

必练题目

题目 难度 考点
LeetCode #303 区域和检索(不可变) Easy 前缀和(对比基线)
LeetCode #307 区域和检索(可修改) Medium BIT / 线段树基础
LeetCode #315 计算右侧小于当前元素的个数 Hard BIT + 离散化
LeetCode #493 翻转对 Hard BIT / 归并排序
LeetCode #327 区间和的个数 Hard BIT + 离散化
LeetCode #218 天际线问题 Hard 线段树 / 扫描线
LeetCode #699 掉落的方块 Hard 线段树 / 坐标压缩

进阶题目(Codeforces):

题目 考点
CF #558E 线段树 + 计数排序
CF #914D 线段树维护 GCD
CF #877E 线段树 + 异或操作
CF #242E 线段树 + 位运算

推荐学习资源

  1. Peter Fenwick, "A New Data Structure for Cumulative Frequency Tables", Software: Practice and Experience, 1994
  2. Thomas H. Cormen et al., "Introduction to Algorithms" (CLRS), Chapter 14 (Augmenting Data Structures)
  3. Competitive Programmer's Handbook (Antti Laaksonen), Chapter 9: Range Queries
  4. CP-Algorithms (cp-algorithms.com) — Segment Tree 专题(有大量变体和代码模板)
本章评分
4.6  / 5  (23 评分)

💬 留言讨论