仅用于个人学习
js归一化方法
function softmax(logits) {
// 第1步:找到 logits 数组中的最大值(对应公式里的 max(x))
const maxLogit = Math.max(...logits);
// 第2步:对每个元素,先减最大值,再求指数(对应公式分子的 exp(x_i - max(x)))
const exps = logits.map(x => Math.exp(x - maxLogit));
// 第3步:求所有指数值的总和(对应公式分母的 sum(exp(x_j - max(x))))
const sumExps = exps.reduce((a, b) => a + b, 0);
// 第4步:每个指数值除以总和,得到概率(对应公式的 分子/分母,完成归一化)
return exps.map(exp => exp / sumExps);
}
举例
我们不用长度 1000 的数组,用简单的 logits = new Float32Array([2, 4, 1]) 来演示整个流程:
- 第 1 步:找最大值
maxLogit = Math.max(2,4,1) = 4 - 第 2 步:减最大值 + 求指数
- 第 1 个元素:
Math.exp(2-4) = Math.exp(-2) ≈ 0.135 - 第 2 个元素:
Math.exp(4-4) = Math.exp(0) = 1 - 第 3 个元素:
Math.exp(1-4) = Math.exp(-3) ≈ 0.0498
所以exps ≈ [0.135, 1, 0.0498]
- 第 1 个元素:
- 第 3 步:求指数和
sumExps ≈ 0.135 + 1 + 0.0498 ≈ 1.1848 - 第 4 步:归一化求概率
- 第 1 个概率:
0.135 / 1.1848 ≈ 0.114(11.4%) - 第 2 个概率:
1 / 1.1848 ≈ 0.844(84.4%) - 第 3 个概率:
0.0498 / 1.1848 ≈ 0.042(4.2%)
最终概率数组:[0.114, 0.844, 0.042],总和≈1,符合要求。
- 第 1 个概率:
总结
这个 softmax 函数的流程可以浓缩为 4 步,每一步都和公式严格对应:
1、 找最大值(防溢出的关键)
2、 每个元素 "减最大值 + 求指数"(得到分子原始值)
3、 求所有指数值的总和(得到分母)
4、 每个指数值除以总和(归一化,得到概率)
最终输出的概率数组,既能反映原始 logits 的 "大小关系"(原始值越大,概率越高),又能满足 "0~1 区间、总和为 1" 的概率特性,方便后续的分类判断。