题目链接
题目描述
给定一个长度为 N N N 的数列, A 1 , A 2 , A 3 , . . , A N A_1,A_2,A_3,..,A_N A1,A2,A3,..,AN,如果其中一段连续的子序列 A i , A i + 1 , . . . , A j A_i,A_{i+1},...,A_j Ai,Ai+1,...,Aj 之和是 K K K 的倍数,我们就称这个区间 [ i , j ] [i,j] [i,j] 是 K K K倍区间。
你能求出数列中总共有多少个 K K K倍区间嘛?
输入格式
第一行包含两个整数 N N N 和 K K K。
以下 N N N 行每行包含一个整数 A i A_i Ai。
输出格式
输出一个整数,代表 K K K 倍区间的数目。
输入输出样例
输入
5 2
1
2
3
4
5
输出
6
数据范围
- 1 ≤ N , K ≤ 1 0 5 1 \leq N,K \leq 10^5 1≤N,K≤105
- 1 ≤ A i ≤ 1 0 5 1 \leq A_i \leq 10^5 1≤Ai≤105
解法:前缀和 + 哈希表
我们用 s s s 表示 数列 a a a 的前 n n n 项和:
- s [ 0 ] = 0 s[0] = 0 s[0]=0
- s [ 1 ] = a [ 1 ] + 0 s[1] = a[1] + 0 s[1]=a[1]+0
- s [ 2 ] = a [ 2 ] + a [ 1 ] + 0 s[2] = a[2] + a[1] + 0 s[2]=a[2]+a[1]+0
- ...
如果区间 [ i , j ] [i,j] [i,j] 是一个 K K K倍区间,那么显然 [ i , j ] [i,j] [i,j] 的区间和 s [ j ] − s [ i − 1 ] s[j] - s[i - 1] s[j]−s[i−1]是 K K K 的倍数,也就是 ( s [ j ] − s [ i − 1 ] ) % K = 0 (s[j] - s[i -1]) \% K = 0 (s[j]−s[i−1])%K=0,即 s [ j ] % K = s [ i − 1 ] % K s[j] \% K = s[i -1] \% K s[j]%K=s[i−1]%K,也就是 s [ j ] s[j] s[j] 模 K K K 的余数 和 s [ i − 1 ] s[i-1] s[i−1] 模 K K K 的余数 相等!
那么对于每一个 s [ j ] s[j] s[j],我们只需要找到前面与 s [ j ] % K s[j] \% K s[j]%K 相同的前缀和的个数即可。
所以我们可以直接使用哈希表 m a p map map 来存储 s [ j ] % K s[j] \% K s[j]%K的出现次数。
为了不漏掉从第一个元素开始的 K K K倍区间的个数。初始时, m a p [ 0 ] = 1 map[0] = 1 map[0]=1。
时间复杂度: O ( n ) O(n) O(n)
C++代码:
cpp
#include <iostream>
#include <vector>
#include <unordered_map>
using namespace std;
using LL = long long;
int main(){
int n,k;
cin>>n>>k;
vector<int> a(n);
for(int i = 0;i < n;i++) cin>>a[i];
unordered_map<int,int> cnt;
cnt[0] = 1;
LL sum = 0, ans = 0;
for(auto x:a){
sum += x;
int t = sum % k;
ans += cnt[t];
cnt[t]++;
}
cout<<ans<<'\n';
}
Java代码:
java
import java.util.*;
import java.io.*;
public class Main {
static BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
static int n;
public static void main(String[] args) throws Exception
{
String[] str = in.readLine().split(" ");
n = Integer.parseInt(str[0]);
int k = Integer.parseInt(str[1]);
int[] a = new int[n];
for(int i = 0;i < n;i++) a[i] = Integer.parseInt(in.readLine().trim());
Map<Long,Integer> map = new HashMap<>();
map.put(0L, 1);
long sum = 0, ans = 0;
for(int i = 0;i < n;i++) {
sum += a[i];
long t = sum % k;
ans += map.getOrDefault(t, 0);
map.put(t, map.getOrDefault(t, 0) + 1);
}
System.out.println(ans);
}
}