P1005 [NOIP 2007 提高组] 矩阵取数游戏 题解
题意
给定一个 n × m n\times m n×m 的矩阵,可以进行 m m m 次取数,每次只能取走每一行的第一个数或最后一个数。设第 i i i 次取出的数是 b 1 b_1 b1, b 2 b_2 b2, ⋯ \cdots ⋯, b n b_n bn,则第 i i i 次取数的得分是 ∑ j = 1 n b j × 2 i \sum_{j=1}^{n}b_j\times 2^i ∑j=1nbj×2i。问取完 m m m 次后的每次取数得分之和最大是多少?
思路
开始看着题目,可能感觉没有什么头绪。但不难发现,其实行与行之间不会相互影响,所以可以将每一行在输入时单独处理,这样可以把 a a a 数组压成一维。同时,每行取数之后总是剩余一个区间,所以令人很自然的联想到区间 dp。
设 d p i , j dp_{i,j} dpi,j 表示剩下 [ i , j ] [i,j] [i,j] 这个区间时的最大得分。
接下来开始推方程。
[ i , j ] [i,j] [i,j] 区间只能从 [ i − 1 , j ] [i-1,j] [i−1,j] 或 [ i , j + 1 ] [i,j+1] [i,j+1] 区间转移而来。同时,可以得到在转移到 [ i , j ] [i,j] [i,j] 区间时是第 m − j + i − 1 m-j+i-1 m−j+i−1 次。
设在第 k k k 行转移,便能得到转移方程:
d p i , j = max ( d p i − 1 , j + a k , i − 1 × 2 m − j + i − 1 , d p i , j + 1 + a k , j + 1 × 2 m − j + i − 1 ) dp_{i,j}=\max(dp_{i-1,j}+a_{k,i-1}\times 2^{m-j+i-1},dp_{i,j+1}+a_{k,j+1}\times 2^{m-j+i-1}) dpi,j=max(dpi−1,j+ak,i−1×2m−j+i−1,dpi,j+1+ak,j+1×2m−j+i−1)
此时只需要把 2 1 2^1 21 到 2 m 2^m 2m 都用快速幂预处理出来便可将转移的复杂度压缩到 O ( 1 ) O(1) O(1)。(用位运算的话似乎会 wa 后四个点)
由 dp 数组的定义可知长度为一的区间不会被计算,所以最终答案为:
∑ i = 1 n max 1 ≤ j ≤ m ( d p j , j + a i , j × 2 m ) \sum_{i=1}^{n}\max\limits_{1\le j\le m}(dp_{j,j}+a_{i,j}\times 2^m) i=1∑n1≤j≤mmax(dpj,j+ai,j×2m)
注意到此题的转移过程是从大区间到小区间的,所以我们的枚举顺序要改一下。即 i i i 从小到大, j j j 从大到小。
时间复杂度: O ( n m 2 ) O(nm^2) O(nm2)。
代码
数据范围会爆 long long
,可以写高精,但我这里用了 __int128
。
代码如下:
cpp
#include<bits/stdc++.h>
#define int __int128
#define inf 1ll << 62
#define max(a , b) a > b ? a : b
#define min(a , b) a < b ? a : b
using namespace std;
const int MAXN = 85;
int n , m;
int a[MAXN] , poww[MAXN];
int dp[MAXN][MAXN] , ans;
inline int read() {
char c = getchar();
int x = 0 , s = 1;
while(c < '0' || c > '9') {
if(c == '-')
s = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
return x * s;
}
inline int ksm(int base , int x) {
int result = 1;
while(x) {
if(x & 1)
result *= base;
x >>= 1;
base *= base;
}
return result;
}
void write(int x) {
if(x < 0) {
putchar('-');
x = -x;
}
if(x >= 10)
write(x / 10);
putchar(x % 10 + '0');
return;
}
signed main() {
n = read();
m = read();
for(register int i = 1;i <= m;i ++)
poww[i] = ksm(2 , i);
while(n --) {
for(register int i = 1;i <= m;i ++)
a[i] = read();
memset(dp , 0 , sizeof(dp));
for(register int i = 1;i <= m;i ++)
for(register int j = m;j >= i;j --)
dp[i][j] = max(dp[i - 1][j] + a[i - 1] * poww[m - j + i - 1] , dp[i][j + 1] + a[j + 1] * poww[m - j + i - 1]);
int maxn = -inf;
for(register int i = 1;i <= m;i ++)
maxn = max(maxn , dp[i][i] + a[i] * poww[m]);
ans += maxn;
}
write(ans);
return 0;
}