题解 - 树上游走(二)(上海月赛2024.7甲组T1)

题目描述

给定一个棵由编号为 1~n 这 n 个节点构成的树,及 m 个关键节点,树的第 i 条双向边连接 ui 和 vi ,长度为 wi 。你可以从树上任何一个节点出发开始游走,到任何一个节点结束,但整个游走过程中必须经过给定的 m 个关键节点。

请问,应当如何设计路线,能够使得整个游走过程中,经过的路径长度最短。

输入格式

输入第一行,两个正整数 n,m

接下来 n−1 行,每行 3 个正整数,其中第 i+1 行分别表示第 i 条边的 ui,vi,wi。

输入最后行,m 个正整数,分别表示 m 个关键节点的编号 ci。

输出格式

输出共一行,表示所求最短长度。

数据范围

1≤m≤n≤1e5,1​ ≤ ui,vi,ci ≤n, 1≤w ≤1e4

样例数据

输入:

5 3

1 2 2

2 3 4

2 4 5

1 5 1

3 4 5

输出:

15

说明:

3-->2-->1-->5-->1-->2-->4

4 2 1 1 2 5

分析

树形dp + 换根

  1. 首先考虑对于起点固定,应该怎么做。

假设起点为st,假设m个关键节点分布在st的k个子树内,则我们将只经过其中1个子树一次,然后其余k-1个子树的每条边经过两次。

一个很显然的贪心是,为了使总路程最短,我们应该满足 只经过1次的子树的路程和 最大。

可以dfs来求

其中f[u]代表从u节点访问u子树的所有关键点然后再返回u节点需要的总路程(即所有路径均为两倍),

mx[i][0]和mx[i][1]分别表示i的子树的最大值和次大值。(至于为什么还需要维护次大值,后续会讲到)

cpp 复制代码
void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

最终结果即为 f[st] - mx[st][0]

  1. 然后考虑起点不固定

可以采用换根dp来写

对于 u → j u \rightarrow j u→j 这条边,根从 u u u 转移到 j j j ,

若 j 子树内关键节点数为0,则 f[j] = f[u] + 2 * val

若 j 子树内关键节点数为m,则 f[j] 不变;

否则, f[j] = f[u]

若 j 子树内关键节点数为m,则还需要更新mx[j][0]和mx[j][1]的值。(具体细节见代码)

cpp 复制代码
void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

完整代码

cpp 复制代码
#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 1e5 + 10;

int n,m;
bool flag[N];
int h[N],e[N * 2],w[N * 2],ne[N * 2],idx;
LL cnt[N],f[N],mx[N][2],id[N][2],res;

void add(int a,int b,int c){
    e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx++;
}

void dfs(int u,int fa){
    if(flag[u]) cnt[u] = 1;

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];
        if(j != fa){
            dfs(j,u);
            
            if(cnt[j]){
                cnt[u] += cnt[j];
                f[u] += f[j] + 2 * val;

                int tmp = mx[j][0] + val;
                if(tmp > mx[u][0]){
                    mx[u][1] = mx[u][0],id[u][1] = id[u][0];
                    mx[u][0] = tmp,id[u][0] = j;
                }else if(tmp > mx[u][1]){
                    mx[u][1] = tmp,id[u][1] = j;
                }
            } 
        }
    }
}

void dp(int u,int fa){
    res = min(res,f[u] - mx[u][0]);

    for(int i = h[u];i != -1;i = ne[i]){
        int j = e[i],val = w[i];

        if(j != fa){
            if(!cnt[j]) f[j] = f[u] + 2 * val;
            else if(cnt[j] == m) f[j] = f[j];
            else f[j] = f[u];

            if(cnt[j] != m){
                int tmp;
                if(id[u][0] == j) tmp = mx[u][1] + val;
                else tmp = mx[u][0] + val;

                if(tmp > mx[j][0]){
                    mx[j][1] = mx[j][0],id[j][1] = id[j][0];
                    mx[j][0] = tmp,id[j][0] = u;
                }else if(tmp > mx[j][1]){
                    mx[j][1] = tmp,id[j][1] = u;
                }
           }

            dp(j,u);
        }
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

    memset(h,-1,sizeof h);

    cin >> n >> m;
    for(int i = 0,u,v,w;i < n - 1;i++){
        cin >> u >> v >> w;
        add(u,v,w),add(v,u,w);
    }
    for(int i = 1,j;i <= m;i++){
        cin >> j;
        flag[j] = true;
    }

    dfs(1,-1);

    res = f[1] - mx[1][0];

    dp(1,-1);

    cout << res;

    return 0;
}
相关推荐
s09071369 分钟前
Xilinx FPGA使用 FIR IP 核做匹配滤波时如何减少DSP使用量
算法·fpga开发·xilinx·ip core·fir滤波
老马啸西风11 分钟前
成熟企业级技术平台-10-跳板机 / 堡垒机(Bastion Host)详解
人工智能·深度学习·算法·职场和发展
子夜江寒12 分钟前
逻辑回归简介
算法·机器学习·逻辑回归
软件算法开发26 分钟前
基于ACO蚁群优化算法的多车辆含时间窗VRPTW问题求解matlab仿真
算法·matlab·aco·vrptw·蚁群优化·多车辆·时间窗
another heaven37 分钟前
【软考 磁盘磁道访问时间】总容量等相关案例题型
linux·网络·算法·磁盘·磁道
tap.AI38 分钟前
理解FSRS算法:一个现代间隔重复调度器的技术解析
算法
摇滚侠39 分钟前
面试实战 问题三十三 Spring 事务常用注解
数据库·spring·面试
老马啸西风1 小时前
成熟企业级技术平台-09-加密机 / 密钥管理服务 KMSS(Key Management & Security Service)
人工智能·深度学习·算法·职场和发展
cooldream20091 小时前
当代 C++ 的三大技术支柱:资源管理、泛型编程与模块化体系的成熟演进
开发语言·c++
while(1){yan}1 小时前
网络基础知识
java·网络·青少年编程·面试·电脑常识