题目大意
有一个有 n n n个点 m m m条边的图,第 i i i条边有一个边权 c i c_i ci。
给定 K K K,有 q q q次询问,每次给出一个 x x x,如果 c i ⊕ x < K c_i\oplus x<K ci⊕x<K,则这条边存在,否则不存在。对于每个询问,输出互相连通的点对个数(即有多少 1 ≤ i < j ≤ n 1\leq i<j\leq n 1≤i<j≤n使得 i , j i,j i,j连通)。
1 ≤ n , m ≤ 1 0 5 , 0 ≤ c i , x , K < 2 30 1\leq n,m\leq 10^5,0\leq c_i,x,K<2^{30} 1≤n,m≤105,0≤ci,x,K<230
时间限制 2000 m s 2000ms 2000ms,空间限制 512 M B 512MB 512MB。
题解
我们可以想到,比较 c i ⊕ x c_i\oplus x ci⊕x和 K K K,如果 c i ⊕ x c_i\oplus x ci⊕x和 K K K的前 t − 1 t-1 t−1位都是相同的:
- 当 K K K的第 t t t位为 1 1 1时, c i ⊕ x c_i\oplus x ci⊕x的第 t t t位为 0 0 0就一定可以取到,否则继续比较
- 当 K K K的第 t t t位为 0 0 0时, c i ⊕ x c_i\oplus x ci⊕x的第 t t t位为 0 0 0才能满足条件,还要继续往下比较
那我们把 c i ⊕ K c_i\oplus K ci⊕K存储到字典树中,判断哪些 x x x可以使边 i i i存在。设当前放到了第 t t t位, c i ⊕ K c_i\oplus K ci⊕K的第 t t t位为 p p p:
- 如果 K K K的第 t t t位为 1 1 1,则当 x x x的前 t − 1 t-1 t−1位和 c i ⊕ K c_i\oplus K ci⊕K的前 t − 1 t-1 t−1位相同且 x ⊕ c i x\oplus c_i x⊕ci的第 t t t位为 0 0 0时(即 x x x的第 t t t位与 c i c_i ci的第 t t t位相同)一定可以取到,在对应子树的根节点加上这条边,然后按 p p p继续往下遍历
- 如果 K K K的第 t t t位为 0 0 0,则继续按 p p p继续往下遍历
然后,我们遍历这棵字典树,来求出每个叶子节点的答案。对于每个点,先将其能取到的边用并查集维护并计算贡献,然后遍历其子树,求完子树中叶子节点的答案之后再将这个点在图上连的边删除,所以要用带删并查集。带删并查集中要用栈来维护加入了哪些边,因为不能路径压缩,所以要用 dsu on tree \text{dsu on tree} dsu on tree来保证每次查找的时间复杂度是 O ( log n ) O(\log n) O(logn)的。删边时不断删去栈顶的边并将栈顶弹出,直到这个点所有在先前加入栈中的边都被删完。
这样的话,每个叶子节点的答案都算出来了,查询就是 O ( 1 ) O(1) O(1)的了。
时间复杂度为 O ( n log V log n + m ) O(n\log V\log n+m) O(nlogVlogn+m),其中 V V V表示 c i , x , K c_i,x,K ci,x,K的值域。
可以参考代码帮助理解。
code
cpp
#include<bits/stdc++.h>
using namespace std;
const int N=100000;
int n,m,q,K,tot=1,tp=0,fa[N+5],siz[N+5];
int wh[N+5],ch[32*N+5][2];
long long now=0,ans[32*N+5];
pair<int,int>st[N+5];
vector<int>v[32*N+5][2];
struct node{
int x,y,w;
}w[N+5];
void pt(int w,int id){
int q,vq=1;
for(int i=30;i>=0;i--){
if((K>>i)&1){
v[vq][(w>>i)&1].push_back(id);
}
q=((K^w)>>i)&1;
if(!ch[vq][q]) ch[vq][q]=++tot;
vq=ch[vq][q];
}
}
int gt(int x){
int q,vq=1;
for(int i=30;i>=0;i--){
q=(x>>i)&1;
if(!ch[vq][q]) ch[vq][q]=++tot;
vq=ch[vq][q];
}
return vq;
}
int find(int ff){
if(fa[ff]!=ff) return find(fa[ff]);
return ff;
}
long long gts(int x){
return 1ll*x*(x-1)/2;
}
void merge(int x,int y){
int v1=find(x),v2=find(y);
if(v1==v2) return;
if(siz[v1]<siz[v2]) swap(v1,v2);
now-=gts(siz[v1])+gts(siz[v2]);
fa[v2]=v1;
siz[v1]+=siz[v2];
st[++tp]={v1,v2};
now+=gts(siz[v1]);
}
void del(int x,int y){
now-=gts(siz[x]);
siz[x]-=siz[y];
fa[y]=y;
now+=gts(siz[x])+gts(siz[y]);
}
void stpop(int tmp){
while(tp>tmp){
del(st[tp].first,st[tp].second);
--tp;
}
}
void solve(int u){
ans[u]=now;
if(ch[u][0]){
int tmp=tp;
for(int p:v[u][0]){
merge(w[p].x,w[p].y);
}
solve(ch[u][0]);
stpop(tmp);
}
if(ch[u][1]){
int tmp=tp;
for(int p:v[u][1]){
merge(w[p].x,w[p].y);
}
solve(ch[u][1]);
stpop(tmp);
}
}
int main()
{
// freopen("xor.in","r",stdin);
// freopen("xor.out","w",stdout);
scanf("%d%d%d%d",&n,&m,&q,&K);
for(int i=1;i<=m;i++){
scanf("%d%d%d",&w[i].x,&w[i].y,&w[i].w);
pt(w[i].w,i);
}
for(int i=1;i<=n;i++){
fa[i]=i;siz[i]=1;
}
for(int i=1,x;i<=q;i++){
scanf("%d",&x);
wh[i]=gt(x);
}
solve(1);
for(int i=1;i<=q;i++){
printf("%lld\n",ans[wh[i]]);
}
return 0;
}