洗牌算法 & 蓄水池抽样算法
大约 4 分钟
描述洗牌算法和随机抽样算法的底层原理和实现
一、Knuth洗牌算法(Fisher-Yates Shuffle)
算法目标:将1个 数组 或 列表 随机打乱,以等可能概率生成 数组 或 列表 的1个 随机排列。
算法流程:
- 对于1个含n个无重复元素的 数组 或 列表,对于 [0, n -1] 范围内的每个 下标为i 的元素,从下标范围 [i, n-1] 中,随机选出1个 下标为k 的元素,与 下标为i 的元素交换。
- 遍历 数组 或 列表,对 数组 或 列表 中的每个元素执行 步骤1。
- 数组遍历完成,即完成 “洗牌”。
算法 正确性证明:
- 首先,根据排列组合,对于1个含n个无重复元素的 数组 或 列表,其总排列数为n x (n - 1) x (n - 2) x … x 1 = n!,而算法目标就是从这 n! 种排列组合中,随机选出1个排列。
- 根据Knuth算法流程,
- 对于下标0,则选择范围是 [0, n-1],故有n种随机选法
- 对于下标1,则选择范围是 [1 , n-1],有n-1种随机选法
- …
- 对于下标n-1,则选择范围是 [n-1, n-1],有1种随机选法
- 因此,Knuth洗牌算法可以从 n! 种选择(排列)中,随机选出1种排列,达成了算法目标。
LC题源:LC384-打乱数组
算法实现(java):
class Solution {
private int[] nums;
public Solution(int[] nums) {
this.nums = nums;
}
public int[] reset() {
return nums;
}
public int[] shuffle() {
Random rand = new Random();
// 深拷贝
int[] ans = nums.clone();
// 对于下标i, 从 下标范围[i, n-1] 中,随机选出1个元素与nums[i]交换
// 故,会随机产生 n x (n-1) x (n-2)x...x1 = n!种组合
for (int i = 0; i < ans.length; i++) {
int k = i + rand.nextInt(ans.length - i);
swap(ans, i, k);
}
return ans;
}
private void swap(int[] arr, int i, int j) {
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
}
二、蓄水池抽样(Reservoir Sampling)算法
算法目标:从未知容量大小(N ≥ k)的样本集中,以等概率选出k个样本。
算法流程:
- 将样本集的 前k个样本(下标范围为**[0, k-1]**的元素) 放到 ”蓄水池” 中。
- 从 第 i 个元素(i > k-1) 开始,每次随机生成1个范围在 [0, i] 间的 随机数 j
- 若 j ≤ k-1,则将 下标为i 的元素 与 下标为j 的元素 交换。
- 若 j > k-1,不执行操作。
算法正确性证明:
对于 第i个元素(i > k - 1),它被选入蓄水池的概率是 k/i;对于已在蓄水池中的元素,它被选中的概率同样是k/i;
故当遍历至 第N-1个 元素时,选中样本集中各个元素的概率均为 k/N,达成了算法目标。
算法时间复杂度:O(N)
LC题源:LC382-链表随机结点
算法实现(java):
import java.util.Random;
public class ReservoirSampling {
// 从数据流中随机抽取k个样本
public static int[] reservoirSample(int[] stream, int k) {
int[] reservoir = new int[k];
Random random = new Random();
// 将前k个元素放入蓄水池
for (int i = 0; i < k; i++) {
reservoir[i] = stream[i];
}
// 处理剩余的元素
for (int i = k; i < stream.length; i++) {
// 生成一个范围在[0, i]之间的随机数
// 每个元素(未在蓄水池 & 已在蓄水池中的元素)被选中的概率都是 k/i(i分之k)
int j = random.nextInt(i + 1);
// 如果随机数j落在蓄水池的范围内,则替换蓄水池中的第j个元素
if (j < k) {
reservoir[j] = stream[i];
}
}
return reservoir;
}
public static void main(String[] args) {
int[] stream = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
int k = 5;
int[] result = reservoirSample(stream, k);
System.out.println("Sampled elements:");
for (int value : result) {
System.out.print(value + " ");
}
}
}