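Weighted Random Sampling Algorithms
Weighted random sampling is a common technique that shows up in load balancing, lottery draws, and similar scenarios.
This blog post summarizes several weighted random sampling algorithms.
Framework Definition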
Java
public interface Select {
    /**
     * Returns the index of the selected element.
     */
    int select();
}

public interface RandomSelect extends Select {
    /**
     * Initializes the sampler; nums[i] is the weight of element i.
     */
    void init(int[] nums);
}
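Expanded Array Method
We start with the simplest approach and work up to the harder ones.
The expanded array method turns the weights into a collection of repeated elements.
Suppose weights={6, 4, 1, 1}.
The expanded array then contains index 0 six times, index 1 four times, and indices 2 and 3 once each.
To select, just pick one entry of the array uniformly at random.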
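The expansion step can be sketched as follows. This is a minimal illustration, not the implementation from this post: the class name `ExpandArrayDemo` and the helper `expand` are hypothetical names chosen for the example.

```java
import java.util.Arrays;

// Hypothetical demo class, used only to illustrate the expansion step.
class ExpandArrayDemo {
    // Builds an array in which index i appears exactly weights[i] times.
    static int[] expand(int[] weights) {
        int total = Arrays.stream(weights).sum();
        int[] expanded = new int[total];
        int pos = 0;
        for (int i = 0; i < weights.length; i++) {
            for (int j = 0; j < weights[i]; j++) {
                expanded[pos++] = i;
            }
        }
        return expanded;
    }

    public static void main(String[] args) {
        int[] expanded = expand(new int[]{6, 4, 1, 1});
        // prints [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 3]
        System.out.println(Arrays.toString(expanded));
    }
}
```

Picking a uniformly random position in this 12-entry array returns index 0 with probability 6/12, index 1 with probability 4/12, and indices 2 and 3 with probability 1/12 each.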
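Advantages:
- O(1) sampling time
- No complex computation
Limitations:
- Memory consumption grows linearly with the total weight
- Not suitable when the weight values are large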
Java
import java.util.concurrent.ThreadLocalRandom;

/**
 * W : sum of nums
 * n : nums size
 * init time : O(W)
 * select time : O(1)
 * memory : O(W)
 *
 * @author qiu
 */
public class ExpandArrayRandomSelect implements RandomSelect {
    // Holds index i exactly nums[i] times
    private int[] expandedList;

    @Override
    public void init(int[] nums) {
        int sum = 0;
        for (int num : nums) {
            sum += num;
        }
        this.expandedList = new int[sum];
        int index = 0;
        for (int i = 0; i < nums.length; i++) {
            for (int j = 0; j < nums[i]; j++) {
                expandedList[index++] = i;
            }
        }
    }

    @Override
    public int select() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        return expandedList[random.nextInt(expandedList.length)];
    }
}
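Prefix Sum with Binary Search
Compute the cumulative weight (prefix sum) array, generate a random number in [0, total weight), then binary-search for the interval it falls into. The code below draws from [1, total weight] instead, which is equivalent.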
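To make the three steps concrete, here is a small sketch for weights={6, 4, 1, 1}. The class `PrefixSumDemo` and its helpers `prefixSums` and `locate` are hypothetical names for this illustration. Instead of drawing random numbers, it feeds every possible value r in [1, total] through the search, which shows that each index owns exactly as many values as its weight.

```java
import java.util.Arrays;

// Hypothetical demo class, used only to illustrate prefix sums plus binary search.
class PrefixSumDemo {
    // prefix[i] = weights[0] + ... + weights[i]
    static int[] prefixSums(int[] weights) {
        int[] prefix = new int[weights.length];
        int running = 0;
        for (int i = 0; i < weights.length; i++) {
            running += weights[i];
            prefix[i] = running;
        }
        return prefix;
    }

    // Returns the first index whose prefix sum is >= r, for r in [1, total].
    static int locate(int[] prefix, int r) {
        int lo = 0, hi = prefix.length - 1;
        while (lo < hi) {
            int mid = lo + (hi - lo) / 2;
            if (prefix[mid] < r) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    public static void main(String[] args) {
        int[] prefix = prefixSums(new int[]{6, 4, 1, 1});
        System.out.println(Arrays.toString(prefix)); // prints [6, 10, 11, 12]

        int total = prefix[prefix.length - 1];
        int[] hits = new int[prefix.length];
        for (int r = 1; r <= total; r++) {
            hits[locate(prefix, r)]++;
        }
        System.out.println(Arrays.toString(hits)); // prints [6, 4, 1, 1]
    }
}
```

Values 1 to 6 map to index 0, 7 to 10 to index 1, 11 to index 2, and 12 to index 3, so a uniform r reproduces the weights.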
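Advantages:
- Supports weight updates (rebuild prefixSum after a change, which costs O(n))
- A good fit for medium-sized data sets (tens of thousands of elements)
Limitations:
- Requires maintaining an ordered structure (the prefix sum array)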
Java
import java.util.concurrent.ThreadLocalRandom;

/**
 * n : nums.size
 * init time : O(n)
 * select time : O(log n)
 * memory : O(n)
 *
 * @author qiu
 */
public class PreSumRandomSelect implements RandomSelect {
    private int sum;
    private int[] preSum;

    @Override
    public void init(int[] nums) {
        // Prefix sums: preSum[i] = nums[0] + ... + nums[i]
        preSum = new int[nums.length];
        preSum[0] = nums[0];
        for (int i = 1; i < preSum.length; i++) {
            preSum[i] = preSum[i - 1] + nums[i];
        }
        sum = preSum[preSum.length - 1];
    }

    @Override
    public int select() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        // Uniform in [1, sum]
        int randomNum = random.nextInt(sum) + 1;
        int left = 0, right = preSum.length - 1;
        // Find the first index whose prefix sum is >= randomNum.
        // There is no early return on equality: with a zero weight, two
        // neighbours share the same prefix sum and the zero-weight one
        // must never be returned.
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (preSum[mid] < randomNum) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return left;
    }
}
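Two-Dimensional Array Method
Take weights={6, 4, 1, 1} as an example:
- x = weights.length = 4
- y = max(weights) = 6
- Build a rectangle of height 6 and width 4, as shown in the figure below
- Pick a random index first, then a random height. If the point lands in the magenta region (the part of a column above that element's weight), draw again.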
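How often a draw is rejected follows from the area of the rectangle: only sum(weights) of its n * max(weights) cells are valid. The sketch below computes that acceptance rate and the resulting expected number of draws. The class `RectangleAcceptanceDemo` and its method are hypothetical names for this illustration.

```java
// Hypothetical demo class, used only to illustrate the cost of rejection sampling.
class RectangleAcceptanceDemo {
    // Fraction of the n x max rectangle that is covered by the weight columns.
    static double acceptanceRate(int[] weights) {
        long sum = 0;
        int max = 0;
        for (int w : weights) {
            sum += w;
            max = Math.max(max, w);
        }
        return (double) sum / ((long) weights.length * max);
    }

    public static void main(String[] args) {
        double rate = acceptanceRate(new int[]{6, 4, 1, 1});
        System.out.println(rate);       // prints 0.5
        // The number of draws is geometrically distributed, so its mean is 1 / rate.
        System.out.println(1.0 / rate); // prints 2.0
    }
}
```

For weights={6, 4, 1, 1} half of the 24 cells are valid, so a select needs two draws on average. The more skewed the weights, the lower the acceptance rate.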
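Advantages:
- Leads naturally into the alias method (AliasMethod) below
Limitations:
- select is slow, because every rejected draw has to be repeated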
Java
import java.util.concurrent.ThreadLocalRandom;

/**
 * n : nums.size
 * W : sum of nums
 * init time : O(n)
 * select best time : O(1)
 * select expected time : O(n * max / W) draws (rejection sampling)
 * select worst time : unbounded, since any draw can be rejected
 * memory : O(n)
 */
public class RectangleRandomSelect implements RandomSelect {
    int[] nums;
    int max;

    @Override
    public void init(int[] nums) {
        this.nums = nums;
        // Reset so that a second init() does not keep a stale maximum
        this.max = 0;
        for (int num : nums) {
            max = Math.max(max, num);
        }
    }

    @Override
    public int select() {
        int n = nums.length;
        ThreadLocalRandom random = ThreadLocalRandom.current();
        int randomIndex;
        int randomNum;
        // Draw a column and a height in [1, max]; redraw while the point
        // lies above the column's weight.
        do {
            randomIndex = random.nextInt(n);
            randomNum = random.nextInt(max) + 1;
        } while (nums[randomIndex] < randomNum);
        return randomIndex;
    }
}
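Alias Method
The rectangle method redraws whenever it hits the magenta region, which drives up the running time. How can we optimize this?
The alias method builds a probability table and an alias table so that every column of the rectangle is completely filled: each column holds its own element plus at most one "alias" element, and no draw is ever rejected.
Advantages:
- O(1) sampling time
Limitations:
- Preprocessing takes O(n) time
- Updating a weight requires rebuilding the tables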
Java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.ThreadLocalRandom;

/**
 * n : nums.size
 * init time : O(n)
 * select time : O(1)
 * memory : O(n)
 *
 * @author qiu
 */
public class AliasMethodRandomSelect implements RandomSelect {
    // Total weight; also the height of a full column after scaling
    int sum;
    // weights[i]: the part of column i that belongs to element i (scaled by n)
    int[] weights;
    // alias[i]: the element that fills the rest of column i
    int[] alias;

    @Override
    public void init(int[] nums) {
        voseAliasInit(nums);
    }

    private void voseAliasInit(int[] nums) {
        int n = nums.length;
        weights = new int[n];
        alias = new int[n];
        Deque<Integer> small = new ArrayDeque<>();
        Deque<Integer> large = new ArrayDeque<>();
        sum = 0;
        for (int num : nums) {
            sum += num;
        }
        for (int i = 0; i < n; i++) {
            // Scale by n so the average column height is exactly sum.
            // Note: nums[i] * n must fit in an int.
            weights[i] = nums[i] * n;
            if (weights[i] < sum) small.add(i);
            else large.add(i);
        }
        while (!small.isEmpty() && !large.isEmpty()) {
            int l = small.pop();
            int g = large.pop();
            // Top up the under-full column l with part of element g
            alias[l] = g;
            weights[g] = (weights[g] + weights[l]) - sum;
            if (weights[g] < sum) small.add(g);
            else large.add(g);
        }
        // Whatever is left is a full column
        while (!small.isEmpty()) {
            int l = small.removeFirst();
            weights[l] = sum;
        }
        while (!large.isEmpty()) {
            int g = large.removeFirst();
            weights[g] = sum;
        }
    }

    @Override
    public int select() {
        ThreadLocalRandom localRandom = ThreadLocalRandom.current();
        // Pick a column, then decide between its own element and its alias
        int i = localRandom.nextInt(weights.length);
        return (localRandom.nextInt(sum) < weights[i]) ? i : alias[i];
    }
}
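To see what the two tables look like, the sketch below rebuilds them for weights={6, 4, 1, 1} and checks the result exactly instead of by sampling. It is a simplified illustration under my own naming: the class `AliasTableDemo`, the field `threshold` (the scaled probability table), and the method `exactNumerators` are hypothetical and not part of the class above.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;

// Hypothetical demo class, used only to inspect the alias tables.
class AliasTableDemo {
    final int sum;         // total weight, also the height of a full column
    final int[] threshold; // column i returns i when the second draw is < threshold[i]
    final int[] alias;     // otherwise column i returns alias[i]

    AliasTableDemo(int[] weights) {
        int n = weights.length;
        threshold = new int[n];
        alias = new int[n];
        int total = 0;
        for (int w : weights) total += w;
        sum = total;

        Deque<Integer> small = new ArrayDeque<>();
        Deque<Integer> large = new ArrayDeque<>();
        for (int i = 0; i < n; i++) {
            threshold[i] = weights[i] * n; // scale so the average column equals sum
            alias[i] = i;                  // a full column never uses its alias
            if (threshold[i] < sum) small.add(i); else large.add(i);
        }
        while (!small.isEmpty() && !large.isEmpty()) {
            int l = small.poll();
            int g = large.poll();
            alias[l] = g;                                     // g tops up column l
            threshold[g] = threshold[g] + threshold[l] - sum; // what remains of g
            if (threshold[g] < sum) small.add(g); else large.add(g);
        }
        // With integer arithmetic every leftover column is exactly full here.
    }

    // Exact probability of each index, expressed as a numerator over n * sum.
    long[] exactNumerators() {
        long[] numerators = new long[threshold.length];
        for (int i = 0; i < threshold.length; i++) {
            numerators[i] += threshold[i];              // column i returns i
            numerators[alias[i]] += sum - threshold[i]; // column i returns its alias
        }
        return numerators;
    }

    public static void main(String[] args) {
        AliasTableDemo demo = new AliasTableDemo(new int[]{6, 4, 1, 1});
        System.out.println(Arrays.toString(demo.threshold));         // prints [12, 8, 4, 4]
        System.out.println(Arrays.toString(demo.alias));             // prints [0, 0, 0, 1]
        System.out.println(Arrays.toString(demo.exactNumerators())); // prints [24, 16, 4, 4]
    }
}
```

The numerators are over n * sum = 48, so the probabilities are 24/48, 16/48, 4/48, and 4/48, which is exactly 6 : 4 : 1 : 1.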
Unit Tests and Benchmarks
Unit Tests
Java
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

class RandomSelectTest {
    private static final int ITERATIONS = 1000000;
    private int[] weights;
    private RandomSelect[] implementations;

    @BeforeEach
    void setUp() {
        weights = new int[]{4, 2, 1, 3, 50}; // weights under test
        implementations = new RandomSelect[]{
                new AliasMethodRandomSelect(),
                new ExpandArrayRandomSelect(),
                new PreSumRandomSelect(),
                new RectangleRandomSelect()
        };
        // Initialize every implementation
        for (RandomSelect impl : implementations) {
            impl.init(weights);
        }
    }

    @Test
    void testDistribution() {
        for (RandomSelect impl : implementations) {
            Map<Integer, Integer> distribution = new HashMap<>();
            // Run a large number of random selections
            for (int i = 0; i < ITERATIONS; i++) {
                int selected = impl.select();
                distribution.merge(selected, 1, Integer::sum);
            }
            // Verify the distribution
            int sum = 0;
            for (int weight : weights) {
                sum += weight;
            }
            System.out.println(distribution);
            for (int i = 0; i < weights.length; i++) {
                double expectedRatio = (double) weights[i] / sum;
                double actualRatio = (double) distribution.getOrDefault(i, 0) / ITERATIONS;
                // Absolute tolerance of 0.005 on the ratio. A looser bound such
                // as 0.05 would pass even if the weight-1 element (expected
                // ratio 1/60) were never selected.
                assertEquals(expectedRatio, actualRatio, 0.005,
                        String.format("%s distribution test failed for index %d",
                                impl.getClass().getSimpleName(), i));
            }
        }
    }
}
Benchmarks
Java
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
@Fork(1)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
public class RandomSelectBenchmark {
    private AliasMethodRandomSelect aliasMethod;
    private ExpandArrayRandomSelect expandArray;
    private PreSumRandomSelect preSum;
    private RectangleRandomSelect rectangle;
    private int[] weights;

    @Setup
    public void setup() {
        // Alternative 7-element data set (90%, 4%, 3%, 3‰, 5‰, 2‰, 2%):
        // {9000, 400, 300, 30, 50, 20, 200}
        weights = new int[1000];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = i + 1;
        }
        aliasMethod = new AliasMethodRandomSelect();
        expandArray = new ExpandArrayRandomSelect();
        preSum = new PreSumRandomSelect();
        rectangle = new RectangleRandomSelect();
        aliasMethod.init(weights);
        expandArray.init(weights);
        preSum.init(weights);
        rectangle.init(weights);
    }

    @Benchmark
    public int aliasMethodSelect() {
        return aliasMethod.select();
    }

    @Benchmark
    public int expandArraySelect() {
        return expandArray.select();
    }

    @Benchmark
    public int preSumSelect() {
        return preSum.select();
    }

    // @Benchmark
    // The rectangle method is too slow, so it is excluded from the benchmark
    public int rectangleSelect() {
        return rectangle.select();
    }

    public static void main(String[] args) throws RunnerException {
        Options opt = new OptionsBuilder()
                .include(RandomSelectBenchmark.class.getSimpleName())
                .build();
        new Runner(opt).run();
    }
}
weights.size=7 weights.sum=1000
bash
Benchmark                                Mode  Cnt   Score   Error  Units
RandomSelectBenchmark.aliasMethodSelect  avgt    5   6.551 ± 0.375  ns/op
RandomSelectBenchmark.expandArraySelect  avgt    5   2.749 ± 0.085  ns/op
RandomSelectBenchmark.preSumSelect       avgt    5   5.199 ± 0.232  ns/op
RandomSelectBenchmark.rectangleSelect    avgt    5  33.411 ± 0.333  ns/op
rectangleSelect is several times slower than the others, so it is eliminated and left out of the next run.
weights.size=1000 weights.sum=250250
bash
Benchmark                                Mode  Cnt   Score   Error  Units
RandomSelectBenchmark.aliasMethodSelect  avgt    5   9.014 ± 0.153  ns/op
RandomSelectBenchmark.expandArraySelect  avgt    5   3.371 ± 0.139  ns/op
RandomSelectBenchmark.preSumSelect       avgt    5  46.309 ± 0.589  ns/op
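As we can see, expandArraySelect is very fast regardless of the input size, but its space complexity is O(W), where W = sum(weights).
Let's work out how much memory that takes (an int is 4 bytes):
- W=10,000 (one-in-ten-thousand precision): about 40 KB
- W=1,000,000 (one-in-a-million precision): about 4 MB. Requirements finer than this are rare, and this footprint is perfectly acceptable for a Java backend project.
- W=100,000,000 (one-in-a-hundred-million precision): about 400 MB, which is hard to accept.
So the rule of thumb is: use the expanded array method when the total weight is at or below the million level, and the alias method above that.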
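The memory figures above are simply W entries times 4 bytes per int. A small sketch of that arithmetic follows; the class `ExpandedArrayMemoryDemo` and its method are hypothetical names, and the estimate ignores the few bytes of array header that the JVM adds.

```java
// Hypothetical demo class, used only to reproduce the memory estimate.
class ExpandedArrayMemoryDemo {
    // Payload size in bytes of an int[] with totalWeight entries (array header ignored).
    static long payloadBytes(long totalWeight) {
        return totalWeight * Integer.BYTES;
    }

    public static void main(String[] args) {
        long[] totals = {10_000L, 1_000_000L, 100_000_000L};
        for (long w : totals) {
            System.out.println("W=" + w + " -> " + payloadBytes(w) + " bytes");
        }
        // prints:
        // W=10000 -> 40000 bytes
        // W=1000000 -> 4000000 bytes
        // W=100000000 -> 400000000 bytes
    }
}
```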
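Why Is the Expanded Array Method Faster Than the Alias Method?
Both have O(1) select time, so why is the expanded array method faster?
select() in the alias method
Steps:
- Generate two random numbers.
- Branch: compare the second random number with weights[i] to decide whether to return i or alias[i].
Key operations:
- Two nextInt() calls (each involves the arithmetic of random number generation).
- Up to two array accesses (weights[i], and alias[i] when the alias is chosen).
- One conditional branch.
select() in ExpandArray
Steps:
- Generate one random number: random.nextInt(expandedList.length).
- Return directly: use the random index to read the value from the expanded array.
Key operations:
- One nextInt() call.
- One array access (the array is a single contiguous block, which is cache-friendly as long as it fits in cache).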