题解
70pts
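树上 01 背包:结点 i 的父亲是 i >> 1,所以 f[i] 可以直接由 f[i >> 1] 做一次 01 背包转移得到;询问 (u, L) 时取 f[u][1..L] 的最大值。时间和空间都是 O(n · max L),只能通过小数据。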
Code
cpp
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Read one (optionally negative) integer from stdin.
// Non-digit characters before the number are skipped. The sign is taken from
// the character immediately preceding the first digit, so "-5" reads -5.
// Returns 0 if input ends before any digit is seen, instead of looping forever.
inline int read()
{
    int x = 0,p = 1;
    int c = getchar();  // int, not char: EOF must stay distinguishable from real bytes
    while(c != EOF && (c > '9' || c < '0')) p = c == '-' ? -1 : 1,c = getchar();
    while(c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48),c = getchar();
    return p * x;
}
// Print the decimal digits of a non-negative integer (no sign, no newline).
template<typename _Tp> inline void _write(_Tp x) { if(x > 9) _write(x / 10); putchar(x % 10 | 48); }
// Print a signed integer to stdout.
// For negative x the last digit is peeled off before negating: -x overflows
// (undefined behaviour) for the most negative value, but -(x / 10) and
// -(x % 10) never do, since division truncates toward zero.
template<typename _Tp> inline void write(_Tp x)
{
    if(x < 0)
    {
        putchar('-');
        if(x / 10 != 0) _write(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    _write(x);
}
// Print a signed integer followed by a newline.
template<typename _Tp> inline void writeln(_Tp x) { write(x); puts(""); }
// Capacity of the per-node and per-query arrays (indices start at 1).
const int N = 100010;
int n,Q;              // number of tree nodes (items) and number of queries
int v[N],w[N];        // value / weight of the item stored at each node
int f[5010][5010];    // dp table: row = node, column = capacity
// One offline query: start node u and capacity limit L.
struct query
{
    int u;
    int L;
} q[N];
signed main()
{
n = read();
for(int i = 1;i <= n;i++) v[i] = read(),w[i] = read();
Q = read();
int maxL = 0;
for(int i = 1;i <= Q;i++)
q[i].u = read(),q[i].L = read(),maxL = max(maxL,q[i].L);
f[1][w[1]] = v[1];
for(int i = 2;i <= n;i++)
{
int k = i >> 1;
for(int j = 1;j < w[i];j++)
f[i][j] = f[k][j];
for(int j = w[i];j <= maxL;j++)
f[i][j] = max(f[k][j],f[k][j - w[i]] + v[i]);
}
for(int i = 1;i <= Q;i++)
{
int ans = 0;
for(int j = 1;j <= q[i].L;j++)
ans = max(ans,f[q[i].u][j]);
writeln(ans);
}
return 0;
}
100 pts
可以发现编号最小的前 $\sqrt{n}$ 个结点被访问的次数比较多(它们是大量结点的祖先),所以考虑预处理这些结点的 01 背包 dp 值:$f[i][c]$ 表示从根到 $i$ 的路径上选出总重量不超过 $c$ 的物品时的最大价值。
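剩下的点(编号大于 $\sqrt{n}$)用 dfs 暴力枚举选或不选即可。沿父亲 i >> 1 向上走时,路径上这样的点只有约 $\frac{1}{2}\log_2 n$ 个,所以子集数只有 $O(\sqrt{n})$。但是统计答案时要特别注意:走到第一个编号不超过 $\sqrt{n}$ 的祖先 $u$ 时,答案是已选价值加上 $f[u][L - \text{已选重量}]$;已选重量超过 $L$ 时要直接剪枝。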
AC Code
cpp
#include <bits/stdc++.h>
#define int long long
using namespace std;
// Read one (optionally negative) integer from stdin.
// Non-digit characters before the number are skipped. The sign is taken from
// the character immediately preceding the first digit, so "-5" reads -5.
// Returns 0 if input ends before any digit is seen, instead of looping forever.
inline int read()
{
    int x = 0,p = 1;
    int c = getchar();  // int, not char: EOF must stay distinguishable from real bytes
    while(c != EOF && (c > '9' || c < '0')) p = c == '-' ? -1 : 1,c = getchar();
    while(c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48),c = getchar();
    return p * x;
}
// Print the decimal digits of a non-negative integer (no sign, no newline).
template<typename _Tp> inline void _write(_Tp x) { if(x > 9) _write(x / 10); putchar(x % 10 | 48); }
// Print a signed integer to stdout.
// For negative x the last digit is peeled off before negating: -x overflows
// (undefined behaviour) for the most negative value, but -(x / 10) and
// -(x % 10) never do, since division truncates toward zero.
template<typename _Tp> inline void write(_Tp x)
{
    if(x < 0)
    {
        putchar('-');
        if(x / 10 != 0) _write(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    _write(x);
}
// Print a signed integer followed by a newline.
template<typename _Tp> inline void writeln(_Tp x) { write(x); puts(""); }
// Capacity of the per-node and per-query arrays (indices start at 1).
const int N = 100010;
int n,Q;              // number of tree nodes (items) and number of queries
int v[N],w[N];        // value / weight of the item stored at each node
int f[320][N];        // precomputed dp rows for nodes 1.._sqrt (column = capacity)
int _sqrt;            // nodes numbered <= _sqrt have a row in f
int ans = 0;          // best value for the query currently being answered (updated by dfs)
// One offline query: start node u and capacity limit L.
struct query
{
    int u;
    int L;
} q[N];
// Enumerate every take/skip choice for the nodes on the path from u towards
// the root (parent = u >> 1) whose index exceeds _sqrt. Once the walk reaches
// a node with a precomputed row in f, combine the choices made so far with
// that row and fold the result into the global `ans`.
//   u      : current node on the upward path
//   L      : capacity of the query
//   value  : total value of the items chosen so far
//   weight : total weight of the items chosen so far
void dfs(int u,int L,int value = 0,int weight = 0)
{
    const int remaining = L - weight;
    if(remaining < 0) return;  // chosen items already exceed the capacity: prune
    if(u > _sqrt)
    {
        const int parent = u >> 1;
        dfs(parent,L,value + v[u],weight + w[u]);  // take item u
        dfs(parent,L,value,weight);                 // leave item u
        return;
    }
    ans = max(ans,value + f[u][remaining]);
}
// 100-point solution: precompute knapsack rows for the first _sqrt nodes,
// then answer each query by brute-forcing the deeper part of its path.
signed main()
{
    // ---- read items and queries ----
    n = read();
    for(int node = 1;node <= n;node++)
    {
        v[node] = read();
        w[node] = read();
    }
    Q = read();
    int maxL = 0;  // largest capacity any query asks about; dp columns stop here
    for(int id = 1;id <= Q;id++)
    {
        q[id].u = read();
        q[id].L = read();
        if(q[id].L > maxL) maxL = q[id].L;
    }

    // ---- precompute dp rows for nodes 1.._sqrt ----
    // f[i][c]: best value from items on the root-to-i path with total weight <= c
    // (assuming non-negative item values).
    for(int cap = w[1];cap <= maxL;cap++) f[1][cap] = v[1];
    _sqrt = sqrt(n);
    for(int node = 2;node <= _sqrt;node++)
    {
        const int parent = node >> 1;
        for(int cap = 1;cap < w[node];cap++)    // item does not fit: inherit the parent's row
            f[node][cap] = f[parent][cap];
        for(int cap = w[node];cap <= maxL;cap++)
            f[node][cap] = max(f[parent][cap],f[parent][cap - w[node]] + v[node]);
    }

    // ---- answer queries ----
    for(int id = 1;id <= Q;id++)
    {
        ans = 0;
        const int start = q[id].u;
        const int cap = q[id].L;
        if(start > _sqrt) dfs(start,cap);       // enumerate the deep part of the path
        else ans = max(ans,f[start][cap]);      // fully covered by the precomputed rows
        writeln(ans);
    }
    return 0;
}