第 34 章

位运算:接近硬件的思考

第三十四章:位运算 — 接近硬件的思考

在你写 if x % 2 == 0 判断奇偶时,CPU 实际执行的是检查 x 的最低位是否为 0。在你写 x * 2 时,CPU 实际执行的是把 x 的所有位左移一位。编程语言给了你"数字"的抽象,但在硬件层面,一切都是比特——0 和 1 的序列。

位运算是直接操作比特的运算。它们比算术运算更快(通常只需 1 个 CPU 周期),但对人类来说不直观。掌握位运算有三个层次的价值:

  1. 写出更高效的代码:用位运算替代除法、取模等昂贵操作
  2. 理解底层系统:操作系统、网络协议、加密算法都大量使用位运算
  3. 解决特定算法问题:有些问题(如"只出现一次的数字")用位运算可以达到 O(1) 空间

Level 1 · 你需要知道的

1.1 六种基本位运算

运算 符号 含义 示例 (4位)
AND & 两位都为 1 才是 1 1100 & 1010 = 1000
OR | 任一位为 1 就是 1 1100 | 1010 = 1110
XOR ^ 两位不同才是 1 1100 ^ 1010 = 0110
NOT ~ 每位取反 ~1100 = 0011
左移 << 所有位左移,右边补 0 1010 << 1 = 10100
右移 >> 所有位右移 1010 >> 1 = 0101
# 基本位运算演示
a = 0b1100  # 12
b = 0b1010  # 10

print(f"a & b  = {bin(a & b)}")   # 0b1000 = 8
print(f"a | b  = {bin(a | b)}")   # 0b1110 = 14
print(f"a ^ b  = {bin(a ^ b)}")   # 0b0110 = 6
print(f"~a     = {bin(~a & 0xF)}")  # 0b0011 = 3 (4位)
print(f"a << 1 = {bin(a << 1)}")  # 0b11000 = 24
print(f"a >> 1 = {bin(a >> 1)}")  # 0b110 = 6

1.2 XOR 的特殊性质

XOR(异或)是位运算中最神奇的操作,因为它同时具有多个有用的性质:

# XOR 的代数性质
# 1. 自反性:a ^ a = 0(任何数异或自己等于 0)
# 2. 零元素:a ^ 0 = a(任何数异或 0 等于自己)
# 3. 交换律:a ^ b = b ^ a
# 4. 结合律:(a ^ b) ^ c = a ^ (b ^ c)

# 这些性质组合起来意味着:
# 一组数字异或,结果与顺序无关,成对的会消除

# 应用 1:不用临时变量交换两个数
def swap_xor(a: int, b: int) -> tuple[int, int]:
    """用 XOR 交换两个数(无需临时变量)"""
    a = a ^ b
    b = a ^ b  # b = (a^b)^b = a
    a = a ^ b  # a = (a^b)^a = b(此时 b 已经是原来的 a)
    return a, b

x, y = 5, 3
x, y = swap_xor(x, y)
print(f"交换后: x={x}, y={y}")  # x=3, y=5

# 应用 2:找出只出现一次的数字
def single_number(nums: list[int]) -> int:
    """
    数组中所有数字都出现两次,只有一个出现一次,找出它
    LeetCode #136
    
    原理:a ^ a = 0, 0 ^ b = b
    所有成对的数字异或后消除,剩下的就是单独的那个
    """
    result = 0
    for num in nums:
        result ^= num
    return result

print(single_number([2, 1, 4, 1, 2]))  # 4

1.3 常用位运算技巧

# 技巧 1:判断奇偶
def is_even(n: int) -> bool:
    """n & 1 == 0 则为偶数"""
    return (n & 1) == 0

# 技巧 2:乘以/除以 2 的幂
def multiply_power_of_2(n: int, k: int) -> int:
    """n * 2^k = n << k"""
    return n << k

def divide_power_of_2(n: int, k: int) -> int:
    """n // 2^k = n >> k(仅对正整数)"""
    return n >> k

# 技巧 3:判断是否是 2 的幂
def is_power_of_two(n: int) -> bool:
    """
    LeetCode #231
    2 的幂的二进制只有一个 1:1, 10, 100, 1000, ...
    n & (n-1) 会清除最低位的 1
    如果清除后变成 0,说明只有一个 1
    """
    return n > 0 and (n & (n - 1)) == 0

# 技巧 4:获取最低位的 1(lowbit)
def lowest_bit(n: int) -> int:
    """
    n & (-n) 或 n & (~n + 1)
    结果是只包含 n 的最低位 1 的数
    例:n = 12 (1100),lowest_bit = 4 (0100)
    
    原理:-n 的补码表示是 ~n + 1
    ~1100 + 1 = 0011 + 1 = 0100
    1100 & 0100 = 0100
    """
    return n & (-n)

# 技巧 5:清除最低位的 1
def clear_lowest_bit(n: int) -> int:
    """
    n & (n-1) 清除 n 的最低位 1
    例:n = 12 (1100),n-1 = 11 (1011)
    1100 & 1011 = 1000
    """
    return n & (n - 1)

# 技巧 6:获取第 k 位
def get_bit(n: int, k: int) -> int:
    """获取 n 的第 k 位(从 0 开始)"""
    return (n >> k) & 1

# 技巧 7:设置第 k 位为 1
def set_bit(n: int, k: int) -> int:
    """将 n 的第 k 位设为 1"""
    return n | (1 << k)

# 技巧 8:清除第 k 位
def clear_bit(n: int, k: int) -> int:
    """将 n 的第 k 位设为 0"""
    return n & ~(1 << k)

# 技巧 9:翻转第 k 位
def toggle_bit(n: int, k: int) -> int:
    """翻转 n 的第 k 位"""
    return n ^ (1 << k)


# 演示
print(f"12 是偶数: {is_even(12)}")          # True
print(f"5 * 8 = {multiply_power_of_2(5, 3)}")  # 40
print(f"16 是 2 的幂: {is_power_of_two(16)}")   # True
print(f"lowest_bit(12) = {lowest_bit(12)}")     # 4
print(f"clear_lowest_bit(12) = {clear_lowest_bit(12)}")  # 8

1.4 位运算操作速查表

操作 代码 说明
判断奇偶 n & 1 0=偶, 1=奇
乘以 2^k n << k
除以 2^k n >> k 仅正整数
是否 2 的幂 n & (n-1) == 0 n>0 前提
最低位的 1 n & (-n) lowbit
清除最低 1 n & (n-1)
取第 k 位 (n >> k) & 1
设第 k 位为 1 n | (1 << k)
清第 k 位 n & ~(1 << k)
翻转第 k 位 n ^ (1 << k)
取低 k 位 n & ((1 << k) - 1) 掩码
所有位取反 ~n 注意符号

1.5 常见错误

错误 1:Python 整数无限精度的陷阱

# Python 的整数是任意精度的,~ 运算不会像 C 那样产生固定位数结果
n = 5  # 101
print(~n)  # -6,不是你期望的"翻转为 010"

# 如果需要 32 位翻转:
def flip_32bit(n: int) -> int:
    return n ^ 0xFFFFFFFF

print(flip_32bit(5))  # 4294967290 (无符号) 或处理为有符号

错误 2:移位运算的溢出

# 在 C/Java 中,左移可能溢出
# Python 不会溢出,但在竞赛中提交 C++ 代码时要注意
# int32: 1 << 31 是负数!应该用 1LL << 31 或 unsigned

# Python 中安全的做法
def safe_left_shift(n: int, k: int, bits: int = 32) -> int:
    """模拟固定位数的左移"""
    mask = (1 << bits) - 1
    return (n << k) & mask

错误 3:运算符优先级

# 位运算优先级低于比较运算!这是常见 bug 来源
n = 5
# 错误:if n & 1 == 0  实际上是 if n & (1 == 0) 即 if n & 0
# 正确:if (n & 1) == 0
print(n & 1 == 0)    # False(因为 1==0 是 False=0,5&0=0,但Python特殊处理)
print((n & 1) == 0)  # False(这才是"5是否为偶数"的正确判断)

Level 2 · 它是怎么运行的

2.1 Brian Kernighan 算法 — 统计 1 的个数

问题:给定一个整数 n,统计它的二进制表示中有多少个 1(这个数叫做 popcount汉明权重)。

朴素方法:逐位检查,O(位数) = O(32) 或 O(64)。

Brian Kernighan 方法:每次清除最低位的 1,直到变为 0。循环次数 = 1 的个数。

def count_bits_kernighan(n: int) -> int:
    """
    Brian Kernighan 算法:统计二进制中 1 的个数
    LeetCode #191: Number of 1 Bits
    
    核心:n & (n-1) 清除最低位的 1
    循环次数 = 1 的个数
    
    时间:O(k),k = 1 的个数(最坏 O(log n) = O(位数))
    
    来源:Brian Kernighan 在 "The C Programming Language" (K&R, 1978) 中介绍
    实际上这个技巧更早由 Peter Wegner 在 1960 年发表
    """
    count = 0
    while n:
        n &= (n - 1)  # 清除最低位的 1
        count += 1
    return count


# 为什么 n & (n-1) 能清除最低位的 1?
# 假设 n = ...1000(最后的 1 后面都是 0)
# 则 n-1 = ...0111(最后的 1 变 0,后面的 0 变 1)
# n & (n-1) = ...0000(最后的 1 和它右边的位都被清零)

# 演示
for n in [0, 1, 7, 12, 128, 255]:
    print(f"n={n:3d} ({bin(n):>10s}): {count_bits_kernighan(n)} 个 1")

查表法(更快的 O(1) 实现):

# 预计算每个字节(0-255)中 1 的个数
POPCOUNT_TABLE = [0] * 256
for i in range(1, 256):
    POPCOUNT_TABLE[i] = POPCOUNT_TABLE[i >> 1] + (i & 1)

def count_bits_table(n: int) -> int:
    """
    查表法:O(1) 时间,但需要 256 字节的表
    把 32 位整数拆成 4 个字节,分别查表后相加
    """
    count = 0
    while n:
        count += POPCOUNT_TABLE[n & 0xFF]
        n >>= 8
    return count


# 位并行法(CPU 实际使用的方法)
def count_bits_parallel(n: int) -> int:
    """
    位并行法:在 32 位整数上 O(1) 操作
    
    思路:分治
    1. 每 2 位一组,统计每组中 1 的个数
    2. 每 4 位一组,合并相邻的 2 位组
    3. 每 8 位一组,合并相邻的 4 位组
    4. ...直到合并为一个数
    
    这就是 CPU 的 POPCNT 指令的实现原理
    """
    # 确保在 32 位范围内
    n = n & 0xFFFFFFFF
    
    # Step 1: 每 2 位一组统计
    n = (n & 0x55555555) + ((n >> 1) & 0x55555555)
    # Step 2: 每 4 位一组统计
    n = (n & 0x33333333) + ((n >> 2) & 0x33333333)
    # Step 3: 每 8 位一组统计
    n = (n & 0x0F0F0F0F) + ((n >> 4) & 0x0F0F0F0F)
    # Step 4: 每 16 位一组统计
    n = (n & 0x00FF00FF) + ((n >> 8) & 0x00FF00FF)
    # Step 5: 最终合并
    n = (n & 0x0000FFFF) + ((n >> 16) & 0x0000FFFF)
    
    return n


# 验证一致性
for n in [0, 1, 7, 12, 128, 255, 0xDEADBEEF]:
    k = count_bits_kernighan(n)
    t = count_bits_table(n)
    p = count_bits_parallel(n)
    assert k == t == p, f"不一致: n={n}"
    print(f"n={n:#010x}: popcount={k}")

2.2 位图(Bitmap)

位图用一个比特表示一个布尔值,比 Python 的 list[bool](每个元素 28 字节)节省 200 多倍空间。

class Bitmap:
    """
    位图:用 1 bit 存储一个布尔值
    
    应用场景:
    - 布隆过滤器(Bloom Filter)的底层
    - 操作系统页面管理(空闲/已用)
    - Redis 的 BITFIELD 命令
    - 大规模去重(40亿个整数去重只需 512MB)
    """
    
    def __init__(self, size: int):
        """创建能存储 size 个位的位图"""
        self.size = size
        # 用字节数组存储,每字节 8 位
        self.data = bytearray((size + 7) // 8)
    
    def set(self, pos: int):
        """将第 pos 位设为 1"""
        if 0 <= pos < self.size:
            self.data[pos >> 3] |= (1 << (pos & 7))
    
    def clear(self, pos: int):
        """将第 pos 位设为 0"""
        if 0 <= pos < self.size:
            self.data[pos >> 3] &= ~(1 << (pos & 7))
    
    def get(self, pos: int) -> bool:
        """获取第 pos 位的值"""
        if 0 <= pos < self.size:
            return bool(self.data[pos >> 3] & (1 << (pos & 7)))
        return False
    
    def count_ones(self) -> int:
        """统计所有 1 的个数"""
        count = 0
        for byte in self.data:
            count += POPCOUNT_TABLE[byte]
        return count
    
    def memory_bytes(self) -> int:
        """实际内存使用"""
        return len(self.data)


# 演示:40亿个整数去重
# 如果用 set 存储:每个整数 28 字节,40亿 * 28 = 112 GB
# 如果用 Bitmap:40亿 / 8 = 500 MB

bm = Bitmap(100)
bm.set(0)
bm.set(42)
bm.set(99)
print(f"位图大小: {bm.memory_bytes()} 字节, 存储 {bm.size} 个位")
print(f"位 0: {bm.get(0)}, 位 42: {bm.get(42)}, 位 50: {bm.get(50)}")
print(f"总共 {bm.count_ones()} 个 1")

2.3 位运算枚举子集

问题:给定集合 {0, 1, ..., n-1},枚举它的所有子集。

每个子集可以用一个 n 位整数表示:第 i 位为 1 表示元素 i 在子集中。

def enumerate_all_subsets(n: int):
    """
    枚举 n 个元素的所有子集
    一个 n 位整数的每一位代表一个元素是否存在
    总共 2^n 个子集
    """
    subsets = []
    for mask in range(1 << n):
        subset = []
        for i in range(n):
            if mask & (1 << i):
                subset.append(i)
        subsets.append(subset)
    return subsets


def enumerate_submasks(mask: int):
    """
    枚举一个掩码的所有子掩码(子集的子集)
    
    例如 mask = 0b1101,它的子掩码有:
    1101, 1100, 1001, 1000, 0101, 0100, 0001, 0000
    
    技巧:sub = (sub - 1) & mask 可以从 mask 开始遍历所有子掩码
    
    时间复杂度:O(2^popcount(mask))
    """
    submasks = []
    sub = mask
    while sub > 0:
        submasks.append(sub)
        sub = (sub - 1) & mask
    submasks.append(0)  # 空集
    return submasks


# 演示
print("集合 {0,1,2} 的所有子集:")
for subset in enumerate_all_subsets(3):
    print(f"  {subset}")

print("\n掩码 0b1101 的所有子掩码:")
for sub in enumerate_submasks(0b1101):
    print(f"  {bin(sub)}")

2.4 状态压缩动态规划

状态压缩(bitmask DP)是用一个整数的各位来表示一组布尔状态,从而把"集合"作为 DP 的状态维度。

经典问题:旅行商问题(TSP)

def tsp_bitmask(dist: list[list[int]]) -> int:
    """
    旅行商问题:访问所有城市恰好一次的最短路径
    
    状态:dp[mask][i] = 已访问集合为 mask,当前在城市 i 的最短距离
    转移:dp[mask | (1<<j)][j] = min(dp[mask][i] + dist[i][j])
           其中 j 不在 mask 中
    
    时间复杂度:O(2^n * n^2)
    空间复杂度:O(2^n * n)
    
    适用范围:n <= 20(2^20 = 100万,再大内存就不够了)
    """
    n = len(dist)
    INF = float('inf')
    
    # dp[mask][i] = 已访问 mask 中的城市,最后在 i 的最短距离
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0  # 从城市 0 出发
    
    for mask in range(1 << n):
        for i in range(n):
            if dp[mask][i] == INF:
                continue
            if not (mask & (1 << i)):
                continue  # 当前位置 i 必须在已访问集合中
            
            # 尝试访问下一个未访问的城市
            for j in range(n):
                if mask & (1 << j):
                    continue  # j 已访问
                new_mask = mask | (1 << j)
                new_dist = dp[mask][i] + dist[i][j]
                if new_dist < dp[new_mask][j]:
                    dp[new_mask][j] = new_dist
    
    # 从所有城市都访问过的状态,回到起点
    full_mask = (1 << n) - 1
    result = INF
    for i in range(n):
        if dp[full_mask][i] + dist[i][0] < result:
            result = dp[full_mask][i] + dist[i][0]
    
    return result


# 演示
dist = [
    [0, 10, 15, 20],
    [10, 0, 35, 25],
    [15, 35, 0, 30],
    [20, 25, 30, 0]
]
print(f"TSP 最短距离: {tsp_bitmask(dist)}")  # 80

经典问题:最大独立集 / 最大团

def max_independent_set(adj: list[list[int]]) -> int:
    """
    最大独立集:找最多的顶点使得任意两个都不相邻
    
    adj[i] = 与顶点 i 相邻的顶点列表
    用位掩码表示"哪些顶点被选中"
    
    O(2^n * n)
    """
    n = len(adj)
    
    # 预处理:adj_mask[i] = 与 i 相邻的所有顶点的掩码
    adj_mask = [0] * n
    for i in range(n):
        for j in adj[i]:
            adj_mask[i] |= (1 << j)
    
    max_size = 0
    
    for mask in range(1 << n):
        # 检查 mask 是否是独立集
        is_independent = True
        for i in range(n):
            if not (mask & (1 << i)):
                continue
            # 如果 i 被选中,检查它的邻居是否也被选中
            if mask & adj_mask[i]:
                is_independent = False
                break
        
        if is_independent:
            size = bin(mask).count('1')
            max_size = max(max_size, size)
    
    return max_size

2.5 位运算在权限系统中的应用

Linux 文件权限、数据库权限、用户角色等都广泛使用位运算:

# Linux 文件权限模型
READ = 0b100    # 4
WRITE = 0b010   # 2
EXECUTE = 0b001 # 1

def check_permission(user_perm: int, required: int) -> bool:
    """检查用户是否有所需权限"""
    return (user_perm & required) == required

def grant_permission(current: int, new_perm: int) -> int:
    """授予新权限"""
    return current | new_perm

def revoke_permission(current: int, perm: int) -> int:
    """撤销权限"""
    return current & ~perm

# 演示
user = READ | WRITE  # rwx = 110 = 读写权限
print(f"有读权限: {check_permission(user, READ)}")       # True
print(f"有执行权限: {check_permission(user, EXECUTE)}")  # False

user = grant_permission(user, EXECUTE)  # 加上执行权限
print(f"授权后: {oct(user)}")  # 0o7 (rwx)

user = revoke_permission(user, WRITE)   # 撤销写权限
print(f"撤销后: {oct(user)}")  # 0o5 (r-x)

Level 3 · 规范怎么定义的

3.1 位运算在 CPU 指令级的实现

在现代 CPU 中,位运算是最基本的操作,直接由硬件电路实现:

AND 门:两个输入都为 1 时输出 1(用两个晶体管串联实现) OR 门:任一输入为 1 时输出 1(用两个晶体管并联实现) XOR 门:两输入不同时输出 1(需要 4 个 NAND 门,约 8 个晶体管) NOT 门:取反(1 个晶体管)

一个 64 位 AND 操作 = 64 个 AND 门并行工作,1 个时钟周期完成。

整数加法器的实现也基于位运算:

def add_without_plus(a: int, b: int) -> int:
    """
    不用 + 号实现加法(面试常考)
    
    原理:
    - a ^ b = 不考虑进位的加法结果
    - (a & b) << 1 = 进位
    - 重复直到进位为 0
    
    这就是硬件加法器的工作原理!
    """
    # Python 整数无限精度,需要模拟 32 位
    MASK = 0xFFFFFFFF
    MAX_INT = 0x7FFFFFFF
    
    a, b = a & MASK, b & MASK
    
    while b:
        carry = ((a & b) << 1) & MASK
        a = (a ^ b) & MASK
        b = carry
    
    # 处理负数(补码)
    return a if a <= MAX_INT else ~(a ^ MASK)


# 演示
print(add_without_plus(5, 3))    # 8
print(add_without_plus(-1, 1))   # 0
print(add_without_plus(-5, -3))  # -8

3.2 补码表示(Two's Complement)

现代计算机几乎都使用补码表示负数。理解补码是理解位运算的基础。

补码的定义(对于 n 位整数):

补码的性质:

def to_twos_complement(n: int, bits: int = 32) -> str:
    """将整数转为 n 位补码的二进制表示"""
    if n >= 0:
        return format(n, f'0{bits}b')
    else:
        # 负数的补码 = 2^bits + n
        return format((1 << bits) + n, f'0{bits}b')

def from_twos_complement(binary: str) -> int:
    """将补码二进制转为整数"""
    bits = len(binary)
    n = int(binary, 2)
    if binary[0] == '1':  # 负数
        n -= (1 << bits)
    return n

# 演示
for n in [5, -5, 0, -1, 127, -128]:
    tc = to_twos_complement(n, 8)
    back = from_twos_complement(tc)
    print(f"{n:4d} -> {tc} -> {back}")

为什么 n & (-n) 能得到最低位的 1?

n    = ...XY10...0  (最低位的 1 后面都是 0)
~n   = ...X'Y'01...1
~n+1 = ...X'Y'10...0  (即 -n 的补码)
n & (-n) = 000010...0  (只有最低位的 1 被保留)

因为 n 和 -n 在最低位的 1 处相同,在更高位处互补。

3.3 SIMD — 单指令多数据

SIMD(Single Instruction, Multiple Data) 是现代 CPU 提供的向量指令,可以对多个数据同时执行相同的位运算。

# SIMD 的概念演示(Python 无法直接使用 SIMD,这里展示原理)

# 没有 SIMD:逐元素处理
def add_arrays_scalar(a: list[int], b: list[int]) -> list[int]:
    """标量加法:每次处理 1 个元素"""
    return [x + y for x, y in zip(a, b)]

# 有 SIMD:4 个元素同时处理(SSE 128 位寄存器,4 个 32 位整数)
# 在 C/Rust 中可以这样写:
# __m128i va = _mm_load_si128(a);
# __m128i vb = _mm_load_si128(b);
# __m128i vc = _mm_add_epi32(va, vb);  // 一条指令加 4 个整数

# 在 Python 中可以通过 NumPy 间接使用 SIMD
import numpy as np

def add_arrays_simd(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """NumPy 底层使用 SIMD 加速"""
    return a + b

SIMD 的实际架构演进:

指令集 位宽 可并行32位操作数 年代
MMX 64 bit 2 1997
SSE 128 bit 4 1999
AVX 256 bit 8 2011
AVX-512 512 bit 16 2017
ARM NEON 128 bit 4 2004

SIMD 在字符串处理中的应用:

现代的 strlen(), memcmp(), memchr() 实现都使用 SIMD:

def simd_style_memchr(data: bytes, target: int) -> int:
    """
    模拟 SIMD 风格的字节搜索
    实际的 libc memchr 使用 SSE/AVX 一次比较 16-64 字节
    """
    # 模拟 16 字节宽度的 SIMD 操作
    LANE_WIDTH = 16
    n = len(data)
    
    # 按 16 字节块处理
    i = 0
    while i + LANE_WIDTH <= n:
        # "SIMD compare":一次比较 16 字节
        # 实际硬件中这是一条指令
        block = data[i:i+LANE_WIDTH]
        for j in range(LANE_WIDTH):
            if block[j] == target:
                return i + j
        i += LANE_WIDTH
    
    # 处理尾部
    while i < n:
        if data[i] == target:
            return i
        i += 1
    
    return -1

3.4 位运算的数学理论

位运算构成了一个布尔代数(Boolean Algebra),它的公理化定义由 George Boole 在 1854 年提出,后来由 Claude Shannon 在 1937 年的硕士论文中证明可以用电路实现。

布尔代数的基本定律:

定律 AND 形式 OR 形式
交换律 a & b = b & a a | b = b | a
结合律 (a & b) & c = a & (b & c) (a | b) | c = a | (b | c)
分配律 a & (b | c) = (a&b) | (a&c) a | (b & c) = (a|b) & (a|c)
吸收律 a & (a | b) = a a | (a & b) = a
德摩根律 ~(a & b) = ~a | ~b ~(a | b) = ~a & ~b
互补律 a & ~a = 0 a | ~a = 全1

XOR 的额外性质:

这些数学性质解释了为什么 XOR 在密码学中如此重要:它是唯一一个信息论上"完美"的按位运算——知道输出和一个输入,可以唯一确定另一个输入。

3.5 汉明码(Hamming Code)— 位运算的经典应用

Richard Hamming 在 1950 年发明的纠错码,是位运算在通信理论中最优美的应用之一:

def hamming_encode(data: int, data_bits: int = 4) -> int:
    """
    Hamming(7,4) 编码:4 位数据 -> 7 位编码
    
    位置编号:1, 2, 3, 4, 5, 6, 7
    校验位:位置 1, 2, 4(2 的幂次位置)
    数据位:位置 3, 5, 6, 7
    
    校验规则:
    p1 覆盖位置 1,3,5,7(二进制位置的第 0 位为 1 的位置)
    p2 覆盖位置 2,3,6,7(二进制位置的第 1 位为 1 的位置)
    p4 覆盖位置 4,5,6,7(二进制位置的第 2 位为 1 的位置)
    
    Hamming, "Error Detecting and Error Correcting Codes", 
    Bell System Technical Journal, 1950
    """
    # 提取 4 个数据位
    d1 = (data >> 3) & 1  # 位置 3
    d2 = (data >> 2) & 1  # 位置 5
    d3 = (data >> 1) & 1  # 位置 6
    d4 = data & 1         # 位置 7
    
    # 计算校验位
    p1 = d1 ^ d2 ^ d4    # 覆盖位置 3,5,7
    p2 = d1 ^ d3 ^ d4    # 覆盖位置 3,6,7
    p4 = d2 ^ d3 ^ d4    # 覆盖位置 5,6,7
    
    # 组装 7 位编码:p1 p2 d1 p4 d2 d3 d4
    encoded = (p1 << 6) | (p2 << 5) | (d1 << 4) | (p4 << 3) | \
              (d2 << 2) | (d3 << 1) | d4
    
    return encoded


def hamming_decode(received: int) -> tuple[int, int]:
    """
    Hamming(7,4) 解码:检测并纠正 1 位错误
    
    返回 (数据, 错误位置),错误位置 0 表示无错误
    """
    # 计算校验子(syndrome)
    # 如果 syndrome = 0,无错误
    # 如果 syndrome != 0,它指出错误位置
    
    bits = [(received >> i) & 1 for i in range(6, -1, -1)]
    # bits[0] = 位置 1, bits[1] = 位置 2, ..., bits[6] = 位置 7
    
    # 校验子的每一位
    s1 = bits[0] ^ bits[2] ^ bits[4] ^ bits[6]  # 位置 1,3,5,7
    s2 = bits[1] ^ bits[2] ^ bits[5] ^ bits[6]  # 位置 2,3,6,7
    s4 = bits[3] ^ bits[4] ^ bits[5] ^ bits[6]  # 位置 4,5,6,7
    
    syndrome = (s4 << 2) | (s2 << 1) | s1
    
    # 纠正错误
    if syndrome != 0:
        error_pos = syndrome - 1  # 转为 0-indexed
        bits[error_pos] ^= 1  # 翻转错误位
    
    # 提取数据位(位置 3, 5, 6, 7 -> index 2, 4, 5, 6)
    data = (bits[2] << 3) | (bits[4] << 2) | (bits[5] << 1) | bits[6]
    
    return data, syndrome


# 演示
original_data = 0b1011  # 数据 = 11
encoded = hamming_encode(original_data)
print(f"原始数据: {bin(original_data)}")
print(f"编码后: {bin(encoded)}")

# 模拟 1 位错误
corrupted = encoded ^ (1 << 3)  # 翻转第 4 位(位置 4)
print(f"传输错误: {bin(corrupted)}")

decoded, error_pos = hamming_decode(corrupted)
print(f"解码结果: {bin(decoded)}, 错误位置: {error_pos}")
print(f"纠错成功: {decoded == original_data}")

Level 4 · 边界与陷阱

4.1 面试题:只出现一次的数字 I(LeetCode #136)

def single_number_136(nums: list[int]) -> int:
    """
    所有数字都出现两次,只有一个出现一次
    XOR 所有数字,成对的消除,剩下单独的
    
    时间 O(n), 空间 O(1)
    """
    result = 0
    for num in nums:
        result ^= num
    return result

print(single_number_136([4, 1, 2, 1, 2]))  # 4

4.2 面试题:只出现一次的数字 II(LeetCode #137)

def single_number_137(nums: list[int]) -> int:
    """
    所有数字都出现三次,只有一个出现一次
    
    思路:对每一位,统计所有数字在该位上 1 的个数
    如果能被 3 整除,说明单独的数字在该位是 0
    否则在该位是 1
    
    时间 O(32n) = O(n), 空间 O(1)
    """
    result = 0
    for i in range(32):
        bit_sum = 0
        for num in nums:
            # 处理 Python 负数
            if num < 0:
                num = num & 0xFFFFFFFF
            bit_sum += (num >> i) & 1
        
        if bit_sum % 3 != 0:
            result |= (1 << i)
    
    # 处理负数结果
    if result >= (1 << 31):
        result -= (1 << 32)
    
    return result


# 方法 2:有限状态机(更巧妙)
def single_number_137_v2(nums: list[int]) -> int:
    """
    用两个变量 ones, twos 模拟"模 3 计数器"
    
    ones: 记录出现 1 次的位
    twos: 记录出现 2 次的位
    出现 3 次时清零
    
    状态转移(对于每一位):
    count = 0: ones=0, twos=0
    count = 1: ones=1, twos=0
    count = 2: ones=0, twos=1
    count = 3: ones=0, twos=0 (归零)
    """
    ones, twos = 0, 0
    for num in nums:
        ones = (ones ^ num) & ~twos
        twos = (twos ^ num) & ~ones
    return ones


# 测试
print(single_number_137([2, 2, 3, 2]))    # 3
print(single_number_137_v2([0, 1, 0, 1, 0, 1, 99]))  # 99

4.3 面试题:只出现一次的数字 III(LeetCode #260)

def single_number_260(nums: list[int]) -> list[int]:
    """
    所有数字都出现两次,有两个出现一次(设为 a 和 b)
    
    思路:
    1. XOR 所有数字得到 a ^ b(成对的消除)
    2. a ^ b 中为 1 的位说明 a 和 b 在该位不同
    3. 用这个位把所有数字分成两组
    4. 每组分别 XOR 得到 a 和 b
    
    时间 O(n), 空间 O(1)
    """
    # Step 1: 得到 a ^ b
    xor_all = 0
    for num in nums:
        xor_all ^= num
    
    # Step 2: 找 a ^ b 中任意一个为 1 的位
    # 使用 lowbit 技巧
    diff_bit = xor_all & (-xor_all)
    
    # Step 3: 按该位分组并分别 XOR
    a, b = 0, 0
    for num in nums:
        if num & diff_bit:
            a ^= num
        else:
            b ^= num
    
    return [a, b]


# 测试
result = single_number_260([1, 2, 1, 3, 2, 5])
print(f"两个只出现一次的数: {sorted(result)}")  # [3, 5]

4.4 面试题:汉明距离(LeetCode #461)

def hamming_distance(x: int, y: int) -> int:
    """
    两个整数的汉明距离 = 对应位不同的位数
    
    方法:XOR 后统计 1 的个数
    x ^ y 中为 1 的位就是两者不同的位
    """
    xor = x ^ y
    count = 0
    while xor:
        xor &= (xor - 1)  # Brian Kernighan
        count += 1
    return count


def total_hamming_distance(nums: list[int]) -> int:
    """
    LeetCode #477: 所有整数对的汉明距离之和
    
    技巧:逐位统计
    对于每一位,统计该位为 1 的数字个数 c
    该位对总距离的贡献 = c * (n - c)
    (c 个 1 与 (n-c) 个 0 配对)
    
    时间 O(32n) = O(n)
    """
    n = len(nums)
    total = 0
    
    for i in range(32):
        ones = sum(1 for num in nums if (num >> i) & 1)
        total += ones * (n - ones)
    
    return total


# 测试
print(f"hamming(1, 4) = {hamming_distance(1, 4)}")  # 2 (001 vs 100)
print(f"total hamming [4,14,2] = {total_hamming_distance([4, 14, 2])}")  # 6

4.5 面试题:2 的幂(LeetCode #231)与位 1 的个数(LeetCode #191)

def is_power_of_two(n: int) -> bool:
    """
    LeetCode #231
    n 是 2 的幂当且仅当 n > 0 且二进制只有一个 1
    """
    return n > 0 and (n & (n - 1)) == 0


def number_of_1_bits(n: int) -> int:
    """
    LeetCode #191
    统计无符号整数的二进制中 1 的个数
    """
    count = 0
    while n:
        n &= (n - 1)
        count += 1
    return count


# 进阶:对 0 到 n 的所有数字统计 1 的个数
def count_bits_all(n: int) -> list[int]:
    """
    LeetCode #338: Counting Bits
    返回 [popcount(0), popcount(1), ..., popcount(n)]
    
    DP 关系:popcount(x) = popcount(x >> 1) + (x & 1)
    即:x 的 1 个数 = x 右移一位的 1 个数 + x 的最低位
    
    时间 O(n), 空间 O(n)
    """
    result = [0] * (n + 1)
    for i in range(1, n + 1):
        result[i] = result[i >> 1] + (i & 1)
    return result


# 替代 DP:popcount(x) = popcount(x & (x-1)) + 1
def count_bits_all_v2(n: int) -> list[int]:
    """利用清除最低位 1 的关系"""
    result = [0] * (n + 1)
    for i in range(1, n + 1):
        result[i] = result[i & (i - 1)] + 1
    return result


# 测试
print(f"16 是 2 的幂: {is_power_of_two(16)}")  # True
print(f"18 是 2 的幂: {is_power_of_two(18)}")  # False
print(f"popcount(11) = {number_of_1_bits(11)}")  # 3
print(f"count_bits(5) = {count_bits_all(5)}")    # [0, 1, 1, 2, 1, 2]

4.6 面试题:更多位运算技巧

def reverse_bits(n: int) -> int:
    """
    LeetCode #190: 反转 32 位无符号整数的位
    
    分治法:先交换相邻的 1 位,再交换相邻的 2 位,
    再交换相邻的 4 位,8 位,16 位
    """
    n = ((n & 0x55555555) << 1) | ((n >> 1) & 0x55555555)   # 交换 1 位
    n = ((n & 0x33333333) << 2) | ((n >> 2) & 0x33333333)   # 交换 2 位
    n = ((n & 0x0F0F0F0F) << 4) | ((n >> 4) & 0x0F0F0F0F)  # 交换 4 位
    n = ((n & 0x00FF00FF) << 8) | ((n >> 8) & 0x00FF00FF)   # 交换 8 位
    n = ((n & 0x0000FFFF) << 16) | ((n >> 16) & 0x0000FFFF) # 交换 16 位
    return n


def missing_number(nums: list[int]) -> int:
    """
    LeetCode #268: 0到n中缺少的数字
    
    方法:XOR 所有下标和所有数字
    下标 0,1,...,n 和 nums 中出现的数字 XOR 后,
    成对的消除,剩下缺失的那个
    """
    n = len(nums)
    result = n  # 从 n 开始(因为下标只到 n-1)
    for i in range(n):
        result ^= i ^ nums[i]
    return result


def single_non_duplicate(nums: list[int]) -> int:
    """
    LeetCode #540: 有序数组中的单一元素
    所有元素出现两次,有一个只出现一次
    
    利用有序性质的二分查找(比直接 XOR 更高效):
    在单一元素之前,pairs 的起始是偶数下标
    在单一元素之后,pairs 的起始是奇数下标
    """
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        # 保证 mid 是偶数
        if mid % 2 == 1:
            mid -= 1
        # 如果 nums[mid] == nums[mid+1],单一元素在右边
        if nums[mid] == nums[mid + 1]:
            lo = mid + 2
        else:
            hi = mid
    return nums[lo]


def find_complement(n: int) -> int:
    """
    LeetCode #476: 数字的补数
    翻转 n 的所有有效位(不包括前导零)
    
    例:5 = 101,补数 = 010 = 2
    """
    # 找到与 n 位数相同的全 1 掩码
    mask = 1
    while mask <= n:
        mask <<= 1
    mask -= 1  # 例:n=5(101),mask = 111 = 7
    
    return n ^ mask


# 测试
print(f"reverse_bits(43261596) = {reverse_bits(43261596)}")
print(f"missing_number([3,0,1]) = {missing_number([3, 0, 1])}")  # 2
print(f"single_non_duplicate([1,1,2,3,3]) = {single_non_duplicate([1,1,2,3,3,4,4,8,8])}")  # 2
print(f"complement(5) = {find_complement(5)}")  # 2

4.7 位运算的实际工程应用

1. 网络协议中的位运算

def parse_ipv4(ip_int: int) -> str:
    """将 32 位整数解析为 IPv4 地址"""
    return f"{(ip_int >> 24) & 0xFF}.{(ip_int >> 16) & 0xFF}.{(ip_int >> 8) & 0xFF}.{ip_int & 0xFF}"

def ip_to_int(ip: str) -> int:
    """将 IPv4 地址转为 32 位整数"""
    parts = [int(p) for p in ip.split('.')]
    return (parts[0] << 24) | (parts[1] << 16) | (parts[2] << 8) | parts[3]

def subnet_match(ip: str, subnet: str) -> bool:
    """
    判断 IP 是否在子网中
    子网格式:192.168.1.0/24
    
    方法:把 IP 和子网地址的前 prefix_len 位做比较
    """
    network, prefix_len = subnet.split('/')
    prefix_len = int(prefix_len)
    
    ip_int = ip_to_int(ip)
    network_int = ip_to_int(network)
    
    # 掩码:前 prefix_len 位为 1,其余为 0
    mask = (0xFFFFFFFF << (32 - prefix_len)) & 0xFFFFFFFF
    
    return (ip_int & mask) == (network_int & mask)


# 演示
print(parse_ipv4(0xC0A80101))  # 192.168.1.1
print(f"192.168.1.100 in 192.168.1.0/24: {subnet_match('192.168.1.100', '192.168.1.0/24')}")  # True
print(f"192.168.2.1 in 192.168.1.0/24: {subnet_match('192.168.2.1', '192.168.1.0/24')}")      # False

2. 哈希函数中的位运算

def murmur_hash_mix(h: int) -> int:
    """
    MurmurHash3 的最终混合步骤
    通过位运算打散哈希值,减少冲突
    """
    h ^= h >> 16
    h = (h * 0x85ebca6b) & 0xFFFFFFFF
    h ^= h >> 13
    h = (h * 0xc2b2ae35) & 0xFFFFFFFF
    h ^= h >> 16
    return h


def fibonacci_hash(key: int, bits: int = 10) -> int:
    """
    Fibonacci 哈希(用于 HashMap 的槽位定位)
    
    利用黄金比例的位运算近似:2^32 / phi ≈ 2654435769
    乘以这个魔数后取高位,分布非常均匀
    
    这就是 Java HashMap 和 Linux 内核使用的哈希方法
    """
    GOLDEN_RATIO = 2654435769  # 2^32 * (sqrt(5) - 1) / 2
    return ((key * GOLDEN_RATIO) & 0xFFFFFFFF) >> (32 - bits)

3. 游戏开发中的位运算

# 碰撞层(Collision Layers)使用位掩码
LAYER_PLAYER = 1 << 0     # 0001
LAYER_ENEMY = 1 << 1      # 0010
LAYER_BULLET = 1 << 2     # 0100
LAYER_WALL = 1 << 3       # 1000

# 定义每个对象可以碰撞的层
player_collision_mask = LAYER_ENEMY | LAYER_WALL      # 1010
bullet_collision_mask = LAYER_ENEMY | LAYER_WALL      # 1010
enemy_collision_mask = LAYER_PLAYER | LAYER_BULLET    # 0101

def can_collide(obj_layer: int, target_mask: int) -> bool:
    """检查两个对象是否可以碰撞"""
    return bool(obj_layer & target_mask)

print(f"子弹能打到敌人: {can_collide(LAYER_BULLET, enemy_collision_mask)}")  # True
print(f"子弹能打到玩家: {can_collide(LAYER_BULLET, player_collision_mask)}")  # False

4.8 本章总结

知识点 关键信息
基本运算 AND, OR, XOR, NOT, SHIFT(6种)
XOR 核心性质 a^a=0, a^0=a, 交换律, 结合律
lowbit n & (-n),补码原理
清除最低 1 n & (n-1),Brian Kernighan 算法
popcount 位并行法 O(1) / 查表法 O(1) / Kernighan O(k)
状态压缩 用 n 位整数表示 n 个元素的子集
TSP 状压 DP O(2^n * n^2),适用 n≤20
SIMD 128/256/512 位并行运算

位运算是算法与硬件的交汇点。在面试中,它考验你对二进制表示的理解;在工程中,它帮你写出高性能代码;在系统设计中,它是网络协议、加密算法、操作系统的基础语言。

记住一个核心原则:计算机不认识"数字",它只认识比特。 当你用位运算思考时,你就在用计算机的母语编程。

本章评分
4.5  / 5  (3 评分)

💬 留言讨论