使用 DFS 轻松求解数独难题(C++ 的一个简单实现)

起因

都说懒惰是第一生产力,最近在玩数独游戏的时候,总会遇到拆解数独比较复杂的情况,就想着自己写个代码解题,解放双手。所以很快就写了一个简单的代码求解经典数独。拿来跑了几个最高难度的数独发现确实很爽!虽说是比较暴力的 DFS,但是由于数独中约束较多的性质,实际上要找出唯一解并不复杂,即使是最高难度的数独也可以在 0.04s 内解完,可以说是非常的方便。

思路

经典数独游戏由 9*9 的方格组成,每个方格可填 1~9 的数字,一般都有三种约束:同行,同列,同宫不可出现相同的数字。只要暴力时利用这些约束,就可以快速剪枝。

考虑最简单的情况:我们对于任何一个空位,可以尝试去填 1~9 的数字,并且检查三种约束是否满足。若满足,就继续填下一个空位。

处理约束

实际上,并不需要每个格子都去把 1~9 全部尝试。因为填的数字越多,约束就越强,我们就越容易发现之前填数时的错误。所以我们可以预先处理三种约束影响的格子范围:

cpp 复制代码
void initializeRelation() {
	memset(digitsUsed, 0, sizeof digitsUsed);
	// sub-grids
	for (int i = 0; i < 3; i++) {
		for (int j = 0; j < 3; j++) {
			int num = i * 3 + j;
			for (int k = 0; k < 3; k++) {
				for (int l = 0; l < 3; l++) {
					int idx = calcIdx(i * 3 + k, j * 3 + l);
					group[2][idx] = num;
					r[num].push_back(idx);
				}
			}
		}
	}
	// rows
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(i, j);
			group[0][idx] = i + N;
			r[i + N].push_back(idx);
		}
	}
	// columns
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(j, i);
			group[1][idx] = i + N * 2;
			r[i + N * 2].push_back(idx);
		}
	}
}

预先处理完约束后,下次要找一个格子到底应该对应哪些约束时,就可以直接找到对应的 idx 序号了。

状态压缩

一个格子可以填 1~9 共九种数字,那么到底哪些是可以填的呢?就如同我们实际解数独时一样,我们可以在格子上标记一下有哪些数字是符合约束的。一个简单的方法是把这个状态压缩成二进制数,每个可用数字代表一个二进制位的 1,若不可用,则该位为 0。那么一个格子上的可用数字可用一个 9 位二进制数表示,范围是0~2^9,也即一个格子至多只有 512 种状态。

接下来 gcc 有一些方便的内建函数可以帮到我们,它们都是以 __builtin 开头:

  • __builtin_popcount(unsigned int x) 返回无符号整型x的二进制中 1 的个数
  • __builtin_ctz(unsigned int x)返回无符号整型x的二进制的末尾有多少个 0

上述函数也可使用 std::bitset<N>::count 等实现,作用类似。

现在计算某个格子还有多少可用数字就可以这样:

cpp 复制代码
inline int calcUsable(int idx) {
	return 9 - __builtin_popcount(digitsUsed[idx]);
}

DFS

当我们枚举数字时,其实就是从当前状态中找到下一个可用数字,并根据约束关系删除与其相关的格子中的可用数字。

那么搜索时如何快速判断当前填的数字否可行呢?一个简单的思路是每次找到可用数字最少的格子,这样的格子可以确定更多的约束,搜索空间也更少,一旦失败了,我们可以迅速回滚。

那么把所有的空格子按照他们的[可用数字个数,可用数字状态]作为一个数对,我们就可以利用std::set构造出一个暴力 DFS 方案:

cpp 复制代码
bool dfs() {
	if (grid.empty()) {
		return true;
	}
	pair<int, int> p = *grid.begin();
	grid.erase(p);
	int idx = p.second;
	int digitBit = MASK & ~digitsUsed[idx];

	for (int nextDigitBit = digitBit; nextDigitBit; nextDigitBit ^= lowbit(nextDigitBit)) {
		int digit = lowbit0Count(nextDigitBit);
		int currentDigitBit = 1 << digit;
		g[idx] = digit + 1;
		vector<int> last;
		for (int j = 0; j < 3; j++) {
			for (auto & x: r[group[j][idx]]) {
				auto it = grid.find(make_pair(calcUsable(x), x));
				if (it != grid.end() && (digitsUsed[x] | currentDigitBit) != digitsUsed[x]) {
					grid.erase(it);
					digitsUsed[x] = digitsUsed[x] | currentDigitBit;
					grid.insert(make_pair(calcUsable(x), x));
					last.push_back(x);
				}
			}
		}
		if (dfs()) {
			return true;
		}
		for (auto &x: last) {
			grid.erase(make_pair(calcUsable(x), x));
			digitsUsed[x] = digitsUsed[x] & ~currentDigitBit;
			grid.insert(make_pair(calcUsable(x), x));
		}
	}
	grid.insert(p);
	return false;
}

结语

由于只考虑经典数独,代码还是非常简洁而且高效的。而对于各种各样的变形数独,也可以考虑根据这种简化约束的方式去暴力求解。如果想要模仿人类解法,对强弱链等逻辑进行推演而非简单暴力的话,还需要更多的工作。

当然,数独如果由机器暴力计算就会缺失很多乐趣,但去寻找现有问题的一种代码实现也同样是另一种乐趣。我觉得能在数学游戏中找到自己喜欢的部分,并发掘出其中的趣味,其本身也是一种快乐的事情。

附录

最终代码如下,输入重定向于sudoku.in,输入格式中星号*代表空位,可在代码最后注释中看到样例。

输出格式为先输出整体的解,再输出只包含原数独中空位的解。

cpp 复制代码
#include <bits/stdc++.h>

using namespace std;

const int N = 9;
const int R_NUM = 27;
const int GRID_NUM = 81;
const int MASK = (1 << N) - 1;

char str[10][100];
int s[9][9];
int g[GRID_NUM];
int group[3][GRID_NUM]; // groups
vector<int> r[R_NUM]; // relations
set<pair<int, int>> grid;
int digitsUsed[GRID_NUM];

/**
group 0:
000000000
111111111
222222222
333333333
444444444
555555555
666666666
777777777
888888888

group 1:
012345678
012345678
012345678
012345678
012345678
012345678
012345678
012345678
012345678

group 2:
000111222
000111222
000111222
333444555
333444555
333444555
666777888
666777888
666777888
**/

inline int calcX(int idx) {
	return group[0][idx];
}

inline int calcIdx(int x, int y) {
	return x * N + y;
}

inline int lowbit(int x) {
	return x & (-x);
}

inline int lowbit0Count(int x) {
	return __builtin_ctz(x);
}

inline int calcUsable(int idx) {
	return 9 - __builtin_popcount(digitsUsed[idx]);
}

void initializeRelation() {
	memset(digitsUsed, 0, sizeof digitsUsed);
	// sub-grids
	for (int i = 0; i < 3; i++) {
		for (int j = 0; j < 3; j++) {
			int num = i * 3 + j;
			for (int k = 0; k < 3; k++) {
				for (int l = 0; l < 3; l++) {
					int idx = calcIdx(i * 3 + k, j * 3 + l);
					group[2][idx] = num;
					r[num].push_back(idx);
				}
			}
		}
	}
	// rows
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(i, j);
			group[0][idx] = i + N;
			r[i + N].push_back(idx);
		}
	}
	// columns
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(j, i);
			group[1][idx] = i + N * 2;
			r[i + N * 2].push_back(idx);
		}
	}
}

void fail() {
	printf("IMPOSSIBLE\n");
	exit(0);
}

void printResult() {
	printf("Result:\n");
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			printf("%d", g[calcIdx(i, j)]);
		}
		printf("\n");
	}
}

void printFillableResult() {
	printf("\nFillable Result:\n");
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			printf("%c", (s[i][j] == 0) ? g[calcIdx(i, j)] + '0' : '*');
		}
		printf("\n");
	}
}

bool dfs() {
	if (grid.empty()) {
		return true;
	}
	pair<int, int> p = *grid.begin();
	grid.erase(p);
	int idx = p.second;
	int digitBit = MASK & ~digitsUsed[idx];

	for (int nextDigitBit = digitBit; nextDigitBit; nextDigitBit ^= lowbit(nextDigitBit)) {
		int digit = lowbit0Count(nextDigitBit);
		int currentDigitBit = 1 << digit;
		g[idx] = digit + 1;
		vector<int> last;
		for (int j = 0; j < 3; j++) {
			for (auto & x: r[group[j][idx]]) {
				auto it = grid.find(make_pair(calcUsable(x), x));
				if (it != grid.end() && (digitsUsed[x] | currentDigitBit) != digitsUsed[x]) {
					grid.erase(it);
					digitsUsed[x] = digitsUsed[x] | currentDigitBit;
					grid.insert(make_pair(calcUsable(x), x));
					last.push_back(x);
				}
			}
		}
		if (dfs()) {
			return true;
		}
		for (auto &x: last) {
			grid.erase(make_pair(calcUsable(x), x));
			digitsUsed[x] = digitsUsed[x] & ~currentDigitBit;
			grid.insert(make_pair(calcUsable(x), x));
		}
	}
	grid.insert(p);
	return false;
}


int main() {
	freopen("sudoku.in", "r", stdin);
	initializeRelation();

	// Enter a sudoku puzzle: (9 lines with 9 characters on each line, use * for blank)
	for (int i = 0; i < N; i++) {
		scanf("%s", str[i]);
	}

	for (int i = 0; i < N; i++) {
		if (strlen(str[i]) != N) {
			exit(0);
		}
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(i, j);
			if (str[i][j] == '*') {
				g[idx] = s[i][j] = 0;
				digitsUsed[idx] = 0;
			} else if (str[i][j] >= '1' && str[i][j] <= '9') {
				g[idx] = s[i][j] = str[i][j] - '0';
			} else {
				exit(0);
			}
		}
	}

	for (int idx = 0; idx < GRID_NUM; idx++) {
		if (g[idx] == 0) {
			for (int j = 0; j < 3; j++) {
				for (auto & cur: r[group[j][idx]]) {
					if (g[cur] != 0) {
						digitsUsed[idx] |= 1 << (g[cur] - 1);
					}
				}
			}
			// pair is (digitsLeftCount, idx)
			grid.insert(make_pair(calcUsable(idx), idx));
		}
	}

	if (dfs()) {
		printResult();
		printFillableResult();
	} else {
		printResult();
		fail();
	}

	return 0;
}


/*
<Sample Input>

*23456789
456789123
789123456
312645978
645978312
978312645
231564897
564897231
897231564

**95*8*7*
23769***4
5**32**1*
8*1935**7
49*8*2*51
**3**6*2*
*1*2*4**6
6*8*****2
*7*1*38**


*****2***
2*4****7*
****5**49
**6**85**
******83*
57**4****
*3*7****6
*65*3**9*
7***9*1**

*/
相关推荐
大二转专业43 分钟前
408算法题leetcode--第24天
考研·算法·leetcode
zaim144 分钟前
计算机的错误计算(一百一十四)
java·c++·python·rust·go·c·多项式
学习使我变快乐1 小时前
C++:const成员
开发语言·c++
凭栏落花侧1 小时前
决策树:简单易懂的预测模型
人工智能·算法·决策树·机器学习·信息可视化·数据挖掘·数据分析
Starry_hello world2 小时前
二叉树实现
数据结构·笔记·有问必答
hong_zc2 小时前
算法【Java】—— 二叉树的深搜
java·算法
吱吱鼠叔3 小时前
MATLAB计算与建模常见函数:5.曲线拟合
算法·机器学习·matlab
嵌入式AI的盲4 小时前
数组指针和指针数组
数据结构·算法
一律清风4 小时前
QT-文件创建时间修改器
c++·qt
风清扬_jd4 小时前
Chromium 如何定义一个chrome.settingsPrivate接口给前端调用c++
前端·c++·chrome