深入解析 PyTorch 中的 torch.multinomial
函数
torch.multinomial
函数的用途:
python
def generate_step(self, seq, temperature=1.0):
logits = self.forward(torch.tensor([seq], dtype=torch.long).to(device))[:, -1, :]
probs = F.softmax(logits / temperature, dim=-1)
return torch.multinomial(probs, 1).item()
想知道 torch.multinomial
是什么,它的作用是什么?本文将以中文博客的形式,详细介绍 torch.multinomial
函数的定义、功能、使用场景,并通过示例和代码解析其工作原理,帮助读者全面理解这一工具。
1. torch.multinomial
是什么?
定义
torch.multinomial
是 PyTorch 中一个用于多项式分布采样的函数。它根据给定的概率分布,从一组离散选项中随机抽取样本。简单来说,它就像一个"加权随机选择器",概率越高的事项越有可能被选中。
官方文档
根据 PyTorch 官方文档,torch.multinomial
的签名如下:
python
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)
input
:输入张量,包含每个选项的概率或权重(非负数,不必归一化)。num_samples
:要抽取的样本数。replacement
:是否放回采样(True
表示可重复,False
表示不重复)。generator
:可选的随机数生成器,用于控制随机性。out
:可选的输出张量。
返回值
- 返回一个张量,包含采样得到的索引(整数),形状取决于
input
和num_samples
。
2. torch.multinomial
的工作原理
多项式分布
torch.multinomial
基于多项式分布(Multinomial Distribution) ,这是离散概率分布的一种,用于描述在 ( n n n ) 次独立试验中,每个类别出现次数的分布。给定 ( k k k ) 个类别,每个类别有概率 ( p i p_i pi )(满足 ( ∑ p i = 1 \sum p_i = 1 ∑pi=1)),multinomial
从中采样,返回类别索引。
执行过程
-
输入处理:
input
是一个张量,其值表示每个类别的权重或概率。- 如果值未归一化(如
[1.0, 2.0, 3.0]
),函数会自动将其视为相对权重,归一化为概率(这里相当于[1/6, 2/6, 3/6]
)。
-
采样:
- 根据权重,随机选择索引。
- 若
replacement=False
,每个索引只能被选中一次;若True
,可重复选中。
-
输出:
- 返回采样得到的索引,类型为
torch.long
。
- 返回采样得到的索引,类型为
数学示例
- 输入:
[0.2, 0.5, 0.3]
(已归一化)。 - 采样:
torch.multinomial([0.2, 0.5, 0.3], 1)
。 - 可能输出:
0
(概率20%),1
(概率50%),2
(概率30%)。
3. 在代码中的具体应用
上下文解析
在 generate_step
方法中:
python
logits = self.forward(torch.tensor([seq], dtype=torch.long).to(device))[:, -1, :] # [vocab_size]
probs = F.softmax(logits / temperature, dim=-1) # [vocab_size]
return torch.multinomial(probs, 1).item()
logits
:模型输出的原始得分,形状[vocab_size]
(例如[0.1, 0.6, 0.3]
)。probs
:通过softmax
转换为概率,形状仍为[vocab_size]
(例如[0.231, 0.524, 0.245]
)。temperature
:调整概率分布的平滑度(值越大,越均匀;值越小,越尖锐)。torch.multinomial(probs, 1)
:从概率分布中采样1个索引,返回形状[1]
的张量。.item()
:将单值张量转为Python标量。
作用
- 随机生成下一个token :在生成任务(如语言模型)中,
multinomial
根据概率分布选择下一个词的索引,而不是简单取最大值(贪婪解码),增加了输出的多样性。 - 示例 :
probs = [0.1, 0.6, 0.3]
。torch.multinomial(probs, 1)
可能返回[1]
(概率最高),也可能返回[2]
或[0]
。
4. 代码示例与模拟
简单示例
python
import torch
# 定义概率分布
probs = torch.tensor([0.1, 0.6, 0.3])
# 采样1次
sample = torch.multinomial(probs, 1)
print(f"Sampled index: {sample.item()}") # 可能是 0, 1, 2
# 采样多次
samples = torch.multinomial(probs, 3, replacement=True)
print(f"Multiple samples: {samples}") # 可能如 [1, 2, 1]
模拟生成过程
假设词汇表 [0, 1, 2]
表示 [我, 喜欢, 是]
:
-
输入 :
probs = [0.1, 0.6, 0.3]
。 -
采样 :
torch.multinomial(probs, 1)
。 -
结果 :
- 60% 概率返回
1
("喜欢"), - 30% 概率返回
2
("是"), - 10% 概率返回
0
("我")。
- 60% 概率返回
-
运行多次 :
pythonfor _ in range(5): print(torch.multinomial(probs, 1).item())
输出可能为
[1, 1, 2, 1, 0]
,反映概率分布。
与温度结合
python
logits = torch.tensor([0.1, 0.6, 0.3])
for temp in [0.5, 1.0, 2.0]:
probs = F.softmax(logits / temp, dim=-1)
sample = torch.multinomial(probs, 1).item()
print(f"Temperature {temp}: probs={probs.tolist()}, sample={sample}")
-
输出示例 :
Temperature 0.5: probs=[0.204, 0.553, 0.243], sample=1 Temperature 1.0: probs=[0.231, 0.524, 0.245], sample=1 Temperature 2.0: probs=[0.282, 0.424, 0.294], sample=2
-
观察:温度越低,分布越尖锐(偏向高概率项);温度越高,分布越平滑(更随机)。
5. multinomial
的使用场景
生成任务
- 语言模型:生成下一个词时,增加多样性(如本文代码)。
- 对话系统:避免重复生成最高概率的回答。
其他领域
- 强化学习:从动作概率分布中采样。
- 数据增强:根据权重随机选择样本。
与 argmax
的对比
argmax
:取最大概率索引,确定性输出(如1
)。multinomial
:随机采样,引入多样性,适合探索性任务。
6. 注意事项
- 输入要求 :
input
必须非负,若全为0会报错。 - 归一化:无需手动归一化,函数内部处理。
- 性能:小规模采样效率高,大规模需考虑替代方法。
7. 总结
torch.multinomial
是一个强大的工具,用于从概率分布中随机采样索引。在生成任务中,它通过 probs
和 num_samples
参数,从词汇表中选择下一个token,结合温度参数(如 temperature
)控制随机性。示例展示了其在语言生成中的作用:从概率 [0.1, 0.6, 0.3]
中采样,可能生成多样化的结果。希望这篇博客让你彻底理解 multinomial
的本质!如需更多探讨,欢迎提问。
后记
2025年3月2日14点27分于上海,在grok3大模型辅助下完成。