0.前言
去年 10 月底学了一遍,当前再次学习并记录。
wqs 二分,也称带权二分,一般用于优化决策单调性优化 DP。
1.wqs 二分
对于此类问题:给定 \(n\) 个物品,现要求将其分为 \(m\) 段,每段存在相应的代价 \(w\),求代价最值。
如果没有限制段数,可以考虑使用决策单调性优化或者斜优来做到 \(\mathcal{O(n \log n)}\) 或 \(\mathcal{O(n)}\),再加上一维段数的限制,我们可以得到一个 \(\mathcal{O(mc)}\)(\(c = n\log n\) 或 \(n\))的算法,即 \(f_{i,cnt}\) 表示前 \(i\) 个物品,分了 \(cnt\) 段的代价最值。那么有如下转移:
\[f_{i,cnt} = f_{j,cnt - 1} + w(j,i) \]
但当 \(\mathcal{O(nm)}\) 过大,上述做法会 TLE,此时需要考虑使用 wqs 二分来优化 DP。
首先假设我们讨论的情形是,求最小值。将 \((i,g_i = f_{n,i})\) 当作点全部放在平面直角坐标系中,并顺次连接,假设得到的是下凸包。
该函数图像就表示了限制分为 \(i\) 段时的代价最小值,得到的图像为下凸包,该条件等价于 \(\forall i \in [2,m - 1], g_{i - 1} - g_{i} \ge g_{i} - g_{i + 1} \Leftrightarrow \Delta \downarrow\)。
画出该下凸包。图中标红的点 C 即为要求的点 \((m,g_{m})\)。

拿一条斜率为 \(k\) 直线去截我们需要的答案点,所二分的即为斜率 \(k\) ,如果能够求出所截到的点 \((p,g_p)\),就能够通过比较 \(p\) 与 \(m\),考虑斜率是需要增大还是减小,再调整二分区间。
如果我们观察该斜率的直线与凸包上每个点的相交后的直线,会发现第一个截到的点的截距最小(上凸包则为截距最大),观察 \(b = y - kx\) 的式子,问题可以等价于求 \(\min\limits_{i = 1}^m \{f_i - ki\}\)。
相当于每次分段的时候将贡献 \(- k\),等价于做无段数限制的 DP。 这就是 wqs 二分的核心。注意并非每次都是将贡献 \(-k\),需要考虑该题求的是最大代价还是最小代价,需要向着与要求最值相反的方向来进行加减,这样方可起到限制段数的作用。
在做 DP 的时候可以记录转移次数,即得到所截点的横坐标 \(p\)。但我们二分完斜率之后,需要将答案斜率再 check 一遍得到答案的 \(f_n\),最终答案需要加上 \(m \times k\)。
2.特别注意
I.多点共线
在题目中,通常会遇到多点共线的情况,如图,所求点为 D,但是点 C,E 也同样会被截到。

在决策单调性优化 DP一文中我提到过取决策点时,通常需要钦定取最小决策点还是最大决策点,就我个人的写法而言,取的是最小决策点,此时会取到共线点的左端点,所以此时在 check 的时候我应该判断转移次数 \(g_n \le m\)(此处的 \(g_n\) 指转移次数,并非上文的 \(f_{n,m}\))。如果写 \(g_n \ge m\),那么在共线的时候并不会记录答案,而是直接改变斜率继续二分,此时可能我们就不能再二分到答案斜率了。
关于该点详见帖子。
II.斜率
若贡献均为整数,则我们二分斜率也为整数,如果贡献为小数,才会采用小数二分,一般情况下整数贡献不会采用小数二分,时间限制可能不能接受。
III.内层 DP 为斜优
此时需要极度注意细节。
思考你在出队头队尾时所取的点是最小决策点还是最大决策点,即取不取等,该点需与 check 的判断条件相对应,与 I. 中所叙述的问题类似。如果取的是最小决策点那么出队时判断不应该取等。
3.例题
I.P6246 [IOI 2000] 邮局 加强版 加强版
此题需要严格注意细节才能通过。
做该题前可以先把 \(\mathcal{O(pn\log n)}\) 的做法写出来,即P4767 [IOI 2000] 邮局 加强版。
cpp
#include <bits/stdc++.h>
// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------")
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
template<class T> il void read(T &x) {
x = 0; T f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
x *= f;
}
template<class T,class... Args> il void read(T &x,Args &...x_) { read(x); read(x_...); }
template<class T> il void print(T x) {
if (x < 0) ptc('-'), x = -x;
if (x > 9) print(x / 10); ptc(x % 10 + '0');
}
template<class T,class T_> il void write(T x,T_ ch) { print(x); ptc(ch); }
template<class T,class T_> il void chmax(T &x,T_ y) { x = max(x,y); }
template<class T,class T_> il void chmin(T &x,T_ y) { x = min(x,y); }
template<class T,class T_,class T__> il T qmi(T a,T_ b,T__ p) {
T res = 1; while (b) {
if (b & 1) res = res * a % p;
a = a * a % p; b >>= 1;
} return res;
}
template<class T> il T gcd(T a,T b) { if (!b) return a; return gcd(b,a % b); }
template<class T,class T_> il void exgcd(T a, T b, T_ &x, T_ &y) {
if (b == 0) { x = 1; y = 0; return; }
exgcd(b,a % b,y,x); y -= a / b * x; return ;
}
template<class T,class T_> il T getinv(T x,T_ p) { T inv,y; exgcd(x,(T)p,inv,y); inv = (inv + p) % p; return inv; }
} using namespace szhqwq;
const int N = 3010,inf = 1e9,mod = 998244353;
const ull base = 131,base_ = 233;
const ll inff = 1e18;
int n,p;
int a[N],f[N][310],s[N];
int q[N],hh = 1,tt;
il int calc(int j,int i) {
int mid = i + j >> 1;
int dis = s[i] - s[mid] - (i - mid) * a[mid] + (mid - j + 1) * a[mid] - (s[mid] - s[j - 1]);
return dis;
}
il int check(int j,int i,int c) {
if (f[j][c - 1] + calc(j + 1,n) <= f[i][c - 1] + calc(i + 1,n)) return n + 1; // 注意 <=
int l = i,r = n,ret = -1;
while (l <= r) {
int mid = l + r >> 1;
if (f[j][c - 1] + calc(j + 1,mid) > f[i][c - 1] + calc(i + 1,mid)) r = mid - 1,ret = mid; // 注意 >,两处地方共同构成了取最小决策点的写法
else l = mid + 1;
}
return ret;
}
il void solve() {
//------------code------------
read(n,p);
rep(i,1,n) read(a[i]),s[i] = s[i - 1] + a[i];
memset(f,0x3f,sizeof f);
sort(a + 1,a + n + 1);
// cerr << calc(1,4) << endl;
f[0][0] = 0;
rep(cnt,1,p) {
hh = 1; tt = 0;
rep(i,1,n) {
// cerr << hh << " " << tt << endl;
while (hh <= tt && check(q[hh - 1],q[hh],cnt) <= i) ++ hh;
int j = q[hh - 1];
f[i][cnt] = f[j][cnt - 1] + calc(j + 1,i);
while (hh <= tt && check(q[tt - 1],q[tt],cnt) >= check(q[tt],i,cnt)) -- tt;
q[++ tt] = i;
}
}
// rep(i,1,n) {
// rep(j,1,p) cerr << f[i][j] << " ";
// cerr << '\n';
// }
write(f[n][p],'\n');
return ;
}
il void init() {
return ;
}
signed main() {
// init();
int _ = 1;
// read(_);
while (_ -- ) solve();
return 0;
}
观察 \((i,f_{n,i})\) 构成的函数图像,感性猜测该题为斜率为负的下凸包,当然可以严谨证明,大多数情况下我们通常仅进行猜测。所以在取最小决策点的情况下,如果 \(g_n \le m\),则增大斜率 \(l \gets mid + 1,ret \gets mid\),反之 \(r \gets mid - 1\)。因为该题要求最小值,所以在 check 里 DP 转移中要限制其的段数,故我们需要 - k,等价于每次转移会多加上一个正整数。
最后答案再加上 \(p \times k\) 即可。
cpp
#include <bits/stdc++.h>
// #define int long long
#define ll long long
#define ull unsigned long long
#define db double
#define ld long double
#define rep(i,l,r) for (int i = (int)(l); i <= (int)(r); ++ i )
#define rep1(i,l,r) for (int i = (int)(l); i >= (int)(r); -- i )
#define il inline
#define fst first
#define snd second
#define ptc putchar
#define Yes ptc('Y'),ptc('e'),ptc('s'),puts("")
#define No ptc('N'),ptc('o'),puts("")
#define YES ptc('Y'),ptc('E'),ptc('S'),puts("")
#define NO ptc('N'),ptc('O'),puts("")
#define pb emplace_back
#define sz(x) (int)(x.size())
#define all(x) x.begin(),x.end()
#define get(x) ((x - 1) / len + 1)
#define debug() puts("------------")
using namespace std;
typedef pair<int,int> PII;
typedef pair<int,PII> PIII;
typedef pair<ll,ll> PLL;
namespace szhqwq {
template<class T> il void read(T &x) {
x = 0; T f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
x *= f;
}
template<class T,class... Args> il void read(T &x,Args &...x_) { read(x); read(x_...); }
template<class T> il void print(T x) {
if (x < 0) ptc('-'), x = -x;
if (x > 9) print(x / 10); ptc(x % 10 + '0');
}
template<class T,class T_> il void write(T x,T_ ch) { print(x); ptc(ch); }
template<class T,class T_> il void chmax(T &x,T_ y) { x = max(x,y); }
template<class T,class T_> il void chmin(T &x,T_ y) { x = min(x,y); }
template<class T,class T_,class T__> il T qmi(T a,T_ b,T__ p) {
T res = 1; while (b) {
if (b & 1) res = res * a % p;
a = a * a % p; b >>= 1;
} return res;
}
template<class T> il T gcd(T a,T b) { if (!b) return a; return gcd(b,a % b); }
template<class T,class T_> il void exgcd(T a, T b, T_ &x, T_ &y) {
if (b == 0) { x = 1; y = 0; return; }
exgcd(b,a % b,y,x); y -= a / b * x; return ;
}
template<class T,class T_> il T getinv(T x,T_ p) { T inv,y; exgcd(x,(T)p,inv,y); inv = (inv + p) % p; return inv; }
} using namespace szhqwq;
const int N = 5e5 + 10,inf = 1e9,mod = 998244353;
const ull base = 131,base_ = 233;
const ll inff = 1e18;
int n,p;
ll a[N],f[N],s[N],g[N];
int q[N],hh = 1,tt;
il ll calc(int j,int i) {
int mid = i + j >> 1;
ll dis = s[i] - s[mid] - (i - mid) * a[mid] + (mid - j + 1) * a[mid] - (s[mid] - s[j - 1]);
return dis;
}
il int check(int j,int i) {
if (f[j] + calc(j + 1,n) <= f[i] + calc(i + 1,n)) return n + 1;
int l = i,r = n,ret = -1;
while (l <= r) {
int mid = l + r >> 1;
if (f[j] + calc(j + 1,mid) > f[i] + calc(i + 1,mid)) r = mid - 1,ret = mid;
else l = mid + 1;
}
return ret;
}
il bool check__(int val) {
hh = 1; tt = 0; f[0] = g[0] = 0;
rep(i,1,n) {
// cerr << hh << " " << tt << endl;
while (hh <= tt && check(q[hh - 1],q[hh]) <= i) ++ hh;
int j = q[hh - 1];
f[i] = f[j] + calc(j + 1,i) - val; g[i] = g[j] + 1;
while (hh <= tt && check(q[tt - 1],q[tt]) >= check(q[tt],i)) -- tt;
q[++ tt] = i;
}
return g[n] <= p;
}
il void solve() {
//------------code------------
read(n,p);
rep(i,1,n) read(a[i]);
sort(a + 1,a + n + 1);
rep(i,1,n) s[i] = s[i - 1] + a[i];
int l = -1e7,r = 0,ret = 0;
while (l <= r) {
int mid = l + r >> 1;
if (check__(mid)) l = mid + 1,ret = mid;
else r = mid - 1;
}
check__(ret);
write(f[n] + p * ret,'\n');
return ;
}
il void init() {
return ;
}
signed main() {
// init();
int _ = 1;
// read(_);
while (_ -- ) solve();
return 0;
}
4.练习题
直接借用他人的题单。Link
5.参考资料
【学习笔记】WQS二分详解及常见理解误区解释 - ikrvxt
作者水平有限,如有错误请指出。