题解 - 树上游走(二)(上海月赛2024.7甲组T1)

题目描述

给定一个棵由编号为 1~n 这 n 个节点构成的树,及 m 个关键节点,树的第 i 条双向边连接 ui 和 vi ,长度为 wi 。你可以从树上任何一个节点出发开始游走,到任何一个节点结束,但整个游走过程中必须经过给定的 m 个关键节点。

请问,应当如何设计路线,能够使得整个游走过程中,经过的路径长度最短。

输入格式

输入第一行,两个正整数 n,m

接下来 n−1 行,每行 3 个正整数,其中第 i+1 行分别表示第 i 条边的 ui,vi,wi。

输入最后行,m 个正整数,分别表示 m 个关键节点的编号 ci。

输出格式

输出共一行,表示所求最短长度。

数据范围

1≤m≤n≤1e5,1​ ≤ ui,vi,ci ≤n, 1≤w ≤1e4

样例数据

输入:

5 3

1 2 2

2 3 4

2 4 5

1 5 1

3 4 5

输出:

15

说明:

3-->2-->1-->5-->1-->2-->4

4 2 1 1 2 5

分析

树形dp + 换根

  1. 首先考虑对于起点固定,应该怎么做。

假设起点为st,假设m个关键节点分布在st的k个子树内,则我们将只经过其中1个子树一次,然后其余k-1个子树的每条边经过两次。

一个很显然的贪心是,为了使总路程最短,我们应该满足 只经过1次的子树的路程和 最大。

可以dfs来求

其中f[u]代表从u节点访问u子树的所有关键点然后再返回u节点需要的总路程(即所有路径均为两倍),

mx[i][0]和mx[i][1]分别表示i的子树的最大值和次大值。(至于为什么还需要维护次大值,后续会讲到)

cpp 复制代码
void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

最终结果即为 f[st] - mx[st][0]

  1. 然后考虑起点不固定

可以采用换根dp来写

对于 u → j u \rightarrow j u→j 这条边,根从 u u u 转移到 j j j ,

若 j 子树内关键节点数为0,则 f[j] = f[u] + 2 * val

若 j 子树内关键节点数为m,则 f[j] 不变;

否则, f[j] = f[u]

若 j 子树内关键节点数为m,则还需要更新mx[j][0]和mx[j][1]的值。(具体细节见代码)

cpp 复制代码
void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

完整代码

cpp 复制代码
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

int n,m;
bool flag[N];
int h[N],e[N * 2],w[N * 2],ne[N * 2],idx;
LL cnt[N],f[N],mx[N][2],id[N][2],res;

void add(int a,int b,int c){
    e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx++;
}

void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

    memset(h,-1,sizeof h);

    cin >> n >> m;
    for(int i = 0,u,v,w;i < n - 1;i++){
        cin >> u >> v >> w;
        add(u,v,w),add(v,u,w);
    }
    for(int i = 1,j;i <= m;i++){
        cin >> j;
        flag[j] = true;
    }

    dfs(1,-1);

    res = f[1] - mx[1][0];

    dp(1,-1);

    cout << res;

    return 0;
}
相关推荐
葛小白11 小时前
C#进阶12:C#全局路径规划算法_Dijkstra
算法·c#·dijkstra算法
前端小L1 小时前
图论专题(五):图遍历的“终极考验”——深度「克隆图」
数据结构·算法·深度优先·图论·宽度优先
Sailing1 小时前
🔥 React 高频 useEffect 导致页面崩溃的真实案例:从根因排查到彻底优化
前端·react.js·面试
byte轻骑兵2 小时前
【安全函数】C语言安全字符串函数详解:告别缓冲区溢出的噩梦
c语言·安全·面试
CoovallyAIHub2 小时前
超越像素的视觉:亚像素边缘检测原理、方法与实战
深度学习·算法·计算机视觉
CoovallyAIHub2 小时前
中科大西工大提出RSKT-Seg:精度速度双提升,开放词汇分割不再难
深度学习·算法·计算机视觉
gugugu.2 小时前
算法:位运算类型题目练习与总结
算法
百***97642 小时前
【语义分割】12个主流算法架构介绍、数据集推荐、总结、挑战和未来发展
算法·架构
代码不停2 小时前
Java分治算法题目练习(快速/归并排序)
java·数据结构·算法
bubiyoushang8882 小时前
基于MATLAB的马尔科夫链蒙特卡洛(MCMC)模拟实现方法
人工智能·算法·matlab