Select from Subtrees
Problem Statement
给你一棵有根树 T T T,共有 N N N 个顶点,顶点编号依次是顶点 1 1 1、顶点 2 2 2、 ... \dots ...、顶点 N N N。顶点 1 1 1 是树 T T T 的根,且顶点 i i i( 2 ≤ i ≤ N 2 \le i \le N 2≤i≤N)的父节点为 P i P_i Pi。
此外,顶点 i i i( 1 ≤ i ≤ N 1 \le i \le N 1≤i≤N)上有 C i C_i Ci 颗糖果。所有 ( C 1 + C 2 + ⋯ + C N ) (C_1+C_2+\dots+C_N) (C1+C2+⋯+CN) 颗糖果都是彼此不同的。
高桥给了 N N N 只松鼠一些任务。具体来说,第 i i i 只松鼠( 1 ≤ i ≤ N 1 \le i \le N 1≤i≤N)的任务是:
- 从以顶点 i i i 为根的子树中,选出并收集 D i D_i Di 颗糖果。
不同的松鼠不能拿同一颗糖果。请输出所有可能的选法数量,对 998244353 998244353 998244353 取模。
注意,即使最终选出的糖果集合相同,只要拿糖果的松鼠不同,也算作不同的选法。
如果无法满足所有松鼠的要求,输出 0 0 0。
Constraints
- 2 ≤ N ≤ 2 × 10 5 2 \le N \le 2 \times 10^5 2≤N≤2×105
- 1 ≤ P i ≤ N 1 \le P_i \le N 1≤Pi≤N
- 1 ≤ C i ≤ 10 9 1 \le C_i \le 10^9 1≤Ci≤109
- 1 ≤ D i 1 \le D_i 1≤Di
- D 1 + D 2 + ⋯ + D N ≤ 10 6 D_1+D_2+\dots+D_N \le 10^6 D1+D2+⋯+DN≤106
- 所有输入均为整数。
- T T T 是以顶点 1 1 1 为根的树。
Input
输入从标准输入读取,格式如下:
N N N
P 2 P_2 P2 P 3 P_3 P3 ... \dots ... P N P_N PN
C 1 C_1 C1 C 2 C_2 C2 ... \dots ... C N C_N CN
D 1 D_1 D1 D 2 D_2 D2 ... \dots ... D N D_N DN
Output
输出所有可能的选法数量,对 998244353 998244353 998244353 取模。
Solution
从叶子到根节点求答案。记 subtree i \text{subtree}i subtreei 为以 i i i 为根的子树。对于节点 i i i,需要在 ∑ j ∈ subtree i , j ≠ i C j \sum{j \in \text{subtree}i,j \neq i} C_j ∑j∈subtreei,j=iCj 颗糖果中选择 D i D_i Di 颗,而每颗糖果不同,所以方案数即为 ( ∑ j ∈ subtree i , j ≠ i C j D i ) \binom{\sum{j \in \text{subtree}i,j \neq i} C_j}{D_i} (Di∑j∈subtreei,j=iCj) 种,再将每个节点的方案数相乘。组合数与 ∑ j ∈ subtree i , j ≠ i C j \sum{j \in \text{subtree}_i,j \neq i} C_j ∑j∈subtreei,j=iCj 可预处理。计算组合数时若运算不合法,即无解。
Code
cpp
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+5,mod=998244353;
int n;
int p[maxn],c[maxn],d[maxn];
int cs[maxn],ds[maxn];
vector<int> g[maxn];
int ans;
int quick_pow(int x,int y){
int res=1;
while(y){
if(y&1) res=res*x%mod;
x=x*x%mod,y>>=1;
}
return res;
}
void dfs(int x){
cs[x]=c[x],ds[x]=d[x];
for(auto i:g[x]){
dfs(i);
cs[x]+=cs[i],ds[x]+=ds[i];
}
}
int cpr(int x,int y){
int res=1;
for(int i=x-y+1;i<=x;i++){
res=res*(i%mod)%mod;
}
for(int i=1;i<=y;i++){
res=res*quick_pow(i%mod,mod-2)%mod;
}
return res;
}
signed main(){
ans=1;
cin>>n;
for(int i=2;i<=n;i++){
cin>>p[i];
g[p[i]].push_back(i);
}
for(int i=1;i<=n;i++){
cin>>c[i];
}
for(int i=1;i<=n;i++){
cin>>d[i];
}
dfs(1);
for(int i=1;i<=n;i++){
ans*=cpr(cs[i]-ds[i]+d[i],d[i]),ans%=mod;
}
cout<<ans<<"\n";
return 0;
}