《算法竞赛·快冲300题 》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。
文章目录
" 质数拼图游戏 " ,链接: http://oj.ecustacm.cn/problem.php?id=1814
题目描述
【题目描述】 给定两个nn的矩阵A和B,记C=A B(此处为矩阵乘法),存在m次询问。
每次询问C中一个子矩阵中所有数字之和。
每次询问给定a,b,c,d四个数字,表示所求子矩阵为第a行第b列到第c行第d列的子矩阵。
【输入格式】 输入第一行为n和m(1≤n≤2000,m≤50000)。
接下来n行,每行n个数字表示矩阵A。
再接下来n行,每行n个数字表示矩阵B。矩阵中每个数字不超过100。
接下来m行,每行4个数字a,b,c,d表示询问的子矩阵,(1≤a,b,c,d≤n)。
本题输入数据量大,建议使用快速读入。
【输出格式】 对于每组询问,输出一行,包含一个数字表示答案。
【输入样例】
c
3 2
1 9 8
3 2 0
1 8 3
9 8 4
0 5 15
1 9 6
1 1 3 3
2 3 1 2
【输出样例】
c
661
388
题解
如果只要求询问一个给定矩阵的子矩阵数字之和,是一个很直白的前缀和应用。
为快速得到一个矩阵的任意子矩阵的和,可以用"二维前缀和"。定义二维数组s\[\]\[\], s i j sij sij表示子矩阵 1 , 1 i , j 1, 1 ~ i, j 1,1 i,j的和。预计算出s\[\]\[\]后,可以快速计算出任意的子矩阵和。如下图所示,阴影子矩阵 i 1 , j 1 i 2 , j 2 i_1, j_1 ~ i_2, j_2 i1,j1 i2,j2的和等于:
s i 2 j 2 − s i 2 j 1 − 1 − s i 1 − 1 j 2 + s i 1 − 1 j 1 − 1 si_2j_2 - si_2j_1-1 - si_1-1j_2 + si_1-1j_1-1 si2j2−si2j1−1−si1−1j2+si1−1j1−1
其中 s i 1 − 1 j 1 − 1 si_1-1 j_1-1 si1−1j1−1被减了2次,需要加回来1次。
用上述公式查询一次子矩阵和,计算量仅为O(1)。

预计算一个矩阵A的所有s\[\]\[\],计算量为 n 2 n^2 n2。代码这样写:
cpp
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
cin >> A[i][j], s[i][j] = s[i-1][j]+s[i][j-1]-s[i-1][j-1]+A[i][j];//预计算s[][]
本题如果用上述方法,需要先求矩阵乘法C=A*B。但是矩阵乘法的计算量是 n 3 n^3 n3,而n≤2000,肯定超时,所以必须避免直接计算矩阵乘法。
下面分析矩阵乘法C=A*B的计算过程,看能不能利用前缀和,从而减少计算量。下图画出了矩阵乘法的细节,求C中子矩阵(a, b) ~ (c, d)的和。

(1)计算C的第b列的区间和,即 C a b + C a + 1 b + . . . + C c b Cab + Ca+1b + ... + Ccb Cab+Ca+1b+...+Ccb。
先看C中第b列标'*'的 C a b Cab Cab的计算过程,它等于A第a行乘以B第b列:
C a b = A a 1 × B 1 b + A a 2 × B 2 b + . . . + A a n × B n b Cab = Aa1×B1b + Aa2×B2b + ... + Aan×Bnb Cab=Aa1×B1b+Aa2×B2b+...+Aan×Bnb
同理,C中第b列的其他坐标的计算过程是:
C a + 1 b = A a + 1 1 × B 1 b + A a + 1 2 × B 2 b + . . . + A a + 1 n × B n b Ca+1b = Aa+11×B1b + Aa+12×B2b + ...+ Aa+1n×Bnb Ca+1b=Aa+11×B1b+Aa+12×B2b+...+Aa+1n×Bnb
...
C c b = A c 1 × B 1 b + A c 2 × B 2 b + . . . + A c n × B n b Ccb = Ac1×B1b + Ac2×B2b + ... +Acn×Bnb Ccb=Ac1×B1b+Ac2×B2b+...+Acn×Bnb (式3-1)
把(式3-1)上下相加得C的子矩阵第b列的区间和:
C a b + C a + 1 b + . . . + C c b Cab + Ca+1b + ... + Ccb Cab+Ca+1b+...+Ccb
= ( A a 1 + A a + 1 1 + . . . + A c 1 ) × B 1 b + = (Aa1+Aa+11 + ... + Ac1)×B1b + =(Aa1+Aa+11+...+Ac1)×B1b+
( A a 2 + A a + 1 2 + . . . + A c 2 ) × B 2 b + (Aa2+Aa+12 + ... + Ac2)×B2b + (Aa2+Aa+12+...+Ac2)×B2b+
...
( A a n + A a + 1 n + . . . + A c n ) × B n b (Aan+Aa+1n + ... + Acn)×Bnb (Aan+Aa+1n+...+Acn)×Bnb (式3-2)
式中的 Aa1+Aa+11+...+Ac1正好是A的第1列的区间和,Aa2+Aa+12+...+Ac2是第2列的区间和,...,等等。
记s1\[\]j为A的第j列的前缀和,有:
A a 1 + A a + 1 1 + . . . + A c 1 = s 1 c 1 − s 1 a − 1 1 Aa1+Aa+11+...+Ac1 = s1c1 - s1a-11 Aa1+Aa+11+...+Ac1=s1c1−s1a−11
A a 2 + A a + 1 2 + . . . + A c 2 = s 1 c 2 − s 1 a − 1 2 Aa2+Aa+12+...+Ac2 = s1c2 - s1a-12 Aa2+Aa+12+...+Ac2=s1c2−s1a−12
...
A a n + A a + 1 n + . . . + A c n = s 1 c n − s 1 a − 1 n Aan+Aa+1n + ... + Acn=s1cn - s1a-1n Aan+Aa+1n+...+Acn=s1cn−s1a−1n
则C的子矩阵第b列的区间和(式3-2)简化为:
C a c b Ca\~cb Ca cb
= ( s 1 c 1 − s 1 a − 1 1 ) × B 1 b + ( s 1 c 2 − s 1 a − 1 2 ) × B 2 b + . . . + s 1 c n − s 1 a − 1 n × B n b = (s1c1 - s1a-11)×B1b + (s1c2 - s1a-12)×B2b + ...+s1cn - s1a-1n×Bnb =(s1c1−s1a−11)×B1b+(s1c2−s1a−12)×B2b+...+s1cn−s1a−1n×Bnb
(2)计算C的子矩阵的和,即把C的第b列、b+1列、...、d列相加。根据(1)的讨论,有:
C a c b + C a c b + 1 + . . . + C a c d Ca\~cb + Ca\~cb+1 + ... + Ca\~cd Ca cb+Ca cb+1+...+Ca cd
= ( s 1 c 1 − s 1 a − 1 1 ) × B 1 b + ( s 1 c 2 − s 1 a − 1 2 ) × B 2 b + . . . = (s1c1 - s1a-11)×B1b + (s1c2 - s1a-12)×B2b + ... =(s1c1−s1a−11)×B1b+(s1c2−s1a−12)×B2b+...
( s 1 c 1 − s 1 a − 1 1 ) × B 1 b + 1 + ( s 1 c 2 − s 1 a − 1 2 ) × B 2 b + 1 + . . . (s1c1 - s1a-11)×B1b+1 + (s1c2 - s1a-12)×B2b+1 + ... (s1c1−s1a−11)×B1b+1+(s1c2−s1a−12)×B2b+1+...
( s 1 c 1 − s 1 a − 1 1 ) × B 1 b + 2 + ( s 1 c 2 − s 1 a − 1 2 ) × B 2 b + 2 + . . . (s1c1 - s1a-11)×B1b+2 + (s1c2 - s1a-12)×B2b+2 + ... (s1c1−s1a−11)×B1b+2+(s1c2−s1a−12)×B2b+2+...
...
( s 1 c 1 − s 1 a − 1 1 ) × B 1 d + ( s 1 c 2 − s 1 a − 1 2 ) × B 2 d + . . . (s1c1 - s1a-11)×B1d + (s1c2 - s1a-12)×B2d + ... (s1c1−s1a−11)×B1d+(s1c2−s1a−12)×B2d+... (式3-3)
把(式3-3)上下相加,得:
C a c b + C a c b + 1 + . . . + C a c d Ca\~cb + Ca\~cb+1 + ... + Ca\~cd Ca cb+Ca cb+1+...+Ca cd
= ( s 1 c 1 − s 1 a − 1 1 ) × ( B 1 b + B 1 b + 1 + . . . + B 1 d ) + = (s1c1 - s1a-11) × (B1b+B1b+1 + ...+ B1d) + =(s1c1−s1a−11)×(B1b+B1b+1+...+B1d)+
( s 1 c 2 − s 1 a − 1 2 ) × ( B 2 b + B 2 b + 1 + . . . + B 2 d ) + (s1c2 - s1a-12) × (B2b+B2b+1 + ... + B2d) + (s1c2−s1a−12)×(B2b+B2b+1+...+B2d)+
...
( s 1 c n − s 1 a − 1 n ) × ( B n b + B n b + 1 + . . . + B n d ) (s1cn - s1a-1n) × (Bnb+Bnb+1 + ... + Bnd) (s1cn−s1a−1n)×(Bnb+Bnb+1+...+Bnd) (式3-4)
记s2i\[\]为B的第i行的前缀和,有:
B 1 b + B 1 b + 1 + . . . + B 1 d = s 2 1 d − s 2 1 b − 1 B1b+B1b+1 + ...+ B1d = s21d - s21b-1 B1b+B1b+1+...+B1d=s21d−s21b−1
B 2 b + B 2 b + 1 + . . . + B 2 d = s 2 2 d − s 2 2 b − 1 B2b+B2b+1 + ...+ B2d = s22d - s22b-1 B2b+B2b+1+...+B2d=s22d−s22b−1
...
B n b + B n b + 1 + . . . + B n d = s 2 n d − s 2 n b − 1 Bnb+Bnb+1 + ... + Bnd=s2nd - s2nb-1 Bnb+Bnb+1+...+Bnd=s2nd−s2nb−1
则(式3-4)改写为:
C a c b + C a c b + 1 + . . . + C a c d Ca\~cb + Ca\~cb+1 + ... + Ca\~cd Ca cb+Ca cb+1+...+Ca cd
= ( s 1 c 1 − s 1 a − 1 1 ) × ( s 2 1 d − s 2 1 b − 1 ) + = (s1c1 - s1a-11) × ( s21d - s21b-1) + =(s1c1−s1a−11)×(s21d−s21b−1)+
( s 1 c 2 − s 1 a − 1 2 ) × ( s 2 2 d − s 2 2 b − 1 ) + (s1c2 - s1a-12) × ( s22d - s22b-1) + (s1c2−s1a−12)×(s22d−s22b−1)+
...
( s 1 c n − s 1 a − 1 n ) × ( s 2 n d − s 2 n b − 1 ) (s1cn - s1a-1n) × ( s2nd - s2nb-1) (s1cn−s1a−1n)×(s2nd−s2nb−1)
这是最后的式子,每一行是两个区间和的乘法,共n行,有n次乘法计算。
总计算量是多少?(1)预计算s1\[\]\[\]和s2\[\]\[\],是 O ( n 2 ) O(n^2) O(n2)的;(2)查询m次子矩阵和,每次有n次乘法计算,是 O ( m n ) O(mn) O(mn)的;(3)总计算量等于 O ( n 2 ) + O ( m n ) O(n^2) + O(mn) O(n2)+O(mn),刚好通过测试。
【重点】 前缀和,矩阵计算 。
C++代码
题目中提到"本题输入数据量大,建议使用快速读入"。
C++的标准输入输出函数是cin/cout、scanf/printf,在默认情况下,cin/cout比scanf/printf慢得多。在需要大量输入输出的场合,一般用scanf、printf就可以。如果还要提高速度,输入用getchar(),输出用putchar(),它们更快。
自己写一个快读函数read(),用到getchar()。getchar()的功能是读1 byte的数据,按char类型读入。下面代码中的read()是整数输入的快读模板,用getchar()读入每个字符,然后转成数字。例如输入"245",用getchar()分3次读入'2'、'4'、'5',然后组合成数字"345"。注意可能有负数,所以需要判断'-'号。
自己写一个快写函数write(),用到putchar()。putchar()的功能是输出一个字符,当需要输出一个数时,把它的每一位转成字符,然后用putchar()输出。
cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2010;
int n,m,a,b,c,d;
int A[N][N],B[N][N],s1[N][N],s2[N][N];
inline int read(int &x) { //快读int型整数。如果需要读long long,把int改成long long
x = 0;
int w = 1;//w:判断正负号
char ch = 0;
while (ch < '0' || ch > '9') { //读字符
if (ch == '-') w = -1; //这是一个负整数数
ch = getchar(); //读一个字符
}
while (ch >= '0' && ch <= '9') { //读数字
x = x * 10 + (ch - '0');
ch = getchar();
}
return x = x * w;
}
void write(ll x) { //快写long long型整数
if (x < 0) { // 判断正负。如果是负数,输出负号
putchar('-');
x = -x; //记得把负数变正,方便下面输出数字
}
if (x > 9) write(x / 10); // 递归,将除最后一位外的其他部分放到递归中输出
putchar(x % 10 + '0'); // 已经输出(递归)完 x 末位前的所有数字,输出末位
}
ll query(int a,int b,int c,int d){ //C=A*B,计算C的子矩阵和
ll ans = 0;
for(int k=1;k<=n;k++){
ll ans1 = s1[c][k] - s1[a-1][k];
ll ans2 = s2[k][d] - s2[k][b-1];
ans += ans1*ans2;
}
return ans;
}
int main(){
read(n),read(m);
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
read(A[i][j]), s1[i][j] = s1[i-1][j]+A[i][j]; //输入A。s1[][j]是第j列的前缀和
//read(A[i][j])等于scanf("%d",&A[i][j])或cin>>A[i][j]
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
read(B[i][j]), s2[i][j] = s2[i][j-1]+B[i][j]; //输入B。s2[i][]是第i行的前缀和
while(m--){
read(a),read(b),read(c),read(d); //等于scanf("%d%d%d%d",&a,&b,&c,&d);
if(a > c) swap(a, c); //可能存在a>c、b>d的情况
if(b > d) swap(b, d);
ll ans = query(a,b,c,d);
write(ans); putchar('\n'); //等于printf("%lld\n",query(a,b,c,d));
}
return 0;
}
Java代码
java
import java.util.*;
import java.io.*;
class Main {
static FastReader scanner = new FastReader();
static int N = 2010;
static int n, m, a, b, c, d;
static int[][] A = new int[N][N], B = new int[N][N], s1 = new int[N][N], s2 = new int[N][N];
public static void main(String[] args) throws IOException {
n = scanner.nextInt();
m = scanner.nextInt();
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++) {
A[i][j] = scanner.nextInt();;
s1[i][j] = s1[i - 1][j] + A[i][j];
}
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++) {
B[i][j] = scanner.nextInt();;
s2[i][j] = s2[i][j - 1] + B[i][j];
}
while (m-- > 0) {
a = scanner.nextInt();
b = scanner.nextInt();
c = scanner.nextInt();
d = scanner.nextInt();
if (a > c) { int temp = a; a = c; c = temp; }
if (b > d) { int temp = b; b = d; d = temp; }
long ans = query(a, b, c, d);
System.out.println(ans);
}
}
static long query(int a, int b, int c, int d) {
long ans = 0;
for (int k = 1; k <= n; k++) {
long ans1 = s1[c][k] - s1[a - 1][k];
long ans2 = s2[k][d] - s2[k][b - 1];
ans += ans1 * ans2;
}
return ans;
}
static class FastReader {
BufferedReader br;
StringTokenizer st;
public FastReader() {
br = new BufferedReader(new InputStreamReader(System.in));
}
String next() {
while (st == null || !st.hasMoreElements()) {
try {st = new StringTokenizer(br.readLine());}
catch (IOException e) { e.printStackTrace(); }
}
return st.nextToken();
}
int nextInt() { return Integer.parseInt(next()); }
long nextLong() { return Long.parseLong(next()); }
double nextDouble() { return Double.parseDouble(next()); }
String nextLine() {
String str = "";
try { str = br.readLine(); }
catch (IOException e) { e.printStackTrace(); }
return str;
}
}
}
Python代码
python
#pypy
import sys
input = sys.stdin.readline
def query(a, b, c, d):
ans = 0
for k in range(1, n+1):
ans1 = s1[c][k] - s1[a-1][k]
ans2 = s2[k][d] - s2[k][b-1]
ans += ans1 * ans2
return ans
n, m = list(map(int, input().split()))
N = n+1
s1 = [[0] * N for _ in range(N)]
s2 = [[0] * N for _ in range(N)]
A = [0] * N
B = [0] * N
for i in range(1, n+1):
A[i] = [0] + list(map(int, input().split()))
for j in range(1, n+1): s1[i][j] = s1[i-1][j] + A[i][j]
for i in range(1, n+1):
B[i] = [0] + list(map(int, input().split()))
for j in range(1, n+1): s2[i][j] = s2[i][j-1] + B[i][j]
for _ in range(m):
a, b, c, d = list(map(int, input().split()))
if a > c: a, c = c, a
if b > d: b, d = d, b
sys.stdout.write(str(query(a, b, c, d)) + '\n')