[蓝桥杯 2015 省 B] 生命之树
题目描述
在 X 森林里,上帝创建了生命之树。
他给每棵树的每个节点(叶子也称为一个节点)上,都标了一个整数,代表这个点的和谐值。
上帝要在这棵树内选出一个节点集合 S S S(允许为空集),使得对于 S S S 中的任意两个点 a , b a,b a,b,都存在一个点列 a , v 1 , v 2 , ⋯ , v k , b {a,v_1,v_2, \cdots ,v_k,b} a,v1,v2,⋯,vk,b 使得这个点列中的每个点都是 S S S 里面的元素,且序列中相邻两个点间有一条边相连。
在这个前提下,上帝要使得 S S S 中的点所对应的整数的和尽量大。
这个最大的和就是上帝给生命之树的评分。
经过 atm 的努力,他已经知道了上帝给每棵树上每个节点上的整数。但是由于 atm 不擅长计算,他不知道怎样有效的求评分。他需要你为他写一个程序来计算一棵树的分数。
输入格式
第一行一个整数 n n n 表示这棵树有 n n n 个节点。
第二行 n n n 个整数,依次表示每个节点的评分。
接下来 n − 1 n-1 n−1 行,每行 2 2 2 个整数 u , v u,v u,v,表示存在一条 u u u 到 v v v 的边。由于这是一棵树,所以是不存在环的。
输出格式
输出一行一个数,表示上帝给这棵树的分数。
样例 #1
样例输入 #1
5
1 -2 -3 4 5
4 2
3 1
1 2
2 5
样例输出 #1
8
提示
对于 30 % 30\% 30% 的数据, n ≤ 10 n \le 10 n≤10。
对于 100 % 100\% 100% 的数据, 0 < n ≤ 1 0 5 , 0<n \le 10^5, 0<n≤105, 每个节点的评分的绝对值不超过 1 0 6 10^6 106。
时限 3 秒, 256M。
蓝桥杯 2015 省赛 B 组 J 题。
思路
首先,定义一些常量和全局变量。其中,N
是节点的最大数量,w[N]
是存储每个节点的评分,dp[N]
是存储每个节点的最大分数,g[N]
是一个向量数组,存储图的邻接表。
dfs
函数是深度优先搜索的实现。在这个函数中,首先将当前节点的评分赋值给 dp[x]
,然后遍历当前节点的所有邻居节点,如果邻居节点不是父节点,则将邻居节点的最大分数加到 dp[x]
上。
状态转移方程为:
d p [ x ] = w [ x ] + ∑ i ∈ g [ x ] , i ≠ f a max ( d f s ( i , x ) , 0 ) dp[x] = w[x] + \sum_{i \in g[x], i \neq fa} \max(dfs(i, x), 0) dp[x]=w[x]+i∈g[x],i=fa∑max(dfs(i,x),0)
其中, d p [ x ] dp[x] dp[x] 表示以 x x x 为根的子树中,按照题目要求选取节点后,能得到的最大和谐值。 w [ x ] w[x] w[x] 是节点 x x x 的权值, g [ x ] g[x] g[x] 是节点 x x x 的所有子节点, f a fa fa 是节点 x x x 的父节点。
以 x x x 为根的子树中,能得到的最大和谐值,等于节点 x x x 自身的权值,加上它所有子节点中,以子节点为根的子树能得到的最大和谐值之和。这里的最大和谐值是非负的,如果以某个子节点为根的子树的最大和谐值是负数,那么就不选取这个子树。
在 main
函数中,首先读取节点的数量 n
,然后读取每个节点的评分,接着读取每条边的两个节点,将这两个节点分别添加到对方的邻居列表中。然后调用 dfs
函数计算每个节点的最大分数,最后遍历 dp
数组,找出最大的分数并输出。
AC代码
cpp
#include <algorithm>
#include <iostream>
#include <vector>
#define mp make_pair
#define AUTHOR "HEX9CF"
using namespace std;
using ll = long long;
const int N = 1e5 + 7;
const int INF = 0x3f3f3f3f;
const ll MOD = 1e9 + 7;
int w[N];
ll dp[N];
vector<int> g[N];
ll dfs(int x, int fa) {
dp[x] = w[x];
// cout << x << " ";
for (const auto i : g[x]) {
if (i == fa) {
continue;
}
dp[x] += max(dfs(i, x), 0LL);
}
return dp[x];
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n;
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> w[i];
}
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
ll ans = 0;
for (int i = 1; i <= n; i++) {
ans = max(ans, dp[i]);
}
cout << ans << endl;
return 0;
}