解决的问题:从N个元素中随机选择m个元素,每个元素被选中的概率相等
1、算法概述
-
蓄水池算法:
- 解决的问题:从N个元素中随机选择m个元素,每个元素被选中的概率相等
-
概率计算详解
- 概率计算的核心在于,对于数据流中第 i 个元素,证明它最终留在蓄水池的概率是 m/N(N为数据流总长度)。证明时需区分两种情况。
- 1)对于前m个元素(i ≤ m)
- 这些元素一开始就被放入蓄水池。它们最终被保留下来,需要在后续所有替换操作中都不被换出。
- 从第 m+1个元素开始,到第 N个元素,每个元素 j都可能带来替换。
- 对于第 j个元素(j > m),它不会替换掉我们关注的第 i 个元素的概率是多少?这需要两个独立事件同时发生:
- 第 j个元素被选中(概率为 m/j)且它恰好替换了第 i个元素(概率为 1/m)。
- 所以,第 i个元素被替换的概率是 (m/j) * (1/m) = 1/j。那么,不被替换的概率就是 1 - 1/j = (j-1)/j。
- 因此,第 i个元素需要连续闯过从 j=m+1到 j=N的所有关卡,其最终被保留的概率是所有关卡概率的乘积:
- P(i最终保留) = 1 * [m/(m+1)] * [(m+1)/(m+2)] * ... * [(N-1)/N] = m/N
- 这个过程中,分子分母相继抵消,最终结果简化为 m/N。
- 2)对于第m个之后的元素(i > m)
- 这类元素需要先被选入蓄水池,然后还要在后续的替换中存活下来。
- 被选入蓄水池的概率:当处理到第 i个元素时,它被选中进入蓄水池的概率就是 m/i。
- 入池后不被后续元素替换的概率:一旦被选入,它需要承受从第 i+1到第 N个元素的考验。利用上面的逻辑,从 j=i+1到 j=N,它不被替换的概率链为:
-
i/(i+1)\] \* \[(i+1)/(i+2)\] \* ... \* \[(N-1)/N\] = i/N
- P(i最终保留) = (被选入的概率) * (入后不被替换的概率) = (m/i) * (i/N) = m/N。
-
蓄水池算法步骤:
- 1),对于i<=m,直接放入蓄水池
- 2),对于i>m,以 m/i的概率(代码中为从(0,i]生成一个数,如果小于m,代表概率为m/i)决定是否放入蓄水池,
- 若放入,则等概率随机(从(0,m]随机生成一个数)替换池中的一个元素.
-
算法特性与应用
- 公平性保证:无论数据何时到达,每个元素被抽中的概率都是 m/N,确保了抽样的绝对公平。
- 空间效率:只需维护大小为 m的蓄水池,空间复杂度是 O(m),非常适合处理海量数据流。
- 经典应用:该算法非常适合在线抽奖系统。例如,在一天内为所有登录用户进行抽奖,无需等到所有用户数据收集完毕,即可在任意时刻为当前已登录用户提供公平的中奖概率。
2、算法实现
java
/**
* 蓄水池采样器
*/
public static class ReservoirSample {
// 水池数组
private final int[] reservoir;
// 水池大小
private final int m;
// 已处理元素数量
private int count;
public ReservoirSample(int m) {
this.reservoir = new int[m];
this.m = m;
this.count = 0;
}
/**
* 随机选择(0,max]之间的一个数
*/
private int rand(int max) {
return (int) (Math.random() * max) + 1;
}
/**
* 向蓄水池添加一个元素
*/
public void add(int i) {
count++;
// i <=m 直接放入
if (i <= m) {
reservoir[count - 1] = i;
} else {
// i>m 以m/i的概率决定是否放入蓄水池
// 调用rand代表随机选择了数,如果选择的数<=m,代表要放入
if (rand(count) <= m) {
// 要放入,随机选择一个位置替换
reservoir[rand(m) - 1] = i;
}
}
}
/**
* 获取蓄水池中的元素
*/
public int[] getReservoir() {
int[] res = new int[m];
for (int i = 0; i < m; i++) {
res[i] = reservoir[i];
}
return res;
}
}
整体代码和应用测试:
java
/**
* 蓄水池算法:
* 解决的问题:从N个元素中随机选择m个元素,每个元素被选中的概率相等
* <br>
* 概率计算详解
* 概率计算的核心在于,对于数据流中第 i 个元素,证明它最终留在蓄水池的概率是 m/N(N为数据流总长度)。证明时需区分两种情况。
* 1)对于前m个元素(i ≤ m)
* 这些元素一开始就被放入蓄水池。它们最终被保留下来,需要在后续所有替换操作中都不被换出。
* 从第 m+1个元素开始,到第 N个元素,每个元素 j都可能带来替换。
* 对于第 j个元素(j > m),它不会替换掉我们关注的第 i 个元素的概率是多少?这需要两个独立事件同时发生:
* 第 j个元素被选中(概率为 m/j)且它恰好替换了第 i个元素(概率为 1/m)。
* 所以,第 i个元素被替换的概率是 (m/j) * (1/m) = 1/j。那么,不被替换的概率就是 1 - 1/j = (j-1)/j。
* 因此,第 i个元素需要连续闯过从 j=m+1到 j=N的所有关卡,其最终被保留的概率是所有关卡概率的乘积:
* P(i最终保留) = 1 * [m/(m+1)] * [(m+1)/(m+2)] * ... * [(N-1)/N] = m/N
* 这个过程中,分子分母相继抵消,最终结果简化为 m/N。
* 2)对于第m个之后的元素(i > m)
* 这类元素需要先被选入蓄水池,然后还要在后续的替换中存活下来。
* 被选入蓄水池的概率:当处理到第 i个元素时,它被选中进入蓄水池的概率就是 m/i。
* 入池后不被后续元素替换的概率:一旦被选入,它需要承受从第 i+1到第 N个元素的考验。利用上面的逻辑,从 j=i+1到 j=N,它不被替换的概率链为:
* [i/(i+1)] * [(i+1)/(i+2)] * ... * [(N-1)/N] = i/N
* 因此,第 i个元素最终被选中的概率是:
* P(i最终保留) = (被选入的概率) * (入后不被替换的概率) = (m/i) * (i/N) = m/N。
* <br>
* 蓄水池算法步骤:
* 1),对于i<=m,直接放入蓄水池
* 2),对于i>m,以 m/i的概率(代码中为从(0,i]生成一个数,如果小于m,代表概率为m/i)决定是否放入蓄水池,
* 若放入,则等概率随机(从(0,m]随机生成一个数)替换池中的一个元素.
* <br>
* 算法特性与应用
* 公平性保证:无论数据何时到达,每个元素被抽中的概率都是 m/N,确保了抽样的绝对公平。
* 空间效率:只需维护大小为 m的蓄水池,空间复杂度是 O(m),非常适合处理海量数据流。
* 经典应用:该算法非常适合在线抽奖系统。例如,在一天内为所有登录用户进行抽奖,无需等到所有用户数据收集完毕,即可在任意时刻为当前已登录用户提供公平的中奖概率。
*/
public class Reservoir {
/**
* 蓄水池采样器
*/
public static class ReservoirSample {
// 水池数组
private final int[] reservoir;
// 水池大小
private final int m;
// 已处理元素数量
private int count;
public ReservoirSample(int m) {
this.reservoir = new int[m];
this.m = m;
this.count = 0;
}
/**
* 随机选择(0,max]之间的一个数
*/
private int rand(int max) {
return (int) (Math.random() * max) + 1;
}
/**
* 向蓄水池添加一个元素
*/
public void add(int i) {
count++;
// i <=m 直接放入
if (i <= m) {
reservoir[count - 1] = i;
} else {
// i>m 以m/i的概率决定是否放入蓄水池
// 调用rand代表随机选择了数,如果选择的数<=m,代表要放入
if (rand(count) <= m) {
// 要放入,随机选择一个位置替换
reservoir[rand(m) - 1] = i;
}
}
}
/**
* 获取蓄水池中的元素
*/
public int[] getReservoir() {
int[] res = new int[m];
for (int i = 0; i < m; i++) {
res[i] = reservoir[i];
}
return res;
}
}
public static void main(String[] args) {
int all = 100;
int choose = 10;
int testTimes = 500000;
int[] counts = new int[all + 1];
// 测试testTimes次,统计每个数选择的次数
for (int i = 0; i < testTimes; i++) {
ReservoirSample box = new ReservoirSample(choose);
for (int num = 1; num <= all; num++) {
box.add(num);
}
int[] ans = box.getReservoir();
for (int j = 0; j < ans.length; j++) {
counts[ans[j]]++;
}
}
// 打印出每个数选择的次数
for (int i = 0; i < counts.length; i++) {
System.out.println(i + " times : " + counts[i]);
}
}
}
后记
个人学习总结笔记,不能保证非常详细,轻喷