原题:link,点击这里喵。
题意:P1484 种树。
将军今天在种树,他在一条直线上挖了 n n n 个坑。这 n n n 个坑都可以种树,但为了保证每一棵树都有充足的养料,将军不会在相邻的两个坑中种树。而且由于将军的树种不够,他至多会种 k k k 棵树。将军有某种神能力,能预知自己在某个坑种树的获利会是多少(可能为负),请你帮助他计算出他的最大获利。
对于 100 % 100\% 100% 的数据, 1 ≤ n ≤ 300000 1 \le n\leq 300000 1≤n≤300000, 1 ≤ k ≤ n 2 1 \le k\leq \dfrac{n}{2} 1≤k≤2n,在一个地方种树获利的绝对值在 1 0 6 10^6 106 以内。
解法:WQS 二分。
首先,如果没有 k k k 的限制,我们可以很轻松的推出 d p dp dp 方程式,设 d p i dp_i dpi 为种到第 i i i 个坑的最大收益,我们有: d p i = max ( d p i − 2 + v i , d p i − 1 ) dp_i=\max(dp_{i-2}+v_i,dp_{i-1}) dpi=max(dpi−2+vi,dpi−1)
注意自行判断边界条件即可。
运用套路二分惩罚(斜率) c c c,把每一个坑的收益减去 c c c,了解此时所选的数量。不过有一点很特别,不同于以往的恰好 ,题目中是至多。
这该怎么办?
我们定义一次惩罚为 c c c 的操作为 solve ( c ) \operatorname{solve}(c) solve(c)。
在程序开始时,我们先执行一次 solve ( 0 ) \operatorname{solve}(0) solve(0),如果其返回值告诉我们它选择了 c n t cnt cnt 个树坑:
- c n t < = k cnt<=k cnt<=k,函数极值的横坐标小于等于 k k k,意味着可以直接输出答案,因为右侧的斜率 c < = 0 c<=0 c<=0。
- c n t > k cnt>k cnt>k,归约到正常的 WQS 二分问题,此时函数在 k k k 限制下的极值的横坐标为 k k k,所以套用恰好模型。
可以参照下面的图片参照理解。
为了防止有相同的斜率产生,我们还需要统计 solve ( 0 ) \operatorname{solve}(0) solve(0) 中, v − c = 0 v-c=0 v−c=0 的 v v v 个数(这个是在 d p dp dp 过程中计入的 0 0 0 的个数),记为 _ c n t \_cnt _cnt,因为 0 0 0 不对 d p dp dp 的值产生变化,所以我们直接在统计答案时加上其个数并与 k k k 取 min \min min 即可。
代码 time。
cpp
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
typedef long long lnt;
using namespace std;
const int N = 2e6 + 10;
struct DATA { // 这样写方便也省事
int cnt, _cnt;
lnt v;
DATA() = default;
DATA(int cntp, int _cntp, lnt _v) : cnt(cntp), _cnt(_cntp) , v(_v){}
bool operator<(const DATA &a) const { return v < a.v; }
DATA operator+(const DATA &a) const {
return DATA(cnt + a.cnt, _cnt + a._cnt, v + a.v);
}
};
vector<DATA> dp, d, v;
int n, k;
DATA solve(int c) {
v = d;
for (int i = 0; i < n; ++i) {
v[i].v -= c;
if (v[i].v < 0) v[i].cnt = 0, v[i].v = 0;
else if(v[i].v == 0) v[i]._cnt = 1, v[i].cnt = 0;
dp[i] = {0, 0, 0};
}
DATA ans = DATA(0, 0, 0);
for (int i = 0; i < n; ++i) {
dp[i] = (i - 2 >= 0 ? dp[i - 2] : DATA(0, 0, 0)) + v[i];
if (i) dp[i] = max(max(dp[i], dp[i - 1]), DATA(0, 0, 0));
ans = max(ans, dp[i]);
}
return ans;
}
int main() { //
scanf("%d%d", &n, &k);
dp.resize(n), d.resize(n);
for (int i = 0; i < n; ++i) {
scanf("%lld", &d[i].v);
d[i].cnt = 1;
}
DATA g = solve(0);
if (g.cnt <= k) {
printf("%lld\n", g.v);
return 0;
}
int l = -1, r = 1e6 + 10;
lnt ans = 0;
while (l < r - 1) {
int mid = (l + r) >> 1;
DATA g = solve(mid);
if (g.cnt <= k) {
ans = max(ans, (lnt)min(g.cnt + g._cnt, k) * mid + g.v); // 注意这里 !
r = mid;
} else if (g.cnt > k) {
l = mid;
}
}
printf("%lld\n", ans);
return 0;
}