题目描述
帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的 n×mn \times mn×m 的矩阵,矩阵中的每个元素 ai,ja_{i,j}ai,j 均为非负整数。游戏规则如下:
- 每次取数时须从每行各取走一个元素,共 nnn 个。经过 mmm 次后取完矩阵内所有元素;
- 每次取走的各个元素只能是该元素所在行的行首或行尾;
- 每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值 ×2i\times 2^i×2i,其中 iii 表示第 iii 次取数(从 111 开始编号);
- 游戏结束总得分为 mmm 次取数得分之和。
帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。
输入格式
输入文件包括 n+1n+1n+1 行:
第一行为两个用空格隔开的整数 nnn 和 mmm。
第 2∼n+12\sim n+12∼n+1 行为 n×mn \times mn×m 矩阵,其中每行有 mmm 个用单个空格隔开的非负整数。
输出格式
输出文件仅包含 111 行,为一个整数,即输入矩阵取数后的最大得分。
输入输出样例 #1
输入 #1
2 3
1 2 3
3 4 2
输出 #1
82
说明/提示
【数据范围】
对于 60%60\%60% 的数据,满足 1≤n,m≤301\le n,m\le 301≤n,m≤30,答案不超过 101610^{16}1016。
对于 100%100\%100% 的数据,满足 1≤n,m≤801\le n,m\le 801≤n,m≤80,0≤ai,j≤10000\le a_{i,j}\le10000≤ai,j≤1000。
【题目来源】
NOIP 2007 提高第三题。
答案
C++
cpp
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
int main() {
int n, m;
cin >> n >> m;
vector<vector<ll>> matrix(n, vector<ll>(m));
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
cin >> matrix[i][j];
}
}
vector<ll> pow2(m + 1);
pow2[0] = 1;
for (int i = 1; i <= m; i++) {
pow2[i] = pow2[i - 1] * 2;
}
ll total = 0;
for (int i = 0; i < n; i++) {
vector<vector<ll>> dp(m, vector<ll>(m, 0));
for (int len = 0; len < m; len++) {
for (int l = 0; l + len < m; l++) {
int r = l + len;
int k = m - len;
if (l == r) {
dp[l][r] = matrix[i][l] * pow2[k];
} else {
dp[l][r] = max(
matrix[i][l] * pow2[k] + dp[l + 1][r],
matrix[i][r] * pow2[k] + dp[l][r - 1]
);
}
}
}
total += dp[0][m - 1];
}
cout << total << endl;
return 0;
}
C
c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
typedef long long ll;
ll max(ll a, ll b) {
return a > b ? a : b;
}
int main() {
int n, m;
scanf("%d %d", &n, &m);
ll** matrix = (ll**)malloc(n * sizeof(ll*));
for (int i = 0; i < n; i++) {
matrix[i] = (ll*)malloc(m * sizeof(ll));
for (int j = 0; j < m; j++) {
scanf("%lld", &matrix[i][j]);
}
}
ll* pow2 = (ll*)malloc((m + 1) * sizeof(ll));
pow2[0] = 1;
for (int i = 1; i <= m; i++) {
pow2[i] = pow2[i - 1] * 2;
}
ll total = 0;
ll** dp = (ll**)malloc(m * sizeof(ll*));
for (int i = 0; i < m; i++) {
dp[i] = (ll*)malloc(m * sizeof(ll));
}
for (int i = 0; i < n; i++) {
memset(dp, 0, m * m * sizeof(ll));
for (int len = 0; len < m; len++) {
for (int l = 0; l + len < m; l++) {
int r = l + len;
int k = m - len;
if (l == r) {
dp[l][r] = matrix[i][l] * pow2[k];
} else {
dp[l][r] = max(
matrix[i][l] * pow2[k] + dp[l + 1][r],
matrix[i][r] * pow2[k] + dp[l][r - 1]
);
}
}
}
total += dp[0][m - 1];
}
printf("%lld\n", total);
for (int i = 0; i < n; i++) {
free(matrix[i]);
}
free(matrix);
free(pow2);
for (int i = 0; i < m; i++) {
free(dp[i]);
}
free(dp);
return 0;
}
Python
python
n, m = map(int, input().split())
matrix = []
for _ in range(n):
row = list(map(int, input().split()))
matrix.append(row)
pow2 = [1] * (m + 1)
for i in range(1, m + 1):
pow2[i] = pow2[i-1] * 2
total = 0
for i in range(n):
dp = [[0] * m for _ in range(m)]
for length in range(m):
for l in range(m - length):
r = l + length
k = m - length
if l == r:
dp[l][r] = matrix[i][l] * pow2[k]
else:
dp[l][r] = max(
matrix[i][l] * pow2[k] + dp[l+1][r],
matrix[i][r] * pow2[k] + dp[l][r-1]
)
total += dp[0][m-1]
print(total)
Java
java
import java.util.Scanner;
public class MatrixGame {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
int m = scanner.nextInt();
long[][] matrix = new long[n][m];
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
matrix[i][j] = scanner.nextLong();
}
}
long[] pow2 = new long[m + 1];
pow2[0] = 1;
for (int i = 1; i <= m; i++) {
pow2[i] = pow2[i - 1] * 2;
}
long total = 0;
long[][] dp = new long[m][m];
for (int i = 0; i < n; i++) {
for (int len = 0; len < m; len++) {
for (int l = 0; l + len < m; l++) {
int r = l + len;
int k = m - len;
if (l == r) {
dp[l][r] = matrix[i][l] * pow2[k];
} else {
dp[l][r] = Math.max(
matrix[i][l] * pow2[k] + dp[l + 1][r],
matrix[i][r] * pow2[k] + dp[l][r - 1]
);
}
}
}
total += dp[0][m - 1];
for (int a = 0; a < m; a++) {
for (int b = 0; b < m; b++) {
dp[a][b] = 0;
}
}
}
System.out.println(total);
scanner.close();
}
}
Go
go
package main
import (
"fmt"
)
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}
func main() {
var n, m int
fmt.Scan(&n, &m)
matrix := make([][]int64, n)
for i := range matrix {
matrix[i] = make([]int64, m)
for j := range matrix[i] {
fmt.Scan(&matrix[i][j])
}
}
pow2 := make([]int64, m+1)
pow2[0] = 1
for i := 1; i <= m; i++ {
pow2[i] = pow2[i-1] * 2
}
var total int64 = 0
dp := make([][]int64, m)
for i := range dp {
dp[i] = make([]int64, m)
}
for i := 0; i < n; i++ {
for len := 0; len < m; len++ {
for l := 0; l+len < m; l++ {
r := l + len
k := m - len
if l == r {
dp[l][r] = matrix[i][l] * pow2[k]
} else {
dp[l][r] = max(
matrix[i][l]*pow2[k]+dp[l+1][r],
matrix[i][r]*pow2[k]+dp[l][r-1],
)
}
}
}
total += dp[0][m-1]
for a := 0; a < m; a++ {
for b := 0; b < m; b++ {
dp[a][b] = 0
}
}
}
fmt.Println(total)
}