原题链接:[NOI2010] 超级钢琴
题目大意:
给出一个长度为 n n n 的数组,且 a i a_{i} ai 可正可负,再给出三个数字 k , L , R k,L,R k,L,R 。
定义每个子数组的价值为其所有元素的和,你需要找到 k k k 个连续的子数组(可重叠但不可重复),且满足长度在 [ L , R ] [L,R] [L,R] 内,问你最后这 k k k 个子数组的价值总和最大为多少。
解题思路:
乍一看似乎没有什么思路,因为我们要找的所有区间几乎是 O ( n 2 ) O(n^{2}) O(n2) 级别的。
先考虑一下 k = 1 k=1 k=1 怎么做,可以想到的是,做一个前缀和,枚举 r r r 的同时找到一个对 r r r 而言最小的 l l l ,那么答案取 max { S u m [ r ] − S u m [ l − 1 ] } \max\{Sum[r]-Sum[l-1]\} max{Sum[r]−Sum[l−1]} 即可。
(其中前缀和数组为 S u m Sum Sum ,为了简化下文写成 S S S 数组)
同理 k = 2 , 3 , . . . , n 2 k=2,3,...,n^{2} k=2,3,...,n2 时候也是一样的,只需要再给每个 r r r 对应的 l l l 打上一个标记表示用过了,查找时候跳过就好,但这么做时间空间复杂度都爆炸了。
找到最大这一个操作,显而易见我们可以想到线段树,堆,之类的数据结构操作结合贪心或者 D P DP DP 来找,转化到这题来说,我们可以发现一个细节。
即当我们在枚举一个 r r r 的同时,我们知道合法的 l l l 是有个范围的。
假设题中给的 L , R L,R L,R 为 [ 1 , n ] [1,n] [1,n] ,那么对于一个确定的 r r r 而言,我们能找到的 l l l 的范围只能是 [ 1 , r ] [1,r] [1,r] 。
我们要对 r r r 在 [ 1 , r ] [1,r] [1,r] 中找到一个 l l l 使得其满足 max { S [ r ] − S [ l − 1 ] } \max\{S[r]-S[l-1]\} max{S[r]−S[l−1]},那么也就是在找 min { S [ l − 1 ] } \min\{S[l-1]\} min{S[l−1]},即就是对 S [ l − 1 ] S[l-1] S[l−1] 做一个最小值的 R M Q RMQ RMQ 即可。
那么我们要怎么高效地去找有哪些区间是合法的呢?
其实我们只需要知道这几个值即可:
当前位置 r r r, l l l 的可选区间 [ x , y ] [x,y] [x,y],以及一个该区间价值 a n s = S [ r ] − min { S [ l − 1 ] } ans=S[r]-\min\{S[l-1]\} ans=S[r]−min{S[l−1]} ,下面说说怎么做。
我们枚举每一个 r r r,同时做 R M Q RMQ RMQ 在合法区间 [ x , y ] [x,y] [x,y] 中找到最小的 S [ l − 1 ] S[l-1] S[l−1],那么我们此时就可以得到 a n s = S [ r ] − S [ l − 1 ] ans=S[r]-S[l-1] ans=S[r]−S[l−1],将其包装成一个结构体 [ a n s , x , y , l , r ] [ans,x,y,l,r] [ans,x,y,l,r] 扔进堆里面,按照 a n s ans ans 键值的大根堆做排序(大的在堆顶)。
这里解释一下 [ a n s , x , y , l , r ] [ans,x,y,l,r] [ans,x,y,l,r] 的含义, a n s ans ans 即当前定下一个 r r r 且只允许在区间 [ x , y ] [x,y] [x,y] 找 l l l 所能获得的最大价值 S [ r ] − S [ l − 1 ] S[r]-S[l-1] S[r]−S[l−1] 。
要时刻注意注意我们的思路是定下一个找另一个,即定 r r r 找 l l l,我们的解法都是从这一点出发的。
当 k = 1 k=1 k=1 时,显然答案就是堆顶的 a n s ans ans 。
当 k = 2 k=2 k=2 时,第一个答案肯定是堆顶的答案,我们设堆顶的为 r r r,考虑第二个答案在哪个 r ′ r' r′ 取得。
可能是在 r r r 的位置,找到另一个 l ′ l' l′ 使得 a n s = S [ r ] − S [ l ′ − 1 ] ans=S[r]-S[l'-1] ans=S[r]−S[l′−1] 最大,但也有可能是另一个 r ′ r' r′ 的 a n s ans ans,和当前的 r r r 无关,所以我们要把 r r r 重新扔到堆里面,让堆来帮我们执行一次贪心的操作,具体而言:
我们将 [ a n s , x , y , l , r ] [ans,x,y,l,r] [ans,x,y,l,r] 变成 [ a n s 1 , x , l − 1 , A , r ] [ans1,x,l-1,A,r] [ans1,x,l−1,A,r] 和 [ a n s 2 , l + 1 , r , B , r ] [ans2,l+1,r,B,r] [ans2,l+1,r,B,r] ,然后再将这两个扔到堆里即可,即我们把原先的合法区间拆成了 [ x , l − 1 ] [x,l-1] [x,l−1] 和 [ l + 1 , y ] [l+1,y] [l+1,y] 这两个区间(当然得考虑一下合法性)。
那么我们 r r r 在 [ x , l − 1 ] [x,l-1] [x,l−1] 的最优值为 a n s 1 = S [ r ] − S [ A − 1 ] ans1 = S[r]-S[A-1] ans1=S[r]−S[A−1],同理 r r r 在 [ l + 1 , y ] [l+1,y] [l+1,y] 的最优值为 a n s 2 = S [ r ] − S [ B − 1 ] ans2 = S[r]-S[B-1] ans2=S[r]−S[B−1]。
这样, k ≥ 3 k \ge 3 k≥3 以及之后的做法就显而易见了。
将一个区间拆分成两份,同时用 S [ r ] S[r] S[r] 在这两个区间分别做 R M Q RMQ RMQ 找到 l l l,再将其扔进堆里面。
这样做,我们的原本一个区间只会被最多分成两个区间,被分成的区间总数最多就是 O ( k ) O(k) O(k) 个,单次取出删除的复杂度是 O ( log n ) O(\log n) O(logn) 的,可以通过。
代码中用的是 S T ST ST 表做 R M Q RMQ RMQ,本质上线段树之类的也行,但是复杂度要多上一个 O ( log n ) O(\log n) O(logn) 且常数稍大。
时间复杂度: O ( k log n ) O(k \log n) O(klogn)
cpp
#include <bits/stdc++.h>
using i64 = long long;
//ST表板子
template <class Ty, const int logn>
struct SparseTable {
std::vector<std::array<Ty, logn>> info;
SparseTable(const std::vector<Ty>& A) { init(A); }
void init(const std::vector<Ty>& A) {
int n = A.size() - 1;
info.assign(A.size(), std::array<Ty, logn>{});
for (int i = 1; i <= n; ++i) {
info[i][0] = A[i];
}
for (int j = 1; j <= logn; ++j) {
for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
info[i][j] = merge(info[i][j - 1], info[i + (1 << j - 1)][j - 1]);
}
}
}
Ty Query(int l, int r) {
int j = std::__lg(r - l + 1);
return merge(info[l][j], info[r - (1 << j) + 1][j]);
};
constexpr Ty merge(const Ty& a, const Ty& b) {
return std::min(a, b);
}
};
//ST表维护 S 和下标 p
struct Info {
int sum, p;
bool operator<(const Info& rhs) const {
return (sum == rhs.sum ? p < rhs.p : sum < rhs.sum);
}
};
void solve() {
int n, k, L, R;
std::cin >> n >> k >> L >> R;
std::vector<Info> A(n + 1);
std::vector<int> s(n + 1);
for (int i = 1; i <= n; ++i) {
std::cin >> s[i];
s[i] += s[i - 1];
A[i] = {s[i - 1], i};
}
SparseTable<Info, 20> ST(A);
std::priority_queue<std::array<int, 5>> heap;
//枚举每个 r 找 l,但是要注意满足题目中说的子数组长度限制 [L, R]
for (int i = L; i <= n; ++i) {
auto [S, p] = ST.Query(std::max(1, i - R + 1), std::max(1, i - L + 1));
heap.push({s[i] - S, std::max(1, i - R + 1), std::max(1, i - L + 1), p, i});
}
i64 ans = 0;
for (int T = 0; T < k; ++T) {
//对应 [ans, x, y, l, r]
auto [S, QL, QR, p, i] = heap.top();
heap.pop();
ans += S;
//找 [x, l - 1]
if (QL <= p - 1) {
auto [V, x] = ST.Query(QL, p - 1);
heap.push({s[i] - V, QL, p - 1, x, i});
}
//找 [l + 1, y]
if (p + 1 <= QR) {
auto [V, x] = ST.Query(p + 1, QR);
heap.push({s[i] - V, p + 1, QR, x, i});
}
}
std::cout << ans << "\n";
}
signed main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
//std::cin >> t;
while (t--) {
solve();
}
return 0;
}