Chapter 29

Dynamic Programming III: Tree DP and Bitmask DP


The previous two chapters conquered one-dimensional DP, knapsack problems, sequence DP, and interval DP. This chapter enters the "advanced arsenal" of dynamic programming: when the state space is no longer a linear sequence or a rectangular grid but rather a tree or a subset of a set, we need entirely new ways to represent states.

The core idea of Tree DP is bottom-up aggregation: compute optimal solutions for all subtrees first, then merge them at the parent. The core idea of Bitmask DP is representing sets as binary numbers: each integer from 0 to 2^n - 1 encodes one of the 2^n subsets of n elements, compressing "which elements have been selected" into a single integer. Digit DP and Game Theory DP further generalize state definitions.

These techniques represent "final boss" difficulty in interviews — few candidates can implement them correctly, yet the underlying principles are surprisingly accessible.


Level 1 · What You Need to Know

29.1 Tree DP Concepts

What Is Tree DP?

Tree DP performs dynamic programming on tree structures. Unlike linear DP, the tree's topology dictates the direction of transitions: a node's state depends on the states of all its children.

Why do we need Tree DP? Many real-world problems are naturally tree-shaped: organizational hierarchies (can you invite a person and their direct report to the same party?), file systems, compiler ASTs, network topologies. Linear DP cannot directly handle optimization on these structures.

Core Pattern

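def dfs(node):
    state = initial_state(node)              # 1. Initialize current node's state
    for child in node.children:              # 2. Recursively process each child
        child_state = dfs(child)
        state = combine(state, child_state)  # 3. Update current node's state using child's result
    return state                             # 4. Return current node's final state

# initial_state and combine are placeholders for the problem-specific logic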

This is the structure of a post-order traversal — process subtrees first, then handle the current node.
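To make the template concrete, here is a minimal sketch of one classic tree DP: the diameter of a tree (the longest path between any two nodes, counted in edges). The `Node` class with a `children` list is a hypothetical node shape assumed for illustration. Each node's state is the length of its deepest downward branch; the best path bending at a node is found while merging its children.

```python
from typing import List, Optional

class Node:
    """Hypothetical multi-child tree node used only for this sketch."""
    def __init__(self, val: int = 0, children: Optional[List["Node"]] = None):
        self.val = val
        self.children = children if children is not None else []

def tree_diameter(root: Node) -> int:
    """Longest path (in edges) between any two nodes, via post-order tree DP."""
    best = 0

    def dfs(node: Node) -> int:
        """Return the longest downward path (in edges) starting at node."""
        nonlocal best
        # 1. Initialize: track the two deepest branches below this node
        top1 = top2 = 0
        # 2. Recursively process each child
        for child in node.children:
            branch = dfs(child) + 1
            # 3. Merge the child's result into this node's state
            if branch > top1:
                top1, top2 = branch, top1
            elif branch > top2:
                top2 = branch
        # A path bending at this node joins its two deepest branches
        best = max(best, top1 + top2)
        # 4. Return this node's final state to the parent
        return top1

    dfs(root)
    return best
```

For a root with a three-node chain under one child and a single leaf under another, `tree_diameter` returns 4: the longest path runs from the bottom of the chain, up through the root, and down to the leaf.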

29.2 House Robber III (LeetCode #337)

Problem Statement

You are given a binary tree in which each node holds a non-negative integer: the amount of money in that house. Rule: you cannot rob two directly connected nodes (a parent and its child cannot both be robbed). Find the maximum amount you can rob.

State Definition

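For each node, define two states:

  • rob(node): the maximum amount obtainable from node's subtree given that node itself is robbed
  • not_rob(node): the maximum amount obtainable from node's subtree given that node is not robbed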

Transition

rob(node) = node.val + not_rob(left) + not_rob(right)
    # If we rob current node, neither child can be robbed

not_rob(node) = max(rob(left), not_rob(left)) + max(rob(right), not_rob(right))
    # If we don't rob current node, each child independently chooses optimal

Why two states? The key constraint is "parent and child cannot both be robbed." If we only define dp[node] = max robbery from subtree, we cannot tell whether node itself was robbed during the transition — but this information is crucial for the parent. Hence two states are necessary.

Complete Implementation

from typing import Optional, Tuple

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def rob(root: Optional[TreeNode]) -> int:
    """House Robber III: Tree DP"""
    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        """Returns (max if rob this node, max if don't rob this node)"""
        if not node:
            return (0, 0)
        
        left_rob, left_not = dfs(node.left)
        right_rob, right_not = dfs(node.right)
        
        # Rob current: children cannot be robbed
        rob_current = node.val + left_not + right_not
        # Don't rob current: each child picks its own best
        not_rob_current = max(left_rob, left_not) + max(right_rob, right_not)
        
        return (rob_current, not_rob_current)
    
    r, nr = dfs(root)
    return max(r, nr)

Time Complexity: O(n), each node visited once. Space Complexity: O(h), where h is the tree height (recursion stack depth).
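When the tree degenerates into a long chain, h approaches n and the recursive version can exceed Python's default recursion limit of 1000. A sketch of the same two-state DP driven by an explicit stack (an alternative we are assuming here, not part of the original solution) looks like this. It redefines the same `TreeNode` shape so it runs on its own, and stores each node's (rob, not_rob) pair in a dictionary keyed by node.

```python
from typing import Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def rob_iterative(root: Optional[TreeNode]) -> int:
    """House Robber III with an explicit stack instead of recursion."""
    if root is None:
        return 0
    # state[node] = (max if node is robbed, max if node is not robbed)
    state = {}
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Missing children default to (0, 0), like dfs(None) in the recursive version
            lr, ln = state.get(node.left, (0, 0))
            rr, rn = state.get(node.right, (0, 0))
            state[node] = (node.val + ln + rn, max(lr, ln) + max(rr, rn))
        else:
            # Revisit this node only after both children have been processed
            stack.append((node, True))
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
    return max(state[root])
```

The pop order guarantees post-order: a node is pushed back with `children_done=True` underneath its children, so its pair is computed only after theirs.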

Common Mistakes

  1. Forgetting to return two states: Returning only one maximum makes it impossible for the parent to know whether the child was robbed.
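  2. Redundant computation: The naive recursion max(node.val + best of grandchildren, best of children) solves the same subtrees over and over and takes exponential time. Memoizing on node objects (a dict keyed by node) fixes this, but a single post-order traversal returning two values is cleaner.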
  3. Confusing "don't rob" with "rob nothing in subtree": Not robbing the current node does NOT mean you must rob the children; children can also choose not to be robbed.

29.3 Bitmask DP Concepts

What Is Bitmask DP?

Bitmask DP (State Compression DP) uses the binary representation of an integer to encode subset selection states. For n elements, each with two choices (selected/not selected), there are 2^n total combinations — representable by a single n-bit binary number.

Why do we need Bitmask DP? Consider: given n cities, find the shortest path visiting all cities exactly once (Travelling Salesman Problem, TSP). Brute-force enumeration of all permutations costs O(n!). By compressing "which cities have been visited" into a bitmask, we transform the problem into DP with complexity O(n^2 * 2^n) — entirely feasible for n <= 20.

Bit Manipulation Basics

# Check if bit i is set (element i is in the set)
(mask >> i) & 1

# Set bit i (add element i to the set)
mask | (1 << i)

# Clear bit i (remove element i from the set)
mask & ~(1 << i)

# Enumerate all non-empty subsets of mask (the loop stops before the empty set)
sub = mask
while sub > 0:
    # process subset sub
    sub = (sub - 1) & mask
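A small runnable check of these operations, using the set {1, 3} (mask = 0b1010) as an assumed example. `elements` and `nonempty_subsets` are hypothetical helper names introduced only for this sketch.

```python
def elements(mask: int) -> list[int]:
    """Indices of the set bits in mask, i.e. the elements of the set."""
    return [i for i in range(mask.bit_length()) if (mask >> i) & 1]

def nonempty_subsets(mask: int) -> list[int]:
    """All non-empty subsets of mask, in decreasing numeric order."""
    subs = []
    sub = mask
    while sub > 0:
        subs.append(sub)
        # sub - 1 clears sub's lowest set bit and sets every bit below it;
        # & mask then keeps only bits that belong to mask
        sub = (sub - 1) & mask
    return subs

mask = 0b1010                       # the set {1, 3}
print(elements(mask))               # [1, 3]
print((mask >> 3) & 1)              # 1: element 3 is in the set
print(elements(mask | (1 << 0)))    # [0, 1, 3]: element 0 added
print(elements(mask & ~(1 << 3)))   # [1]: element 3 removed
print([bin(s) for s in nonempty_subsets(mask)])  # ['0b1010', '0b1000', '0b10']
```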

29.4 Travelling Salesman Problem (TSP)

Problem Statement

Given n cities and a distance matrix dist[i][j], find the shortest path starting from city 0, visiting every city exactly once, and returning to city 0.

State Definition

dp[mask][i] = shortest path length having visited the set of cities represented by mask, currently at city i.

Note: mask is an n-bit binary number; bit k being 1 means city k has been visited.

Transition

dp[mask][i] = min(dp[mask ^ (1 << i)][j] + dist[j][i])
    where j != i, j is in mask (bit j of mask is 1)
    mask ^ (1 << i) removes city i from mask

Intuition: to reach state (mask, i), we must have come from some city j in mask (other than i), paying dist[j][i].

Initial Condition

dp[1][0] = 0  # Only city 0 visited, currently at city 0, path length 0

Final Answer

answer = min(dp[(1 << n) - 1][i] + dist[i][0])  for all i != 0

Complete Implementation

def tsp(dist: list[list[int]]) -> int:
    """
    Travelling Salesman Problem: Bitmask DP
    dist[i][j] = distance from city i to city j
    Returns shortest distance from city 0 through all cities back to city 0
    """
    n = len(dist)
    if n == 1:
        return 0  # A single city needs no travel (the loop below would return INF)
    INF = float('inf')
    
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0  # Start: only city 0 visited
    
    for mask in range(1, 1 << n):
        for i in range(n):
            if dp[mask][i] == INF:
                continue
            if not (mask >> i & 1):
                continue  # City i not in mask, invalid state
            
            # From city i, try visiting unvisited city j
            for j in range(n):
                if mask >> j & 1:
                    continue  # City j already visited
                new_mask = mask | (1 << j)
                dp[new_mask][j] = min(dp[new_mask][j], dp[mask][i] + dist[i][j])
    
    # After visiting all cities, return to start
    full_mask = (1 << n) - 1
    ans = INF
    for i in range(1, n):
        if dp[full_mask][i] != INF:
            ans = min(ans, dp[full_mask][i] + dist[i][0])
    
    return ans

Time Complexity: O(n^2 * 2^n). Outer loop enumerates all 2^n masks; inner loops enumerate current city i and next city j.

Space Complexity: O(n * 2^n).

Practical Limits: For n = 20, 2^20 = 1,048,576, multiplied by n^2 = 400, gives ~400 million operations — too slow for Python but feasible in C++ within 1-2 seconds. In interviews, n is typically at most 15.
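The implementation above fills the table in the "push" direction: from each state (mask, i) it extends to every unvisited city j. A sketch that instead follows the transition formula dp[mask][i] = min(dp[mask ^ (1 << i)][j] + dist[j][i]) literally, in the "pull" direction, could look like this (the name `tsp_pull` is ours). Both orders fill the same (mask, i) states with the same values.

```python
def tsp_pull(dist: list[list[int]]) -> float:
    """Held-Karp, filling dp[mask][i] from dp[mask ^ (1 << i)][j] (pull style)."""
    n = len(dist)
    if n == 1:
        return 0  # a single city needs no travel
    INF = float('inf')
    # dp[mask][i]: shortest path from city 0 visiting exactly the cities in mask, ending at i
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0
    for mask in range(3, 1 << n, 2):  # odd masks only: every path contains city 0
        for i in range(1, n):         # the path may end anywhere except back at city 0
            if not (mask >> i) & 1:
                continue
            prev = mask ^ (1 << i)    # visited set just before arriving at city i
            for j in range(n):
                if (prev >> j) & 1 and dp[prev][j] < INF:
                    dp[mask][i] = min(dp[mask][i], dp[prev][j] + dist[j][i])
    full = (1 << n) - 1
    return min(dp[full][i] + dist[i][0] for i in range(1, n))
```

Since prev is always numerically smaller than mask, iterating masks in increasing order guarantees every dp[prev][j] is final before it is read.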


Level 2 · How It Works Under the Hood

29.5 Digit DP

Motivating Problem: Number of Digit One (LeetCode #233)

Given an integer n, count the total number of times digit 1 appears in all integers from 1 to n. For example, n = 13 gives 6 (the numbers 1, 10, 11, 12, 13 contain six 1s total).

Core Idea of Digit DP

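Digit DP views numbers as sequences of digits, making decisions from the most significant digit to the least significant. The key technique is maintaining a tight flag:

  • tight = True: every digit chosen so far equals the corresponding digit of n, so the current digit may be at most n's digit at this position.
  • tight = False: some earlier digit was strictly smaller than n's digit, so the current digit can be anything from 0 to 9.

This ensures we never build a number larger than n. The DP actually ranges over [0, n], but 0 contains no 1s, so the count is the same as for [1, n].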

Why the tight constraint? Consider n = 345. If we've chosen 3 for the hundreds digit, the tens digit cannot exceed 4. But if we chose 2 for hundreds, the tens digit can be anything from 0 to 9. The tight flag distinguishes these cases.

General Framework

from functools import lru_cache

def count_digit_one(n: int) -> int:
    """LeetCode 233: Number of Digit One"""
    if n <= 0:
        return 0
    
    digits = list(str(n))
    length = len(digits)
    
    @lru_cache(maxsize=None)
    def dp(pos: int, count: int, tight: bool, started: bool) -> int:
        """
        pos: current digit position (MSB to LSB)
        count: number of 1s seen so far
        tight: whether we're still constrained by upper bound
        started: whether we've placed a non-zero digit (handles leading zeros;
                 not needed to count 1s, but kept for variants where 0 matters)
        """
        if pos == length:
            return count
        
        limit = int(digits[pos]) if tight else 9
        total = 0
        
        for d in range(0, limit + 1):
            new_tight = tight and (d == limit)
            new_started = started or (d != 0)
            new_count = count + (1 if d == 1 else 0)
            total += dp(pos + 1, new_count, new_tight, new_started)
        
        return total
    
    return dp(0, 0, True, False)

Step-by-Step Execution (n = 13)

digits = ['1', '3']

Tens place (pos=0):
  d=0: tight=False, ones place can be 0-9
       d=1 in ones contributes 1 → subtree contributes 1
  d=1: tight=True, ones place can be 0-3
       tens digit itself contributes 1 to each number
       ones d=0: count=1, contributes 1
       ones d=1: count=2, contributes 2
       ones d=2: count=1, contributes 1
       ones d=3: count=1, contributes 1
       subtree contributes = 1+2+1+1 = 5

Total = 1 + 5 = 6 ✓

Generality of Digit DP

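This framework solves a wide class of "count numbers in [L, R] satisfying some digit property" problems: run the DP with upper bound R and subtract the result for upper bound L - 1. Problems such as counting non-negative integers without consecutive 1s in their binary representation use the same framework with different state parameters and transition conditions.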
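As an illustration of swapping in a different state parameter, here is a sketch that counts integers in [lo, hi] whose decimal digit sum equals a target. The property and the function names are assumptions chosen for this example. The state replaces `count` with the running digit sum, and `started` is dropped because leading zeros add nothing to the sum.

```python
from functools import lru_cache

def count_digit_sum_upto(n: int, target: int) -> int:
    """Count integers x in [0, n] whose decimal digit sum equals target."""
    if n < 0:
        return 0
    digits = [int(c) for c in str(n)]

    @lru_cache(maxsize=None)
    def dp(pos: int, total: int, tight: bool) -> int:
        if total > target:
            return 0  # digit sums never decrease, so prune early
        if pos == len(digits):
            return 1 if total == target else 0
        limit = digits[pos] if tight else 9
        return sum(dp(pos + 1, total + d, tight and d == limit)
                   for d in range(limit + 1))

    return dp(0, 0, True)

def count_digit_sum_in_range(lo: int, hi: int, target: int) -> int:
    """Range query via the prefix trick: f(hi) - f(lo - 1)."""
    return count_digit_sum_upto(hi, target) - count_digit_sum_upto(lo - 1, target)
```

For example, `count_digit_sum_in_range(1, 100, 1)` counts 1, 10, and 100, giving 3.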

29.6 Game Theory DP

Nim Game

Two players alternate taking stones from a pile. Each turn, take 1 to 3 stones. The player who takes the last stone wins. Does the first player have a winning strategy?

Key Result: If n % 4 == 0, the second player wins; otherwise the first player wins.

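Why 4? Each move removes 1 to 3 stones, so whatever the opponent takes (k stones), you can reply with 4 - k and keep the remaining count a multiple of 4. This is a special case of Sprague-Grundy theory. Analyzing small cases:

  • n = 1, 2, 3: the first player takes every stone and wins.
  • n = 4: every move leaves 1 to 3 stones, so the opponent wins.
  • n = 5, 6, 7: the first player takes 1, 2, or 3 stones to leave exactly 4 and wins.
  • n = 8: every move leaves 5 to 7 stones, a winning position for the opponent.

The pattern: the first player loses iff n is a multiple of 4.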

def can_win_nim(n: int) -> bool:
    """Nim Game: take 1-3 stones per turn"""
    return n % 4 != 0
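The closed form is easy to sanity-check with a direct win/lose DP: a position is winning if at least one move leads to a losing position for the opponent. In this sketch, `nim_win_table` is a hypothetical name and `max_take` is an assumed generalization of the 1-to-3 rule.

```python
def nim_win_table(n: int, max_take: int = 3) -> list[bool]:
    """win[i] is True if the player to move with i stones left can force a win."""
    win = [False] * (n + 1)  # win[0] = False: the previous player took the last stone
    for i in range(1, n + 1):
        # Winning iff some legal move leaves the opponent in a losing position
        win[i] = any(not win[i - take] for take in range(1, min(i, max_take) + 1))
    return win

table = nim_win_table(12)
print([i for i, w in enumerate(table) if not w])  # [0, 4, 8, 12]
```

With `max_take = k`, the same table shows the losing positions are the multiples of k + 1.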

Stone Game (LeetCode #877)

An even number of stone piles is arranged in a row. Two players alternate taking an entire pile from either end of the row. The total number of stones is odd, so there are no ties. Does the first player always win?

State Definition

dp[i][j] = maximum score advantage the current player can achieve facing piles[i..j].

Transition

dp[i][j] = max(
    piles[i] - dp[i+1][j],   # take left
    piles[j] - dp[i][j-1]    # take right
)

Intuition: if I take piles[i], my opponent faces piles[i+1..j] and achieves advantage dp[i+1][j]. From my perspective, their advantage is my disadvantage, so I subtract it.

Complete Implementation

def stone_game(piles: list[int]) -> bool:
    """Stone Game: interval game DP"""
    n = len(piles)
    dp = [[0] * n for _ in range(n)]
    
    # Base case: single pile, current player takes it all
    for i in range(n):
        dp[i][i] = piles[i]
    
    # Fill by increasing interval length
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = max(
                piles[i] - dp[i + 1][j],
                piles[j] - dp[i][j - 1]
            )
    
    return dp[0][n - 1] > 0

Time Complexity: O(n^2). Space Complexity: O(n^2).

Mathematical Shortcut for LeetCode #877

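For the specific constraints of LeetCode #877 (even number of piles, odd total), the first player always wins. With an even number of piles, the two end piles always have indices of different parity on the first player's turn, and after the opponent takes one end, the newly exposed pile has the same parity as the pile the first player just took. So the first player can collect every even-indexed pile, or every odd-indexed pile, whichever they choose. Since the total is odd, these two groups have different sums; the first player picks the larger group, so the function could simply return True.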

The DP solution above is general and works for any variant.
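The two views can be cross-checked: the parity strategy guarantees an advantage of exactly |sum of even-indexed piles - sum of odd-indexed piles|, so optimal play can only do at least as well. The sketch below (function names are ours; the random instances are assumed test data) repeats the interval DP but returns the advantage instead of a boolean.

```python
import random

def stone_game_advantage(piles: list[int]) -> int:
    """First player's optimal score advantage (same interval DP as stone_game)."""
    n = len(piles)
    dp = [[0] * n for _ in range(n)]
    for i in range(n):
        dp[i][i] = piles[i]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = max(piles[i] - dp[i + 1][j], piles[j] - dp[i][j - 1])
    return dp[0][n - 1]

def parity_strategy_advantage(piles: list[int]) -> int:
    """Advantage from collecting every even-indexed or every odd-indexed pile."""
    return abs(sum(piles[0::2]) - sum(piles[1::2]))

random.seed(0)
for _ in range(200):
    piles = [random.randint(1, 20) for _ in range(2 * random.randint(1, 8))]
    if sum(piles) % 2 == 0:
        piles[0] += 1  # force an odd total, as in LeetCode #877
    # Optimal play is at least as good as the parity strategy, which is strictly positive
    assert stone_game_advantage(piles) >= parity_strategy_advantage(piles) > 0
```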

29.7 More Tree DP Examples

All Possible Full Binary Trees (LeetCode #894)

A full binary tree: every node is either a leaf or has exactly two children. Given n nodes, return all possible full binary trees.

Key observation: a full binary tree must have an odd number of nodes (root + left subtree + right subtree, where both subtrees have odd node counts). If n is even, the answer is empty.

from functools import lru_cache
from typing import Optional, List

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def all_possible_fbt(n: int) -> List[Optional[TreeNode]]:
    """All Possible Full Binary Trees"""
    @lru_cache(maxsize=None)
    def build(num_nodes: int) -> List[Optional[TreeNode]]:
        if num_nodes == 1:
            return [TreeNode(0)]
        if num_nodes % 2 == 0:
            return []
        
        result = []
        # Enumerate left subtree size (odd numbers)
        for left_count in range(1, num_nodes - 1, 2):
            right_count = num_nodes - 1 - left_count
            left_trees = build(left_count)
            right_trees = build(right_count)
            for lt in left_trees:
                for rt in right_trees:
                    root = TreeNode(0)
                    root.left = lt
                    root.right = rt
                    result.append(root)
        return result
    
    return build(n)
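Because the output can be huge, it helps to know how many trees `build(n)` will return before generating them. A sketch that counts instead of builds, using the same odd left/right split (the name `count_full_binary_trees` is ours):

```python
def count_full_binary_trees(n: int) -> int:
    """Number of full binary trees with n nodes, using the same split as build()."""
    if n % 2 == 0:
        return 0
    count = [0] * (n + 1)
    count[1] = 1
    for total in range(3, n + 1, 2):
        # root + left subtree (odd size) + right subtree (odd size)
        count[total] = sum(count[left] * count[total - 1 - left]
                           for left in range(1, total - 1, 2))
    return count[n]

print([count_full_binary_trees(n) for n in (1, 3, 5, 7, 9, 11)])  # [1, 1, 2, 5, 14, 42]
```

These are the Catalan numbers C_k with k = (n - 1) / 2.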

Level 3 · What the Theory Says

29.8 Complexity Analysis of Bitmask DP

Why O(n * 2^n)?

The state space of bitmask DP consists of two components:

  1. Subset enumeration: 2^n subsets of n elements
  2. An extra dimension per subset: typically "which element are we currently at," giving n choices

Total states: n * 2^n. Transitions typically enumerate which element we came from (another factor of n), yielding O(n^2 * 2^n).

Comparison with Brute Force

| Method | Time Complexity | n = 15 | n = 20 |
|---|---|---|---|
| Brute-force permutation | O(n!) | 1.3 * 10^12 | 2.4 * 10^18 |
| Bitmask DP | O(n^2 * 2^n) | 7.4 * 10^6 | 4.2 * 10^8 |

For n = 20, bitmask DP needs roughly 6 * 10^9 times fewer operations than brute force. This is the power of the Held-Karp algorithm (1962).
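For small inputs, the O(n!) brute force from the table is still useful as a test oracle for the DP. A minimal sketch (the function name `tsp_brute_force` is ours) uses `itertools.permutations` and fixes city 0 as the start, so only (n - 1)! orderings are tried:

```python
from itertools import permutations

def tsp_brute_force(dist: list[list[int]]) -> int:
    """Exact TSP by trying all (n-1)! orderings of cities 1..n-1; O(n!) time."""
    n = len(dist)
    best = float('inf')
    for order in permutations(range(1, n)):
        tour = (0, *order, 0)  # start and end at city 0
        length = sum(dist[a][b] for a, b in zip(tour, tour[1:]))
        best = min(best, length)
    return best
```

On the assumed four-city matrix [[0, 10, 15, 20], [10, 0, 35, 25], [15, 35, 0, 30], [20, 25, 30, 0]] it returns 80 (tour 0 → 1 → 3 → 2 → 0), which the bitmask DP should reproduce.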

History of the Held-Karp Algorithm

Michael Held and Richard Karp published "A Dynamic Programming Approach to Sequencing Problems" in the Journal of the Society for Industrial and Applied Mathematics (1962); Richard Bellman published the same idea independently that year. Combining state compression with dynamic programming reduced TSP's exact solution complexity from O(n!) to O(n^2 * 2^n). Though still exponential, this remains the best known worst-case bound for exact TSP on general weighted instances, and whether it can be improved to O(c^n) for some c < 2 is a long-standing open problem.

Enumerating Subsets of a Subset

Some bitmask DP problems require enumerating all subsets of a given set:

# Enumerate all non-empty subsets of mask
sub = mask
while sub > 0:
    # process subset sub
    sub = (sub - 1) & mask

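What is the complexity? For a fixed mask with k bits set, there are 2^k subsets, 2^k - 1 of them non-empty (the loop visits exactly those). Summing 2^k over all masks, grouped by their number of set bits k:

$$\sum_{k=0}^{n} \binom{n}{k} \cdot 2^k = 3^n$$

This follows from the binomial theorem, $(1+2)^n = 3^n$; intuitively, each element is either outside mask, in mask but not in sub, or in sub. So "enumerate subsets of all subsets" has total complexity O(3^n), not O(4^n).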
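A quick empirical check of this bound, counting how many times the loop body runs over all masks of n elements (`count_submask_pairs` is a hypothetical helper name). Because the loop skips the empty subset, the exact count is 3^n - 2^n, which is still O(3^n).

```python
def count_submask_pairs(n: int) -> int:
    """Count (mask, sub) pairs visited by the non-empty submask loop over all masks."""
    visits = 0
    for mask in range(1 << n):
        sub = mask
        while sub > 0:
            visits += 1
            sub = (sub - 1) & mask
    return visits

for n in range(1, 11):
    # The empty subset of each mask is never visited, hence the "- 2 ** n"
    assert count_submask_pairs(n) == 3 ** n - 2 ** n
```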

29.9 Sprague-Grundy Theorem

Game Theory Foundations

Combinatorial Game Theory studies two-player games with alternating turns, perfect information, and no randomness. In such games, every position is either an N-position (Next player wins) or a P-position (Previous player wins, i.e., current player loses).

Grundy Values

Every game position can be assigned a non-negative integer g (Grundy value or nimber):

$$g(x) = \text{mex}\{\, g(y) : y \text{ is a successor of } x \,\}$$

where mex is the smallest non-negative integer not in the given set.

Statement of the Sprague-Grundy Theorem

Independently discovered by Roland Sprague (1935) and Patrick Grundy (1939):

Theorem: Every finite impartial game under normal play (the player who cannot move loses) is equivalent to a Nim heap whose size equals the game's Grundy value. The Grundy value of a sum of independent games (on each turn, a player moves in exactly one of them) equals the XOR of their individual Grundy values.

Why does this matter? This theorem reduces any finite, acyclic impartial game to Nim. Once you can compute each sub-game's Grundy value, the win/loss outcome of the entire combined game is determined by a single XOR operation.

Multi-pile Nim via SG Analysis

Multi-pile Nim: k piles with n_1, n_2, ..., n_k stones. Players alternate taking any number from any single pile.

Each pile is an independent sub-game. A Nim heap of size n_i has Grundy value n_i.

Total Grundy value = n_1 XOR n_2 XOR ... XOR n_k. The first player wins iff this XOR is nonzero.
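The XOR rule can be verified against exhaustive game search on small positions. This sketch (`first_player_wins` is a hypothetical name; the limit of 3 piles with at most 4 stones each is an assumption to keep the search tiny) marks a position as winning if some move reaches a losing one.

```python
from functools import lru_cache
from itertools import product

@lru_cache(maxsize=None)
def first_player_wins(piles: tuple[int, ...]) -> bool:
    """Exhaustive game search: can the player to move force a win?"""
    for idx, size in enumerate(piles):
        for take in range(1, size + 1):
            nxt = list(piles)
            nxt[idx] -= take
            # Sorting maps equivalent positions to a single cache entry
            if not first_player_wins(tuple(sorted(nxt))):
                return True  # found a move to a P-position
    return False  # no move exists, or every move leads to an N-position

# Compare with the XOR rule on every position with 3 piles of at most 4 stones
for piles in product(range(5), repeat=3):
    assert first_player_wins(tuple(sorted(piles))) == ((piles[0] ^ piles[1] ^ piles[2]) != 0)
```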

Code Verification

def compute_grundy(n: int, max_take: int) -> int:
    """
    Compute Grundy value for a stone-taking game
    One pile of n stones, at most max_take per turn
    """
    grundy = [0] * (n + 1)
    for i in range(1, n + 1):
        reachable = set()
        for take in range(1, min(i, max_take) + 1):
            reachable.add(grundy[i - take])
        mex = 0
        while mex in reachable:
            mex += 1
        grundy[i] = mex
    return grundy[n]

# Verify: with max_take=3, Grundy value = n % 4
for n in range(20):
    assert compute_grundy(n, 3) == n % 4

SG Theorem in Interviews

When facing game theory problems in interviews:

  1. Tabulate small cases to find P/N position patterns
  2. If the game decomposes into independent sub-games, compute Grundy values and XOR
  3. Many problems don't require full SG analysis — direct pattern recognition or mathematical shortcuts suffice

29.10 Theoretical Foundations of Digit DP

Correctness Proof

Digit DP's core converts "count numbers in [0, n] satisfying a condition" into "digit-by-digit decisions." Correctness relies on this observation:

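For a d-digit number n = a_{d-1} a_{d-2} ... a_1 a_0, any number x in [0, n] (padded with leading zeros to d digits) falls into exactly one of these cases:

  • x = n, or
  • there is a unique highest position i at which x's digit is less than a_i; the digits above position i equal those of n, and the digits below position i are arbitrary (0 to 9).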

This is exactly what the tight flag captures: tight = True means "all preceding digits match the upper bound"; once a digit is chosen less than the bound, tight becomes False and all subsequent digits are unconstrained.

Effectiveness of Memoization

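In the memoized digit DP, the state (pos, count, tight, started) has:

  • pos: d + 1 values (0 to d, including the terminal position)
  • count: at most d + 1 values (0 to d ones)
  • tight and started: 2 values each

Total states = O(d^2 * 4) = O(d^2), and each state tries at most 10 digits. For n up to 10^9 (d = 10), that is fewer than 500 states, so digit DP is remarkably efficient.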


Level 4 · Edge Cases and Pitfalls

29.11 Interview Deep Dive: House Robber III (#337)

Interview Traps

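  1. Stack overflow: If the tree degenerates into a linked list, recursion depth reaches O(n). Python's default recursion limit is 1000; sys.setrecursionlimit raises it, but very deep recursion may still crash the interpreter, so an explicit-stack post-order traversal is the safer fix.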

  2. Return value design: Many first attempts write dfs(node, robbed), passing whether the current node is robbed as a parameter; without memoization each subtree is then solved repeatedly and the running time becomes exponential. The cleaner approach: each node returns a tuple (rob, not_rob).

  3. Follow-up: What if it's not a binary tree? Same idea, but not_rob becomes the sum of max(rob_child, not_rob_child) over all children:

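class Node:
    """Multi-child tree node assumed by this follow-up: a value plus a list of children."""
    def __init__(self, val=0, children=None):
        self.val = val
        self.children = children if children is not None else []

def rob_general_tree(root):
    """House Robber on a general (multi-child) tree"""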
    def dfs(node):
        if not node:
            return (0, 0)
        rob_sum = node.val
        not_rob_sum = 0
        for child in node.children:
            child_rob, child_not = dfs(child)
            rob_sum += child_not
            not_rob_sum += max(child_rob, child_not)
        return (rob_sum, not_rob_sum)
    
    r, nr = dfs(root)
    return max(r, nr)

29.12 Interview Deep Dive: Shortest Superstring (#943)

Problem Statement

Given an array of strings words, find the shortest string containing every string as a substring.

Analysis

This is essentially a TSP variant! Treat each string as a "city." The "distance" from string A to string B is the number of new characters needed when appending B after A (i.e., B's length minus its overlap with A's suffix).

Steps:

  1. Preprocessing: compute pairwise overlap lengths
  2. Bitmask DP: like TSP, find the optimal ordering
  3. Path reconstruction: build the final string from the DP result

Complete Implementation

def shortest_superstring(words: list[str]) -> str:
    """Shortest Superstring: Bitmask DP"""
    n = len(words)
    
    # Precompute: overlap[i][j] = max overlap of words[i]'s suffix with words[j]'s prefix
    overlap = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            max_k = min(len(words[i]), len(words[j]))
            for k in range(max_k, 0, -1):
                if words[i].endswith(words[j][:k]):
                    overlap[i][j] = k
                    break
    
    # Bitmask DP
    # dp[mask][i] = max total overlap when selected strings are mask, last is words[i]
    dp = [[0] * n for _ in range(1 << n)]
    parent = [[-1] * n for _ in range(1 << n)]
    
    for mask in range(1, 1 << n):
        for i in range(n):
            if not (mask >> i & 1):
                continue
            prev_mask = mask ^ (1 << i)
            if prev_mask == 0:
                continue
            for j in range(n):
                if not (prev_mask >> j & 1):
                    continue
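                val = dp[prev_mask][j] + overlap[j][i]
                # Accept the first valid predecessor even if its total overlap is 0;
                # otherwise parent stays -1 and path reconstruction drops words
                if parent[mask][i] == -1 or val > dp[mask][i]:
                    dp[mask][i] = val
                    parent[mask][i] = j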
    
    # Find optimal last string
    full_mask = (1 << n) - 1
    last = max(range(n), key=lambda i: dp[full_mask][i])
    
    # Path reconstruction
    path = []
    mask = full_mask
    cur = last
    while cur != -1:
        path.append(cur)
        prev = parent[mask][cur]
        mask ^= (1 << cur)
        cur = prev
    path.reverse()
    
    # Build result string
    result = words[path[0]]
    for k in range(1, len(path)):
        i, j = path[k - 1], path[k]
        result += words[j][overlap[i][j]:]
    
    return result

Complexity: O(n^2 * 2^n) for the DP, plus O(n^2 * L^2) for the overlap precomputation, where L is the maximum word length; feasible for n <= 12.

29.13 Interview Deep Dive: All Possible Full Binary Trees (#894)

Interview Key Points

  1. Recursion + Memoization: The key observation is that n must be odd for a solution to exist; n = 1 yields a single node. For n >= 3, enumerate left subtree sizes 1, 3, 5, ..., n - 2.

  2. Common follow-up questions:

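    • "What's the time complexity?" The output size dominates. With n = 2k + 1 nodes, the number of full binary trees is the Catalan number C_k, which grows like 4^k / (k^(3/2) * sqrt(pi)). Because subtrees are shared, each tree costs O(1) new work, so total time is roughly proportional to the number of trees built, O(C_k).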
    • "Can you implement it iteratively?" — Yes, build bottom-up from n=1
  3. Reference sharing: In the memoized version, returned subtrees may be referenced by multiple parents. This is fine if we only read the trees, but deep copying is needed if modifications are required.

# Iterative version
def all_possible_fbt_iterative(n: int) -> list:
    """Build all full binary trees bottom-up"""
    if n % 2 == 0:
        return []
    
    dp = [[] for _ in range(n + 1)]
    dp[1] = [TreeNode(0)]
    
    for total in range(3, n + 1, 2):
        for left in range(1, total - 1, 2):
            right = total - 1 - left
            for lt in dp[left]:
                for rt in dp[right]:
                    root = TreeNode(0, lt, rt)
                    dp[total].append(root)
    
    return dp[n]

29.14 Practical Engineering Tips for Bitmask DP

Python Optimization for Bitmask DP

  1. Bit operations over set operations: Python's set is convenient, but a plain integer mask is hashable, cheap to copy, and usable directly as a list index, and bit operations on it are typically much faster than set operations.

  2. Precompute popcount:

popcount = [bin(i).count('1') for i in range(1 << n)]

  3. Layered processing: If dp[mask] only depends on states with smaller popcount, process masks layer by layer.

  4. Integer overflow: Python integers have no upper limit, but in other languages 1 << n overflows a signed 32-bit integer for n > 30, so use 64-bit types. Mention this in interviews.
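A sketch of tips 2 and 3 together: building the popcount table in O(2^n) from smaller masks, then grouping masks into layers by popcount. The function names are ours; on Python 3.10+, `mask.bit_count()` is another way to get a single popcount.

```python
def popcount_table(n: int) -> list[int]:
    """popcount[mask] for every mask in [0, 2^n), built in O(2^n)."""
    pc = [0] * (1 << n)
    for mask in range(1, 1 << n):
        # bits of mask >> 1 (already computed, since it is smaller) plus bit 0
        pc[mask] = pc[mask >> 1] + (mask & 1)
    return pc

def masks_by_layer(n: int) -> list[list[int]]:
    """Group masks by popcount; layer k holds all masks with exactly k bits set."""
    pc = popcount_table(n)
    layers = [[] for _ in range(n + 1)]
    for mask in range(1 << n):
        layers[pc[mask]].append(mask)
    return layers

print(masks_by_layer(3))  # [[0], [1, 2, 4], [3, 5, 6], [7]]
```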

Common Interview Traps

| Trap | Description | Solution |
|---|---|---|
| Mask enumeration order | dp[mask] is read before the states it depends on are computed | Iterate masks from small to large: adding an element always produces a larger mask |
| Missing initial state | Forgetting dp[1][0] = 0 in TSP | Carefully define "start with only the origin" |
| Return to start | TSP must add the return distance | Don't forget + dist[last][0] |
| Symmetry | Some problems treat (A,B) and (B,A) as equivalent | Fix the starting point to save half the time |

29.15 Comprehensive Comparison: When to Use Which DP

| Problem Characteristic | DP Type | Typical Problems |
|---|---|---|
| Optimization on tree structures | Tree DP | House Robber III, longest path in tree |
| Order-dependent subset selection | Bitmask DP | TSP, Shortest Superstring |
| Counting within numeric ranges | Digit DP | Number of Digit One, Non-negative Integers without Consecutive Ones |
| Two-player alternating games | Game Theory DP | Stone Game, Nim |
| Relationship between two sequences | Sequence DP | LCS, Edit Distance |
| Interval merging/splitting | Interval DP | Matrix Chain, Burst Balloons |

In interviews, first classify the problem type, then apply the corresponding state definition template — much faster than designing states from scratch.


Chapter Summary

  1. Tree DP relies on post-order traversal: compute subtrees first, then merge at parent. States typically split into "select/don't select current node."

  2. Bitmask DP encodes subset states as bitmasks, reducing O(n!) problems to O(n^2 * 2^n). Applicable when n <= 20 for permutation/selection problems.

  3. Digit DP makes digit-by-digit decisions with a tight constraint ensuring we don't exceed the upper bound. Used for counting numbers with specific digit properties within a range.

  4. Game Theory DP: the Sprague-Grundy theorem reduces any finite impartial game to Nim; combine independent sub-games via XOR. Score-based games such as Stone Game use an interval DP on the score advantage instead.

  5. Interview strategy: classify first (identify problem characteristics), apply template (state definition + transition), then verify (hand-compute small examples).
