知识详解
核心思想:在树上做DP=DFS+状态转移
通常以节点为状态,父结节点的值依赖于子节点的值
使用后序遍历(先处理孩子,在处理父亲)
基本框架

经典问题




例题

cs
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAX 100005
typedef long long ll;
const ll INF=1e18;
//邻接表存储图
typedef struct AdjNode
{
int v;//邻居节点编号
struct AdjNode *next;//下一个邻居
}AdjNode;
typedef struct AdjList
{
AdjNode *head;//链表头指针
}AdjList;
AdjList graph[MAX];//图的邻接表数组
ll val[MAX];
ll dp[MAX];
ll maxsum=0;
void add(int u,int v)
{
AdjNode *newnode=(AdjNode*)malloc(sizeof(AdjNode));
newnode->v=v;
newnode->next=graph[u].head;
graph[u].head=newnode;
newnode=(AdjNode*)malloc(sizeof(AdjNode));
newnode->v=u;
newnode->next=graph[v].head;
graph[v].head=newnode;
}
void dfs(int u,int parent)
{
dp[u]=val[u];//初始化为节点自己的权值
AdjNode *p=graph[u].head;
while(p!=NULL)
{
int v=p->v;
if(v!=parent)//避免回到父节点
{
dfs(v,u);//先递归处理字节点
if(dp[v]>0)//如果子节点的贡献为正
{
dp[u]+=dp[v];//就累加到当前节点
}
}
p=p->next;
}
if(dp[u]>maxsum) maxsum=dp[u];//更新全局最大值
}
void free_graph(int n)
{
for(int i=1;i<=n;i++)
{
AdjNode *p=graph[i].head;
while(p!=NULL)
{
AdjNode *t=p;
p=p->next;
free(t);
}
graph[i].head=NULL;
}
}
int main(int argc, char *argv[])
{
int n;
scanf("%d",&n);
memset(graph,0,sizeof(graph));
for(int i=1;i<=n;i++)
{
scanf("%lld",&val[i]);
}
for(int i=0;i<n-1;i++)
{
int u,v;
scanf("%d %d",&u,&v);
add(u,v);
}
dfs(1,-1);
printf("%lld",maxsum);
free_graph(n);
return 0;
}

cs
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#define MAX 100010
int head[MAX],e[MAX],next[MAX],idx;
//head[u]存储节点u的第一条边的索引(头指针)
//e[i]存储第i条边的目标节点
//next[i]存储第i条边的下一条边的索引
//idx边的计数器,每添加一条边就+1
void add(int a,int b)
{
e[idx]=b;//这条边指向b
next[idx]=head[a];//新边的next指向原来的第一条边
head[a]=idx++;//更新头指针为当前边,并让idx自增
}
int dfs(int u)
{
int hmax=0,cnt=0;
for(int i=head[u];i!=-1;i=next[i])
{
int j=e[i];
int child_depth=dfs(j);
if(child_depth>hmax) hmax=child_depth;
cnt++;
}
return hmax+cnt;
}
int main()
{
int n;
scanf("%d",&n);
memset(head,-1,sizeof(head));
idx=0;
for(int i=2;i<=n;i++)
{
int p;
scanf("%d",&p);
add(p,i);
}
printf("%d",dfs(1));
return 0;
}

cs
#include <stdio.h>
#include <stdlib.h>
#define MAX 100010
typedef long long ll;
int n;
int head[MAX],to[MAX],next_edge[MAX],weight[MAX],edge_cnt;
ll dist[MAX];
int farthest_node;
ll max_dist;
void add_edge(int u,int v,int w)
{
to[edge_cnt]=v;
weight[edge_cnt]=w;
next_edge[edge_cnt]=head[u];
head[u]=edge_cnt++;
}
void dfs(int u,int parent,ll d)
{
dist[u]=d;//记录起点到u的距离
if(d>max_dist)//更新最远距离
{
max_dist=d;
farthest_node=u;//记录最远节点
}
for(int i=head[u];i!=-1;i=next_edge[i])
{
int v=to[i];
int w=weight[i];
if(v==parent)continue;//不往回走
dfs(v,u,d+w);
}
}
int main(int argc, char *argv[])
{
scanf("%d",&n);
memset(head,-1,sizeof(head));
edge_cnt=0;
for(int i=0;i<n-1;i++)
{
int p,q,d;
scanf("%d %d %d",&p,&q,&d);
add_edge(p,q,d);//添加正向边
add_edge(q,p,d);//添加反向边(无边图)
}
//第一次DFS:从1出发找最远点
//找直径端点
max_dist=-1;
dfs(1,-1,0);
int u=farthest_node;
//第二次DFS:从u出发找最远点v
//找直径长度
max_dist=-1;
dfs(u,-1,0);
ll diameter=max_dist;
ll cost=diameter*(diameter+21)/2;
printf("%lld",cost);
return 0;
}

在树上选择若干节点,不能同时选相邻节点(父子关系),最大化节点权值和。
cs
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAX 100010
typedef long long ll;
int head[MAX],to[MAX],next_edge[MAX],edge_cnt;
ll a[MAX];
ll dp[MAX][2];
void add_edge(int u,int v)
{
to[edge_cnt]=v;
next_edge[edge_cnt]=head[u];
head[u]=edge_cnt++;
}
void dfs(int u,int parent)
{
dp[u][0]=0;
dp[u][1]=a[u];
for(int i=head[u];i!=-1;i=next_edge[i])
{
int v=to[i];
if(v==parent) continue;
dfs(v,u);
dp[u][0]+=(dp[v][0]>dp[v][1]?dp[v][0]:dp[v][1]);
dp[u][1]+=dp[v][0];
}
}
int main(int argc, char *argv[])
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
}
memset(head,-1,sizeof(head));
edge_cnt=0;
for(int i=0;i<n-1;i++)
{
int u,v;
scanf("%d %d",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
dfs(1,-1);
ll ans=dp[1][0]>dp[1][1]?dp[1][0]:dp[1][1];
printf("%lld",ans);
return 0;
}

cs
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#define MAXN 100010
#define MAXM 200020
int head[MAXN], to[MAXM], next_edge[MAXM], edge_cnt;
int parent[MAXN]; // 记录父节点,用于找根
int size[MAXN]; // 子树大小(包括自己)
int n, m;
void add_edge(int u, int v) {
to[edge_cnt] = v;
next_edge[edge_cnt] = head[u];
head[u] = edge_cnt++;
}
void dfs(int u) {
size[u] = 1;
for (int i = head[u]; i != -1; i = next_edge[i]) {
int v = to[i];
dfs(v);
size[u] += size[v];
}
}
typedef struct {
int id;
int cnt; // 手下人数
} Person;
int cmp(const void *a, const void *b) {
Person *pa = (Person*)a;
Person *pb = (Person*)b;
if (pa->cnt != pb->cnt) {
return pb->cnt - pa->cnt; // 手下人数多的在前
}
return pa->id - pb->id; // 人数相同,id 小的在前
}
int main() {
scanf("%d %d", &n, &m);
memset(head, -1, sizeof(head));
memset(parent, 0, sizeof(parent));
edge_cnt = 0;
for (int i = 0; i < n - 1; i++) {
int l, r;
scanf("%d %d", &l, &r);
add_edge(r, l); // r 是 l 的上级
parent[l] = r; // 记录 l 的父节点
}
// 找根(没有父节点的节点)
int root = -1;
for (int i = 1; i <= n; i++) {
if (parent[i] == 0) {
root = i;
break;
}
}
// 计算子树大小
dfs(root);
// 计算每个人的手下人数
Person *people = (Person*)malloc(n * sizeof(Person));
for (int i = 1; i <= n; i++) {
people[i-1].id = i;
people[i-1].cnt = size[i] - 1; // 手下人数 = 子树大小 - 1
}
// 排序
qsort(people, n, sizeof(Person), cmp);
// 找小明的排名
int rank = 0;
for (int i = 0; i < n; i++) {
if (people[i].id == m) {
rank = i + 1;
break;
}
}
printf("%d\n", rank);
free(people);
return 0;
}