给你三个整数 a ,b 和 n ,请你返回 (a XOR x) * (b XOR x) 的 最大值 且 x 需要满足 0 <= x < 2n。
由于答案可能会很大,返回它对 109 + 7 取余 后的结果。
注意,XOR 是按位异或操作。
示例 1:
输入:a = 12, b = 5, n = 4
输出:98
解释:当 x = 2 时,(a XOR x) = 14 且 (b XOR x) = 7 。所以,(a XOR x) * (b XOR x) = 98 。
98 是所有满足 0 <= x < 2n 中 (a XOR x) * (b XOR x) 的最大值。
示例 2:
输入:a = 6, b = 7 , n = 5
输出:930
解释:当 x = 25 时,(a XOR x) = 31 且 (b XOR x) = 30 。所以,(a XOR x) * (b XOR x) = 930 。
930 是所有满足 0 <= x < 2n 中 (a XOR x) * (b XOR x) 的最大值。
示例 3:
输入:a = 1, b = 6, n = 3
输出:12
解释: 当 x = 5 时,(a XOR x) = 4 且 (b XOR x) = 3 。所以,(a XOR x) * (b XOR x) = 12 。
12 是所有满足 0 <= x < 2n 中 (a XOR x) * (b XOR x) 的最大值。
提示:
0 <= a, b < 250
0 <= n <= 50
java
class Solution {
public int maximumXorProduct(long a, long b, int n) {
if(a < b) {
a = a ^ b;
b = a ^ b;
a = a ^ b;
}
long MOD = (int)1e9+7;
long mask = (1L<<n)-1; // n个1
long ax = a & ~mask; // 没办法通过 xor x 修改的部分
long bx = b & ~mask;
a &= mask; // 保留低于n的比特位
b &= mask;
long left = a ^ b; // 一个是0 一个是1 的比特位
long one = mask ^ left; // 全为 1 或者 全为0 的比特位
ax |= one; // 异或结果一定是1先加到结果
bx |= one;
if(left > 0 && ax == bx) {
// left 的最高位给ax 其余给bx
long high_bit = 1L << (63 - Long.numberOfLeadingZeros(left));
ax |= high_bit;
left ^= high_bit;
}
bx |= left;
return (int) (ax % MOD * (bx % MOD) % MOD);
}
}