下标线段树
下标线段树即最普通的线段树,用来维护一个区间 \(L,R\) 的信息,可以发现这其实对应了数组里面的下标,故称其为下标线段树
基本原理
对于每个节点 \(l,r\),左儿子是 \(l,mid\),右儿子是 \(mid+1,r\),\(mid=(l+r)/2\)(向下取整)
我们发现除了最后一层是一个满二叉树,于是我们就可以用与二叉堆类似的方法存储
对于一个编号为 \(x\) 的节点,父亲节点 \(x/2\)(x >> 1)(向下取整),左儿子 \(2x\)(x << 1),右儿子 \(2x+1\)(x << 1 | 1)
除去最后一层点数共有:\(n+n/2+n/4+...+2+1=2n-1\) 个节点,最后一层因为是倒数第二层的两倍,所以最多有 \(2n\) 个点,所以说开线段树开 \(4n\) 个空间
如图假如需要维护一个 \(1,10\) 的区间:

基本操作
push up:由子节点算出父节点的值
build:将一段区间初始化成一棵线段树
对于线段树,我们一般使用结构体来存储
首先需要存储左儿子和右儿子,那么怎么判断需要存哪些值呢?
首先看问的是什么(求某个区间的某种属性),除此之外还需要存储一些辅助信息,即是否直接可以从两个子节点推出父节点
根据基本原理,我们可以写出以下代码
c++
void build(int u, int L, int R)
{
tr[u].L = L;
tr[u].R = R;
if (L == R) return;
int mid = L + R >> 1;
build(u << 1, L, mid); build(u << 1 | 1, mid + 1, R);
//pushup(u); 一般会写在这里
}
modify:修改单点(简单)修改区间(困难,需要懒标记)
单点修改非常简单,因为只需要修改一个点,所以从根节点不停递归到一个叶子节点上就行了,回溯的时候更新父节点的值就行了(push up 操作)
query:查询某一段区间的信息
查询的时候一般分为这几种情况:
我们使用 \(L,R\) 来表示查询的区间,\(T_L,T_R\) 来表示树中节点
这里以求某一个区间的最大值为例,每个 \(T_L,T_R\) 维护区间内的最大值
- \(L,R\supsetT_L,T_R\),如果完全包含的话,直接返回维护的最大值即可
- \(L,R\capT_L,T_R\neq \varnothing\),和左边有交集就递归左边,和右边有交集就递归右边,都有就同时递归

此为第一种情况:\(T_L\leq L\leq T_R\leq R\)
又分为两种情况:
\(L>mid\) 只递归右边
\(L\leq mid\) 既递归左边也递归右边,但是右儿子只会递归 \(1\) 次,只会对常数产生影响,所以复杂度仍为 \(O(\log n)\)

此为第二种情况:\(L\leq T_L\leq R\leq T_R\)
与上一种是类似的

此为第三种情况,又分为三种情况:
\(R\leq mid\) 只递归左边
\(L> mid\) 只递归右边
剩下的情况需要递归左边和右边,这种情况最多发生一次,之后便会变成上述两种情况。因此复杂度为 \(O(2\log n)\) 忽略常数还是 \(O(\log n)\) 的
- \(L,R\capT_L,T_R= \varnothing\),这种情况是不存在的
总结一下操作过程:
- 若 \(l,r\) 完全覆盖了当前节点所覆盖的区间,则立即回溯,并且该节点的值做为一个备选答案
- 若左子节点与 \(l,r\) 有重叠部分,则递归访问左子节点
- 若右子节点与 \(l,r\) 有重叠部分,则递归访问右子节点
push down:给一个区间加上一个数,更新到两个子节点
例题
I. JSOI2008 最大数
因为动态添加比较麻烦,所以我们直接建 \(1\sim m\) 个空间就行了(最多 \(m\) 个操作,也就是 \(m\) 个空间),每次新增一个数的时候相当于直接修改,问后 \(L\) 个数里面最大值就是问 \(n-L+1,n\)(\(n\) 在每次加数的时候也顺带 ++)
c++
#include <iostream>
using namespace std;
const int N = 200005;
struct Node
{
int l, r; //左儿子右儿子
int v; //记录的最大值
}tr[N * 4];
//建树
void build(int u, int l, int r)
{
tr[u] = { l,r };
if (l == r) return; //叶子节点
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
//根据子节点的值计算父节点的值
void pushup(int u)
{
//取最大值
tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}
//查询操作
int query(int u, int l, int r)
{
//包含在 [l,r] 中,可作为备选答案
if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;
int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
//左子节点有重叠
if (l <= mid) v = query(u << 1, l, r);
//右子节点有重叠
if (r > mid) v = max(v, query(u << 1 | 1, l, r));
return v;
}
//修改操作
void modify(int u, int x, int v)
{
//遍历到了叶子节点(需要修改的 x 位置)
if (tr[u].l == x && tr[u].r == x) tr[u].v = v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u); //更新父节点的值
}
}
int main()
{
int m, p;
scanf("%d%d", &m, &p);
//建立线段树
build(1, 1, m);
int last = 0, n = 0; //记录上一个结果和线段树编号
int x; char op[2];
while (m--)
{
scanf("%s%d", op, &x);
if (*op == 'Q')
{
last = query(1, n - x + 1, n);
printf("%d\n", last);
}
else
{
modify(1, n + 1, ((long long)last + x) % p);
n++;
}
}
return 0;
}
II. 你能回答这些问题吗
首先考虑一下,线段树中存储什么,显然需要存储最大连续子段和,左右儿子
那么我们可以从两个子节点的最大连续子段和推出父节点的最大连续子段和吗,其实是不行的
因为在两个子区间中顺序是无所谓的,但到了父区间时,要求必须是连续的,所以不行,我们需要额外存储每个区间的最大前缀和和后缀和,这样就可以拼在一起求出父节点
那么可以分为三种情况:
- 完全在左子区间
- 完全在右子区间
- 由左子区间的最大后缀和右子区间的最大前缀和拼接而成
接下来再来分析如何得到最大后缀和、最大前缀和
最大前缀和分为两种情况:
- 完全在左子区间中
- 左边的和加上右子区间的最大前缀和
我们可以发现又需要记录一个 sum 信息,那么根据左子区间和右子区间的和可以推出来父节点的和,所以不需要再记录新的信息
c++
#include <iostream>
using namespace std;
const int N = 500005;
int w[N];
int n, m;
struct Node
{
int l, r;
int lmax, rmax, sum; //最大前缀和、后缀和、和
int tmax; //最大连续字段和
}tr[N * 4];
void pushup(Node& u, Node& l, Node& r)
{
u.sum = l.sum + r.sum;
u.lmax = max(l.lmax, l.sum + r.lmax);
u.rmax = max(r.rmax, r.sum + l.rmax);
u.tmax = max(max(l.tmax, r.tmax), l.rmax + r.lmax);
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
//叶子节点
if (l == r) tr[u] = { l,r,w[l],w[l],w[l],w[l] };
else
{
tr[u] = { l,r };
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
//更新父节点的值
pushup(u);
}
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x) tr[u] = { x,x,v,v,v,v };
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u); //更新值
}
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
int main()
{
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> w[i];
build(1, 1, n);
while (m--)
{
int k, x, y;
cin >> k >> x >> y;
if (k == 1)
{
if (x > y) swap(x, y);
cout << query(1, x, y).tmax << endl;
}
else modify(1, x, y);
}
return 0;
}
III. 区间最大公约数
对于这种区间修改的问题,一般是要使用懒标记的,但是在这个问题中只需要区间都加上一个数,所以只需要运用差分的思想,单点修改就行了
对于查询操作来说,只需要维护一个最大公约数就够了,那么修改怎么办
我们可以发现修改一个数非常容易,但是修改一个区间非常难
因为一个数的最大公约数就是它自己
那么我们是否可以用差分的思想来做:
我们知道 \((x,y,z)=(x,y-x,z-y)\)
那么对于 \(n\) 个数来言:\((a_1,a_2,...,a_n)=(a_1,a_2-a_1,...,a_n-a_{n-1})\)
(参考更相减损术)
所以我们维护 \(a\) 的差分序列 \(b_i=a_i-a_{i-1}\)。用线段树维护 \(b\) 的最大公约数
那么对于每一个查询,就相当于求出 \(gcd(a_l,gcd(b_{l+1}\sim b_r))\)
c++
#include <iostream>
using namespace std;
const int N = 500005;
typedef long long LL;
struct Node
{
int l, r;
LL d, sum; //记录最大公约数与区间和
//区间和用来记录原本 a[i] 的值,因为维护的是差分序列
}tr[N * 4];
LL w[N];
int n, m;
LL gcd(LL a, LL b)
{
return b ? gcd(b, a % b) : a;
}
void pushup(Node& u, Node& left, Node& right)
{
u.sum = left.sum + right.sum;
u.d = gcd(left.d, right.d);
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r)
{
LL d = w[r] - w[r - 1]; //差分
tr[u] = { l,r,d,d };
}
else
{
tr[u].l = l, tr[u].r = r;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, LL v)
{
if (tr[u].l == x && tr[u].r == x)
{
LL sum = tr[u].sum + v;
tr[u].sum = sum;
tr[u].d = sum;
}
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (mid >= x) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (l > r) return { 0 }; //左端点比右端点大
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%lld", &w[i]);
build(1, 1, n);
int l, r;
char op[2];
while (m--)
{
scanf("%s%d%d", op, &l, &r);
if (*op == 'Q')
{
// a[l] 的值
Node left = query(1, 1, l);
Node right = query(1, l + 1, r);
// 需要取绝对值,可能存在负数
printf("%lld\n", abs(gcd(left.sum, right.d)));
}
else
{
LL d;
scanf("%lld", &d);
modify(1, l, d);
// 有可能越界进行特判
if (r + 1 <= n) modify(1, r + 1, -d);
}
}
return 0;
}
也可以再使用一个树状数组来维护 \(al\) 的值
c++
#include <iostream>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 500005;
int n, m;
LL A[N];
struct Node
{
int l, r;
LL d;
}tr[N * 4];
LL tr1[N];
int lowbit(int x)
{
return x & -x;
}
void add(int x, LL c)
{
for (int i = x; i <= n; i += lowbit(i)) tr1[i] += c;
}
LL sum(int x)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr1[i];
return res;
}
LL gcd(LL a, LL b)
{
return b ? gcd(b, a % b) : a;
}
void pushup(Node& u, Node& left, Node& right)
{
u.d = gcd(left.d, right.d);
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = { l, r, A[r] - A[r - 1] };
else
{
tr[u] = { l, r };
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, LL v)
{
if (tr[u].l == x && tr[u].r == x) tr[u].d += v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (mid >= x) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (l > r) return { 0 };
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
int main()
{
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> A[i];
for (int i = 1; i <= n; i++) add(i, A[i] - A[i - 1]);
build(1, 1, n);
while (m--)
{
char op;
int l, r;
LL d;
cin >> op >> l >> r;
if (op == 'C')
{
cin >> d;
add(l, d), modify(1, l, d);
if (r < n) add(r + 1, -d), modify(1, r + 1, -d);
}
else cout << ((l + 1 <= r) ? llabs(gcd(sum(l), query(1, l + 1, r).d)) : sum(l)) << endl;
}
return 0;
}
IV. 3 的倍数区间
题意简述:
给定 \(n\) 个元素的序列,\(m\) 次操作
操作 1:
1 x a把数组中下标为 \(x\) 的数修改为 \(a\)操作 2:
2 L R求区间 \(L,R\) 范围内,有多少对 \((x,y)\) 满足 \(L\le x \le y\le R,(\sum^y_{i=x}A_i)\bmod 3=0\)
考虑一个大区间的答案如何由两个小子区间得来,可以发现要么在左子区间里面要么在右子区间里面,要么就是中间一段,比较像 你能回答这些问题吗 一题
所以说需要额外记录一个区间从右到左,余数分别为 \(0,1,2\) 的区间各有多少个,从左往右,余数分别为 \(0,1,2\) 的区间各有多少个
相乘便可以得到答案
维护这两个额外信息也很简单,一个区间从左到右的值要么是左子区间从左到右的值,要么是左子区间的 \(sum\) 加上右子区间对应余数的方案
请读者参考代码进行理解
c++
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 300005;
typedef long long LL;
struct Node
{
int l, r;
LL v;
int sum;
int cnt1[3], cnt2[3]; //从左到右,从右到左
}tr[N * 4];
int a[N];
void pushup(Node& u, Node& l, Node& r)
{
u.v = l.v + r.v;
u.sum = l.sum + r.sum;
for (int i = 0; i < 3; i++)
u.v += 1ll * l.cnt2[i] * r.cnt1[(3 - i) % 3];
for (int i = 0; i < 3; i++) u.cnt1[i] = u.cnt2[i] = 0;
for (int i = 0; i < 3; i++)
{
u.cnt1[i] += r.cnt1[(i - l.sum % 3 + 3) % 3];
u.cnt2[i] += l.cnt2[(i - r.sum % 3 + 3) % 3];
u.cnt1[i] += l.cnt1[i];
u.cnt2[i] += r.cnt2[i];
}
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = { l, r, a[l] % 3 == 0, a[l] };
tr[u].cnt1[a[l] % 3]++;
tr[u].cnt2[a[l] % 3]++;
}
else
{
tr[u] = { l, r };
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x)
{
tr[u].v = (v % 3 == 0);
tr[u].cnt1[a[x] % 3]--;
tr[u].cnt2[a[x] % 3]--;
a[x] = v;
tr[u].sum = v;
tr[u].cnt1[a[x] % 3]++;
tr[u].cnt2[a[x] % 3]++;
}
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
build(1, 1, n);
for (int i = 1; i <= m; i++)
{
int op, a, b;
scanf("%d%d%d", &op, &a, &b);
if (op == 1) modify(1, a, b);
else printf("%lld\n", query(1, a, b).v);
}
return 0;
}
懒标记
如果使用 pushup 操作进行区间修改的话,最多会修改 \(n\) 个数,那么复杂度就退化为 \(O(n)\)
试想一下,若是我们需要修改一个 tr[u].l >= l && tr[u].r <= r 的区间和其子区间,但是后面查询的时候又没用到,那么就白改了,所以可以打个标记,表示这个节点被修改过但是子区间并没有更新,等查到它的时候再根据标记修改它的子区间,称作 pushdown 操作
pushdown 操作实际上来源于区间查询的思想,区间查询时 if (tr[u].l >= l && tr[u].r <= r) 及时返回,那么复杂度就是 \(O(4\log n)\)
我们以此题为例,分为两种操作:
- sum,记录当前区间的总和
- add,使用懒标记,给以当前节点为根的子树中的每一个节点加上一个数(不包含根节点)
那么查询区间时,对于一个区间实际上应该加上打的标记
为了实现这个操作,我们只需要在遍历到一个区间时,如果不是答案区间,清空标记,然后将标记传给子节点
在修改时也需要做类似的操作,举个例子:
例如需要修改区间的时候,如果之前的父区间打了一个标记意为全部加上 10,此时要对这个子区间加上 5,如果不将父区间的标记下传的话,将会导致父区间在 pushup 操作中的值被修改成两个子区间的和,原本两个子区间是要被加上 10 的,现在并没有,导致错误
凡是割裂开来无法一块修改的均需要用 pushdown 操作
c++
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 100005;
typedef long long LL;
int A[N];
struct Node
{
int l, r;
LL sum, add;
}tr[N * 4];
void pushup(Node& u, Node& l, Node& r)
{
u.sum = l.sum + r.sum;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void pushdown(int u)
{
auto& root = tr[u], & left = tr[u << 1], & right = tr[u << 1 | 1];
if (root.add)
{
left.sum += (LL)(left.r - left.l + 1) * root.add;
left.add += root.add;
right.sum += (LL)(right.r - right.l + 1) * root.add;
right.add += root.add;
root.add = 0;
}
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = { l, r, A[l], 0 };
else
{
tr[u] = { l, r };
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
void modify(int u, int l, int r, int d)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum += (LL)(tr[u].r - tr[u].l + 1) * d;
tr[u].add += d;
}
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, d);
if (r > mid) modify(u << 1 | 1, l, r, d);
pushup(u);
}
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &A[i]);
build(1, 1, n);
while (m--)
{
char op[2];
int l, r;
scanf("%s%d%d", op, &l, &r);
if (*op == 'Q') printf("%lld\n", query(1, l, r).sum);
else
{
int d;
scanf("%d", &d);
modify(1, l, r, d);
}
}
return 0;
}
权值线段树

权值线段树用来维护区间内的数的一些信息,区间范围是数据的值域
如图就是数列 \(1,1,2,3,3,5,5,5\) 的权值线段树,其用来维护一个区间内数的出现次数
例题
I. 逆序对
这道题可以用权值线段树维护区间内数的出现次数
求逆序对可以用归并排序,树状数组来解决,时间复杂度都为 \(O(n\log n)\),其实使用权值线段树也可以达到同样的复杂度
首先需要对所有数进行离散化,然后在处理 每一个数的过程中,查询得到其离散化后的结果 \(id\),将其加入到权值线段树中,进行一次单点修改,逆序对的数量即在这个数之前被插入且大于它的数的个数之和,也就是对区间 \(id+1,n\) 进行一次区间求和,最终累加结果即可
c++
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 500005;
struct Node
{
int l, r, sum;
}tr[N * 4];
vector<int> alls;
int a[N];
int find(int x)
{
int l = 0, r = alls.size() - 1;
while (l < r)
{
int mid = l + r >> 1;
if (alls[mid] >= x) r = mid;
else l = mid + 1;
}
return r + 1;
}
void pushup(Node &u, Node &l, Node &r)
{
u.sum = l.sum + r.sum;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = { l, r, 0 };
else
{
tr[u] = { l, r };
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x) tr[u].sum += v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
int main()
{
int n;
cin >> n;
build(1, 1, n);
for (int i = 1; i <= n; i++)
{
cin >> a[i];
alls.push_back(a[i]);
}
sort(alls.begin(), alls.end()); //此处离散化不能去重
int res = 0;
for (int i = 1; i <= n; i++)
{
int id = find(a[i]);
modify(1, id, 1);
if (id == n) continue; //没有数比 ai 更大
res += query(1, id + 1, n).sum;
}
cout << res;
return 0;
}
II. 普通平衡树
把除了操作 4 以外的所有数进行离散化,再用权值线段树来维护这些数的出现次数
c++
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 100005;
struct Node
{
int l, r, sum;
}tr[N * 4];
struct Q
{
int op, x;
}q[N];
int a[N];
vector<int> alls;
int find(int x)
{
return lower_bound(alls.begin(), alls.end(), x) - alls.begin() + 1;
}
void pushup(Node &u, Node &l, Node &r)
{
u.sum = l.sum + r.sum;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = { l, r, 0 };
else
{
tr[u] = { l, r };
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x) tr[u].sum += v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
Node res;
pushup(res, left, right);
return res;
}
}
//查询排名为 x 的数类似于 modify
int query(int u, int x)
{
if (tr[u].l == tr[u].r) return tr[u].l;
if (x <= tr[u << 1].sum) return query(u << 1, x);
else return query(u << 1 | 1, x - tr[u << 1].sum);
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%d%d", &q[i].op, &q[i].x);
if (q[i].op != 4) alls.push_back(q[i].x);
}
build(1, 1, alls.size());
sort(alls.begin(), alls.end());
alls.erase(unique(alls.begin(), alls.end()), alls.end());
for (int i = 1; i <= n; i++)
{
int op = q[i].op, x = q[i].x, id;
if (op != 4) id = find(x);
if (op == 1) modify(1, find(x), 1);
if (op == 2) modify(1, find(x), -1);
if (op == 3) printf("%d\n", id > 1 ? query(1, 1, id - 1).sum + 1 : 1);
if (op == 4) printf("%d\n", alls[query(1, x) - 1]);
if (op == 5)
{
//查询比 x 小的数的排名 y
int rank = query(1, 1, id - 1).sum;
//查询排名为 y 的数是什么
printf("%d\n", alls[query(1, rank) - 1]);
}
if (op == 6)
{
//查询比 x 大的数的排名 y
int rank = query(1, 1, find(x)).sum + 1;
//查询排名为 y 的数是什么
printf("%d\n", alls[query(1, rank) - 1]);
}
}
return 0;
}
III. Rmq Problem / mex
权值线段树不仅可以用来维护区间内数的出现次数,还可以维护下标(属于是将下标线段树反过来了)
这题有一个好处,由于没有修改操作,所以可以进行离线
我们希望维护区间 \(l,r\) 中没有出现过的最小整数 \(x\)
可以通过离线将 \(r\) 这一维去掉,按照 \(r\) 将所有的询问进行排序,然后再用 \(x\) 去映射 \(l\)(权值线段树)
处理一个询问 \(l,r\) 时,先保证 \(1\sim r\) 的所有数都已经插入进线段树,然后再进行查询,本题的查询和 普通平衡树 里面的一样,都是二分的查询
举个例子,如果区间 \(0,5\) 中出现的位置均 \(\ge l\),则 \(0,5\) 都不是答案,所以我们维护下标的最小值即可
本题不需要离散化,因为一共只有 \(n\) 个数,显然答案最大只到 \(n\),当有数字 \(>n\) 时,我们将其设为 \(n\),此时 \(n\) 不可能是答案
c++
#include <iostream>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 2000005;
typedef pair<int, int> PII;
struct Node
{
int l, r;
int minn;
}tr[N * 4];
vector<PII> q[N];
int ans[N];
int a[N];
void pushup(Node& u, Node& l, Node& r)
{
u.minn = min(l.minn, r.minn);
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
tr[u] = { l, r };
if (l == r) return;
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
}
void modify(int u, int x, int v)
{
//注意一个数时这里要的是数出现的位置最大值
if (tr[u].l == x && tr[u].r == x) tr[u].minn = v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
}
int query(int u, int x)
{
if (tr[u].l == tr[u].r) return tr[u].l;
if (tr[u << 1].minn < x) return query(u << 1, x);
else return query(u << 1 | 1, x);
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
if (a[i] > n) a[i] = n;
}
build(1, 0, n);
for (int i = 1; i <= m; i++)
{
int l, r;
scanf("%d%d", &l, &r);
q[r].push_back({ l, i });
}
for (int i = 1; i <= n; i++)
{
modify(1, a[i], i);
for (auto j : q[i]) ans[j.second] = query(1, j.first);
}
for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
return 0;
}
IV. MEX Queries
不考虑数据范围的情况下,分析这个问题可以发现,这个问题就是在权值线段树上进行操作
操作 1 相当于将 \(l,r\) 全部赋值为 \(1\),操作 2 相当于将 \(l,r\) 全部赋值为 \(0\),操作 3 相当于将 \(l,r\) 的值全部
c++
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 300005, M = 100005;
typedef long long LL;
struct Q
{
LL l, r;
int op;
}q[M];
struct Node
{
int l, r;
// tag_assign: -1 表示没有赋值 0, 1 表示赋值多少
// tag_rev: 0 表示没有反转,1 表示有反转
int tag_assign, tag_rev;
int sum;
}tr[N * 4];
vector<LL> vec;
int find(LL x)
{
return lower_bound(vec.begin(), vec.end(), x) + 1;
}
void pushdown(int u)
{
auto& root = tr[u], & left = tr[u << 1], & right = tr[u << 1 | 1];
if (root.tag_assign)
{
left.sum = root.tag_assign * (left.r - left.l + 1);
left.tag_assign = root.tag_assign;
}
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = { l, r, -1, 0, 0 };
else
{
tr[u] = { l, r };
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
}
void modify(int u, int l, int r, int op)
{
if (tr[u].l >= l && tr[u].r <= r)
{
// 有赋值时,反转操作失效
if (op == 1)
{
tr[u].tag_rev = 0;
tr[u].sum = tr[u].r - tr[u].l + 1;
tr[u].tag_assign = 1;
}
if (op == 2)
{
tr[u].tag_rev = 0;
tr[u].sum = 0;
tr[u].tag_assign = 0;
}
if (op == 3)
{
// 反转时,赋值操作反过来
tr[u].tag_rev ^= 0; // 反转两次等于没操作
tr[u].sum = tr[u].r - tr[u].l + 1 - tr[u].sum;
tr[u].tag_assign ^= 0;
}
}
else
{
}
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
LL l, r; int op;
scanf("%lld%lld%d", &l, &r, &op);
vec.push_back(l);
vec.push_back(r);
vec.push_back(r + 1);
q[i] = { l, r, op };
}
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
build(1, 1, n);
}
动态开点
当线段树维护的值域太大,但是需要访问的点很少时,需要使用动态开点,下面以一道例题具体说明
例题
I. Physical Education Lessons
区间赋值问题,但是数据范围为 \(1\le n\le 10^9\),可以用离散化加标记永久化解决,也可以使用动态开点的做法
我们可以套用懒标记的思想,当需要开点的时候再进行开点,一边修改查询,一边建树,那如何知道一个节点的儿子呢?我们可以直接建立一个数组 \(lsx,rsx\) 分别表示 \(x\) 的左儿子和右儿子
本题还需要估计空间,设修改次数为 \(m\),分裂开点时,最坏情况会走到两个分支,到叶子节点,所以最坏开点数为 \(2\log n\),空间复杂度为 \(O(m\cdot 2\log n)\)
具体地,我们开了四个数组 \(sum,tag,ls,rs\),估计内存为 \(3\times 10^5\times 2\times \log 10^9\times 4\times 4\)(另一个 \(4\) 表示 int 有四个字节),总计约为 \(274MB\) 依旧超出空间限制,此时就需要将 \(2\log n\) 估计得小一些,毕竟不可能每次都是最坏情况,改成 \(50\) 即可通过本题
c++
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 300005 * 50;
int sum[N], tag[N], ls[N], rs[N];
//tag 0 工作日 1 非工作日 -1 没有修改过
int root, tot; //根节点以及节点的数量
void pushup(int u)
{
sum[u] = sum[ls[u]] + sum[rs[u]];
}
void pushdown(int u, int l, int r)
{
if (tag[u] == -1) return;
if (!ls[u]) ls[u] = ++tot;
if (!rs[u]) rs[u] = ++tot;
int mid = l + r >> 1;
sum[ls[u]] = tag[u] * (mid - l + 1);
sum[rs[u]] = tag[u] * (r - mid);
tag[ls[u]] = tag[rs[u]] = tag[u];
tag[u] = -1;
}
void modify(int& u, int l, int r, int x, int y, int d)
{
if (!u) u = ++tot;
if (l >= x && r <= y)
{
sum[u] = d * (r - l + 1);
tag[u] = d;
return;
}
pushdown(u, l, r);
int mid = l + r >> 1;
if (x <= mid) modify(ls[u], l, mid, x, y, d);
if (y > mid) modify(rs[u], mid + 1, r, x, y, d);
pushup(u);
}
int main()
{
memset(tag, -1, sizeof tag);
int n, q;
scanf("%d%d", &n, &q);
for (int i = 1; i <= q; i++)
{
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
if (k == 1) modify(root, 1, n, l, r, 1);
else modify(root, 1, n, l, r, 0);
printf("%d\n", n - sum[root]);
}
return 0;
}
标记永久化
所谓标记永久化,就是修改时留下的懒标记不下传,也不删除,而是留在打标记的那个节点上。当查询经过这个点时,就加上这个节点的懒标记造成的影响
下面以一道区间修改的模板题具体说明
例题
I. 模板 线段树 1
如果用懒标记进行区间修改将会变得非常麻烦,考虑使用标记永久化
样例数据:
c++
4 4
1 1 1 1
1 1 4 1
1 2 4 2
2 2 4
2 3 3

如图先进行建树操作
当我们想对 \(1,4\) 加一时,首先修改它的和,然后打上标记 \(1\)

当我们修改 \(2,4\) 时,递归会经过 \(1,4,1,2,3,4,2,2\) 几个区间,需要修改它们的区间和,并在完全覆盖的区间上打上标记

查询时需要计算没有下传下来的懒标记,例如查询 \(3,3\),就要加上一路递归下来的 \(1,4,3,4\) 的懒标记
c++
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 100005;
typedef long long LL;
struct Node
{
int l, r;
LL sum, tag;
}tr[N * 4];
LL a[N];
void build(int u, int l, int r)
{
if (l == r) tr[u] = { l, r, a[l], 0 };
else
{
tr[u] = { l, r };
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
}
LL query(int u, int l, int r, LL s)
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u].sum + (min(tr[u].r, r) - max(tr[u].l, l) + 1) * s;
s += tr[u].tag;
LL sum = 0;
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) sum += query(u << 1, l, r, s);
if (r > mid) sum += query(u << 1 | 1, l, r, s);
return sum;
}
void modify(int u, int l, int r, int d)
{
tr[u].sum += (min(tr[u].r, r) - max(tr[u].l, l) + 1) * d;
if (tr[u].l >= l && tr[u].r <= r) tr[u].tag += d;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, d);
if (r > mid) modify(u << 1 | 1, l, r, d);
}
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
build(1, 1, n);
for (int i = 1; i <= m; i++)
{
int op, x, y, k;
scanf("%d%d%d", &op, &x, &y);
if (op == 1)
{
scanf("%d", &k);
modify(1, x, y, k);
}
else printf("%lld\n", query(1, x, y, 0));
}
return 0;
}
II. Mayor's posters
想要知道一共有多少种不同的海报,直观上想,就是扫一遍整个序列,看有多少种不同的数
区间让我们想到了线段树,扫一遍整个序列,其实就是对整个序列的每一个位置进行一遍单点查询
我们把每个叶子节点当做如图的格子
我们按照贴海报的顺序,进行编号,编号大的说明贴的晚
每次贴一张海报,相当于对于一段区间进行了一次赋值操作,对于区间赋值的操作,我们可以不下传标记,而是在单点查询的过程中,在递归到叶子节点的路径上取最大值,就像上一题一样,将路径上的标记加起来再乘以有效长度
但这题 \(l_i\) 和 \(r_i\) 的范围很大,所以需要离散化,这题的难点就在这里
假设现在要在 \(1,35,7\) 这两个区间贴海报,离散化之后会变成 \(1,23,4\)
但原本 \(3,5\) 之间是存在距离的,离散化之后,\(1,35,7\) 两个区间可以正常表示,但是 \(3,5\) 就不行了,原因是离散化之后 \(2,3\) 是相邻的,中间没有空隙,导致错误
解决方案是:我们可以把区间 \(l,r\),\(r+1\) 的值放进离散化数组中,如果原本两个区间就是相邻的,那么就会在去重那一步去掉,否则就会存在一个数,让题目给定的两两区间的左右端点之间存在空隙
c++
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 30005, M = 10005;
typedef pair<int, int> PII;
struct Node
{
int l, r, v;
}tr[N * 4];
vector<int> alls;
vector<int> ans;
PII range[M];
void build(int u, int l, int r)
{
tr[u] = { l, r, 0 };
if (l == r) return;
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
}
int find(int x)
{
return lower_bound(alls.begin(), alls.end(), x) - alls.begin() + 1;
}
void modify(int u, int l, int r, int id)
{
if (tr[u].l >= l && tr[u].r <= r) tr[u].v = id;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, id);
if (r > mid) modify(u << 1 | 1, l, r, id);
}
}
int query(int u, int x, int s)
{
if (tr[u].l == x && tr[u].r == x)
{
s = max(s, tr[u].v);
return s;
}
s = max(s, tr[u].v);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) return query(u << 1, x, s);
else return query(u << 1 | 1, x, s);
}
int main()
{
int t;
scanf("%d", &t);
while (t--)
{
memset(tr, 0, sizeof tr);
int n;
scanf("%d", &n);
alls.clear();
for (int i = 1; i <= n; i++)
{
int l, r;
scanf("%d%d", &l, &r);
range[i] = { l, r };
alls.push_back(l), alls.push_back(r);
alls.push_back(r + 1);
}
sort(alls.begin(), alls.end());
alls.erase(unique(alls.begin(), alls.end()), alls.end());
alls.pop_back();
build(1, 1, alls.size());
for (int i = 1; i <= n; i++)
{
int& l = range[i].first;
int& r = range[i].second;
l = find(l), r = find(r);
modify(1, l, r, i);
}
ans.clear();
for (int i = 1; i <= alls.size(); i++)
{
int x = query(1, i, 0);
if (x == 0) continue;
ans.push_back(x);
}
sort(ans.begin(), ans.end());
ans.erase(unique(ans.begin(), ans.end()), ans.end());
printf("%d\n", ans.size());
}
return 0;
}
线段树合并
通常合并动态开点的权值线段树,以下是算法流程

假设现在合并到了两棵线段树 \(x,y\) 的 \(p\) 节点:
-
如果一棵树的 \(p\) 节点为空,那么返回另一个的 \(p\) 节点(如图,\(x\) 线段树 \(1,3\) 没有左子节点,所以返回了 \(1,2\))
-
如果已经合并到两棵线段树的叶子节点,那么就把 \(y\) 在 \(p\) 节点的值加到 \(x\) 上,并将 \(p\) 节点返回
-
递归处理左子树,右子树
-
用左右子树的值更新当前节点
-
将 \(p\) 节点返回
例题
I. Vani有约会 雨天的尾巴
将 \(x\to y\) 的路径上都发放一袋救济粮,使用树上差分进行操作
用权值线段树维护,范围是救济粮种类的值域,维护 \(l,r\) 范围中,最多的救济粮的类型以及数量
将每一个点都开一棵线段树,修改时,使用动态开点防止爆空间
由于每次需要更改四棵线段树,所以总的时间复杂度为 \(O(4\cdot m\log n)\)
空间复杂度方面,每次最坏 \(2\log n\),\(m\) 次操作,每次修改 \(4\) 棵 ,总的空间复杂度为 \(O(m\cdot 8\log n)\),计算得知可以通过
树上差分中,我们最后需要求出子树和,这道题中,我们需要将所有的线段树进行合并
我们依旧按照计算子树和一样的顺序,自底向上地进行合并
c++
#include <iostream>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;
const int N = 100005, M = N * 2;
int sum[N * 140], type[N * 140], ls[N * 140], rs[N * 140];
int root[N], tot;
int ans[N];
int depth[N], fa[N][17];
int h[N], ne[M], e[M], idx;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void pushup(int u)
{
if (sum[ls[u]] >= sum[rs[u]])
{
sum[u] = sum[ls[u]];
type[u] = type[ls[u]];
}
else
{
sum[u] = sum[rs[u]];
type[u] = type[rs[u]];
}
}
void modify(int& u, int l, int r, int p, int k)
{
if (!u) u = ++tot;
if (l == r)
{
sum[u] += k;
type[u] = p;
return;
}
int mid = l + r >> 1;
if (p <= mid) modify(ls[u], l, mid, p, k);
else modify(rs[u], mid + 1, r, p, k);
pushup(u);
}
void bfs(int root)
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1;
queue<int> q;
q.push(root);
while (!q.empty())
{
int t = q.front();
q.pop();
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q.push(j);
fa[j][0] = t;
for (int k = 1; k <= 16; k++)
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b)
{
if (depth[a] < depth[b]) swap(a, b);
for (int k = 16; k >= 0; k--)
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return b;
for (int k = 16; k >= 0; k--)
if (fa[a][k] != fa[b][k])
a = fa[a][k], b = fa[b][k];
return fa[a][0];
}
int merge(int x, int y, int l, int r)
{
if (!x || !y) return x + y;
if (l == r)
{
sum[x] += sum[y];
return x;
}
int mid = l + r >> 1;
ls[x] = merge(ls[x], ls[y], l, mid);
rs[x] = merge(rs[x], rs[y], mid + 1, r);
pushup(x);
return x;
}
void dfs(int u, int fa)
{
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
dfs(j, u);
root[u] = merge(root[u], root[j], 1, N);
}
ans[u] = sum[root[u]] ? type[root[u]] : 0;
}
int main()
{
memset(h, -1, sizeof h);
int n, m;
scanf("%d%d", &n, &m);
for (int i = 2; i <= n; i++)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
bfs(1);
for (int i = 1; i <= m; i++)
{
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
modify(root[x], 1, N, z, 1);
modify(root[y], 1, N, z, 1);
int p = lca(x, y);
modify(root[p], 1, N, z, -1);
modify(root[fa[p][0]], 1, N, z, -1);
}
dfs(1, -1);
for (int i = 1; i <= n; i++) printf("%d\n", ans[i]);
return 0;
}
可持久化线段树
可持久化线段树可以支持查询历史区间信息
P3834 【模板】可持久化线段树 2 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
看这样的一个问题,如果使用权值线段树,很难做到查询某一个区间数出现次数为 \(k\) 的数
此时,我们按照每个数在序列中的出现次数,划分一个个历史阶段,使用可持久化线段树
如图,我们发现每次修改都最多只会影响到 \(\log n+1\) 个节点

于是一些没被修改过的点如果重新开一棵就浪费了,我们将被修改的点合并到前一个阶段的线段树中去

如图,可持久化线段树使用动态开点的方式进行存储,原因是对于一个区间存在多个历史版本不好确定子节点究竟是什么
对于每个节点存储其左右儿子的编号,对于每个历史版本,还需要保存根节点的编号,就相当于我们每次询问时找到了一个历史版本的入口
内存分析:一开始会建立一棵空树,需要 \(2n-1\) 个节点。有 \(n\) 次插入对应不同的历史版本,每次最多新增 \(\log n+1\) 个节点
c++
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 200005;
struct Node
{
int l, r;
int cnt;
}tr[N * 4 + N * 18];
vector<int> alls;
int a[N];
int root[N], idx;
int find(int x)
{
return lower_bound(alls.begin(), alls.end(), x) - alls.begin();
}
//建立空树
int build(int l, int r)
{
int p = ++idx;
if (l == r) return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
return p;
}
//对于一条链上的所有的节点需要新开一个
//因为一次修改会影响到它们
int insert(int p, int l, int r, int x)
{
int q = ++idx; //复制一个新节点
tr[q] = tr[p];
if (l == r)
{
tr[q].cnt++;
return q;
}
int mid = l + r >> 1;
if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x);
else tr[q].r = insert(tr[p].r, mid + 1, r, x);
tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
return q;
}
//q 保存了 1 ~ r 的答案
//p 是用来剔除掉 1 ~ l - 1 对 q 答案的影响
int query(int q, int p, int l, int r, int k)
{
if (l == r) return r;
int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt; //减去 1 ~ l - 1 出现的所有数
int mid = l + r >> 1;
if (k <= cnt) return query(tr[q].l, tr[p].l, l, mid, k);
else return query(tr[q].r, tr[p].r, mid + 1, r, k - cnt);
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
alls.push_back(a[i]);
}
sort(alls.begin(), alls.end());
alls.erase(unique(alls.begin(), alls.end()), alls.end());
root[0] = build(0, alls.size() - 1);
for (int i = 1; i <= n; i++)
root[i] = insert(root[i - 1], 0, alls.size() - 1, find(a[i]));
while (m--)
{
int l, r, k;
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", alls[query(root[r], root[l - 1], 0, alls.size() - 1, k)]);
}
return 0;
}
树上主席树
区间主席树是按照区间将所有的数值信息丢到主席树里面去,查询时 \(rootr\) 提供了 \(1\sim r\) 时刻的信息,还要消除 \(1\sim l -1\) 时刻信息的影响
与前缀和是非常类似的,前缀和存在树上前缀和的操作,这启发我们主席树是不是也可以在树上维护呢?
类比树上前缀和 \(sx+sy-slca-sfa\[lca]\),消除了 \(lca\) 和 \(falca\) 对答案的影响,同理我们用主席树也可以这样做
P2633 Count on a tree - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
c++
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cstdio>
using namespace std;
const int N = 100005, M = N * 2;
struct Node
{
int l, r;
int cnt;
}tr[N * 4 + N * 18];
int a[N];
vector<int> alls;
int find(int x)
{
return lower_bound(alls.begin(), alls.end(), x) - alls.begin();
}
int root[N], tot;
int fa[N][17], depth[N];
int h[N], ne[M], e[M], idx;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int build(int l, int r)
{
int p = ++tot;
if (l == r) return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
return p;
}
int insert(int p, int l, int r, int x)
{
int q = ++tot;
tr[q] = tr[p];
if (l == r)
{
tr[q].cnt++;
return q;
}
int mid = l + r >> 1;
if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x);
else tr[q].r = insert(tr[p].r, mid + 1, r, x);
tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
return q;
}
int query(int q1, int q2, int p1, int p2, int l, int r, int k)
{
if (l == r) return r;
int cnt = tr[tr[q1].l].cnt + tr[tr[q2].l].cnt - tr[tr[p1].l].cnt - tr[tr[p2].l].cnt;
int mid = l + r >> 1;
if (k <= cnt) return query(tr[q1].l, tr[q2].l, tr[p1].l, tr[p2].l, l, mid, k);
else return query(tr[q1].r, tr[q2].r, tr[p1].r, tr[p2].r, mid + 1, r, k - cnt);
}
int lca(int a, int b)
{
if (depth[a] < depth[b]) swap(a, b);
for (int k = 16; k >= 0; k--)
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
for (int k = 16; k >= 0; k--)
if (fa[a][k] != fa[b][k])
a = fa[a][k], b = fa[b][k];
return fa[a][0];
}
void dfs(int u, int father)
{
root[u] = insert(root[father], 0, alls.size() - 1, find(a[u]));
fa[u][0] = father;
depth[u] = depth[father] + 1;
for (int j = 1; j <= 16; j++)
fa[u][j] = fa[fa[u][j - 1]][j - 1];
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == father) continue;
dfs(j, u);
}
}
int query_path(int u, int v, int k)
{
int p = lca(u, v);
int res = query(root[u], root[v], root[p], root[fa[p][0]], 0, alls.size() - 1, k);
return alls[res];
}
int main()
{
memset(h, -1, sizeof h);
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
alls.push_back(a[i]);
}
sort(alls.begin(), alls.end());
alls.erase(unique(alls.begin(), alls.end()), alls.end());
for (int i = 2; i <= n; i++)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
root[0] = build(0, alls.size() - 1);
dfs(1, 0);
int last = 0;
for (int i = 1; i <= m; i++)
{
int u, v, k;
scanf("%d%d%d", &u, &v, &k);
u ^= last;
last = query_path(u, v, k);
printf("%d\n", last);
}
return 0;
}