加权随机采样算法

加权随机采样算法

加权随机采样算法是一种常见的算法,常出现于负载均衡,随机抽奖等方面。

这篇blog就来总结一下这些加权随机算法。

框架定义

Java 复制代码
public interface Select {
    /**
     * return the index of the selected element
     */
    int select();
}
public interface RandomSelect extends Select {
    /**
     * value is the weight of the element
     */
    void init(int[] nums);
}

扩展数组法

由简入难

扩展数组法将权重转化为重复元素集合

假设weights={6, 4, 1, 1}

拓展数组如下图所示

在select时只需随机从数组中选择一个即可

优势:

  • 采样时间复杂度O(1)
  • 无需复杂计算逻辑

局限:

  • 内存消耗与总权重线性相关
  • 不适用大权重值场景
Java 复制代码
/**
 * W : sum of nums
 * n : nums size
 * init time : O(W)
 * select time : O(1)
 * memory : O(W)
 *
 * @author qiu
 */
public class ExpandArrayRandomSelect implements RandomSelect {
    private int[] expandedList;

    @Override
    public void init(int[] nums) {
        int sum = 0;
        for (int num : nums) {
            sum += num;
        }
        this.expandedList = new int[sum];
        int index = 0;
        for (int i = 0; i < nums.length; i++) {
            for (int j = 0; j < nums[i]; j++) {
                expandedList[index++] = i;
            }
        }
    }

    @Override
    public int select() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        return expandedList[random.nextInt(expandedList.length)];
    }
}

前缀和二分法

计算累积权重数组 生成[0,总权重)的随机数 二分查找定位区间

优势:

  • 支持动态更新权重(修改后重建prefixSum)
  • 适合中等规模数据集(万级元素)

局限:

  • 需要维护有序结构
Java 复制代码
/**
 * n : nums.size
 * init time : O(n)
 * select time : O(log_n)
 * memory : O(n)
 *
 * @author qiu
 */
public class PreSumRandomSelect implements RandomSelect {
    private int sum;
    private int[] preSum;

    @Override
    public void init(int[] nums) {
        // 前缀和
        preSum = new int[nums.length];
        preSum[0] = nums[0];
        for (int i = 1; i < preSum.length; i++) {
            preSum[i] = preSum[i - 1] + nums[i];
        }
        sum = preSum[preSum.length - 1];
    }

    @Override
    public int select() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        int randomNum = random.nextInt(sum) + 1;
        int left = 0, right = preSum.length - 1;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (preSum[mid] == randomNum) {
                return mid;
            } else if (preSum[mid] < randomNum) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return left;
    }
}

二维数组法

举个例子 weights={6, 4, 1, 1}

  • 取 x=weights.length=4
  • y = max(weights) = 6
  • 构造一个高6,宽4 的矩形,如下图所示
  • 先随机选定下标,再随机选定高度,如果是玫红色区域重选。

优势:

  • 引出下文别名法(AliasMethod)

局限:

  • select 慢
Java 复制代码
/**
 * n : nums.size
 * init time : O(n)
 * select best  time : O(1)
 * select worst time : O(n)
 * memory : O(n)
 */
public class RectangleRandomSelect implements RandomSelect {
    int[] nums;
    int max;

    @Override
    public void init(int[] nums) {
        this.nums = nums;
        for (int num : nums) {
            max = Math.max(max, num);
        }
    }

    @Override
    public int select() {
        int n = nums.length;
        ThreadLocalRandom random = ThreadLocalRandom.current();
        int randomIndex = random.nextInt(n);
        int num = nums[randomIndex];
        int randomNum = random.nextInt(max)+1;
        while (num < randomNum) {
            randomIndex = random.nextInt(n);
            num = nums[randomIndex];
            randomNum = random.nextInt(max)+1;
        }
        return randomIndex;
    }
}

别名法

矩阵法,在选中玫红色时会重选,导致时间复杂度激增,如何优化呢?

别名法通过构建概率表和别名表,将权重分布转化为等概率二维坐标系的映射

优势:

  • 采样时间复杂度O(1)

局限:

  • 预处理需要O(n)时间
  • 权重更新需要重建结构
Java 复制代码
/**
 * n : nums.size
 * init time : O(n)
 * select time : O(1)
 * memory : O(n)
 *
 * @author qiu
 */
public class AliasMethodRandomSelect implements RandomSelect{
    int sum;
    int[] weights;
    int[] alias;

    @Override
    public void init(int[] nums) {
        VoseAliasInit(nums);
    }

    private void VoseAliasInit(int[] nums) {
        int n = nums.length;
        weights = new int[n];
        alias = new int[n];

        Deque<Integer> small = new ArrayDeque<>();
        Deque<Integer> large = new ArrayDeque<>();
        sum = 0;
        for (int num : nums) {
            sum += num;
        }

        for(int i=0; i<n; i++) {
            weights[i] = nums[i] * n;
            if(weights[i] < sum) small.add(i);
            else large.add(i);
        }

        while(!small.isEmpty() && !large.isEmpty()) {
            int l = small.pop();
            int g = large.pop();
            alias[l] = g;
            weights[g] = (weights[g] + weights[l]) - sum;
            if(weights[g] < sum) small.add(g);
            else large.add(g);
        }

        while (!small.isEmpty()) {
            int l = small.removeFirst();
            weights[l] = sum;
        }
        while (!large.isEmpty()) {
            int g = large.removeFirst();
            weights[g] = sum;
        }
    }

    @Override
    public int select() {
        ThreadLocalRandom localRandom = ThreadLocalRandom.current();
        int i = localRandom.nextInt(weights.length);
        return (localRandom.nextInt(sum) < weights[i]) ? i : alias[i];
    }
}

单元测试与基准测试

单元测试

Java 复制代码
class RandomSelectTest {
    private static final int ITERATIONS = 1000000;
    private int[] weights;
    private RandomSelect[] implementations;

    @BeforeEach
    void setUp() {
        weights = new int[]{4, 2, 1, 3, 50}; // 测试权重数组
        implementations = new RandomSelect[]{
                new AliasMethodRandomSelect(),
                new ExpandArrayRandomSelect(),
                new PreSumRandomSelect(),
                new RectangleRandomSelect()
        };

        // 初始化所有实现
        for (RandomSelect impl : implementations) {
            impl.init(weights);
        }
    }

    @Test
    void testDistribution() {
        for (RandomSelect impl : implementations) {
            Map<Integer, Integer> distribution = new HashMap<>();

            // 进行大量随机选择
            for (int i = 0; i < ITERATIONS; i++) {
                int selected = impl.select();
                distribution.merge(selected, 1, Integer::sum);
            }

            // 验证分布
            int sum = 0;
            for (int weight : weights) {
                sum += weight;
            }
            System.out.println(distribution);
            for (int i = 0; i < weights.length; i++) {
                double expectedRatio = (double) weights[i] / sum;
                double actualRatio = (double) distribution.getOrDefault(i, 0) / ITERATIONS;
                // 允许5%的误差
                assertEquals(expectedRatio, actualRatio, 0.05,
                        String.format("%s distribution test failed for index %d",
                                impl.getClass().getSimpleName(), i));
            }
        }
    }
}

基准测试

Java 复制代码
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
@Fork(1)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
public class RandomSelectBenchmark {
    private AliasMethodRandomSelect aliasMethod;
    private ExpandArrayRandomSelect expandArray;
    private PreSumRandomSelect preSum;
    private RectangleRandomSelect rectangle;
    private int[] weights;

    @Setup
    public void setup() {
        // 90% 4% 3% 3‰ 5‰ 2‰ 2%
        weights = new int[1000];/*{9000, 400, 300, 30, 50, 20, 200}*/
        for (int i = 0; i < weights.length; i++) {
            weights[i] = i + 1;
        }

        aliasMethod = new AliasMethodRandomSelect();
        expandArray = new ExpandArrayRandomSelect();
        preSum = new PreSumRandomSelect();
        rectangle = new RectangleRandomSelect();

        aliasMethod.init(weights);
        expandArray.init(weights);
        preSum.init(weights);
        rectangle.init(weights);
    }

    @Benchmark
    public int aliasMethodSelect() {
        return aliasMethod.select();
    }

    @Benchmark
    public int expandArraySelect() {
        return expandArray.select();
    }

    @Benchmark
    public int preSumSelect() {
        return preSum.select();
    }

    // @Benchmark
    // 矩形算法太慢了,不测试
    public int rectangleSelect() {
        return rectangle.select();
    }

    public static void main(String[] args) throws RunnerException {
        Options opt = new OptionsBuilder()
                .include(RandomSelectBenchmark.class.getSimpleName())
                .build();
        new Runner(opt).run();
    }
}

weights.size=7 weights.sum=1000

bash 复制代码
Benchmark                                Mode  Cnt   Score   Error  Units
RandomSelectBenchmark.aliasMethodSelect  avgt    5   6.551 ± 0.375  ns/op
RandomSelectBenchmark.expandArraySelect  avgt    5   2.749 ± 0.085  ns/op
RandomSelectBenchmark.preSumSelect       avgt    5   5.199 ± 0.232  ns/op
RandomSelectBenchmark.rectangleSelect    avgt    5  33.411 ± 0.333  ns/op

rectangleSelect淘汰

weights.size=1000 weights.sum=250250

bash 复制代码
Benchmark                                Mode  Cnt   Score   Error  Units
RandomSelectBenchmark.aliasMethodSelect  avgt    5   9.014 ± 0.153  ns/op
RandomSelectBenchmark.expandArraySelect  avgt    5   3.371 ± 0.139  ns/op
RandomSelectBenchmark.preSumSelect       avgt    5  46.309 ± 0.589  ns/op

我们可以看到expandArraySelect无论参数大小都非常快 但是expandArraySelect空间复杂度是 O(W) , W:sum(weights)

我们来计算一下需要使用多少内存

  • 如果W=10000 , 万分之一级别,内存消耗为40kb
  • 如果W=1000000 ,百万分之一级别,内存消耗为4MB
  • 再高的需求几乎没有吧?这对于Java 后端项目来说内存消耗完全可以接受
  • 如果W=100000000 ,亿分之一级别,内存消耗为400MB(这不太能接受)

所以我们在选型时,百万级别一下使用拓展数组法,百万级别以上使用别名法。

为什么拓展数组法比别名法快?

为什么拓展数组法比别名法快?他们不都是O(1)吗

Alias Method 的 select() 逻辑

步骤:

  • 生成 两个随机数
  • 条件判断:根据随机数与 weights[i] 的比较结果,决定返回 i 或 alias[i]。

关键操作:

  • 两次 nextInt() 调用(涉及随机数生成的算术运算)。
  • 两次数组访问(weights[i] 和 alias[i])。
  • 一次条件分支

ExpandArray 的 select() 逻辑

步骤:

  • 生成 一个随机数:random.nextInt(expandedList.length)。
  • 直接返回:通过随机索引从扩展数组中取值。

关键操作:

  • 一次 nextInt() 调用。
  • 一次数组访问(连续内存访问,缓存友好)
相关推荐
m0_7482548814 分钟前
SpringBoot整合MQTT最详细版(亲测有效)
java·spring boot·后端
uhakadotcom20 分钟前
Kubernetes入门指南:从基础到实践
后端·面试·github
Monika Zhang20 分钟前
Maven 简介及其核心概念
java·maven
用户10005229303928 分钟前
Django DRF API 单元测试完整方案(基于 `TestCase`)
后端
时光呢35 分钟前
JAVA泛型的作用
java·windows·python
<但凡.1 小时前
C++修炼:内存管理
c++·算法
Asthenia04121 小时前
Redis面试复盘:从连接到扩容与数据定位的极致详解(含Java RedisTemplate交互)
后端
tpoog1 小时前
[贪心算法]买卖股票的最佳时机 && 买卖股票的最佳时机Ⅱ && K次取反后最大化的数组和 && 按身高排序 && 优势洗牌(田忌赛马)
算法·贪心算法
不7夜宵1 小时前
dockerSDK-Go语言实现
开发语言·后端·golang
uhakadotcom1 小时前
Scikit-learn 安装和使用教程
后端·面试·github