回溯&剪枝
一、回溯算法的核心思想
回溯法是一种 试探性 的搜索方法,主要用于 求解组合、排列、子集、路径等问题 。其本质是通过构建决策树,递归地搜索所有可能的解法,遇到不满足条件的分支就 回退(撤销选择),继续尝试其他路径。
模板框架(通用):
java
void backtrack(List<选择类型> path, List<选择类型> 选择列表) {
if (满足结束条件) {
记录结果(path);
return;
}
for (选择类型 选择 : 选择列表) {
// 做出选择
path.add(选择);
// 根据当前选择构造新的选择列表(或维护索引)
backtrack(path, 新的选择列表);
// 撤销选择
path.remove(path.size() - 1);
}
}
二、剪枝策略
剪枝 是优化回溯的关键手段,用于提前排除不必要的递归分支,减少时间复杂度。常见的剪枝方法有:
- 重复剪枝:避免重复计算,如组合、子集问题中。
- 条件剪枝:基于题意排除不满足的路径,如数独判断是否合法。
- 排序+跳过:常用于有重复元素的排列、组合问题。
- 限制条件:限制搜索空间,如火柴拼正方形问题中限定边长。
三、题目分类与解析
题号 | 题目 | 类型 | 回溯点 | 剪枝点 |
---|---|---|---|---|
22 | 括号生成 | 生成型 | 左右括号数受限,递归构造 | 如果右括号比左括号多,则非法,剪枝 |
36 | 有效的数独 | 判定型 | 遍历每个格子尝试填数 | 每次填入数字前先判断是否冲突 |
79 | 单词搜索 | 路径搜索 | 从每个起点开始DFS | 已访问路径不能重复走,提前剪枝 |
46 | 全排列 | 排列型 | 每层选择一个未使用元素 | 用 visited[] 防止重复选择 |
77 | 组合 | 组合型 | 从当前数开始向后选 | 若剩余元素不足剪枝(如 n - i + 1 < k - path.size()) |
78 | 子集 | 子集型 | 每次可以选或不选当前元素 | 无需剪枝,结果包含所有子集 |
17 | 电话号码字母组合 | 多叉树搜索 | 每个数字对应多个字母进行选择 | 不需剪枝 |
473 | 火柴拼正方形 | 构造型 | 把火柴分到4个边 | 长度和大于边长立即剪枝、排序+逆序剪枝 |
四、一些实战技巧
回溯模板变体:
- 排列问题 :需要一个
visited[]
来标记使用过的数字。 - 组合问题:通常用起始索引控制递归范围,避免重复。
- 子集问题:其实是组合问题的泛化,路径可以为空。
- 构造问题:比如括号生成/拼图问题,剪枝是核心。
五、回溯题目小结口诀
- 排列(46):不重复地选满所有元素。
- 组合(77):选固定个数的组合,按顺序递增。
- 子集(78):所有组合都要,允许为空。
- 括号(22):递归模拟规则,左右平衡。
- 电话字母(17):多叉树构造路径。
- 火柴拼正方形(473):递归分组+剪枝优化。
- 数独(36):约束强,先判断再递归。
- 单词搜索(79):网格中 DFS 搜索路径。【深度遍历】
六、算法题代码
括号生成
java
import java.util.ArrayList;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public List<String> generateParenthesis(int n) {
List<String> res = new ArrayList<>();
backtrack(res,new StringBuilder(),0,0,n);
return res;
}
void backtrack(List<String> res,StringBuilder cur,int open,int close,int max){
if (cur.length() == max * 2) {
res.add(cur.toString());
return;
}
if (open < max) {
cur.append("(");
backtrack(res,cur,open + 1,close,max);
cur.deleteCharAt(cur.length()-1);
}
if (close < open) {
cur.append(")");
backtrack(res,cur,open,close+1,max);
cur.deleteCharAt(cur.length()-1);
}
}
}
//leetcode submit region end(Prohibit modification and deletion)
有效的数独
java
import java.util.HashSet;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public boolean isValidSudoku(char[][] board) {
// 设置三个 set 集合;横向 竖向 以及 3*3 .一次遍历
Set<String> rows = new HashSet<>();
Set<String> clos = new HashSet<>();
Set<String> boxes = new HashSet<>();
for (int i = 0; i < 9; i++) {
for (int j = 0; j < 9; j++) {
char num = board[i][j];
if (num == '.') {
continue;
}
// 判断行
if(!rows.add(i + "-" + num)) return false;
// 判断列
if(!clos.add(j + "-" + num)) return false;
// 判断格子
int boxIndex = (i/3)*3 + j / 3;
if (!boxes.add(boxIndex + "-" + num)) return false;
}
}
return true;
}
}
//leetcode submit region end(Prohibit modification and deletion)
单词搜索
java
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public boolean exist(char[][] board, String word) {
// 长宽
int h = board.length;
int w = board[0].length;
// 标记 数组
boolean[][] visited = new boolean[h][w];
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
if (check(board, word, 0, i, j, visited)) {
return true;
}
}
}
return false;
}
// 深度遍历
boolean check(char[][] board,String word,int index,int i,int j,boolean[][] visited) {
if (i < 0 || i >= board.length || j < 0 || j >= board[0].length || board[i][j] != word.charAt(index) || visited[i][j]) {
return false;
}
if (index == word.length()-1) {
return true;
}
// 标记
visited[i][j] = true;
int[][] directions = {{0,1},{0,-1},{1,0},{-1,0}};
for (int[] dir : directions) {
int newi = i + dir[0];
int newj = j + dir[1];
if (check(board,word,index +1,newi,newj,visited)) {
return true;
}
}
// 回溯,取消标记
visited[i][j] = false;
return false;
}
}
//leetcode submit region end(Prohibit modification and deletion)
全排列
java
import java.util.ArrayList;
import java.util.List;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public List<List<Integer>> permute(int[] nums) {
List<List<Integer>> res = new ArrayList<>();
boolean[] used = new boolean[nums.length];
backtrack(nums,new ArrayList<>(),res,used);
return res;
}
void backtrack(int[] nums,List<Integer> temp,List<List<Integer>> res,boolean[] used) {
if (nums.length == temp.size()) {
res.add(new ArrayList<>(temp));
return;
}
for (int i = 0; i < nums.length; i++) {
if (used[i]) {
continue;
}
temp.add(nums[i]);
used[i] = true;
backtrack(nums, temp, res, used);
temp.remove(temp.size()-1);
used[i] = false;
}
}
}
//leetcode submit region end(Prohibit modification and deletion)
全排列②
java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public List<List<Integer>> permuteUnique(int[] nums) {
List<List<Integer>> res = new ArrayList<>();
Arrays.sort(nums);
boolean[] used = new boolean[nums.length];
backtrack(nums,new ArrayList<>(),res,used);
return res;
}
void backtrack(int[] nums,List<Integer> temp,List<List<Integer>> res,boolean[] used) {
if (nums.length == temp.size()) {
res.add(new ArrayList<>(temp));
return;
}
for (int i = 0; i < nums.length; i++) {
if (used[i]) {
continue;
}
if (i > 0 && nums[i] == nums[i - 1] && !used[i - 1]) continue; // 去重
temp.add(nums[i]);
used[i] = true;
backtrack(nums, temp, res, used);
temp.remove(temp.size()-1);
used[i] = false;
}
}
}
//leetcode submit region end(Prohibit modification and deletion)
组合
java
import java.util.ArrayList;
import java.util.List;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public List<List<Integer>> combine(int n, int k) {
List<List<Integer>> res = new ArrayList<>();
backtrack(n,k,1,new ArrayList<>(),res);
return res;
}
void backtrack(int n, int k, int start, List<Integer> temp,List<List<Integer>> res) {
if (temp.size() == k) {
res.add(new ArrayList<>(temp));
return;
}
for (int i = start; i <= n ; i++) {
temp.add(i);
backtrack(n,k,i+1,temp,res);
temp.remove(temp.size()-1);
}
}
}
//leetcode submit region end(Prohibit modification and deletion)
子集
java
import java.util.ArrayList;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public List<List<Integer>> subsets(int[] nums) {
List<List<Integer>> res = new ArrayList<>();
backtrack(nums,res,0, new ArrayList<>());
return res;
}
void backtrack(int[] nums,List<List<Integer>> res,int start, List<Integer> temp){
res.add(new ArrayList<>(temp));
for (int i = start; i < nums.length ; i++) {
temp.add(nums[i]);
backtrack(nums,res,i+1,temp);
temp.remove(temp.size() -1);
}
}
}
//leetcode submit region end(Prohibit modification and deletion)
电话号码的字母组合
java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public List<String> letterCombinations(String digits) {
// 枚举出 2-9
HashMap<Integer, List<String>> map = new HashMap<>();
map.put(2, Arrays.asList("a", "b", "c"));
map.put(3, Arrays.asList("d", "e", "f"));
map.put(4, Arrays.asList("g", "h", "i"));
map.put(5, Arrays.asList("j", "k", "l"));
map.put(6, Arrays.asList("m", "n", "o"));
map.put(7, Arrays.asList("p", "q", "r", "s"));
map.put(8, Arrays.asList("t", "u", "v"));
map.put(9, Arrays.asList("w", "x", "y", "z"));
if (digits.length() == 0 || digits == null) {
return new ArrayList<>();
}
// 排列组合
List<String> res = new ArrayList<>();
backtrack(res, new StringBuilder(), digits, 0, map);
return res;
}
void backtrack(List<String> res, StringBuilder temp, String digits, int index , HashMap<Integer, List<String>> map){
if (digits.length() == index) {
res.add(temp.toString());
return;
}
int digit = digits.charAt(index) - '0';
List<String> letters = map.get(digit);
for (String letter:letters){
temp.append(letter);
backtrack(res,temp,digits,index + 1,map);
temp.deleteCharAt(temp.length() -1);
}
}
}
//leetcode submit region end(Prohibit modification and deletion)
火柴拼正方形
java
import java.lang.reflect.Array;
import java.util.Arrays;
//leetcode submit region begin(Prohibit modification and deletion)
class Solution {
public boolean makesquare(int[] matchsticks) {
// 求和 sum
int sum = Arrays.stream(matchsticks).sum();
if(sum % 4 != 0) {
return false;
}
// 排序
Arrays.sort(matchsticks);
// 降序
reverse(matchsticks);
int side = sum/4;
int[] sides = new int [4];
return backtrack(matchsticks,sides,side,0);
}
boolean backtrack(int[] matchsticks,int[] sides,int target,int index) {
if (index == matchsticks.length) {
return sides[0] == target && sides[1] == target &&
sides[2] == target &&sides[3] == target;
}
for (int i = 0; i < sides.length; i++) {
// 终止条件
if (sides[i] + matchsticks[index] > target) continue;
// 执行流程
sides[i] = sides[i] + matchsticks[index];
// 递归
if (backtrack(matchsticks,sides,target,index+1)) return true;
// 回溯
sides[i] = sides[i] - matchsticks[index];
// 剪枝
if(sides[i] == 0) {
break;
}
}
return false;
}
void reverse (int[] arr) {
int left = 0;
int right = arr.length -1;
while (left < right) {
int tmp = arr[left];
arr[left] = arr[right];
arr[right] = tmp;
left++;
right--;
}
}
}
//leetcode submit region end(Prohibit modification and deletion)