动态规划进阶:区间DP深度解析

1. 区间DP概述

区间DP是动态规划的一种特殊形式,主要用于解决区间性质的问题。这类问题的特征是问题的解可以通过小区间的解组合得到大区间的解,具有典型的"分治+记忆化"特征。

2. 区间DP基本概念

2.1 区间DP特点

  • 状态定义 :通常定义为 dp[i][j],表示区间 [i, j] 的最优解
  • 遍历顺序:按区间长度从小到大遍历
  • 状态转移:通过划分区间中间点进行转移

2.2 通用模板

python 复制代码
def interval_dp_template(nums):
    n = len(nums)
    # 初始化DP数组
    dp = [[0] * n for _ in range(n)]
    
    # 初始化基本情况(长度为1的区间)
    for i in range(n):
        dp[i][i] = base_value  # 根据问题确定
    
    # 按区间长度从小到大遍历
    for length in range(2, n + 1):           # 区间长度
        for i in range(n - length + 1):      # 区间起点
            j = i + length - 1               # 区间终点
            
            # 初始化dp[i][j],根据问题设定
            dp[i][j] = init_value
            
            # 遍历分割点
            for k in range(i, j):            # 分割点位置
                # 根据问题确定状态转移方程
                dp[i][j] = max/min(dp[i][j], 
                                   dp[i][k] + dp[k+1][j] + cost)
    
    return dp[0][n-1]  # 整个区间的解

3. 经典区间DP问题

3.1 戳气球 (LeetCode 312)

问题描述:戳破气球获得硬币,求能获得的最大硬币数。

状态定义

dp[i][j]:戳破区间 (i, j) 内所有气球能获得的最大硬币数

注意:这里区间是开区间 (i, j),不包括i和j

状态转移方程
复制代码
dp[i][j] = max(
    dp[i][k] + dp[k][j] + nums[i] * nums[k] * nums[j]
    for k in range(i+1, j)
)
Python实现
python 复制代码
def maxCoins(nums):
    """
    戳气球问题
    关键:将问题转化为添加气球,而不是戳破
    """
    # 在首尾添加虚拟气球,值为1
    nums = [1] + nums + [1]
    n = len(nums)
    
    # dp[i][j]表示戳破区间(i,j)内所有气球的最大硬币数
    dp = [[0] * n for _ in range(n)]
    
    # 从下往上,从左往右遍历(按区间长度)
    for length in range(2, n):  # 区间长度至少为3(包含一个气球)
        for i in range(n - length):  # 区间起点
            j = i + length  # 区间终点
            
            # 遍历最后一个被戳破的气球位置
            for k in range(i + 1, j):
                # 戳破k位置气球的收益
                profit = dp[i][k] + dp[k][j] + nums[i] * nums[k] * nums[j]
                dp[i][j] = max(dp[i][j], profit)
    
    return dp[0][n-1]

#### 记忆化搜索(自顶向下)实现
def maxCoins_memo(nums):
    nums = [1] + nums + [1]
    n = len(nums)
    memo = [[-1] * n for _ in range(n)]
    
    def dfs(left, right):
        # 开区间 (left, right) 内没有气球
        if left + 1 == right:
            return 0
        
        if memo[left][right] != -1:
            return memo[left][right]
        
        max_coins = 0
        # 遍历最后一个被戳破的气球
        for k in range(left + 1, right):
            coins = (dfs(left, k) + dfs(k, right) + 
                    nums[left] * nums[k] * nums[right])
            max_coins = max(max_coins, coins)
        
        memo[left][right] = max_coins
        return max_coins
    
    return dfs(0, n - 1)
Java实现
java 复制代码
public class BurstBalloons {
    public int maxCoins(int[] nums) {
        int n = nums.length;
        // 添加虚拟气球
        int[] newNums = new int[n + 2];
        newNums[0] = 1;
        newNums[n + 1] = 1;
        for (int i = 0; i < n; i++) {
            newNums[i + 1] = nums[i];
        }
        
        n += 2;
        int[][] dp = new int[n][n];
        
        // 按区间长度遍历
        for (int len = 2; len < n; len++) {
            for (int i = 0; i < n - len; i++) {
                int j = i + len;
                // 遍历最后一个戳破的气球
                for (int k = i + 1; k < j; k++) {
                    dp[i][j] = Math.max(dp[i][j], 
                        dp[i][k] + dp[k][j] + newNums[i] * newNums[k] * newNums[j]);
                }
            }
        }
        
        return dp[0][n-1];
    }
    
    // 记忆化搜索版本
    public int maxCoinsMemo(int[] nums) {
        int n = nums.length;
        int[] newNums = new int[n + 2];
        newNums[0] = 1;
        newNums[n + 1] = 1;
        for (int i = 0; i < n; i++) {
            newNums[i + 1] = nums[i];
        }
        
        int[][] memo = new int[n + 2][n + 2];
        for (int[] row : memo) Arrays.fill(row, -1);
        
        return dfs(newNums, 0, n + 1, memo);
    }
    
    private int dfs(int[] nums, int left, int right, int[][] memo) {
        if (left + 1 == right) return 0;
        if (memo[left][right] != -1) return memo[left][right];
        
        int maxCoins = 0;
        for (int k = left + 1; k < right; k++) {
            int coins = dfs(nums, left, k, memo) + 
                       dfs(nums, k, right, memo) + 
                       nums[left] * nums[k] * nums[right];
            maxCoins = Math.max(maxCoins, coins);
        }
        
        memo[left][right] = maxCoins;
        return maxCoins;
    }
}

3.2 多边形三角剖分的最低得分 (LeetCode 1039)

问题描述:将凸多边形三角剖分,使得所有三角形得分和最小。

状态定义

dp[i][j]:顶点i到顶点j构成的多边形的最低得分

状态转移方程
复制代码
dp[i][j] = min(
    dp[i][k] + dp[k][j] + values[i] * values[k] * values[j]
    for k in range(i+1, j)
)
Python实现
python 复制代码
def minScoreTriangulation(values):
    """
    多边形三角剖分
    """
    n = len(values)
    dp = [[0] * n for _ in range(n)]
    
    # 按区间长度遍历(至少3个顶点才能构成三角形)
    for length in range(2, n):  # length=2表示有3个顶点
        for i in range(n - length):
            j = i + length
            dp[i][j] = float('inf')
            
            # 遍历分割点,以(i,j)为边,k为第三个顶点
            for k in range(i + 1, j):
                score = dp[i][k] + dp[k][j] + values[i] * values[k] * values[j]
                dp[i][j] = min(dp[i][j], score)
    
    return dp[0][n-1]

#### 递归+记忆化版本
def minScoreTriangulation_memo(values):
    n = len(values)
    memo = [[-1] * n for _ in range(n)]
    
    def dfs(i, j):
        # 只有两个顶点,无法构成三角形
        if j - i < 2:
            return 0
        
        if memo[i][j] != -1:
            return memo[i][j]
        
        min_score = float('inf')
        # 以(i,j)为边,k为第三个顶点
        for k in range(i + 1, j):
            score = dfs(i, k) + dfs(k, j) + values[i] * values[k] * values[j]
            min_score = min(min_score, score)
        
        memo[i][j] = min_score
        return min_score
    
    return dfs(0, n - 1)
Java实现
java 复制代码
public class MinScoreTriangulation {
    public int minScoreTriangulation(int[] values) {
        int n = values.length;
        int[][] dp = new int[n][n];
        
        for (int len = 2; len < n; len++) {
            for (int i = 0; i < n - len; i++) {
                int j = i + len;
                dp[i][j] = Integer.MAX_VALUE;
                
                for (int k = i + 1; k < j; k++) {
                    dp[i][j] = Math.min(dp[i][j], 
                        dp[i][k] + dp[k][j] + values[i] * values[k] * values[j]);
                }
            }
        }
        
        return dp[0][n-1];
    }
}

3.3 奇怪的打印机 (LeetCode 664)

问题描述:打印机每次可以打印连续相同字符,求打印目标字符串的最少次数。

状态定义

dp[i][j]:打印子串 s[i:j+1] 所需的最少次数

状态转移方程
复制代码
if s[i] == s[j]:
    dp[i][j] = dp[i][j-1]  # 首尾相同,可以一起打印
else:
    dp[i][j] = min(dp[i][k] + dp[k+1][j] for k in range(i, j))
Python实现
python 复制代码
def strangePrinter(s):
    """
    奇怪的打印机
    """
    if not s:
        return 0
    
    n = len(s)
    dp = [[0] * n for _ in range(n)]
    
    # 初始化:单个字符需要打印1次
    for i in range(n):
        dp[i][i] = 1
    
    # 按区间长度遍历
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            
            # 默认情况:先打印s[i],再打印剩余部分
            dp[i][j] = dp[i+1][j] + 1
            
            # 如果s[i] == s[k],可以一起打印
            for k in range(i + 1, j + 1):
                if s[i] == s[k]:
                    left = dp[i][k-1] if k > i else 0
                    right = dp[k+1][j] if k < j else 0
                    dp[i][j] = min(dp[i][j], left + right)
    
    return dp[0][n-1]

#### 优化版本(预处理相同字符)
def strangePrinter_optimized(s):
    if not s:
        return 0
    
    n = len(s)
    # 预处理:合并连续相同字符
    s_compressed = []
    for ch in s:
        if not s_compressed or ch != s_compressed[-1]:
            s_compressed.append(ch)
    
    m = len(s_compressed)
    dp = [[0] * m for _ in range(m)]
    
    for i in range(m):
        dp[i][i] = 1
    
    for length in range(2, m + 1):
        for i in range(m - length + 1):
            j = i + length - 1
            dp[i][j] = dp[i+1][j] + 1
            
            for k in range(i + 1, j + 1):
                if s_compressed[i] == s_compressed[k]:
                    left = dp[i][k-1] if k > i else 0
                    right = dp[k+1][j] if k < j else 0
                    dp[i][j] = min(dp[i][j], left + right)
    
    return dp[0][m-1]
Java实现
java 复制代码
public class StrangePrinter {
    public int strangePrinter(String s) {
        if (s == null || s.length() == 0) return 0;
        
        int n = s.length();
        int[][] dp = new int[n][n];
        
        // 初始化
        for (int i = 0; i < n; i++) {
            dp[i][i] = 1;
        }
        
        // 按区间长度遍历
        for (int len = 2; len <= n; len++) {
            for (int i = 0; i <= n - len; i++) {
                int j = i + len - 1;
                dp[i][j] = dp[i+1][j] + 1;
                
                // 寻找可以一起打印的相同字符
                for (int k = i + 1; k <= j; k++) {
                    if (s.charAt(i) == s.charAt(k)) {
                        int left = (k > i) ? dp[i][k-1] : 0;
                        int right = (k < j) ? dp[k+1][j] : 0;
                        dp[i][j] = Math.min(dp[i][j], left + right);
                    }
                }
            }
        }
        
        return dp[0][n-1];
    }
}

4. 区间DP的典型应用场景

4.1 回文相关问题

最长回文子序列

(在序列DP部分已详细介绍,这里展示区间DP解法)

python 复制代码
def longestPalindromeSubseq_interval(s):
    n = len(s)
    dp = [[0] * n for _ in range(n)]
    
    # 初始化:单个字符是回文,长度为1
    for i in range(n):
        dp[i][i] = 1
    
    # 按区间长度遍历
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            
            if s[i] == s[j]:
                dp[i][j] = dp[i+1][j-1] + 2
            else:
                dp[i][j] = max(dp[i+1][j], dp[i][j-1])
    
    return dp[0][n-1]

4.2 石子合并问题

问题描述:N堆石子排成一排,每次合并相邻两堆,代价为两堆石子数之和,求最小合并代价。

python 复制代码
def stoneMerge(stones):
    """
    石子合并问题(最小代价)
    """
    n = len(stones)
    
    # 前缀和,方便计算区间和
    prefix_sum = [0] * (n + 1)
    for i in range(n):
        prefix_sum[i+1] = prefix_sum[i] + stones[i]
    
    dp = [[0] * n for _ in range(n)]
    
    # 按区间长度遍历
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            
            # 遍历分割点
            for k in range(i, j):
                cost = dp[i][k] + dp[k+1][j] + prefix_sum[j+1] - prefix_sum[i]
                dp[i][j] = min(dp[i][j], cost)
    
    return dp[0][n-1]

#### 环形石子合并(扩展)
def stoneMergeCircular(stones):
    """
    环形石子合并
    技巧:复制数组,将环形转化为线性
    """
    n = len(stones)
    # 复制数组,形成2n长度的数组
    extended_stones = stones + stones
    
    # 前缀和
    prefix_sum = [0] * (2 * n + 1)
    for i in range(2 * n):
        prefix_sum[i+1] = prefix_sum[i] + extended_stones[i]
    
    dp = [[0] * (2 * n) for _ in range(2 * n)]
    
    # 按区间长度遍历
    for length in range(2, n + 1):
        for i in range(2 * n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            
            for k in range(i, j):
                cost = dp[i][k] + dp[k+1][j] + prefix_sum[j+1] - prefix_sum[i]
                dp[i][j] = min(dp[i][j], cost)
    
    # 取所有长度为n的区间的最小值
    result = float('inf')
    for i in range(n):
        result = min(result, dp[i][i+n-1])
    
    return result
Java实现
java 复制代码
public class StoneMerge {
    // 线性石子合并
    public int stoneMerge(int[] stones) {
        int n = stones.length;
        int[] prefix = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefix[i + 1] = prefix[i] + stones[i];
        }
        
        int[][] dp = new int[n][n];
        
        for (int len = 2; len <= n; len++) {
            for (int i = 0; i <= n - len; i++) {
                int j = i + len - 1;
                dp[i][j] = Integer.MAX_VALUE;
                
                for (int k = i; k < j; k++) {
                    int cost = dp[i][k] + dp[k+1][j] + prefix[j+1] - prefix[i];
                    dp[i][j] = Math.min(dp[i][j], cost);
                }
            }
        }
        
        return dp[0][n-1];
    }
}

5. 区间DP优化技巧

5.1 四边形不等式优化

对于某些区间DP问题,如果代价函数满足四边形不等式,可以使用四边形不等式优化,将时间复杂度从O(n³)降低到O(n²)。

四边形不等式条件

设w(i, j)为区间[i, j]的代价,如果满足:

复制代码
w(i, j) + w(i', j') ≤ w(i, j') + w(i', j) 对于所有 i ≤ i' ≤ j ≤ j'

则可以使用优化。

优化后的石子合并
python 复制代码
def stoneMerge_optimized(stones):
    """
    四边形不等式优化版本
    """
    n = len(stones)
    prefix_sum = [0] * (n + 1)
    for i in range(n):
        prefix_sum[i+1] = prefix_sum[i] + stones[i]
    
    dp = [[0] * n for _ in range(n)]
    # s[i][j]记录最优分割点
    s = [[0] * n for _ in range(n)]
    
    # 初始化
    for i in range(n):
        s[i][i] = i
        dp[i][i] = 0
    
    # 按区间长度遍历
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            
            # 优化:只在s[i][j-1]到s[i+1][j]之间搜索
            start = s[i][j-1] if i <= j-1 else i
            end = s[i+1][j] if i+1 <= j else j
            
            for k in range(start, end + 1):
                if k < j:  # 确保k不是最后一个元素
                    cost = dp[i][k] + dp[k+1][j] + prefix_sum[j+1] - prefix_sum[i]
                    if cost < dp[i][j]:
                        dp[i][j] = cost
                        s[i][j] = k
    
    return dp[0][n-1]

5.2 断环成链技巧

对于环形区间DP问题,常用的技巧是将数组复制一倍,转化为线性问题。

python 复制代码
def circular_interval_dp(nums):
    """
    环形区间DP通用解法
    """
    n = len(nums)
    # 复制数组
    extended_nums = nums + nums
    
    # 处理长度为2n的线性数组
    dp = [[0] * (2 * n) for _ in range(2 * n)]
    
    for length in range(2, n + 1):
        for i in range(2 * n - length + 1):
            j = i + length - 1
            # ... 状态转移 ...
    
    # 取所有长度为n的区间的结果
    result = max/min(dp[i][i+n-1] for i in range(n))
    return result

6. 区间DP解题模板总结

6.1 通用解题框架

python 复制代码
def solve_interval_dp(nums):
    n = len(nums)
    
    # 1. 初始化DP表
    dp = [[0] * n for _ in range(n)]
    
    # 2. 初始化基本情况(长度为1的区间)
    for i in range(n):
        dp[i][i] = base_case_value(nums[i])
    
    # 3. 按区间长度从小到大遍历
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            
            # 4. 初始化当前区间
            dp[i][j] = init_value
            
            # 5. 遍历分割点
            for k in range(i, j):
                # 6. 计算子问题组合的代价
                left = dp[i][k]
                right = dp[k+1][j] if k+1 <= j else 0
                cost = calculate_cost(nums, i, k, j)
                
                # 7. 更新最优解
                dp[i][j] = update_value(dp[i][j], left + right + cost)
    
    # 8. 返回整个区间的解
    return dp[0][n-1]

6.2 不同问题的初始化

问题类型 区间长度1初始化 区间长度2初始化 备注
戳气球 dp[i][i] = 0 特殊处理(开区间) 添加虚拟气球
三角剖分 dp[i][i] = 0 dp[i][i+1] = 0 至少3个点才能剖分
石子合并 dp[i][i] = 0 dp[i][i+1] = stones[i]+stones[i+1] 需要前缀和
奇怪打印机 dp[i][i] = 1 根据字符是否相同 字符打印问题

6.3 状态转移方程对比

问题 状态转移方程 特点
戳气球 dp[i][j] = max(dp[i][k] + dp[k][j] + nums[i]*nums[k]*nums[j]) 开区间,k在(i,j)内
三角剖分 dp[i][j] = min(dp[i][k] + dp[k][j] + v[i]*v[k]*v[j]) 类似戳气球
石子合并 dp[i][j] = min(dp[i][k] + dp[k+1][j] + sum[i:j+1]) 需要区间和
回文子序列 dp[i][j] = dp[i+1][j-1]+2 if s[i]==s[j] else max(dp[i+1][j], dp[i][j-1]) 字符比较

6.4 遍历顺序的重要性

python 复制代码
# 正确的遍历顺序(按区间长度)
for length in range(2, n+1):
    for i in range(n-length+1):
        j = i + length - 1
        # 此时所有更小区间的解都已经计算好了

# 错误的遍历顺序
for i in range(n):
    for j in range(i, n):
        # 可能依赖未计算的子区间

7. 常见错误与调试技巧

7.1 常见错误

  1. 区间定义混淆

    • 开区间 vs 闭区间
    • 索引从0开始还是1开始
  2. 边界条件错误

    • 区间长度为1或2的特殊处理
    • 分割点范围错误
  3. 初始化遗漏

    • 忘记初始化对角线
    • 区间长度2的情况需要特殊处理
  4. 遍历顺序错误

    • 没有按区间长度从小到大
    • 分割点范围不对

7.2 调试技巧

  1. 打印DP表
python 复制代码
def print_dp_table(dp):
    n = len(dp)
    for i in range(n):
        for j in range(n):
            print(f"{dp[i][j]:3d}", end=" ")
        print()
  1. 小规模测试
python 复制代码
# 测试用例
test_cases = [
    ([1], 0),  # 边界情况
    ([1, 2], 2),  # 最小非平凡情况
    ([3, 1, 5, 8], 167),  # 标准测试
]
  1. 逐步验证
    • 手动计算小规模案例
    • 验证初始化是否正确
    • 检查状态转移是否覆盖所有情况

7.3 性能优化建议

  1. 空间优化
python 复制代码
# 如果只依赖相邻行,可以使用滚动数组
def interval_dp_space_optimized(nums):
    n = len(nums)
    dp_curr = [0] * n
    dp_prev = [0] * n
    
    for length in range(2, n+1):
        for i in range(n-length+1):
            j = i + length - 1
            # 计算dp_curr[i]
        dp_prev, dp_curr = dp_curr, dp_prev
  1. 时间优化
    • 四边形不等式优化
    • 预处理前缀和避免重复计算
    • 记忆化搜索减少重复状态计算

8. 进阶练习题目

8.1 推荐练习顺序

  1. 基础题目

    • 最长回文子序列(复习)
    • 石子合并(线性)
  2. 中等难度

    • 戳气球
    • 多边形三角剖分
  3. 进阶题目

    • 奇怪的打印机
    • 环形石子合并
  4. 挑战题目

    • 合并果子(优先队列解法)
    • 最优二叉搜索树

8.2 变种问题

  1. 最大得分问题

    • 将最小代价改为最大得分
    • 状态转移从min改为max
  2. 带权区间DP

    • 每个区间有额外权重
    • 代价函数更复杂
  3. 高维区间DP

    • 二维区间DP(矩阵链乘法)
    • 树形区间DP

8.3 综合应用

  1. 结合其他算法

    • 区间DP + 贪心
    • 区间DP + 二分查找
    • 区间DP + 状态压缩
  2. 实际问题建模

    • 任务调度问题
    • 资源分配问题
    • 字符串编辑问题

9. 面试准备建议

9.1 必备知识点

  1. 理解区间DP的基本思想
  2. 掌握经典问题的状态定义和转移
  3. 熟悉区间DP的通用模板
  4. 了解常见的优化技巧

9.2 解题思路

  1. 识别问题:判断是否属于区间DP
  2. 定义状态:明确dp[i][j]的含义
  3. 确定转移:如何从小区间得到大区间
  4. 确定顺序:按区间长度从小到大
  5. 边界处理:处理好最小区间的情况

9.3 沟通技巧

  1. 清晰解释状态定义
  2. 说明遍历顺序的原因
  3. 分析时间空间复杂度
  4. 讨论可能的优化方案

10. 总结

区间DP是动态规划中非常重要的一类问题,其核心思想是"分治+记忆化"。通过将大区间分解为小区间,利用小区间的最优解组合得到大区间的最优解。

关键要点

  1. 状态定义dp[i][j] 表示区间 [i, j] 的最优解
  2. 遍历顺序:按区间长度从小到大
  3. 状态转移:遍历分割点,组合子区间解
  4. 初始化:处理好最小区间的情况
  5. 优化:四边形不等式、断环成链等技巧

掌握区间DP不仅有助于解决特定的算法问题,更能培养将复杂问题分解为子问题的思维能力,这是解决许多工程问题的关键技能。建议通过大量练习,深入理解区间DP的思想和应用。

相关推荐
QiZhang | UESTC2 小时前
【算法题学习方法调整】回溯核心逻辑调整:从记代码到套逻辑调整
算法·学习方法
救救孩子把2 小时前
59-机器学习与大模型开发数学教程-5-6 Adam、RMSProp、AdaGrad 等自适应优化算法
人工智能·算法·机器学习
Σίσυφος19002 小时前
PCL 中常用的滤波对比
算法
进击的小头2 小时前
连续系统离散化方法(嵌入式信号处理实战指南)
c语言·算法·信号处理
永远都不秃头的程序员(互关)2 小时前
【决策树深度探索(五)】智慧之眼:信息增益,如何找到最佳决策问题?
算法·决策树·机器学习
智者知已应修善业2 小时前
【输出方形点阵】2024-11-1
c语言·c++·经验分享·笔记·算法
近津薪荼2 小时前
优选算法——双指针专题2(模拟)
c++·学习·算法
乌萨奇也要立志学C++2 小时前
【洛谷】DFS 新手必学的4 道DFS经典题 手把手教你剪枝与回溯
算法·深度优先
sali-tec2 小时前
C# 基于OpenCv的视觉工作流-章15-多边形逼近
图像处理·人工智能·opencv·算法·计算机视觉