F. Strange Memory

思路:
题目要求计算树上所有满足 a i ⊕ a j = a lca ( i , j ) a_i \oplus a_j = a_{\text{lca}(i,j)} ai⊕aj=alca(i,j) 的点对 ( i , j ) (i,j) (i,j) 的 i ⊕ j i \oplus j i⊕j 值之和。直接枚举所有点对的时间复杂度为 O ( n 2 ) O(n^2) O(n2),显然会超时。因此我们需要一种更高效的算法来解决这个问题。
核心思路:树上启发式合并(DSU on Tree)
我们采用树上启发式合并算法来解决这个问题,其主要思想如下:
-
预处理重儿子:
- 首先进行一次 DFS,计算每个节点的子树大小并标记重儿子(子树最大的儿子)
-
分治处理:
- 对于每个节点 u,先递归处理其所有轻儿子,并清除它们的影响
- 然后处理重儿子,保留其影响
- 将当前节点 u 加入集合
- 遍历每个轻儿子子树:
- 先统计该子树与已有集合(重儿子子树和已处理的轻儿子)形成的点对贡献
- 再将该子树加入集合
-
贡献统计:
- 使用映射
map<int, vector<int>> M
存储权值到节点列表的映射 - 对于当前节点 u(作为 LCA),统计满足 a i ⊕ a j = a u a_i \oplus a_j = a_u ai⊕aj=au 的点对
- 具体实现时,对于轻儿子子树中的节点 x,查找 M 中键为 a u ⊕ a x a_u \oplus a_x au⊕ax 的节点列表,计算 x ⊕ y x \oplus y x⊕y 并累加
- 使用映射
-
时间复杂度 : O ( n log 2 n ) O(n \log^2 n) O(nlog2n) , 通过重用重儿子子树信息避免重复计算 O ( n log n ) O(n \log n) O(nlogn) , map操作是 O ( log n ) O(\log n) O(logn) 的。
树上启发式合并是解决子树统计问题的强大工具,通过重用重儿子子树信息显著降低时间复杂度。本题中,我们利用该算法高效地统计了满足特定条件的点对异或和,避免了 O ( n 2 ) O(n^2) O(n2) 的暴力枚举。
代码:
cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define endl '\n'
#define int long long
#define pb push_back
#define pii pair<int, int>
#define FU(i, a, b) for (int i = (a); i <= (b); ++i)
#define FD(i, a, b) for (int i = (a); i >= (b); --i)
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int maxn = 1e6 + 5, MAXN = maxn;
int n;
int ans = 0;
int a[maxn];
vector<int> ed[maxn];
int hc[maxn];
int cs[maxn];
map<int, vector<int>> M;
int tans = 0;
void predfs(int x, int f) {
cs[x] = 1;
int ms = 0;
for (int e : ed[x]) {
if (e == f)
continue;
predfs(e, x);
cs[x] += cs[e];
if (cs[e] > ms) {
ms = cs[e];
hc[x] = e;
}
}
}
void add(int x, int f, int rt, bool mg) {
if (mg) { // 合并
M[a[x]].pb(x);
} else { // 统计
for (int e : M[a[x] ^ a[rt]]) {
ans += e ^ x;
}
}
for (int e : ed[x]) {
if (e == f)
continue;
add(e, x, rt, mg);
}
}
void dfs(int x, int f, int k) {
for (int e : ed[x]) {
if (e == f || e == hc[x])
continue;
dfs(e, x, 0);
}
if (hc[x]) {
dfs(hc[x], x, 1);
}
M[a[x]].pb(x);
for (int e : ed[x]) {
if (e == f || e == hc[x])
continue;
add(e, x, x, 0);
add(e, x, x, 1);
}
if (!k) {
M.clear();
}
}
void solve() {
ans = 0;
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
ed[u].pb(v);
ed[v].pb(u);
}
predfs(1, 0);
dfs(1, 0, 0);
cout << ans << endl;
}
signed main() {
#ifndef ONLINE_JUDGE
freopen("../in.txt", "r", stdin);
#endif
cin.tie(0)->ios::sync_with_stdio(0);
int T = 1;
// cin >> T;
while (T--) {
solve();
}
return 0;
}