题目描述
Farmer John 有一个大农场,农场上有 N 个谷仓(1≤N≤105),其中一些已经涂色,另一些尚未涂色。Farmer John 想要为这些剩余的谷仓涂色,使得所有谷仓都被涂色,但他只有三种可用的油漆颜色。此外,他的获奖奶牛 Bessie 如果发现两个直接相连的谷仓颜色相同,会感到困惑,因此他希望确保这种情况不会发生。
保证 N 个谷仓之间的连接不会形成任何"环"。也就是说,任意两个谷仓之间最多只有一条连接路径。
Farmer John 有多少种方式可以为剩余的未涂色谷仓涂色?
输入格式
第一行包含两个整数 N 和 K(0≤K≤N),分别表示农场上的谷仓数量和已经涂色的谷仓数量。
接下来的 N−1 行每行包含两个整数 x 和 y(1≤x,y≤N,x=y),描述直接连接谷仓 x 和 y 的路径。
接下来的 K 行每行包含两个整数 b 和 c(1≤b≤N, 1≤c≤3),表示谷仓 b 已经被涂成颜色 c。
输出格式
计算为剩余谷仓涂色的有效方式数量,模 109+7,要求任何两个直接相连的谷仓颜色不同。
输入输出样例
输入 #1
4 1
1 2
1 3
1 4
4 3
输出 #1
8
思路:
依旧树形DP。
状态:dp[i][0]代表i涂第1种颜色的方法数,dp[i][1]代表i涂第2种颜色的方法数,dp[i][2]代表i涂第3种颜色的方法数。
状态转移方程:
dp[u][0]=dp[u][0]*(dp[v][2]+dp[v][1])%mod;
dp[u][1]=dp[u][1]*(dp[v][2]+dp[v][0])%mod;
dp[u][2]=dp[u][2]*(dp[v][1]+dp[v][0])%mod;
代码:
cpp
#include <bits/stdc++.h>
#include <bits/c++config.h>
#include <ostream>
#include <istream>
#include <algorithm>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <string>
#include <math.h>
#include <time.h>
#include <ctime>
#include <cstdlib>
#define ll long long
#define ull unsigned long long
#define db double
#define st string
#define ch char
#define bo bool
#define s1 27
#define s2 205
#define s3 2005
#define s4 20005
#define s5 200005
#define s6 2000005
#define s7 20000005
using namespace std;
ll mod=1000000007;
ll n,m,dp[s5][3],a[s5];
//dp[i][0]代表i涂第1种颜色,dp[i][1]代表i涂第2种颜色,dp[i][2]代表i涂第3种颜色。
vector<int> g[s5];
void dfs(int u,int fa){
if(a[u]==0){
for(int i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==fa) continue;
dfs(v,u);
//dp[u][0]
dp[u][0]=dp[u][0]*(dp[v][2]+dp[v][1])%mod;
//dp[u][1]
dp[u][1]=dp[u][1]*(dp[v][2]+dp[v][0])%mod;
//dp[u][2]
dp[u][2]=dp[u][2]*(dp[v][1]+dp[v][0])%mod;
}
}
else if(a[u]!=0){
if(a[u]==1){
dp[u][1]=0;
dp[u][2]=0;
for(int i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==fa) continue;
dfs(v,u);
dp[u][0]=dp[u][0]*(dp[v][2]+dp[v][1])%mod;
}
}
else if(a[u]==2){
dp[u][2]=0;
dp[u][0]=0;
for(int i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==fa) continue;
dfs(v,u);
dp[u][1]=dp[u][1]*(dp[v][2]+dp[v][0])%mod;
}
}
else if(a[u]==3){
dp[u][1]=0;
dp[u][0]=0;
for(int i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==fa) continue;
dfs(v,u);
dp[u][2]=dp[u][2]*(dp[v][1]+dp[v][0])%mod;
}
}
}
return ;
}
signed main(){
cin>>n>>m;
for(int i=0;i<=100005;i++){
dp[i][1]=1;
dp[i][0]=1;
dp[i][2]=1;
}
for(int i=1;i<n;i++){
int k,c;
cin>>k>>c;
g[c].push_back(k);
g[k].push_back(c);
}
for(int i=1;i<=m;i++){
int b,c;
cin>>b>>c;
a[b]=c;
}
dfs(1,0);
cout<<(dp[1][0]+dp[1][1]%mod+dp[1][2]%mod)%mod<<endl;
return 0;
}