跳至主要內容

洗牌算法 & 蓄水池抽样算法

大约 4 分钟algorithm抽样随机排列

描述洗牌算法和随机抽样算法的底层原理和实现

一、Knuth洗牌算法(Fisher-Yates Shuffle)

算法目标:将1个 数组 或 列表 随机打乱,以等可能概率生成 数组 或 列表 的1个 随机排列
算法流程

  1. 对于1个含n个无重复元素的 数组 或 列表,对于 [0, n -1] 范围内的每个 下标为i 的元素,从下标范围 [i, n-1] 中,随机选出1个 下标为k 的元素,与 下标为i 的元素交换。
  2. 遍历 数组 或 列表,对 数组 或 列表 中的每个元素执行 步骤1
  3. 数组遍历完成,即完成 “洗牌”。

算法 正确性证明

  1. 首先,根据排列组合,对于1个含n个无重复元素的 数组 或 列表,其总排列数为n x (n - 1) x (n - 2) x … x 1 = n!,而算法目标就是从这 n! 种排列组合中,随机选出1个排列。
  2. 根据Knuth算法流程,
    1. 对于下标0,则选择范围是 [0, n-1],故有n种随机选法
    2. 对于下标1,则选择范围是 [1 , n-1],有n-1种随机选法
    3. 对于下标n-1,则选择范围是 [n-1, n-1],有1种随机选法
  3. 因此,Knuth洗牌算法可以从 n! 种选择(排列)中,随机选出1种排列,达成了算法目标。
Knuth洗牌算法,对于 含重复元素 的数组/列表同样适用,上文中以“n个无重复元素”表述,主要是为了论证 Knuth洗牌算法 的正确性 & 有效性。

LC题源LC384-打乱数组open in new window
算法实现(java):

class Solution {
    private int[] nums;
    public Solution(int[] nums) {
        this.nums = nums;
    }
    
    public int[] reset() {
        return nums;
    }
    
    public int[] shuffle() {
        Random rand = new Random();
        // 深拷贝
        int[] ans = nums.clone();
        // 对于下标i, 从 下标范围[i, n-1] 中,随机选出1个元素与nums[i]交换
        // 故,会随机产生 n x (n-1) x (n-2)x...x1 = n!种组合
        for (int i = 0; i < ans.length; i++) {
            int k = i + rand.nextInt(ans.length - i);
            swap(ans, i, k);
        }
        return ans;
    }

    private void swap(int[] arr, int i, int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }
}

二、蓄水池抽样(Reservoir Sampling)算法

算法目标:从未知容量大小(N ≥ k)的样本集中,以等概率选出k个样本。

算法流程

  1. 将样本集的 前k个样本(下标范围为**[0, k-1]**的元素) 放到 ”蓄水池” 中。
  2. 第 i 个元素(i > k-1) 开始,每次随机生成1个范围在 [0, i] 间的 随机数 j
    1. j ≤ k-1,则将 下标为i 的元素 与 下标为j 的元素 交换
    2. j > k-1,不执行操作。

算法正确性证明
对于 第i个元素(i > k - 1),它被选入蓄水池的概率是 k/i;对于已在蓄水池中的元素,它被选中的概率同样是k/i
故当遍历至 第N-1个 元素时,选中样本集中各个元素的概率均为 k/N,达成了算法目标。

算法时间复杂度:O(N)
LC题源LC382-链表随机结点open in new window
算法实现(java):

import java.util.Random;

public class ReservoirSampling {
    // 从数据流中随机抽取k个样本
    public static int[] reservoirSample(int[] stream, int k) {
        int[] reservoir = new int[k];
        Random random = new Random();

        // 将前k个元素放入蓄水池
        for (int i = 0; i < k; i++) {
            reservoir[i] = stream[i];
        }

        // 处理剩余的元素
        for (int i = k; i < stream.length; i++) {
            // 生成一个范围在[0, i]之间的随机数
            // 每个元素(未在蓄水池 & 已在蓄水池中的元素)被选中的概率都是 k/i(i分之k)
            int j = random.nextInt(i + 1);
						
            // 如果随机数j落在蓄水池的范围内,则替换蓄水池中的第j个元素
            if (j < k) {
                reservoir[j] = stream[i];
            }
        }
        return reservoir;
    }
		
		
    public static void main(String[] args) {
        int[] stream = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        int k = 5;
        int[] result = reservoirSample(stream, k);

        System.out.println("Sampled elements:");
        for (int value : result) {
            System.out.print(value + " ");
        }
    }
}