AT_abc401_f [ABC401F] Add One Edge 3
洛谷题目传送门
atcoder题目传送门
题目描述
给定两棵树:
- 树 1 包含 N 1 N_1 N1 个顶点,编号为 1 1 1 到 N 1 N_1 N1
- 树 2 包含 N 2 N_2 N2 个顶点,编号为 1 1 1 到 N 2 N_2 N2
树 1 的第 i i i 条边双向连接顶点 u 1 , i u_{1,i} u1,i 和 v 1 , i v_{1,i} v1,i,树 2 的第 i i i 条边双向连接顶点 u 2 , i u_{2,i} u2,i 和 v 2 , i v_{2,i} v2,i。
如果在树 1 的顶点 i i i 和树 2 的顶点 j j j 之间添加一条双向边,将得到一棵新的树。定义这棵新树的直径为 f ( i , j ) f(i,j) f(i,j)。
请计算 ∑ i = 1 N 1 ∑ j = 1 N 2 f ( i , j ) \displaystyle\sum_{i=1}^{N_1}\sum_{j=1}^{N_2} f(i,j) i=1∑N1j=1∑N2f(i,j) 的值。
其中:
- 两顶点之间的距离定义为它们之间最短路径的边数
- 树的直径定义为所有顶点对之间距离的最大值
输入格式
输入通过标准输入给出,格式如下:
N 1 N_1 N1
u 1 , 1 u_{1,1} u1,1 v 1 , 1 v_{1,1} v1,1
⋮ \vdots ⋮
u 1 , N 1 − 1 u_{1,N_1-1} u1,N1−1 v 1 , N 1 − 1 v_{1,N_1-1} v1,N1−1
N 2 N_2 N2
u 2 , 1 u_{2,1} u2,1 v 2 , 1 v_{2,1} v2,1
⋮ \vdots ⋮
u 2 , N 2 − 1 u_{2,N_2-1} u2,N2−1 v 2 , N 2 − 1 v_{2,N_2-1} v2,N2−1
输出格式
输出计算结果。
输入输出样例 #1
输入 #1
3
1 3
1 2
3
1 2
3 1
输出 #1
39
输入输出样例 #2
输入 #2
7
5 6
1 3
5 7
4 5
1 6
1 2
5
5 3
2 4
2 3
5 1
输出 #2
267
说明/提示
约束条件
- 1 ≤ N 1 , N 2 ≤ 2 × 10 5 1 \leq N_1, N_2 \leq 2 \times 10^5 1≤N1,N2≤2×105
- 1 ≤ u 1 , i , v 1 , i ≤ N 1 1 \leq u_{1,i}, v_{1,i} \leq N_1 1≤u1,i,v1,i≤N1
- 1 ≤ u 2 , i , v 2 , i ≤ N 2 1 \leq u_{2,i}, v_{2,i} \leq N_2 1≤u2,i,v2,i≤N2
- 输入的两张图都是树
- 输入的所有数值均为整数
样例解释 1
例如,当连接树 1 的顶点 2 和树 2 的顶点 3 时,得到的新树直径为 5,因此 f ( 2 , 3 ) = 5 f(2,3)=5 f(2,3)=5。所有 f ( i , j ) f(i,j) f(i,j) 的总和为 39。
思路详解
直径
首先我们要搞懂什么是直径,树的直径即为树中的最远点对之间的路径。
直径还有一个性质:距离每个点最远的点一定是直径的一个端点,证明如下:
我们把图抽象一下,变成下面的样子:
考虑使用反证法,黑线为到一个点的最长距离,红线为直径。若1+2>1+4,则2>4,则3+2>3+4,则红线不为直径,矛盾。
那我们想要求出直径就很简单了,先随便找一个1点求出哪个点距离这个点最远,这个点即为直径的一个端点。再从这个端点找另一个端点。就可以求出直径了。
题目分析
思考如何在 i , j i,j i,j之间连边了如何求出 f i , j f_{i,j} fi,j,我们定义 d i s i dis_{i} disi为距离 i i i最远的点的距离,这个距离我们可以从直径的端点跑2次 b f s bfs bfs求得。所以 f i , j = m a x ( d 1 , d 2 , d i s i + d i s j + 1 ) f_{i,j}=max(d_{1},d_{2},dis_{i}+dis_{j}+1) fi,j=max(d1,d2,disi+disj+1), d 1 , d 2 d_{1},d_{2} d1,d2为第1/2棵树得直径。
但是,我们发现枚举 i , j i,j i,j依然会超时,那怎么办呢?我们显然只能枚举一个,考虑对于 i i i,如何快速求出 ∑ f i , j \sum f_{i,j} ∑fi,j。我们发现,当 d i s j < = m a x ( d 1 , d 2 ) − d i s i − 1 dis_{j}<=max(d_{1},d_{2})-dis_{i}-1 disj<=max(d1,d2)−disi−1时,答案都为 m a x ( d 1 , d 2 ) max(d_{1},d_{2}) max(d1,d2),大于则答案为 ∑ d i s i + d i s j + 1 \sum dis_{i}+dis_{j}+1 ∑disi+disj+1即 ( 1 + d i s i ) ∗ c n t + ∑ j = p o s n d i s j (1+dis_{i})*cnt+\sum {j=pos}^{n} dis{j} (1+disi)∗cnt+∑j=posndisj, c n t cnt cnt为大于的个数, p o s pos pos为第一个大于的位置。那很显然了,我们对 d i s j dis_{j} disj进行排序,然后用upper_bound查找即可。
思路分析
我们再来理清一次过程:
- 首先求出两棵树的直径。
- 然后从,每个直径端点去更新每个点的 d i s dis dis。
- 对第二棵树的 d i s dis dis进行排序,并求前缀和。
- 最后枚举第一棵树的 i i i,求解答案。
code
cpp
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll N=2e5+5;
ll n[2];
vector<ll>e[2][N];
ll p[5];
ll dis[N],mx[4][N],ne[2][N];
void dfs1(ll u,ll fa,ll j){//找直径端点
dis[u]=dis[fa]+1;
for(ll v:e[j][u]){
if(v==fa)continue;
dfs1(v,u,j);
}
}
void find(ll j){
memset(dis,0,sizeof(dis));
//由直径的定义可得距离每一个点最远的点一定是直径的一个端点
dfs1(1,0,j);//先随便找一个点求出一个端点
ll d=0,id=0;
for(ll i=1;i<=n[j];i++){
if(dis[i]>d){
d=dis[i];id=i;
}
}
p[(j+1)*2-1]=id;
memset(dis,0,sizeof(dis));
dfs1(id,0,j);//再寻找另一个端点
d=0,id=0;
for(ll i=1;i<=n[j];i++){
if(dis[i]>d){
d=dis[i];id=i;
}
}
p[(j+1)*2]=id;
}
void dfs2(ll u,ll fa,ll j,ll k){
mx[k][u]=mx[k][fa]+1;
for(ll v:e[j][u]){
if(v==fa)continue;
dfs2(v,u,j,k);
}
}
void co(ll j){//求出距离每个点最远的距离
mx[j*2][0]=-1;//到第一个端点的距离
dfs2(p[(j+1)*2-1],0,j,j*2);
mx[(j+1)*2-1][0]=-1;//到第二个端点的距离
dfs2(p[(j+1)*2],0,j,(j+1)*2-1);
for(ll i=1;i<=n[j];i++)ne[j][i]=max(mx[j*2][i],mx[(j+1)*2-1][i]);//求出最远距离
}
ll sum[N];
int main(){
//显然对于每个i,j,f(i,j)=max(max(d1,d2),ne[i]+ne[j]+1)
//其中d1,d2分别为树1,2的直径,ne[i]为离i最远的距离
//我们先将ne[j]排序,再枚举i,对于ne[i]+ne[j]+1<=max(d1,d2)部分的和即为几个max(d1,d2)
//大于部分的和为n*(ne[i]+1)+sum{ne[j]}
cin>>n[0];
for(ll i=1;i<=n[0]-1;i++){
ll x,y;
cin>>x>>y;
e[0][x].push_back(y);
e[0][y].push_back(x);
}//图0
cin>>n[1];
for(ll i=1;i<=n[1]-1;i++){
ll x,y;
cin>>x>>y;
e[1][x].push_back(y);
e[1][y].push_back(x);
}//图1
for(ll i=0;i<=1;i++)find(i);//先找每一个直径的端点
for(ll i=0;i<=1;i++)co(i);
ll d1=max(ne[0][p[1]],ne[0][p[2]]),d2=max(ne[1][p[3]],ne[1][p[4]]);
sort(ne[1]+1,ne[1]+1+n[1]);//将第二棵树排序
for(ll i=1;i<=n[1];i++)sum[i]=sum[i-1]+ne[1][i];//求前缀和
ll ans=0;
for(ll i=1;i<=n[0];i++){
ll pos=upper_bound(ne[1]+1,ne[1]+1+n[1],max(d1,d2)-ne[0][i]-1)-(ne[1]+1);
//找有几个小于max(d1,d2)的
ans+=pos*max(d1,d2)+(n[1]-pos)*(ne[0][i]+1)+sum[n[1]]-sum[pos];
}
cout<<ans;
return 0;
}