原理
树套树,顾名思义就是在树里套一个树,这里的树是一种代称,它指的是可以高效的查询和修改的数据结构,比如分块,树状数组,线段树,平衡树,那么既然我们有了这些高效的数据结构我们为什么还要再往里面套一个呢,难道是因为不够高效吗?肯定不是。主要是因为无法解决问题,大多数情况下是因为维度不够,(想象一下,不管是分块,树状数组,线段树,平衡树,他们管理的都是一个一维的数组,也就是区间问题),如果问题是二维的,我们就需要升维,那自然而然,往一维里套一个一维不就是二维的吗。
题目
U644377 平面点对 - 洛谷
https://www.luogu.com.cn/problem/U644377
AC代码
cpp
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 10;
const int inf = 1e5; // 坐标范围 [-1e5, 1e5]
struct Segt1 { // 内层线段树
// tr[u]存储点的个数, ls[u]存储左儿子,rs[u]存储右儿子
int idx, tr[maxn * 100], ls[maxn * 100], rs[maxn * 100];
void push_up(int u) {
tr[u] = tr[ls[u]] + tr[rs[u]];
}
// v = 1 插入;v = -1 删除
void add(int y, int v, int l, int r, int& u) {
if (!u) u = ++idx;
if (l == r) {
tr[u] += v;
return;
}
int mid = l + r >> 1;
(y <= mid) ? add(y, v, l, mid, ls[u]) : add(y, v, mid + 1, r, rs[u]);
push_up(u);
}
// 查询有多少个 y 在 [y1, y2] 中
int query(int y1, int y2, int l, int r, int u) {
if (y1 <= l && r <= y2)
return tr[u];
int res = 0, mid = l + r >> 1;
if (y1 <= mid) res += query(y1, y2, l, mid, ls[u]);
if (y2 > mid) res += query(y1, y2, mid + 1, r, rs[u]);
return res;
}
} segt1;
int rt[maxn];// 外层改为树状数组
const int offset = 100001; // 用于处理负数坐标
const int max_range = 200005;
#define lowbit(x) x & -x
void bit_add(int x, int y, int v) {
x += offset; // 移位,确保 x > 0
for (; x < max_range; x += lowbit(x))
segt1.add(y, v, -inf, inf, rt[x]);
}
int bit_query(int x, int y1, int y2) {
x += offset;
int res = 0;
for (; x > 0; x -= lowbit(x))
res += segt1.query(y1, y2, -inf, inf, rt[x]);
return res;
}
int px[50005], py[50005];// 原始点坐标存储
int main() {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d%d", &px[i], &py[i]);
bit_add(px[i], py[i], 1);
}
int q;
scanf("%d", &q);
while (q--) {
int op;
scanf("%d", &op);
if (op == 1) {
int p, nx, ny;
scanf("%d%d%d", &p, &nx, &ny);
bit_add(px[p], py[p], -1); // 删旧点
px[p] = nx; py[p] = ny;
bit_add(px[p], py[p], 1); // 加新点
}
else {
int x1, x2, y1, y2;
scanf("%d%d%d%d", &x1, &x2, &y1, &y2);
printf("%d\n", bit_query(x2, y1, y2) - bit_query(x1 - 1, y1, y2));
}
}
return 0;
}
逐步分析
我们可以很显然的发现,普通的线段树或别的一些数据结构无法处理这样的问题,问题在于他要处理的是两个独立的问题,也就是二维,那我们就要分析一下这道题的两个维度都是什么,这道题非常的直白,这两个维度分别是 x1≤x≤x2 和 y1≤y≤y2 ,也就是说,当我们在修改或查找时,我们要先修改或查找符合条件的 x ,接着修改或查找符合条件的 y。
选择用什么套
那么接下来就是数据结构选择的问题了,到底是什么套什么呢?先说结论:通常来说当二维问题只涉及到单点修改时我们使用树状数组套线段树,当问题涉及到区间修改时我们使用线段树套线段树,为什么呢?
内层的选择
内层的树占用了远比外层的树更多的空间(就像是一个年级有十个班但是一个年级总共有几百人一样),我们需要尽可能的优化内层的空间,而线段树的动态开点就完美的符合这一条件,(而且线段树的实现相对简单,主要是用的多,自然熟练,不像平衡树)。
|------------|-----------------|--------------------------|
| 特性 | 静态线段树 (Static) | 动态开点线段树 (Dynamic) |
| 理论空间 | 4N | O(MlogN) |
| N=10^9 场景 | 需要约 4×10^9 个节点 | 如果 M=105,仅需约 3×10^6 个节点 |
| 内存占用 | 约 15 GB (无法运行) | 约 48 MB (轻松运行) |
| 空间分布 | 预先开辟一整块连续内存 | 随用随开,按需分配 |
外层的选择
外层主要起的是向导的作用,它既不需要区间修改,又不需要很好的空间优化,那我们自然就可以效率优异,实现简单的树状数组了。
注:用树状数组时记得索引一定要大于 0 ,一定不能等于 0,否则就死循环了。
下面同样是一道树状数组套线段树的题目,我们先分析问题再转换问题就可以解决了,虽然逻辑不如这道题清晰,但是实现反而比这道题简单,大家可以参考一下:
参考题目
P6514 [QkOI#R1] Quark and Strings - 洛谷
https://www.luogu.com.cn/problem/P6514
AC代码
cpp
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;
int n, Q;
struct Seg { // 内层线段树
// tr[u]:点的个数, ls[u]左儿子,rs[u]右儿子
int idx, tr[maxn * 200], ls[maxn * 200], rs[maxn * 200];
void push_up(int u) {
tr[u] = tr[ls[u]] + tr[rs[u]];
}
// 插入一个 R
void add(int R, int l, int r, int& u) {
if (!u) u = ++idx;
if (l == r) {
tr[u]++;
return;
}
int mid = l + r >> 1;
(R <= mid) ? add(R, l, mid, ls[u]) : add(R, mid + 1, r, rs[u]);
push_up(u);
}
// 查询有多少个 y 在 [L, R] = [y1, y2]
int query(int L, int R, int l, int r, int u) {
if (L <= l && r <= R)
return tr[u];
int res = 0, mid = l + r >> 1;
if (L <= mid) res += query(L, R, l, mid, ls[u]);
if (R > mid) res += query(L, R, mid + 1, r, rs[u]);
return res;
}
} seg1;
int rt[maxn]; // 树状数组每个节点对应的内层线段树根节点
#define lowbit(x) x & -x
// 树状数组添加:在外层索引为 L 的位置,向其对应的内层线段树插入 R
void bit_add(int L, int R) {
for (; L <= n; L += lowbit(L)) {
seg1.add(R, 1, n, rt[L]);
}
}
// 树状数组查询:统计外层索引在 [1, L] 范围内,且内层索引在 [v, n] 范围内的点数
int bit_query(int L, int v) {
int res = 0;
for (; L > 0; L -= lowbit(L)) {
res += seg1.query(v, n, 1, n, rt[L]);
}
return res;
}
int main() {
if (scanf("%d%d", &n, &Q) == EOF) return 0;
while (Q--) {
int op, l, r;
scanf("%d%d%d", &op, &l, &r);
if (op == 1)
bit_add(l, r);
else
printf("%d\n", bit_query(l, r));
}
return 0;
}
