题目
简要题意:
给定一棵 n n n 个节点的树,初始时 1 1 1 号节点为红色,其余为蓝色。
要求支持如下操作:
-
将一个节点变为红色。
-
询问节点 u u u 到最近红色节点的距离。
共 q q q 次操作。
1 ≤ n , q ≤ 1 0 5 1 \leq n, q \leq 10^5 1≤n,q≤105。
分析:
非常 典 的一道题。
我们首先考虑一种 修改 O ( n ) O(n) O(n),查询 O ( 1 ) O(1) O(1) 的算法:每次改变一个点的颜色就把它放进队列里跑一遍 bfs,去更新其它点到红点的最小值。
接着我们考虑一种 修改 O ( 1 ) O(1) O(1),查询 O ( n ) O(n) O(n) 的算法:每次 O ( 1 ) O(1) O(1) 标记一个点是否为红色。然后每次查询枚举红色的点并计算距离,时间复杂度是 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n) 的。
我们考虑如何平衡这两种算法。
因为 bfs 可以在 O ( n ) O(n) O(n) 的复杂度内跑 多个终点的最短路 ,因此我们可以将红点储存起来一起跑bfs。所以可以对操作进行分块。
设块长为 S S S,我们每一次从当前块到下一块时,我们把当前块的 所有染红操作的点 放进队列里面跑 bfs 更新 其它点的 d i s dis dis 值。然后对于当前块的询问,我们扫块内的所有操作,如果为 1 1 1 操作,那么我们 O ( l o g 2 n ) O(log_2n) O(log2n) 的复杂度内查出询问点和修改点的距离并与 d i s dis dis 数组取 m i n min min 即可。
时间复杂度是 O ( q S × n + q × S × l o g 2 n ) O(\frac{q}{S} \times n + q \times S \times log_2n) O(Sq×n+q×S×log2n) 的,当 S = n l o g 2 n S = \sqrt{\frac{n}{log_2n}} S=log2nn 时复杂度最小,为 O ( q n l o g 2 n ) O(q\sqrt{nlog_2n}) O(qnlog2n )。
CODE:
csharp
#include<bits/stdc++.h>// 好题
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
typedef pair< int, int > PII;
int n, Q, dis[N], blo, op, x, bl, u, v, dep[N], fa[N][25];
inline int read(){
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
while(isdigit(c)){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
return x * f;
}
vector< int > E[N];
vector< PII > vec[N];
queue< int > q;
void dfs(int x, int fat){
dep[x] = dep[fat] + 1; fa[x][0] = fat;
for(int i = 1; i <= 20; i++) fa[x][i] = fa[fa[x][i - 1]][i - 1];
for(auto v : E[x]){
if(v == fat) continue;
dfs(v, x);
}
}
int LCA(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; i--){
if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
}
if(x == y) return x;
for(int i = 20; i >= 0; i--){
if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
}
return fa[x][0];
}
void bfs(){
while(!q.empty()){
int u = q.front(); q.pop();
for(auto v : E[u]){
if(dis[v] > dis[u] + 1){
dis[v] = dis[u] + 1;
q.push(v);
}
}
}
}
int main(){
memset(dis, 0x3f, sizeof dis);
n = read(), Q = read();
for(int i = 1; i < n; i++){
u = read(), v = read();
E[u].pb(v); E[v].pb(u);
}
dfs(1, 0);
blo = max(1, (int)sqrt(1.0 * n / log2(n)));
for(int i = 1; i <= Q; i++){
op = read(), x = read();
bl = (i - 1) / blo + 1;
vec[bl].pb(make_pair(op, x));
}
dis[1] = 0;
q.push(1);
bfs();
for(int i = 1; i <= bl; i++){
for(int j = 0; j < vec[i].size(); j++){
int op = vec[i][j].first, x = vec[i][j].second;
if(op == 1) dis[x] = 0, q.push(x);
else{
int y = dis[x];
for(int k = 0; k < j; k++){
if(vec[i][k].first == 1) y = min(y, dep[x] + dep[vec[i][k].second] - 2 * dep[LCA(vec[i][k].second, x)]);
}
printf("%d\n", y);
}
}
bfs();
}
return 0;
}