题目传送门
思路
暴力思路就是以每个查询节点 u u u 为根,然后 d f s dfs dfs 来看是否有深度为 d d d 的节点。
暴力里面蕴含着换根 d p dp dp 的思想,所以考虑换根。
首先可以以 1 1 1 为根 d f s dfs dfs 算出每个节点的深度,然后以每个节点的 d f s dfs dfs 序 为下标来建立一个 维护深度区间最大值 m x p mx_p mxp 的线段树,并记录最大值的对应节点 i d p id_p idp。
然后再写一个 d f s dfs dfs 用来换根求解答案:
假如当前节点是 u u u,它的一个子节点为 v v v,现在根从 u u u 换到 v v v,那么整棵树中以 v v v 为根节点子树中的节点到根的距离都会 − 1 -1 −1,其他节点都会 + 1 +1 +1,这个用线段树维护就行。
然后考虑求解答案:
- 如果此时线段树中深度最大的节点的深度依旧小于 d d d,那么答案就是 − 1 -1 −1;
- 假设距离 u u u 最远的节点是 v v v,他们在原树(即以 1 1 1 为根的树)上的最近公共祖先是 l c a lca lca。如果 d i s ( u , l c a ) ≥ d dis(u, lca) \geq d dis(u,lca)≥d,那么直接从 u u u 开始跳倍增跳到距离为 d d d 的祖先,这就是答案;不然就从 v v v 开始跳距离为 d d d 的祖先。
时间复杂度 O ( ( q + n ) × l o g ( n ) ) O((q + n) \times log(n)) O((q+n)×log(n))。
代码
cpp
#include <bits/stdc++.h>
#define mkpr make_pair
#define fir first
#define sec second
#define il inline
using namespace std;
typedef pair<int, int> pii;
typedef unsigned long long ull;
typedef long long ll;
typedef long double ld;
typedef double db;
const int maxn = 2e5 + 7;
const int inf = 0x3f3f3f3f;
int n, Q;
int h[maxn], ecnt;
struct edge {int v, nxt;} e[maxn << 1];
#define addedge(u, v) (e[++ecnt] = edge{v, h[u]}, h[u] = ecnt)
vector<pii> que[maxn];
int ans[maxn];
int L[maxn], R[maxn], rfc[maxn], dfnCnt;
int fa[maxn][22], dep[maxn];
void dfs1(int u) {
rfc[L[u] = ++dfnCnt] = u;
for (int i = 1; i <= 20; ++i)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
if (fa[u][0]) dep[u] = dep[fa[u][0]] + 1;
for (int i = h[u]; i; i = e[i].nxt)
if (e[i].v != fa[u][0])
fa[e[i].v][0] = u, dfs1(e[i].v);
R[u] = dfnCnt;
}
#define ls (p << 1)
#define rs (p << 1 | 1)
#define mid ((l + r) >> 1)
struct SegmentTree {
int mx[maxn << 2], id[maxn << 2], tg[maxn << 2];
il void pushup(int p) {
mx[p] = max(mx[ls], mx[rs]);
if (mx[ls] > mx[rs]) id[p] = id[ls];
else id[p] = id[rs];
}
il void pushdown(int p) {
if (tg[p]) {
mx[ls] += tg[p], tg[ls] += tg[p];
mx[rs] += tg[p], tg[rs] += tg[p];
tg[p] = 0;
}
}
void build(int p, int l, int r) {
if (l == r) {mx[p] = dep[rfc[l]], id[p] = rfc[l]; return ;}
build(ls, l, mid), build(rs, mid + 1, r), pushup(p);
}
void mdf(int p, int l, int r, int ql, int qr, int v) {
if (ql > qr) return ;
if (qr < l || r < ql) return ;
if (ql <= l && r <= qr) {mx[p] += v, tg[p] += v; return ;}
pushdown(p);
mdf(ls, l, mid, ql, qr, v), mdf(rs, mid + 1, r, ql, qr, v);
pushup(p);
}
pii ask(int p, int l, int r, int ql, int qr) {
if (qr < l || r < ql) return mkpr(-inf, 0);
if (ql <= l && r <= qr) return mkpr(mx[p], id[p]);
pushdown(p);
pii l_res = ask(ls, l, mid, ql, qr);
pii r_res = ask(rs, mid + 1, r, ql, qr);
if (l_res.fir > r_res.fir) return l_res;
return r_res;
}
void print(int p, int l, int r) {
pushdown(p);
printf("p:%d, l:%d, r:%d, mx:%d, id:%d\n", p, l, r, mx[p], id[p]);
if (l == r) return ;
print(ls, l, mid), print(rs, mid + 1, r);
}
} sgt;
int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
int dif = dep[x] - dep[y];
for (int i = 20; i >= 0; --i)
if (dif & (1 << i)) x = fa[x][i];
if (x == y) return x;
for (int i = 20; i >= 0; --i)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
int jump(int x, int d) {
for (int i = 20; i >= 0; --i)
if (d & (1 << i)) x = fa[x][i];
return x;
}
#define dis(u, v, lca) (dep[u] + dep[v] - 2 * dep[lca])
void dfs2(int u) {
for (pii q : que[u]) {
pii res = sgt.ask(1, 1, n, 1, n);
int qry_d = q.fir, id = q.sec;
int v = res.sec, mxd = res.fir;
int lca = LCA(u, v);
if (mxd < qry_d) {
ans[id] = -1;
continue;
}
if (dis(u, lca, lca) >= qry_d) {
ans[id] = jump(u, qry_d);
continue;
}
ans[id] = jump(v, dis(u, v, lca) - qry_d);
}
for (int i = h[u]; i; i = e[i].nxt) {
int v = e[i].v;
if (v == fa[u][0]) continue;
sgt.mdf(1, 1, n, L[v], R[v], -1);
sgt.mdf(1, 1, n, 1, L[v] - 1, 1);
sgt.mdf(1, 1, n, R[v] + 1, n, 1);
dfs2(v);
sgt.mdf(1, 1, n, L[v], R[v], 1);
sgt.mdf(1, 1, n, 1, L[v] - 1, -1);
sgt.mdf(1, 1, n, R[v] + 1, n, -1);
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; ++i) {
int u, v; scanf("%d%d", &u, &v);
addedge(u, v), addedge(v, u);
}
scanf("%d", &Q);
for (int i = 1; i <= Q; ++i) {
int u, d;
scanf("%d%d", &u, &d);
que[u].push_back(mkpr(d, i));
}
dfs1(1);
sgt.build(1, 1, n);
dfs2(1);
for (int i = 1; i <= Q; ++i) printf("%d\n", ans[i]);
return 0;
}