- 题目
- 题解(15)
- 讨论(4)
- 排行
中等 通过率:36.69% 时间限制:1秒 空间限制:256M
知识点dfs

校招时部分企业笔试将禁止编程题跳出页面,为提前适应,练习时请使用在线自测,而非本地IDE。
描述
一个 N×MN×M 的由非负整数构成的数字矩阵,你需要在其中取出若干个数字,使得取出的任意两个数字不相邻(若一个数字在另外一个数字相邻 88 个格子中的一个即认为这两个数字相邻),求取出数字和最大是多少。
输入描述:
第一行有一个正整数 TT(1≦T≦201≦T≦20),表示了有 TT 组数据。
对于每一组数据,第一行有两个正整数 N,MN,M(1≦N,M≦61≦N,M≦6),表示了数字矩阵为 NN 行 MM 列。
接下来 NN 行,每行 MM 个非负整数,描述了这个数字矩阵,满足 1≦ai,j≦1051≦ai,j≦105。
输出描述:
输出共 TT 行,每行一个非负整数,输出所求得的答案。
示例1
输入:
1
3 3
1 1 1
1 1 1
1 1 1
复制输出:
4
cpp
#include <bits/stdc++.h>
using namespace std;
using int64 = long long;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
if (!(cin >> T)) return 0;
while (T--) {
int n, m;
cin >> n >> m;
vector<vector<int64>> a(n, vector<int64>(m));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) cin >> a[i][j];
}
// 枚举合法行掩码
vector<int> masks;
int full = 1 << m;
for (int s = 0; s < full; ++s) {
if (s & (s << 1)) continue; // 行内相邻
masks.push_back(s);
}
int S = (int)masks.size();
// 预处理兼容列表
vector<vector<int>> compat(S);
for (int i = 0; i < S; ++i) {
int p = masks[i];
for (int j = 0; j < S; ++j) {
int s = masks[j];
if ((s & p) == 0) {
int lp = (p << 1) & (full - 1);
int rp = (p >> 1);
if ((s & lp) == 0 && (s & rp) == 0) compat[i].push_back(j);
}
}
}
// 预处理每行在掩码 s 下的行和
vector<vector<int64>> rowSum(n, vector<int64>(S, 0));
for (int r = 0; r < n; ++r) {
for (int j = 0; j < S; ++j) {
int s = masks[j];
int64 sum = 0;
for (int c = 0; c < m; ++c) {
if (s & (1 << c)) sum += a[r][c];
}
rowSum[r][j] = sum;
}
}
const int64 NEG = (int64)-4e18;
vector<int64> dpPrev(S, NEG), dpCurr(S, NEG);
// 第一行
for (int j = 0; j < S; ++j) dpPrev[j] = rowSum[0][j];
// 后续各行
for (int r = 1; r < n; ++r) {
fill(dpCurr.begin(), dpCurr.end(), NEG);
for (int i = 0; i < S; ++i) {
if (dpPrev[i] == NEG) continue;
for (int j : compat[i]) {
dpCurr[j] = max(dpCurr[j], dpPrev[i] + rowSum[r][j]);
}
}
dpPrev.swap(dpCurr);
}
int64 ans = 0;
for (int j = 0; j < S; ++j) ans = max(ans, dpPrev[j]);
cout << ans << "\n";
}
return 0;
}