题目描述
给定一个棵由编号为 1~n 这 n 个节点构成的树,及 m 个关键节点,树的第 i 条双向边连接 ui 和 vi ,长度为 wi 。你可以从树上任何一个节点出发开始游走,到任何一个节点结束,但整个游走过程中必须经过给定的 m 个关键节点。
请问,应当如何设计路线,能够使得整个游走过程中,经过的路径长度最短。
输入格式
输入第一行,两个正整数 n,m
接下来 n−1 行,每行 3 个正整数,其中第 i+1 行分别表示第 i 条边的 ui,vi,wi。
输入最后行,m 个正整数,分别表示 m 个关键节点的编号 ci。
输出格式
输出共一行,表示所求最短长度。
数据范围
1≤m≤n≤1e5,1 ≤ ui,vi,ci ≤n, 1≤w ≤1e4
样例数据
输入:
5 3
1 2 2
2 3 4
2 4 5
1 5 1
3 4 5
输出:
15
说明:
3-->2-->1-->5-->1-->2-->4
4 2 1 1 2 5
分析
树形dp + 换根
- 首先考虑对于起点固定,应该怎么做。
假设起点为st,假设m个关键节点分布在st的k个子树内,则我们将只经过其中1个子树一次,然后其余k-1个子树的每条边经过两次。
一个很显然的贪心是,为了使总路程最短,我们应该满足 只经过1次的子树的路程和 最大。
可以dfs来求
其中f[u]代表从u节点访问u子树的所有关键点然后再返回u节点需要的总路程(即所有路径均为两倍),
mx[i][0]和mx[i][1]分别表示i的子树的最大值和次大值。(至于为什么还需要维护次大值,后续会讲到)
cpp
void dfs(int u,int fa){
if(flag[u]) cnt[u] = 1;
for(int i = h[u];i != -1;i = ne[i]){
int j = e[i],val = w[i];
if(j != fa){
dfs(j,u);
if(cnt[j]){
cnt[u] += cnt[j];
f[u] += f[j] + 2 * val;
int tmp = mx[j][0] + val;
if(tmp > mx[u][0]){
mx[u][1] = mx[u][0],id[u][1] = id[u][0];
mx[u][0] = tmp,id[u][0] = j;
}else if(tmp > mx[u][1]){
mx[u][1] = tmp,id[u][1] = j;
}
}
}
}
}
最终结果即为 f[st] - mx[st][0]
- 然后考虑起点不固定
可以采用换根dp来写
对于 u → j u \rightarrow j u→j 这条边,根从 u u u 转移到 j j j ,
若 j 子树内关键节点数为0,则 f[j] = f[u] + 2 * val
;
若 j 子树内关键节点数为m,则 f[j] 不变;
否则, f[j] = f[u]
。
若 j 子树内关键节点数为m,则还需要更新mx[j][0]和mx[j][1]的值。(具体细节见代码)
cpp
void dp(int u,int fa){
res = min(res,f[u] - mx[u][0]);
for(int i = h[u];i != -1;i = ne[i]){
int j = e[i],val = w[i];
if(j != fa){
if(!cnt[j]) f[j] = f[u] + 2 * val;
else if(cnt[j] == m) f[j] = f[j];
else f[j] = f[u];
if(cnt[j] != m){
int tmp;
if(id[u][0] == j) tmp = mx[u][1] + val;
else tmp = mx[u][0] + val;
if(tmp > mx[j][0]){
mx[j][1] = mx[j][0],id[j][1] = id[j][0];
mx[j][0] = tmp,id[j][0] = u;
}else if(tmp > mx[j][1]){
mx[j][1] = tmp,id[j][1] = u;
}
}
dp(j,u);
}
}
}
完整代码
cpp
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10;
int n,m;
bool flag[N];
int h[N],e[N * 2],w[N * 2],ne[N * 2],idx;
LL cnt[N],f[N],mx[N][2],id[N][2],res;
void add(int a,int b,int c){
e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx++;
}
void dfs(int u,int fa){
if(flag[u]) cnt[u] = 1;
for(int i = h[u];i != -1;i = ne[i]){
int j = e[i],val = w[i];
if(j != fa){
dfs(j,u);
if(cnt[j]){
cnt[u] += cnt[j];
f[u] += f[j] + 2 * val;
int tmp = mx[j][0] + val;
if(tmp > mx[u][0]){
mx[u][1] = mx[u][0],id[u][1] = id[u][0];
mx[u][0] = tmp,id[u][0] = j;
}else if(tmp > mx[u][1]){
mx[u][1] = tmp,id[u][1] = j;
}
}
}
}
}
void dp(int u,int fa){
res = min(res,f[u] - mx[u][0]);
for(int i = h[u];i != -1;i = ne[i]){
int j = e[i],val = w[i];
if(j != fa){
if(!cnt[j]) f[j] = f[u] + 2 * val;
else if(cnt[j] == m) f[j] = f[j];
else f[j] = f[u];
if(cnt[j] != m){
int tmp;
if(id[u][0] == j) tmp = mx[u][1] + val;
else tmp = mx[u][0] + val;
if(tmp > mx[j][0]){
mx[j][1] = mx[j][0],id[j][1] = id[j][0];
mx[j][0] = tmp,id[j][0] = u;
}else if(tmp > mx[j][1]){
mx[j][1] = tmp,id[j][1] = u;
}
}
dp(j,u);
}
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
memset(h,-1,sizeof h);
cin >> n >> m;
for(int i = 0,u,v,w;i < n - 1;i++){
cin >> u >> v >> w;
add(u,v,w),add(v,u,w);
}
for(int i = 1,j;i <= m;i++){
cin >> j;
flag[j] = true;
}
dfs(1,-1);
res = f[1] - mx[1][0];
dp(1,-1);
cout << res;
return 0;
}