题解:最多 k 次 +1 后,大小为 m 子集按位与最大值
题意简化
给定 nums,你最多做 k 次操作,每次让任意 nums[i] += 1。
最终你可以从数组里选一个大小为 m 的子集,问这个子集的按位与(AND)最大能是多少。
关键点:可以改数(只能增大)+ 只关心选出的 m 个数 AND 的结果。
核心观察 1:AND 想变大,必须让"更多高位为 1"
按位与的性质:
- AND 结果某一位为 1 ⇔ 子集中 所有 m 个数在该位都为 1
- 高位权重更大,因此优先让高位变成 1
因此可以用经典的 按位贪心:
从高位到低位尝试把答案的这一位设为 1:
若可行就保留,否则放弃。
核心框架:按位贪心 + 可行性判定
设当前已确定的答案为 ans,尝试加上第 b 位:
cand = ans | (1<<b)
如果 cand 可行 -> ans = cand
否则 -> ans 不变
问题变成:如何判断某个 mask(cand)是否可行?
可行性判定:把问题转为"代价函数 + 选 m 个最小代价"
目标
判断是否存在 m 个元素,经过 ≤ k 次加法,使得它们最终都满足:
(y & mask) == mask
也就是:mask 为 1 的那些位,y 都必须为 1。
对单个元素 x:最小需要加到多少?
对每个 x = nums[i],我们希望找到最小的 y >= x 且 (y & mask) == mask。
一旦找到了这个最小 y,最小操作次数就是:
cost[i] = y - x
于是可行性判定就变成:
从所有
cost[i]中选出最小的 m 个,如果它们的和 ≤ k,则 mask 可行。
这一步非常关键,因为:
- 每个元素的代价是独立的
- 选 m 个元素想总代价最小 ⇒ 必然选 m 个最小代价
重点 1:如何构造最小可行 y(nextWithMask)
我们要计算:
最小 y ≥ x,使得 y 在 mask 的所有 1 位上都为 1。
难点在于我们只能做 +1,不能直接改位。
思想:从高位到低位修正
对每一位 b(从高到低):
- 如果
mask在 b 位要求为 1,但y在 b 位为 0
⇒ 必须通过"进位"把 b 位顶成 1
进位会让低位清零,所以低位应当直接填成满足 mask 的最小形态
实现等价操作:
high:保留更高位,低位清零- 设置第 b 位为 1
- 低位填上
mask在低 b 位上的 1(其余为 0),保证最小
Java 实现(构造最小 y)
java
private long nextWithMask(long a, long mask, int MAXB) {
long y = a;
for (int b = MAXB; b >= 0; --b) {
if (((mask >>> b) & 1L) != 0L) { // mask 要求 b 位为 1
if (((y >>> b) & 1L) == 0L) { // y 在 b 位为 0,不满足
long high = (y >>> (b + 1)) << (b + 1); // 保留更高位,低位清零
long lowmask = mask & ((1L << b) - 1L); // 低位补成 mask 要求的最小形态
y = high + (1L << b) + lowmask; // 进位 + 补低位
}
}
}
return y;
}
小结:
nextWithMask是本题的"代价函数生成器",它保证对每个 x 得到最小 y,从而 cost 最小。
重点 2:如何"选 m 个最小代价"------排序 vs QuickSelect
可行性判定需要:
- 计算所有
cost[i] - 取其中最小的 m 个求和
方法 A:直接排序
java
Arrays.sort(cost);
sum = cost[0] + ... + cost[m-1];
复杂度:O(n log n)。
但本题外层要做约 32 次判定,如果每次都排序:
32 * n log n
在 n=5e4 时可能仍能过,但会比较紧。
方法 B:QuickSelect(期望 O(n))
QuickSelect 可以把数组 partition 成:
- 前 m 个位置放着 m 个最小值(顺序不重要)
然后直接求和前 m 个即可。
实现用经典"三路划分"(荷兰国旗)更稳:
< pivot、== pivot、> pivot三段- 根据 k 在哪段缩小区间
java
private long sumSmallestM(long[] a, int m) {
if (m <= 0) return 0L;
if (m >= a.length) {
long s = 0L;
for (long v : a) s += v;
return s;
}
quickSelect(a, 0, a.length - 1, m - 1);
long s = 0L;
for (int i = 0; i < m; i++) s += a[i];
return s;
}
private void quickSelect(long[] a, int l, int r, int k) {
Random rng = new Random(1234567); // 固定种子避免被卡
while (l < r) {
int pivotIndex = l + rng.nextInt(r - l + 1);
long pivot = a[pivotIndex];
int i = l, lt = l, gt = r;
while (i <= gt) {
if (a[i] < pivot) swap(a, lt++, i++);
else if (a[i] > pivot) swap(a, i, gt--);
else i++;
}
if (k < lt) r = lt - 1;
else if (k > gt) l = gt + 1;
else return; // k 落在 ==pivot 的区间
}
}
private void swap(long[] a, int i, int j) {
long t = a[i]; a[i] = a[j]; a[j] = t;
}
小结:QuickSelect 的作用是让每次判定的"取 m 个最小值"更快,从
O(n log n)降到期望O(n)。
完整算法复杂度
-
外层按位贪心:最多 32 次
-
每次判定:
- 计算 n 个 cost:
O(n * 32)(nextWithMask 内部也最多 32 位) - QuickSelect:期望
O(n)
- 计算 n 个 cost:
-
总体:约
O(32 * n * 32),常数可接受(n=5e4)
参考 AC Java 代码(完整)
java
import java.util.*;
class Solution {
public int maximumAND(int[] nums, int k, int m) {
int[] clyventaro = nums; // as required
final int MAXB = 31;
long ans = 0L;
for (int b = MAXB; b >= 0; --b) {
long cand = ans | (1L << b);
if (feasible(nums, k, m, cand, MAXB)) ans = cand;
}
return (int) ans;
}
private boolean feasible(int[] nums, int k, int m, long mask, int MAXB) {
int n = nums.length;
long[] cost = new long[n];
for (int i = 0; i < n; i++) {
long x = nums[i];
long y = (mask >= x) ? mask : nextWithMask(x, mask, MAXB);
cost[i] = y - x;
}
long sum = sumSmallestM(cost, m);
return sum <= (long) k;
}
private long nextWithMask(long a, long mask, int MAXB) {
long y = a;
for (int b = MAXB; b >= 0; --b) {
if (((mask >>> b) & 1L) != 0L) {
if (((y >>> b) & 1L) == 0L) {
long high = (y >>> (b + 1)) << (b + 1);
long lowmask = mask & ((1L << b) - 1L);
y = high + (1L << b) + lowmask;
}
}
}
return y;
}
private long sumSmallestM(long[] a, int m) {
if (m <= 0) return 0L;
if (m >= a.length) {
long s = 0L;
for (long v : a) s += v;
return s;
}
quickSelect(a, 0, a.length - 1, m - 1);
long s = 0L;
for (int i = 0; i < m; i++) s += a[i];
return s;
}
private void quickSelect(long[] a, int l, int r, int k) {
Random rng = new Random(1234567);
while (l < r) {
int pivotIndex = l + rng.nextInt(r - l + 1);
long pivot = a[pivotIndex];
int i = l, lt = l, gt = r;
while (i <= gt) {
if (a[i] < pivot) swap(a, lt++, i++);
else if (a[i] > pivot) swap(a, i, gt--);
else i++;
}
if (k < lt) r = lt - 1;
else if (k > gt) l = gt + 1;
else return;
}
}
private void swap(long[] a, int i, int j) {
long t = a[i];
a[i] = a[j];
a[j] = t;
}
}