题解 - 树上游走(二)(上海月赛2024.7甲组T1)

题目描述

给定一个棵由编号为 1~n 这 n 个节点构成的树,及 m 个关键节点,树的第 i 条双向边连接 ui 和 vi ,长度为 wi 。你可以从树上任何一个节点出发开始游走,到任何一个节点结束,但整个游走过程中必须经过给定的 m 个关键节点。

请问,应当如何设计路线,能够使得整个游走过程中,经过的路径长度最短。

输入格式

输入第一行,两个正整数 n,m

接下来 n−1 行,每行 3 个正整数,其中第 i+1 行分别表示第 i 条边的 ui,vi,wi。

输入最后行,m 个正整数,分别表示 m 个关键节点的编号 ci。

输出格式

输出共一行,表示所求最短长度。

数据范围

1≤m≤n≤1e5,1​ ≤ ui,vi,ci ≤n, 1≤w ≤1e4

样例数据

输入:

5 3

1 2 2

2 3 4

2 4 5

1 5 1

3 4 5

输出:

15

说明:

3-->2-->1-->5-->1-->2-->4

4 2 1 1 2 5

分析

树形dp + 换根

  1. 首先考虑对于起点固定,应该怎么做。

假设起点为st,假设m个关键节点分布在st的k个子树内,则我们将只经过其中1个子树一次,然后其余k-1个子树的每条边经过两次。

一个很显然的贪心是,为了使总路程最短,我们应该满足 只经过1次的子树的路程和 最大。

可以dfs来求

其中f[u]代表从u节点访问u子树的所有关键点然后再返回u节点需要的总路程(即所有路径均为两倍),

mx[i][0]和mx[i][1]分别表示i的子树的最大值和次大值。(至于为什么还需要维护次大值,后续会讲到)

cpp 复制代码
void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

最终结果即为 f[st] - mx[st][0]

  1. 然后考虑起点不固定

可以采用换根dp来写

对于 u → j u \rightarrow j u→j 这条边,根从 u u u 转移到 j j j ,

若 j 子树内关键节点数为0,则 f[j] = f[u] + 2 * val

若 j 子树内关键节点数为m,则 f[j] 不变;

否则, f[j] = f[u]

若 j 子树内关键节点数为m,则还需要更新mx[j][0]和mx[j][1]的值。(具体细节见代码)

cpp 复制代码
void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

完整代码

cpp 复制代码
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

int n,m;
bool flag[N];
int h[N],e[N * 2],w[N * 2],ne[N * 2],idx;
LL cnt[N],f[N],mx[N][2],id[N][2],res;

void add(int a,int b,int c){
    e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx++;
}

void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

    memset(h,-1,sizeof h);

    cin >> n >> m;
    for(int i = 0,u,v,w;i < n - 1;i++){
        cin >> u >> v >> w;
        add(u,v,w),add(v,u,w);
    }
    for(int i = 1,j;i <= m;i++){
        cin >> j;
        flag[j] = true;
    }

    dfs(1,-1);

    res = f[1] - mx[1][0];

    dp(1,-1);

    cout << res;

    return 0;
}
相关推荐
tinker在coding14 分钟前
Coding Caprice - Linked-List 1
算法·leetcode
LCG元3 小时前
【面试问题】JIT 是什么?和 JVM 什么关系?
面试·职场和发展
唐诺4 小时前
几种广泛使用的 C++ 编译器
c++·编译器
XH华5 小时前
初识C语言之二维数组(下)
c语言·算法
南宫生5 小时前
力扣-图论-17【算法学习day.67】
java·学习·算法·leetcode·图论
不想当程序猿_5 小时前
【蓝桥杯每日一题】求和——前缀和
算法·前缀和·蓝桥杯
落魄君子5 小时前
GA-BP分类-遗传算法(Genetic Algorithm)和反向传播算法(Backpropagation)
算法·分类·数据挖掘
冷眼看人间恩怨5 小时前
【Qt笔记】QDockWidget控件详解
c++·笔记·qt·qdockwidget
菜鸡中的奋斗鸡→挣扎鸡5 小时前
滑动窗口 + 算法复习
数据结构·算法
红龙创客6 小时前
某狐畅游24校招-C++开发岗笔试(单选题)
开发语言·c++