解析 PyTorch 中的 torch.multinomial 函数

深入解析 PyTorch 中的 torch.multinomial 函数

torch.multinomial 函数的用途:

python 复制代码
def generate_step(self, seq, temperature=1.0):
    logits = self.forward(torch.tensor([seq], dtype=torch.long).to(device))[:, -1, :]
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()

想知道 torch.multinomial 是什么,它的作用是什么?本文将以中文博客的形式,详细介绍 torch.multinomial 函数的定义、功能、使用场景,并通过示例和代码解析其工作原理,帮助读者全面理解这一工具。


1. torch.multinomial 是什么?
定义

torch.multinomial 是 PyTorch 中一个用于多项式分布采样的函数。它根据给定的概率分布,从一组离散选项中随机抽取样本。简单来说,它就像一个"加权随机选择器",概率越高的事项越有可能被选中。

官方文档

根据 PyTorch 官方文档,torch.multinomial 的签名如下:

python 复制代码
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)
  • input:输入张量,包含每个选项的概率或权重(非负数,不必归一化)。
  • num_samples:要抽取的样本数。
  • replacement :是否放回采样(True 表示可重复,False 表示不重复)。
  • generator:可选的随机数生成器,用于控制随机性。
  • out:可选的输出张量。
返回值
  • 返回一个张量,包含采样得到的索引(整数),形状取决于 inputnum_samples

2. torch.multinomial 的工作原理
多项式分布

torch.multinomial 基于多项式分布(Multinomial Distribution) ,这是离散概率分布的一种,用于描述在 ( n n n ) 次独立试验中,每个类别出现次数的分布。给定 ( k k k ) 个类别,每个类别有概率 ( p i p_i pi )(满足 ( ∑ p i = 1 \sum p_i = 1 ∑pi=1)),multinomial 从中采样,返回类别索引。

执行过程
  1. 输入处理

    • input 是一个张量,其值表示每个类别的权重或概率。
    • 如果值未归一化(如 [1.0, 2.0, 3.0]),函数会自动将其视为相对权重,归一化为概率(这里相当于 [1/6, 2/6, 3/6])。
  2. 采样

    • 根据权重,随机选择索引。
    • replacement=False,每个索引只能被选中一次;若 True,可重复选中。
  3. 输出

    • 返回采样得到的索引,类型为 torch.long
数学示例
  • 输入:[0.2, 0.5, 0.3](已归一化)。
  • 采样:torch.multinomial([0.2, 0.5, 0.3], 1)
  • 可能输出:
    • 0(概率20%),
    • 1(概率50%),
    • 2(概率30%)。

3. 在代码中的具体应用
上下文解析

generate_step 方法中:

python 复制代码
logits = self.forward(torch.tensor([seq], dtype=torch.long).to(device))[:, -1, :]  # [vocab_size]
probs = F.softmax(logits / temperature, dim=-1)  # [vocab_size]
return torch.multinomial(probs, 1).item()
  • logits :模型输出的原始得分,形状 [vocab_size](例如 [0.1, 0.6, 0.3])。
  • probs :通过 softmax 转换为概率,形状仍为 [vocab_size](例如 [0.231, 0.524, 0.245])。
  • temperature:调整概率分布的平滑度(值越大,越均匀;值越小,越尖锐)。
  • torch.multinomial(probs, 1) :从概率分布中采样1个索引,返回形状 [1] 的张量。
  • .item():将单值张量转为Python标量。
作用
  • 随机生成下一个token :在生成任务(如语言模型)中,multinomial 根据概率分布选择下一个词的索引,而不是简单取最大值(贪婪解码),增加了输出的多样性。
  • 示例
    • probs = [0.1, 0.6, 0.3]
    • torch.multinomial(probs, 1) 可能返回 [1](概率最高),也可能返回 [2][0]

4. 代码示例与模拟
简单示例
python 复制代码
import torch

# 定义概率分布
probs = torch.tensor([0.1, 0.6, 0.3])

# 采样1次
sample = torch.multinomial(probs, 1)
print(f"Sampled index: {sample.item()}")  # 可能是 0, 1, 2

# 采样多次
samples = torch.multinomial(probs, 3, replacement=True)
print(f"Multiple samples: {samples}")  # 可能如 [1, 2, 1]
模拟生成过程

假设词汇表 [0, 1, 2] 表示 [我, 喜欢, 是]

  • 输入probs = [0.1, 0.6, 0.3]

  • 采样torch.multinomial(probs, 1)

  • 结果

    • 60% 概率返回 1("喜欢"),
    • 30% 概率返回 2("是"),
    • 10% 概率返回 0("我")。
  • 运行多次

    python 复制代码
    for _ in range(5):
        print(torch.multinomial(probs, 1).item())

    输出可能为 [1, 1, 2, 1, 0],反映概率分布。

与温度结合
python 复制代码
logits = torch.tensor([0.1, 0.6, 0.3])
for temp in [0.5, 1.0, 2.0]:
    probs = F.softmax(logits / temp, dim=-1)
    sample = torch.multinomial(probs, 1).item()
    print(f"Temperature {temp}: probs={probs.tolist()}, sample={sample}")
  • 输出示例

    复制代码
    Temperature 0.5: probs=[0.204, 0.553, 0.243], sample=1
    Temperature 1.0: probs=[0.231, 0.524, 0.245], sample=1
    Temperature 2.0: probs=[0.282, 0.424, 0.294], sample=2
  • 观察:温度越低,分布越尖锐(偏向高概率项);温度越高,分布越平滑(更随机)。


5. multinomial 的使用场景
生成任务
  • 语言模型:生成下一个词时,增加多样性(如本文代码)。
  • 对话系统:避免重复生成最高概率的回答。
其他领域
  • 强化学习:从动作概率分布中采样。
  • 数据增强:根据权重随机选择样本。
argmax 的对比
  • argmax :取最大概率索引,确定性输出(如 1)。
  • multinomial:随机采样,引入多样性,适合探索性任务。

6. 注意事项
  • 输入要求input 必须非负,若全为0会报错。
  • 归一化:无需手动归一化,函数内部处理。
  • 性能:小规模采样效率高,大规模需考虑替代方法。

7. 总结

torch.multinomial 是一个强大的工具,用于从概率分布中随机采样索引。在生成任务中,它通过 probsnum_samples 参数,从词汇表中选择下一个token,结合温度参数(如 temperature)控制随机性。示例展示了其在语言生成中的作用:从概率 [0.1, 0.6, 0.3] 中采样,可能生成多样化的结果。希望这篇博客让你彻底理解 multinomial 的本质!如需更多探讨,欢迎提问。

后记

2025年3月2日14点27分于上海,在grok3大模型辅助下完成。

相关推荐
天天扭码14 分钟前
从图片到语音:我是如何用两大模型API打造沉浸式英语学习工具的
前端·人工智能·github
张彦峰ZYF1 小时前
从检索到生成:RAG 如何重构大模型的知识边界?
人工智能·ai·aigc
刘海东刘海东1 小时前
结构型智能科技的关键可行性——信息型智能向结构型智能的转变(修改提纲)
人工智能·算法·机器学习
2301_805054561 小时前
Python训练营打卡Day59(2025.7.3)
开发语言·python
**梯度已爆炸**1 小时前
NLP文本预处理
人工智能·深度学习·nlp
uncle_ll1 小时前
李宏毅NLP-8-语音模型
人工智能·自然语言处理·语音识别·语音模型·lm
Liudef061 小时前
FLUX.1-Kontext 高效训练 LoRA:释放大语言模型定制化潜能的完整指南
人工智能·语言模型·自然语言处理·ai作画·aigc
静心问道1 小时前
大型语言模型中的自动化思维链提示
人工智能·语言模型·大模型
万千思绪1 小时前
【PyCharm 2025.1.2配置debug】
ide·python·pycharm
众链网络2 小时前
你的Prompt还有很大提升
人工智能·prompt·ai写作·ai工具·ai智能体