树上莫队
前言
学习这个前,请确保你已经学会了普通莫队
如果可以,最好把带修莫队也学了
正文
例题1
题意简述:
- 给定一棵 n n n 个节点的树,根节点为 1 1 1。每个节点上有一个颜色 c i c_i ci。 m m m 次操作。操作有一种:
u k
:询问在以 u u u 为根的子树中,出现次数 ≥ k \ge k ≥k 的颜色有多少种。
数据范围:
- 2 ≤ n ≤ 1 0 5 2\le n\le 10^5 2≤n≤105, 1 ≤ m ≤ 1 0 5 1\le m\le 10^5 1≤m≤105, 1 ≤ c i , k ≤ 1 0 5 1\le c_i,k\le 10^5 1≤ci,k≤105。
分析
看着数据范围,应该是 O ( n log n ) O(n\log n) O(nlogn)或 ( n n ) (n\sqrt n) (nn )的复杂度,可是 n log n n\log n nlogn 的复杂度不太好想,再仔细看问题,这不明显莫队吗,可这是在树上啊。
于是树上莫队就被发明了,顾名思义,就是在树上莫队。
考虑将树上的询问转换为区间询问。
我们需要用到一个东西:dfs序或欧拉序
这样就能将树上的询问转换为区间询问。
dfs序多用于求有关子树的,欧拉序多用于求有关树上路径的,如果两个都有关,想想树剖吧 {\color{Red} \text{dfs序多用于求有关子树的,欧拉序多用于求有关树上路径的,如果两个都有关,想想树剖吧} } dfs序多用于求有关子树的,欧拉序多用于求有关树上路径的,如果两个都有关,想想树剖吧
看回这道题,因为是求一整颗子树,所以我们要记录这颗子树的起始节点即根节点的dfs序和最后一个叶子节点的dfs序,剩下的正常莫队维护就可以。
注意:
- add和del操作操作的是dfs序,而不是直接的 l , r l,r l,r
Code:
cpp
#include <bits/stdc++.h>
#define int long long
#define IOS ios::sync_with_stdio(false), cin.tie(NULL), cout.tie(NULL)
#define cou(i) cout << fixed << setprecision(i)
using namespace std;
const int N = 2e5 + 1;
struct fy {
int l, r, id, k;
} q[N];
vector<int> E[N];
int n, m, a[N], b[N], ks, pos[N], cnt[N], ans, Ans[N], sum[N];
int dfn[N], las[N], beg[N], idx;
bool cmp(fy x, fy y) { return pos[x.l] == pos[y.l] ? x.r < y.r : pos[x.l] < pos[y.l]; }
inline void del(int x) {
sum[cnt[a[x]]]--;
cnt[a[x]]--;
}
inline void add(int x) {
cnt[a[x]]++;
sum[cnt[a[x]]]++;
}
void dfs(int u, int fa) {
dfn[++idx] = u;
beg[u] = idx;
for (int i = 0; i < E[u].size(); i++) {
int v = E[u][i];
if (v == fa) continue;
dfs(v, u);
}
las[u] = idx;
}
signed main() {
IOS;
cou(0);
cin >> n >> m;
ks = pow(n, 2.0 / 3.0);
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 0; i < N; i++) pos[i] = (i - 1) / ks + 1;
// sort(b + 1, b + n + 1);
// int len = unique(b + 1, b + n + 1) - b - 1;
// for (int i = 1; i <= n; i++) {
// a[i] = lower_bound(b + 1, b + n + 1, a[i]) - b;
// }
for (int i = 1, x, y; i < n; i++) cin >> x >> y, E[x].push_back(y), E[y].push_back(x);
dfs(1, 0);
for (int i = 1, x, y; i <= m; i++) {
cin >> x >> y, q[i].id = i;
q[i].l = beg[x], q[i].r = las[x], q[i].k = y;
}
sort(q + 1, q + m + 1, cmp);
int l = 1, r = 0;
for (int i = 1; i <= m; i++) {
int L = q[i].l, R = q[i].r, id = q[i].id;
while (l > L) add(dfn[--l]);
while (r < R) add(dfn[++r]);
while (l < L) del(dfn[l++]);
while (r > R) del(dfn[r--]);
Ans[id] = sum[q[i].k];
}
for (int i = 1; i <= m; i++) cout << Ans[i] << "\n";
return 0;
}
例题2
题意简述:
- 给定 n n n 个结点的树,每个结点有一种颜色。
- m m m 次询问,每次询问给出 u , v u,v u,v,回答 u , v u,v u,v 之间的路径上的结点的不同颜色数。
- 1 ≤ n ≤ 4 × 1 0 4 1\le n\le 4\times 10^4 1≤n≤4×104, 1 ≤ m ≤ 1 0 5 1\le m\le 10^5 1≤m≤105,颜色是不超过 2 × 1 0 9 2 \times 10^9 2×109 的非负整数。
分析
还是一样的套路,使用树上莫队,但注意到这里问的是路径问题,所以使用欧拉序。
剩下的就很简单了,直接套模板即可。
注意:
- 数据很大,需离散化
- 欧拉序的长度为 2 n 2n 2n,所以块长要改一下
- 需要处理lca的节点
- 处理lca时注意节点顺序
- 注意处理无效节点
呃,简单讲一下:
因为欧拉序的性质,所以容易发现如果两个节点在同一颗子树内,那么直接用入栈顺序即可,那lca的点自然不用管。
若不在,我们发现,需要一个节点出栈以后,才能扫到另一个节点,所以一个用入栈序号(假设为 y y y),一个用出栈序号(假设为 x x x点,且 z = l c a ( x , y ) z=lca(x,y) z=lca(x,y))。
那么 z z z的除含有 x , y x,y x,y节点的子树就是无用的,但也在区间之内,需要删掉,观察到他们在这个区间内出现了两次,所以若出现次数为两次,就不要统计进答案之中。
同时在这个区间中,因为 y y y还没有返回到 z z z,所以 z z z节点的答案是少算的,要加上,并在事后删除。
Code:
cpp
#include <bits/stdc++.h>
#define int long long
#define IOS ios::sync_with_stdio(false), cin.tie(NULL), cout.tie(NULL)
#define cou(i) cout << fixed << setprecision(i)
using namespace std;
const int N = 2e5 + 1;
struct fy {
int l, r, id, k;
} q[N];
vector<int> E[N];
int n, m, a[N], b[N], ks, pos[N], cnt[N], ans, Ans[N], sum[N], f[N][20];
int dfn[N], las[N], beg[N], idx;
bool cmp(fy x, fy y) {
return pos[x.l] == pos[y.l] ? (pos[x.l] & 1 ? x.r < y.r : x.r > y.r) : pos[x.l] < pos[y.l];
}
inline void del(int x) {
cnt[x]--;
if (cnt[x] == 1) ans += (++sum[a[x]] == 1);
if (cnt[x] == 0) ans -= (--sum[a[x]] == 0);
}
inline void add(int x) {
cnt[x]++;
if (cnt[x] == 1) ans += (++sum[a[x]] == 1);
if (cnt[x] == 2) ans -= (--sum[a[x]] == 0);
}
void dfs(int u, int fa) {
dfn[++idx] = u;
beg[u] = idx;
f[u][0] = fa;
for (int i = 1; i < 20; i++) f[u][i] = f[f[u][i - 1]][i - 1];
for (int i = 0; i < E[u].size(); i++) {
int v = E[u][i];
if (v == fa) continue;
dfs(v, u);
}
dfn[++idx] = u;
las[u] = idx;
}
int lca(int x, int y) {
if (x == y) return x;
if (beg[x] < beg[y]) swap(x, y);
for (int i = 19; i >= 0; i--)
if (beg[f[x][i]] > beg[y]) x = f[x][i];
return f[x][0];
}
signed main() {
IOS;
cou(0);
cin >> n >> m;
ks = pow(2 * n, 2.0 / 3.0);
for (int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
for (int i = 0; i < N; i++) pos[i] = (i - 1) / ks + 1;
sort(b + 1, b + n + 1);
int len = unique(b + 1, b + n + 1) - b - 1;
for (int i = 1; i <= n; i++) {
a[i] = lower_bound(b + 1, b + len + 1, a[i]) - b;
}
for (int i = 1, x, y; i < n; i++) cin >> x >> y, E[x].push_back(y), E[y].push_back(x);
dfs(1, 0);
for (int i = 1, x, y, z; i <= m; i++) {
cin >> x >> y, q[i].id = i;
if (beg[y] < beg[x]) swap(x, y);
z = lca(x, y);
if (z == x) q[i].l = beg[x], q[i].r = beg[y], q[i].k = 0;
else q[i].l = las[x], q[i].r = beg[y], q[i].k = z;
}
sort(q + 1, q + m + 1, cmp);
int l = 1, r = 0;
for (int i = 1; i <= m; i++) {
int L = q[i].l, R = q[i].r, id = q[i].id;
while (l > L) add(dfn[--l]);
while (r < R) add(dfn[++r]);
while (l < L) del(dfn[l++]);
while (r > R) del(dfn[r--]);
if (q[i].k) add(q[i].k);
Ans[id] = ans;
if (q[i].k) del(q[i].k);
}
for (int i = 1; i <= m; i++) cout << Ans[i] << "\n";
return 0;
}
扩展
树上带修莫队例题
洛谷
树上莫队,但是带修改,好像也没什么好讲的,注意块长设为 ( 2 n ) 2 3 (2n)^\frac{2}{3} (2n)32
就代码比较长
部分Code:
cpp
void dfs(int u, int fa) {
dfn[++idx] = u;
beg[u] = idx;
f[u][0] = fa;
for (int i = 1; i < 20; i++) f[u][i] = f[f[u][i - 1]][i - 1];
for (int i = 0; i < E[u].size(); i++) {
int v = E[u][i];
if (v == fa) continue;
dfs(v, u);
}
dfn[++idx] = u;
las[u] = idx;
}
int lca(int x, int y) {
if (x == y) return x;
if (beg[x] < beg[y]) swap(x, y);
for (int i = 19; i >= 0; i--)
if (beg[f[x][i]] > beg[y]) x = f[x][i];
return f[x][0];
}
void modify(int x) {
int X = C[x].X, Y = C[x].Y;
if (cnt[X] == 1) {
ans -= sum[a[X]] * (sum[a[X]] - 1) / 2;
sum[a[X]]--;
ans += sum[a[X]] * (sum[a[X]] - 1) / 2;
ans -= sum[Y] * (sum[Y] - 1) / 2;
sum[Y]++;
ans += sum[Y] * (sum[Y] - 1) / 2;
}
swap(a[C[x].X], C[x].Y);
}
signed main() {
IOS;
cou(0);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
ks = 2 * pow(n, 2.0 / 3.0);
for (int i = 0; i < N; i++) pos[i] = (i - 1) / ks + 1;
for (int i = 1, x, y; i < n; i++) {
cin >> x >> y;
x++, y++;
E[x].push_back(y), E[y].push_back(x);
}
dfs(1, 0);
for (int i = 1, op, x, y, z; i <= m; i++) {
cin >> op >> x >> y;
if (op == 1) {
x++;
C[++cid].X = x, C[cid].Y = y;
}
if (op == 2) {
x++, y++;
qid++;
q[qid].id = qid, q[qid].t = cid;
if (beg[y] < beg[x]) swap(x, y);
z = lca(x, y);
if (z == x) q[qid].l = beg[x], q[qid].r = beg[y], q[qid].k = 0;
else q[qid].l = las[x], q[qid].r = beg[y], q[qid].k = z;
}
}
sort(q + 1, q + qid + 1, cmp);
int l = 1, r = 0, T = 0;
for (int i = 1; i <= qid; i++) {
int L = q[i].l, R = q[i].r, id = q[i].id, ti = q[i].t;
while (l > L) add(dfn[--l]);
while (r < R) add(dfn[++r]);
while (l < L) del(dfn[l++]);
while (r > R) del(dfn[r--]);
while (T < ti) modify(++T);
while (T > ti) modify(T--);
if (q[i].k) add(q[i].k);
Ans[id] = ans;
if (q[i].k) del(q[i].k);
}
for (int i = 1; i <= qid; i++) cout << Ans[i] << "\n";
return 0;
}