Alias O(1)采样算法

参考

算法解析

  1. 假设游戏抽奖: 1星:80权重, 2星:60权重, 3星:40权重, 4星:20权重。
  2. 根据权重值,计算他们出现的概率。得到概率数组:distribution,同时概率对应的值数组为:values
    1
    2
    3
    let weights = [80, 60, 40, 20] // 权重数组
    let distribution = [0.4, 0.3, 0.2, 0.1] // 概率数组
    let values = [1, 2, 3, 4] // 对应值数组
  3. 根据概率数组,制表得prob,alisa 数组。
  4. 根据数组随机

代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/*****************************
*@file: AliasMethod
*@author: 陈吕唐
*@desc: Alias 采样算法 初始化O(n) 调用O(1)
*@date: 2025-03-21 17:44
*****************************/
export default class AliasMethod {


/**
* 概率数组 存放第i列事件的概率
*/
private prob: number[];

/**
* 别名数组 存放第i列不是事件i的另外一个事件的标号
*/
private alias: number[];


private values: any[];
/**
*
* @param distribution 概率分布权重数组
* @param values 对应数值数组
*/
constructor(distribution: number[], values: any[] = []) {
if (values.length <= 0) {
values = distribution;
}


let n = distribution.length;

this.prob = new Array(n);
this.alias = new Array(n);
this.values = values;

/**
* 权重和
*/
let sum = distribution.reduce((a, b) => a + b);

/**
* 概率数组
*/
let normalizedDistribution = distribution.map(p => p / sum);

/**
* 柱状表
*/
normalizedDistribution = normalizedDistribution.map(p => p * n);

/**
* 概率小于1的下标
*/
let small: number[] = [];

/**
* 概率大等于1的下标
*/
let large: number[] = [];

/**
* 初始化下标数组
*/
for (let i = 0; i < normalizedDistribution.length; i++) {
if (normalizedDistribution[i] < 1) {
small.push(i);
} else {
large.push(i);
}
}

//制表
while (small.length > 0 && large.length > 0) {
/**
* 概率小于1的下标
*/
let smallIndex = small.pop();
/**
* 概率大于等于1的下标
*/
let largeIndex = large.pop();

this.prob[smallIndex] = normalizedDistribution[smallIndex];
this.alias[smallIndex] = largeIndex;

normalizedDistribution[largeIndex] = normalizedDistribution[largeIndex] + normalizedDistribution[smallIndex] - 1;

if (normalizedDistribution[largeIndex] < 1) {
small.push(largeIndex);
}

if (normalizedDistribution[largeIndex] >= 1) {
large.push(largeIndex);
}
}

while (large.length > 0) {
this.prob[large.pop()] = 1;
}

while (small.length > 0) {
this.prob[small.pop()] = 1;
}
}

/**
*
* @returns 随机值
*/
public sample(): number {
let n = this.prob.length;
let idx = Math.floor(Math.random() * n);
return Math.random() < this.prob[idx] ? this.values[idx] : this.values[this.alias[idx]];
}
}