重链剖分

重链剖分


定义

重链剖分是树链剖分的一种解决方式,通过树链剖分将树划分为多个连续部分方便执行操作

我们进行如下定义:

  • son[u] 表示 u 结点的重儿子。对于重儿子, 我们定义在 u 的所有子树中最大的子节点,且重儿子只有一个
  • sz[u] 表示 u 结点的子树大小, 包括结点 u 本身
  • dep[u] 表示 u 结点的深度
  • dfn[u] 表示 u 结点所表示的 dfs
  • top[u] 表示 u 结点所在链的顶点
  • fa[u] 表示 u 结点的父节点

我们对一棵树进行如下划分,对于每个结点 u ,将其重儿子称为重子结点,而由重儿子所组成的一条链成为重链,其余节点成为轻儿子,特别的,重子节点一定会组成一条链,因为一个节点只存在一个重儿子,所以任何重子结点不存在两个及以上的重子节点。

通过以上分化后,我们可以通过 dfs 对每个结点定义其 dfs 序,对于这个遍历过程,我们采用重儿子优先的策略,优先向重儿子进行递归,可以保证重链的 dfs 序是连续的,并且对于任意一结点 u ,保证其子树内的 dfs 序大于 u ,并且不超过 dfn[u] + sz[u] - 1

总结就是:

  • 重儿子所组成的一定是一条链

  • 每个节点至多有一个重儿子,满足每个结点均在一条重链上

  • 重链的 dfs 序是连续的

  • 子树 udfs 序范围是 [dfn[u], dfn[u] + sz[u] - 1]


应用

通过以上定义,我们可以通过线段树快速对一个子树内的所有元素进行统一修改,因为一个子树内的所有节点的 dfs 序一定连续

同时,我们还可以快速查询树上最近公共祖先 LCA ,我们可以知道,树上任意一个点 u 都必然存在于一个重链之上,所以两个点只要不断跳到其所在链的顶点,再不断上跳,就能快速找到最近公共祖先;还可以对树上的一条路径进行操作

示例

以洛谷题目 P3384 【模板】重链剖分 为例

初始化

dfs1 :计算每个子树的大小,以及每个节点的深度,父亲,以及重儿子

cpp 复制代码
void dfs1(ll u, ll father) {
	fa[u] = father;
	dep[u] = dep[father] + 1;
	sz[u] = 1;
	for (auto v : g[u]) {
		if (v == father) continue;
		dfs1(v, u);
		sz[u] += sz[v];
		if (sz[v] > sz[son[u]]) son[u] = v;
	}
}

dfs2:计算每个节点的 dfs 序,以及每个链的链顶

cpp 复制代码
void dfs2(ll u, ll topf) {
	top[u] = topf;
	dfn[u] = ++timer;
	rk[timer] = u;//rk在这里是一个逆数组,用于标记每个dfn所对应的结点,用于后续操作
	if (son[u]) dfs2(son[u], topf);
	for (auto v : g[u]) {
		if (v != fa[u] && v != son[u])
			dfs2(v, v);
	}
}

线段树模板:

cpp 复制代码
//线段树维护的是dfs序对应的数组 a[rk[i]],因此树上的子树或路径问题可以转化为数组的区间问题。
struct SegmentTree {
	void push_up(ll p) {
		tree[p] = (tree[p << 1] + tree[p << 1 | 1]) % mod;
	}
	void push_down(ll p, ll l, ll r) {
		if (lz[p]) {
			lz[p << 1] = (lz[p << 1] + lz[p]) % mod;
			lz[p << 1 | 1] = (lz[p << 1 | 1] + lz[p]) % mod;
			ll mid = (l + r) >> 1;
			tree[p << 1] = (tree[p << 1] + lz[p] * (mid - l + 1) % mod) % mod;
			tree[p << 1 | 1] = (tree[p << 1 | 1] + lz[p] * (r - mid) % mod) % mod;
			lz[p] = 0;
		}
	}

	void build(ll p, ll l, ll r) {
		if (l == r) {
			tree[p] = a[rk[l]] % mod;
			return;
		}
		ll mid = (l + r) >> 1;
		build(p << 1, l, mid);
		build(p << 1 | 1, mid + 1, r);
		push_up(p);
	}

	void add(ll p, ll l, ll r, ll L, ll R, ll k) {
		if (L <= l && r <= R) {
			tree[p] = (tree[p] + k*(r - l + 1) % mod) % mod;
			lz[p] = (lz[p] + k) % mod;
			return;
		}
		push_down(p, l, r);
		ll mid = (l + r) >> 1;
		if (L <= mid)add(p << 1, l, mid, L, R, k);
		if (R > mid)add(p << 1 | 1, mid + 1, r, L, R, k);
		push_up(p);
	}

	ll search(ll p, ll l, ll r, ll L, ll R) {
		if (L <= l && r <= R) {
			return tree[p];
		}
		push_down(p, l, r);
		ll mid = (l + r) >> 1;
		ll res = 0;
		if (L <= mid)res = (res + search(p << 1, l, mid, L, R)) % mod;
		if (R > mid)res = (res + search(p << 1 | 1, mid + 1, r, L, R)) % mod;
		return res % mod;
	}
} st;

修改路径

cpp 复制代码
void update(ll u, ll v, ll k) {
	while (top[u] != top[v]) {//不断向上跳链,跳到链顶就修改一次,直到两个节点在同一个链上
		if (dep[top[u]] < dep[top[v]])swap(u, v);
		st.add(1, 1, n, dfn[top[u]], dfn[u], k);
		u = fa[top[u]];
	}
	if (dep[u] > dep[v])swap(u, v);
	st.add(1, 1, n, dfn[u], dfn[v], k);
}

查询路径

cpp 复制代码
ll query(ll u, ll v) {//与修改一样,不断向上跳即可
	ll res = 0;
	while (top[u] != top[v]) {
		if (dep[top[u]] < dep[top[v]])swap(u, v);
		res = (res + st.search(1, 1, n, dfn[top[u]], dfn[u])) % mod;
		u = fa[top[u]];
	}
	if (dep[u] > dep[v])swap(u, v);
	res = (res + st.search(1, 1, n, dfn[u], dfn[v])) % mod;
	return res % mod;
}

完整代码:

cpp 复制代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5+10;

ll n, m, R, mod, deep, timer;
ll a[N], dfn[N], sz[N], son[N], dep[N], fa[N], rk[N], top[N], tree[N << 2], lz[N << 2];

vector<ll> g[N];

void dfs1(ll u, ll father) {
	fa[u] = father;
	dep[u] = dep[father] + 1;
	sz[u] = 1;
	for (auto v : g[u]) {
		if (v == father) continue;
		dfs1(v, u);
		sz[u] += sz[v];
		if (sz[v] > sz[son[u]]) son[u] = v;
	}
}

void dfs2(ll u, ll topf) {
	top[u] = topf;
	dfn[u] = ++timer;
	rk[timer] = u;
	if (son[u]) dfs2(son[u], topf);
	for (auto v : g[u]) {
		if (v != fa[u] && v != son[u])
			dfs2(v, v);
	}
}


struct SegmentTree {
	void push_up(ll p) {
		tree[p] = (tree[p << 1] + tree[p << 1 | 1]) % mod;
	}
	void push_down(ll p, ll l, ll r) {
		if (lz[p]) {
			lz[p << 1] = (lz[p << 1] + lz[p]) % mod;
			lz[p << 1 | 1] = (lz[p << 1 | 1] + lz[p]) % mod;
			ll mid = (l + r) >> 1;
			tree[p << 1] = (tree[p << 1] + lz[p] * (mid - l + 1) % mod) % mod;
			tree[p << 1 | 1] = (tree[p << 1 | 1] + lz[p] * (r - mid) % mod) % mod;
			lz[p] = 0;
		}
	}

	void build(ll p, ll l, ll r) {
		if (l == r) {
			tree[p] = a[rk[l]] % mod;
			return;
		}
		ll mid = (l + r) >> 1;
		build(p << 1, l, mid);
		build(p << 1 | 1, mid + 1, r);
		push_up(p);
	}

	void add(ll p, ll l, ll r, ll L, ll R, ll k) {
		if (L <= l && r <= R) {
			tree[p] = (tree[p] + k*(r - l + 1) % mod) % mod;
			lz[p] = (lz[p] + k) % mod;
			return;
		}
		push_down(p, l, r);
		ll mid = (l + r) >> 1;
		if (L <= mid)add(p << 1, l, mid, L, R, k);
		if (R > mid)add(p << 1 | 1, mid + 1, r, L, R, k);
		push_up(p);
	}

	ll search(ll p, ll l, ll r, ll L, ll R) {
		if (L <= l && r <= R) {
			return tree[p];
		}
		push_down(p, l, r);
		ll mid = (l + r) >> 1;
		ll res = 0;
		if (L <= mid)res = (res + search(p << 1, l, mid, L, R)) % mod;
		if (R > mid)res = (res + search(p << 1 | 1, mid + 1, r, L, R)) % mod;
		return res % mod;
	}
} st;

void update(ll u, ll v, ll k) {
	while (top[u] != top[v]) {
		if (dep[top[u]] < dep[top[v]])swap(u, v);
		st.add(1, 1, n, dfn[top[u]], dfn[u], k);
		u = fa[top[u]];
	}
	if (dep[u] > dep[v])swap(u, v);
	st.add(1, 1, n, dfn[u], dfn[v], k);
}

ll query(ll u, ll v) {
	ll res = 0;
	while (top[u] != top[v]) {
		if (dep[top[u]] < dep[top[v]])swap(u, v);
		res = (res + st.search(1, 1, n, dfn[top[u]], dfn[u])) % mod;
		u = fa[top[u]];
	}
	if (dep[u] > dep[v])swap(u, v);
	res = (res + st.search(1, 1, n, dfn[u], dfn[v])) % mod;
	return res % mod;
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin >> n >> m >> R >> mod;
	for (int i = 1; i <= n; i++) {
		cin >> a[i];
		a[i] %= mod;
	}
	for (int i = 1; i < n; i++) {
		ll u, v;
		cin >> u >> v;
		g[u].push_back(v);
		g[v].push_back(u);
	}

	dfs1(R, 0);
	dfs2(R, R);
	st.build(1, 1, n);

	for (int i = 1; i <= m; i++) {
		ll op, x, y, z;
		cin >> op;
		if (op == 1) {
			cin >> x >> y >> z;
			update(x, y, z);
		}
		if (op == 2) {
			cin >> x >> y;
			cout << query(x, y) << '\n';
		}
		if (op == 3) {
			cin >> x >> y;
			st.add(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, y);
		}
		if (op == 4) {
			cin >> x;
			cout << st.search(1, 1, n, dfn[x], dfn[x] + sz[x] - 1) << '\n';
		}
	}
	return 0;
}

时间复杂度的简单证明

树链剖分整体时间复杂度为 O(N log_2 N)

其中 dfs 时间复杂度为 \(O(N)\) ,其余查询操作均为 \(O(log_2 N)\)

其中简单讲一下为什么树上路径的查询与修改是 \(O(log_2 N)\)

我们考虑树剖每次的跳跃可以理解为从链底直接跳到链顶,而为了使其跳跃次数尽可能地多,只能让其子树尽可能的多,最极限的状态是就是这是一颗满二叉树,但未什么不在某一个位置深度更高一点呢?因为如果某一个地方的深度过深,那么这里一定是一条重链,那么跳越次数就又会锐减,所以很明显的,树剖的查询与修改操作最极限的时间复杂度为 \(O(log_2 N)\) 并且常数极小,效率极高