使用 DFS 轻松求解数独难题(C++ 的一个简单实现)

起因

都说懒惰是第一生产力,最近在玩数独游戏的时候,总会遇到拆解数独比较复杂的情况,就想着自己写个代码解题,解放双手。所以很快就写了一个简单的代码求解经典数独。拿来跑了几个最高难度的数独发现确实很爽!虽说是比较暴力的 DFS,但是由于数独中约束较多的性质,实际上要找出唯一解并不复杂,即使是最高难度的数独也可以在 0.04s 内解完,可以说是非常的方便。

思路

经典数独游戏由 9*9 的方格组成,每个方格可填 1~9 的数字,一般都有三种约束:同行,同列,同宫不可出现相同的数字。只要暴力时利用这些约束,就可以快速剪枝。

考虑最简单的情况:我们对于任何一个空位,可以尝试去填 1~9 的数字,并且检查三种约束是否满足。若满足,就继续填下一个空位。

处理约束

实际上,并不需要每个格子都去把 1~9 全部尝试。因为填的数字越多,约束就越强,我们就越容易发现之前填数时的错误。所以我们可以预先处理三种约束影响的格子范围:

cpp 复制代码
void initializeRelation() {
	memset(digitsUsed, 0, sizeof digitsUsed);
	// sub-grids
	for (int i = 0; i < 3; i++) {
		for (int j = 0; j < 3; j++) {
			int num = i * 3 + j;
			for (int k = 0; k < 3; k++) {
				for (int l = 0; l < 3; l++) {
					int idx = calcIdx(i * 3 + k, j * 3 + l);
					group[2][idx] = num;
					r[num].push_back(idx);
				}
			}
		}
	}
	// rows
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(i, j);
			group[0][idx] = i + N;
			r[i + N].push_back(idx);
		}
	}
	// columns
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(j, i);
			group[1][idx] = i + N * 2;
			r[i + N * 2].push_back(idx);
		}
	}
}

预先处理完约束后,下次要找一个格子到底应该对应哪些约束时,就可以直接找到对应的 idx 序号了。

状态压缩

一个格子可以填 1~9 共九种数字,那么到底哪些是可以填的呢?就如同我们实际解数独时一样,我们可以在格子上标记一下有哪些数字是符合约束的。一个简单的方法是把这个状态压缩成二进制数,每个可用数字代表一个二进制位的 1,若不可用,则该位为 0。那么一个格子上的可用数字可用一个 9 位二进制数表示,范围是0~2^9,也即一个格子至多只有 512 种状态。

接下来 gcc 有一些方便的内建函数可以帮到我们,它们都是以 __builtin 开头:

  • __builtin_popcount(unsigned int x) 返回无符号整型x的二进制中 1 的个数
  • __builtin_ctz(unsigned int x)返回无符号整型x的二进制的末尾有多少个 0

上述函数也可使用 std::bitset<N>::count 等实现,作用类似。

现在计算某个格子还有多少可用数字就可以这样:

cpp 复制代码
inline int calcUsable(int idx) {
	return 9 - __builtin_popcount(digitsUsed[idx]);
}

DFS

当我们枚举数字时,其实就是从当前状态中找到下一个可用数字,并根据约束关系删除与其相关的格子中的可用数字。

那么搜索时如何快速判断当前填的数字否可行呢?一个简单的思路是每次找到可用数字最少的格子,这样的格子可以确定更多的约束,搜索空间也更少,一旦失败了,我们可以迅速回滚。

那么把所有的空格子按照他们的[可用数字个数,可用数字状态]作为一个数对,我们就可以利用std::set构造出一个暴力 DFS 方案:

cpp 复制代码
bool dfs() {
	if (grid.empty()) {
		return true;
	}
	pair<int, int> p = *grid.begin();
	grid.erase(p);
	int idx = p.second;
	int digitBit = MASK & ~digitsUsed[idx];

	for (int nextDigitBit = digitBit; nextDigitBit; nextDigitBit ^= lowbit(nextDigitBit)) {
		int digit = lowbit0Count(nextDigitBit);
		int currentDigitBit = 1 << digit;
		g[idx] = digit + 1;
		vector<int> last;
		for (int j = 0; j < 3; j++) {
			for (auto & x: r[group[j][idx]]) {
				auto it = grid.find(make_pair(calcUsable(x), x));
				if (it != grid.end() && (digitsUsed[x] | currentDigitBit) != digitsUsed[x]) {
					grid.erase(it);
					digitsUsed[x] = digitsUsed[x] | currentDigitBit;
					grid.insert(make_pair(calcUsable(x), x));
					last.push_back(x);
				}
			}
		}
		if (dfs()) {
			return true;
		}
		for (auto &x: last) {
			grid.erase(make_pair(calcUsable(x), x));
			digitsUsed[x] = digitsUsed[x] & ~currentDigitBit;
			grid.insert(make_pair(calcUsable(x), x));
		}
	}
	grid.insert(p);
	return false;
}

结语

由于只考虑经典数独,代码还是非常简洁而且高效的。而对于各种各样的变形数独,也可以考虑根据这种简化约束的方式去暴力求解。如果想要模仿人类解法,对强弱链等逻辑进行推演而非简单暴力的话,还需要更多的工作。

当然,数独如果由机器暴力计算就会缺失很多乐趣,但去寻找现有问题的一种代码实现也同样是另一种乐趣。我觉得能在数学游戏中找到自己喜欢的部分,并发掘出其中的趣味,其本身也是一种快乐的事情。

附录

最终代码如下,输入重定向于sudoku.in,输入格式中星号*代表空位,可在代码最后注释中看到样例。

输出格式为先输出整体的解,再输出只包含原数独中空位的解。

cpp 复制代码
#include <bits/stdc++.h>

using namespace std;

const int N = 9;
const int R_NUM = 27;
const int GRID_NUM = 81;
const int MASK = (1 << N) - 1;

char str[10][100];
int s[9][9];
int g[GRID_NUM];
int group[3][GRID_NUM]; // groups
vector<int> r[R_NUM]; // relations
set<pair<int, int>> grid;
int digitsUsed[GRID_NUM];

/**
group 0:
000000000
111111111
222222222
333333333
444444444
555555555
666666666
777777777
888888888

group 1:
012345678
012345678
012345678
012345678
012345678
012345678
012345678
012345678
012345678

group 2:
000111222
000111222
000111222
333444555
333444555
333444555
666777888
666777888
666777888
**/

inline int calcX(int idx) {
	return group[0][idx];
}

inline int calcIdx(int x, int y) {
	return x * N + y;
}

inline int lowbit(int x) {
	return x & (-x);
}

inline int lowbit0Count(int x) {
	return __builtin_ctz(x);
}

inline int calcUsable(int idx) {
	return 9 - __builtin_popcount(digitsUsed[idx]);
}

void initializeRelation() {
	memset(digitsUsed, 0, sizeof digitsUsed);
	// sub-grids
	for (int i = 0; i < 3; i++) {
		for (int j = 0; j < 3; j++) {
			int num = i * 3 + j;
			for (int k = 0; k < 3; k++) {
				for (int l = 0; l < 3; l++) {
					int idx = calcIdx(i * 3 + k, j * 3 + l);
					group[2][idx] = num;
					r[num].push_back(idx);
				}
			}
		}
	}
	// rows
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(i, j);
			group[0][idx] = i + N;
			r[i + N].push_back(idx);
		}
	}
	// columns
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(j, i);
			group[1][idx] = i + N * 2;
			r[i + N * 2].push_back(idx);
		}
	}
}

void fail() {
	printf("IMPOSSIBLE\n");
	exit(0);
}

void printResult() {
	printf("Result:\n");
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			printf("%d", g[calcIdx(i, j)]);
		}
		printf("\n");
	}
}

void printFillableResult() {
	printf("\nFillable Result:\n");
	for (int i = 0; i < N; i++) {
		for (int j = 0; j < N; j++) {
			printf("%c", (s[i][j] == 0) ? g[calcIdx(i, j)] + '0' : '*');
		}
		printf("\n");
	}
}

bool dfs() {
	if (grid.empty()) {
		return true;
	}
	pair<int, int> p = *grid.begin();
	grid.erase(p);
	int idx = p.second;
	int digitBit = MASK & ~digitsUsed[idx];

	for (int nextDigitBit = digitBit; nextDigitBit; nextDigitBit ^= lowbit(nextDigitBit)) {
		int digit = lowbit0Count(nextDigitBit);
		int currentDigitBit = 1 << digit;
		g[idx] = digit + 1;
		vector<int> last;
		for (int j = 0; j < 3; j++) {
			for (auto & x: r[group[j][idx]]) {
				auto it = grid.find(make_pair(calcUsable(x), x));
				if (it != grid.end() && (digitsUsed[x] | currentDigitBit) != digitsUsed[x]) {
					grid.erase(it);
					digitsUsed[x] = digitsUsed[x] | currentDigitBit;
					grid.insert(make_pair(calcUsable(x), x));
					last.push_back(x);
				}
			}
		}
		if (dfs()) {
			return true;
		}
		for (auto &x: last) {
			grid.erase(make_pair(calcUsable(x), x));
			digitsUsed[x] = digitsUsed[x] & ~currentDigitBit;
			grid.insert(make_pair(calcUsable(x), x));
		}
	}
	grid.insert(p);
	return false;
}


int main() {
	freopen("sudoku.in", "r", stdin);
	initializeRelation();

	// Enter a sudoku puzzle: (9 lines with 9 characters on each line, use * for blank)
	for (int i = 0; i < N; i++) {
		scanf("%s", str[i]);
	}

	for (int i = 0; i < N; i++) {
		if (strlen(str[i]) != N) {
			exit(0);
		}
		for (int j = 0; j < N; j++) {
			int idx = calcIdx(i, j);
			if (str[i][j] == '*') {
				g[idx] = s[i][j] = 0;
				digitsUsed[idx] = 0;
			} else if (str[i][j] >= '1' && str[i][j] <= '9') {
				g[idx] = s[i][j] = str[i][j] - '0';
			} else {
				exit(0);
			}
		}
	}

	for (int idx = 0; idx < GRID_NUM; idx++) {
		if (g[idx] == 0) {
			for (int j = 0; j < 3; j++) {
				for (auto & cur: r[group[j][idx]]) {
					if (g[cur] != 0) {
						digitsUsed[idx] |= 1 << (g[cur] - 1);
					}
				}
			}
			// pair is (digitsLeftCount, idx)
			grid.insert(make_pair(calcUsable(idx), idx));
		}
	}

	if (dfs()) {
		printResult();
		printFillableResult();
	} else {
		printResult();
		fail();
	}

	return 0;
}


/*
<Sample Input>

*23456789
456789123
789123456
312645978
645978312
978312645
231564897
564897231
897231564

**95*8*7*
23769***4
5**32**1*
8*1935**7
49*8*2*51
**3**6*2*
*1*2*4**6
6*8*****2
*7*1*38**


*****2***
2*4****7*
****5**49
**6**85**
******83*
57**4****
*3*7****6
*65*3**9*
7***9*1**

*/
相关推荐
LNTON羚通35 分钟前
摄像机视频分析软件下载LiteAIServer视频智能分析平台玩手机打电话检测算法技术的实现
算法·目标检测·音视频·监控·视频监控
Red Red1 小时前
网安基础知识|IDS入侵检测系统|IPS入侵防御系统|堡垒机|VPN|EDR|CC防御|云安全-VDC/VPC|安全服务
网络·笔记·学习·安全·web安全
贰十六2 小时前
笔记:Centos Nginx Jdk Mysql OpenOffce KkFile Minio安装部署
笔记·nginx·centos
知兀2 小时前
Java的方法、基本和引用数据类型
java·笔记·黑马程序员
哭泣的眼泪4082 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
Ysjt | 深2 小时前
C++多线程编程入门教程(优质版)
java·开发语言·jvm·c++
ephemerals__2 小时前
【c++丨STL】list模拟实现(附源码)
开发语言·c++·list
Microsoft Word3 小时前
c++基础语法
开发语言·c++·算法
天才在此3 小时前
汽车加油行驶问题-动态规划算法(已在洛谷AC)
算法·动态规划