题干
给定一棵 n n n 个节点的树,节点编号为 1 ∼ n 1∼n 1∼n。每个节点都被染成了黑色(用 1 1 1 表示)或白色(用 0 0 0 表示)。从黑色节点无法到达白色节点,反之亦然。因此,两个同色节点相互可达的前提是,两个同色节点之间的路径中不含另一种颜色的节点。
我们希望将树中的所有节点都染成同一种颜色(全黑或全白均可)。为此,你可以采用我们指定的染色操作。每次操作可以选择一个节点 v v v,并改变节点 v v v 以及其所有可达同色节点的颜色(黑变白、白变黑)。
例如,在下图中,点 1 1 1 和点 2 , 3 , 8 , 9 2,3,8,9 2,3,8,9 之间相互可达,但是点 1 1 1 和点 6 6 6 之间相互不可达(被点 5 5 5 挡住了),因此,如果选择点 1 1 1 进行染色操作,会将点 1 , 2 , 3 , 8 , 9 1,2,3,8,9 1,2,3,8,9 全部染黑。

请你计算,为了达成目标,至少需要进行多少次染色操作。
输入
第一行包含整数 n n n。
第二行包含 n n n 个整数 c 1 , c 2 , ⋯ , c n c_1,c_2,\cdots,c_n c1,c2,⋯,cn,其中 c i c_i ci 为节点 i i i 的颜色( 1 1 1 表示黑, 0 0 0 表示白)。
接下来 n − 1 n−1 n−1 行,每行包含两个整数 u i , v i u_i,v_i ui,vi,表示节点 u i u_i ui 和节点 v i v_i vi 之间存在一条边。
输出
一个整数,表示所需的最少染色操作次数。
思路
题目的考点是并查集 +树的直径。
并查集
既然相同颜色的,互相联通的点可以在一次操作内进行染色操作(即全部由黑变白,或者由白变黑),那么不如将他们视作一个点。通过并查集可以实现这一点。
找直径
缩点后形成的树 T T T 中,相邻结点具有不同颜色。考虑树 T T T 的直径 D = max u , v ∈ T d ( u , v ) D=\max_{u,v\in T}d(u,v) D=maxu,v∈Td(u,v),则必有一条长为 D D D 的路 p p p。
- 当 D D D 为奇数时,这条路由 D − 1 2 \frac{D-1}{2} 2D−1 个黑点(或者白点)和 D + 1 2 \frac{D+1}{2} 2D+1 个白点(或者黑点)组成,最少通过 D − 1 2 = ⌊ D / 2 ⌋ \frac{D-1}{2}=\lfloor D/2\rfloor 2D−1=⌊D/2⌋ 次操作将黑点全部换成白点(或者白点全部换成黑点)即可将这条路变成同一种颜色。
- 当 D D D 为偶数时,这条路由 D / 2 D/2 D/2 个黑点和 D / 2 D/2 D/2 个白点(或者黑点)组成,最少通过 D / 2 = ⌊ D / 2 ⌋ D/2=\lfloor D/2\rfloor D/2=⌊D/2⌋ 次操作将黑点全部换成白点(或者白点全部换成黑点)即可将这条路变成同一种颜色。
因此,如果要将树 T T T 变为一种颜色,至少要将这条路 p p p 变成一种颜色,次数 a n s ≥ ⌊ D / 2 ⌋ \mathrm{ans}\geq \lfloor D/2\rfloor ans≥⌊D/2⌋。此外,我们还能知道,从这条路的中心出发,通过 ⌊ D / 2 ⌋ \lfloor D/2\rfloor ⌊D/2⌋ 次操作,还可以将 T − p T-p T−p (即路外其他结点)也转变为一种颜色。因为如果做不到,说明我们在找直径的时候就找错了。
所以答案就是 a n s = ⌊ D / 2 ⌋ \mathrm{ans}=\lfloor D/2\rfloor ans=⌊D/2⌋。对于一个数 T T T 而言,它的直径为 D = f ( T ) D=f(T) D=f(T),其中
f ( T ) = max ( 1 + d T 1 + d T 2 , f ( T 1 ) , f ( T 2 ) ) f(T)=\max(1+d_{T_1}+d_{T_2},f(T_1),f(T_2)) f(T)=max(1+dT1+dT2,f(T1),f(T2))
d T d_T dT 表示树 T T T 的深度,而 T 1 , T 2 T_1,T_2 T1,T2 表示 T T T 最深的两个子树。
Code
cpp
# include <iostream>
# include <cstring>
# include <vector>
using namespace std;
int n,dad[200005],dp[200005],uu[200005],vv[200005];
bool c[200005];
vector<int> nex[200005];
int getdad(int node){
if(dad[node] == node)
return node;
return dad[node] = getdad(dad[node]);
}
int getdp(int node,int fa){
int mdp = 0;
for(int &x : nex[node])
if(x != fa)
mdp = max(mdp,getdp(x,node));
return dp[node] = mdp + 1;
}
int maxdist(int node,int fa){
int dp1 = -1,dp2 = -1,ans = 1;
for(int &x : nex[node])
if(x != fa){
ans = max(ans,maxdist(x,node));
if(dp[x] > dp1){
if(dp1 < dp2) dp1 = dp[x];
else dp2 = dp[x];
}
else dp2 = max(dp2,dp[x]);
}
if(dp1 == -1 && dp2 == -1) return ans;
if(dp1 == -1) return max(ans,1 + dp2);
else if(dp2 == -1) return max(ans,1 + dp1);
return max(ans,1 + dp1 + dp2);
}
int main(){
int u,v;
cin >> n;
for(int i = 1;i <= n;i++)
cin >> c[i];
for(int i = 1;i < n;i++){
cin >> u >> v;
if(dad[u] && dad[v]){
if(c[u] == c[v]) dad[getdad(v)] = getdad(u);
}
else if(dad[u]){
dad[v] = c[u] == c[v]?getdad(u):v;
}
else if(dad[v]){
dad[u] = c[u] == c[v]?getdad(v):u;
}
else{
dad[u] = u;
dad[v] = c[u] == c[v]?u:v;
}
uu[i] = u;
vv[i] = v;
}
for(int i = 1;i < n;i++){
u = getdad(uu[i]);
v = getdad(vv[i]);
if(c[u] != c[v]){
nex[u].push_back(v);
nex[v].push_back(u);
}
}
getdp(getdad(1),0);
return cout << maxdist(getdad(1),0) / 2,0;
}