LUOGU P2048 [NOI2010] 超级钢琴(贪心+堆)

原题链接:[NOI2010] 超级钢琴


题目大意:


给出一个长度为 n n n 的数组,且 a i a_{i} ai 可正可负,再给出三个数字 k , L , R k,L,R k,L,R 。

定义每个子数组的价值为其所有元素的和,你需要找到 k k k 个连续的子数组(可重叠但不可重复),且满足长度在 [ L , R ] [L,R] [L,R] 内,问你最后这 k k k 个子数组的价值总和最大为多少。

解题思路:


乍一看似乎没有什么思路,因为我们要找的所有区间几乎是 O ( n 2 ) O(n^{2}) O(n2) 级别的。

先考虑一下 k = 1 k=1 k=1 怎么做,可以想到的是,做一个前缀和,枚举 r r r 的同时找到一个对 r r r 而言最小的 l l l ,那么答案取 max ⁡ { S u m [ r ] − S u m [ l − 1 ] } \max\{Sum[r]-Sum[l-1]\} max{Sum[r]−Sum[l−1]} 即可。

(其中前缀和数组为 S u m Sum Sum ,为了简化下文写成 S S S 数组)

同理 k = 2 , 3 , . . . , n 2 k=2,3,...,n^{2} k=2,3,...,n2 时候也是一样的,只需要再给每个 r r r 对应的 l l l 打上一个标记表示用过了,查找时候跳过就好,但这么做时间空间复杂度都爆炸了。

找到最大这一个操作,显而易见我们可以想到线段树,堆,之类的数据结构操作结合贪心或者 D P DP DP 来找,转化到这题来说,我们可以发现一个细节。

即当我们在枚举一个 r r r 的同时,我们知道合法的 l l l 是有个范围的。

假设题中给的 L , R L,R L,R 为 [ 1 , n ] [1,n] [1,n] ,那么对于一个确定的 r r r 而言,我们能找到的 l l l 的范围只能是 [ 1 , r ] [1,r] [1,r] 。

我们要对 r r r 在 [ 1 , r ] [1,r] [1,r] 中找到一个 l l l 使得其满足 max ⁡ { S [ r ] − S [ l − 1 ] } \max\{S[r]-S[l-1]\} max{S[r]−S[l−1]},那么也就是在找 min ⁡ { S [ l − 1 ] } \min\{S[l-1]\} min{S[l−1]},即就是对 S [ l − 1 ] S[l-1] S[l−1] 做一个最小值的 R M Q RMQ RMQ 即可。

那么我们要怎么高效地去找有哪些区间是合法的呢?

其实我们只需要知道这几个值即可:

当前位置 r r r, l l l 的可选区间 [ x , y ] [x,y] [x,y],以及一个该区间价值 a n s = S [ r ] − min ⁡ { S [ l − 1 ] } ans=S[r]-\min\{S[l-1]\} ans=S[r]−min{S[l−1]} ,下面说说怎么做。

我们枚举每一个 r r r,同时做 R M Q RMQ RMQ 在合法区间 [ x , y ] [x,y] [x,y] 中找到最小的 S [ l − 1 ] S[l-1] S[l−1],那么我们此时就可以得到 a n s = S [ r ] − S [ l − 1 ] ans=S[r]-S[l-1] ans=S[r]−S[l−1],将其包装成一个结构体 [ a n s , x , y , l , r ] [ans,x,y,l,r] [ans,x,y,l,r] 扔进堆里面,按照 a n s ans ans 键值的大根堆做排序(大的在堆顶)。

这里解释一下 [ a n s , x , y , l , r ] [ans,x,y,l,r] [ans,x,y,l,r] 的含义, a n s ans ans 即当前定下一个 r r r 且只允许在区间 [ x , y ] [x,y] [x,y] 找 l l l 所能获得的最大价值 S [ r ] − S [ l − 1 ] S[r]-S[l-1] S[r]−S[l−1] 。

要时刻注意注意我们的思路是定下一个找另一个,即定 r r r 找 l l l,我们的解法都是从这一点出发的。

当 k = 1 k=1 k=1 时,显然答案就是堆顶的 a n s ans ans 。

当 k = 2 k=2 k=2 时,第一个答案肯定是堆顶的答案,我们设堆顶的为 r r r,考虑第二个答案在哪个 r ′ r' r′ 取得。

可能是在 r r r 的位置,找到另一个 l ′ l' l′ 使得 a n s = S [ r ] − S [ l ′ − 1 ] ans=S[r]-S[l'-1] ans=S[r]−S[l′−1] 最大,但也有可能是另一个 r ′ r' r′ 的 a n s ans ans,和当前的 r r r 无关,所以我们要把 r r r 重新扔到堆里面,让堆来帮我们执行一次贪心的操作,具体而言:

我们将 [ a n s , x , y , l , r ] [ans,x,y,l,r] [ans,x,y,l,r] 变成 [ a n s 1 , x , l − 1 , A , r ] [ans1,x,l-1,A,r] [ans1,x,l−1,A,r] 和 [ a n s 2 , l + 1 , r , B , r ] [ans2,l+1,r,B,r] [ans2,l+1,r,B,r] ,然后再将这两个扔到堆里即可,即我们把原先的合法区间拆成了 [ x , l − 1 ] [x,l-1] [x,l−1] 和 [ l + 1 , y ] [l+1,y] [l+1,y] 这两个区间(当然得考虑一下合法性)。

那么我们 r r r 在 [ x , l − 1 ] [x,l-1] [x,l−1] 的最优值为 a n s 1 = S [ r ] − S [ A − 1 ] ans1 = S[r]-S[A-1] ans1=S[r]−S[A−1],同理 r r r 在 [ l + 1 , y ] [l+1,y] [l+1,y] 的最优值为 a n s 2 = S [ r ] − S [ B − 1 ] ans2 = S[r]-S[B-1] ans2=S[r]−S[B−1]。

这样, k ≥ 3 k \ge 3 k≥3 以及之后的做法就显而易见了。

将一个区间拆分成两份,同时用 S [ r ] S[r] S[r] 在这两个区间分别做 R M Q RMQ RMQ 找到 l l l,再将其扔进堆里面。

这样做,我们的原本一个区间只会被最多分成两个区间,被分成的区间总数最多就是 O ( k ) O(k) O(k) 个,单次取出删除的复杂度是 O ( log ⁡ n ) O(\log n) O(logn) 的,可以通过。

代码中用的是 S T ST ST 表做 R M Q RMQ RMQ,本质上线段树之类的也行,但是复杂度要多上一个 O ( log ⁡ n ) O(\log n) O(logn) 且常数稍大。

时间复杂度: O ( k log ⁡ n ) O(k \log n) O(klogn)

cpp 复制代码
#include <bits/stdc++.h>

using i64 = long long;

//ST表板子
template <class Ty, const int logn>
struct SparseTable {
    std::vector<std::array<Ty, logn>> info;

    SparseTable(const std::vector<Ty>& A) { init(A); }
    void init(const std::vector<Ty>& A) {
        int n = A.size() - 1;
        info.assign(A.size(), std::array<Ty, logn>{});
        for (int i = 1; i <= n; ++i) {
            info[i][0] = A[i];
        }
        for (int j = 1; j <= logn; ++j) {
            for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
                info[i][j] = merge(info[i][j - 1], info[i + (1 << j - 1)][j - 1]);
            }
        }
    }
    Ty Query(int l, int r) {
        int j = std::__lg(r - l + 1);
        return merge(info[l][j], info[r - (1 << j) + 1][j]);
    };
    constexpr Ty merge(const Ty& a, const Ty& b) {
        return std::min(a, b);
    }
};

//ST表维护 S 和下标 p
struct Info {
    int sum, p;
    bool operator<(const Info& rhs) const {
        return (sum == rhs.sum ? p < rhs.p : sum < rhs.sum);
    }
};

void solve() {
    int n, k, L, R;
    std::cin >> n >> k >> L >> R;

    std::vector<Info> A(n + 1);
    std::vector<int> s(n + 1);
    for (int i = 1; i <= n; ++i) {
        std::cin >> s[i];
        s[i] += s[i - 1];
        A[i] = {s[i - 1], i};
    }

    SparseTable<Info, 20> ST(A);

    std::priority_queue<std::array<int, 5>> heap;

	//枚举每个 r 找 l,但是要注意满足题目中说的子数组长度限制 [L, R]
    for (int i = L; i <= n; ++i) {
        auto [S, p] = ST.Query(std::max(1, i - R + 1), std::max(1, i - L + 1));
        heap.push({s[i] - S, std::max(1, i - R + 1), std::max(1, i - L + 1), p, i});
    }

    i64 ans = 0;
    for (int T = 0; T < k; ++T) {
    	//对应 [ans, x, y, l, r]
        auto [S, QL, QR, p, i] = heap.top();
        heap.pop();
        ans += S;
        //找 [x, l - 1]
        if (QL <= p - 1) {
            auto [V, x] = ST.Query(QL, p - 1);
            heap.push({s[i] - V, QL, p - 1, x, i});
        }
        //找 [l + 1, y]
        if (p + 1 <= QR) {
            auto [V, x] = ST.Query(p + 1, QR);
            heap.push({s[i] - V, p + 1, QR, x, i});
        }
    }

    std::cout << ans << "\n";
}

signed main() {

    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int t = 1;
    //std::cin >> t;
    while (t--) {
        solve();
    }

    return 0;
}
相关推荐
weixin_4866811411 分钟前
C++系列-STL中find相关的算法
java·c++·算法
码了三年又三年27 分钟前
ArrayList、LinkedList和Vector的区别
开发语言·c++·链表
月夕花晨37429 分钟前
C++学习笔记(14)
c++·笔记·学习
我是真爱学JAVA37 分钟前
第四章 类和对象 课后训练(1)
java·开发语言·算法
Qiuner1 小时前
【机器学习】分类与回归——掌握两大核心算法的区别与应用
算法·机器学习·分类
oufoc1 小时前
第J1周:ResNet-50算法实战与解析
神经网络·算法·tensorflow
金博客2 小时前
QT使用相机拍照
c++·qt
轩源源2 小时前
函数模板(初阶)
数据结构
想拿大厂offer2 小时前
【数据结构】第八节:链式二叉树
c语言·数据结构
Youkiup2 小时前
【重构数组,形成排列的最大长度】
算法