解析 PyTorch 中的 torch.multinomial 函数

深入解析 PyTorch 中的 torch.multinomial 函数

torch.multinomial 函数的用途:

python 复制代码
def generate_step(self, seq, temperature=1.0):
    logits = self.forward(torch.tensor([seq], dtype=torch.long).to(device))[:, -1, :]
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()

想知道 torch.multinomial 是什么,它的作用是什么?本文将以中文博客的形式,详细介绍 torch.multinomial 函数的定义、功能、使用场景,并通过示例和代码解析其工作原理,帮助读者全面理解这一工具。


1. torch.multinomial 是什么?
定义

torch.multinomial 是 PyTorch 中一个用于多项式分布采样的函数。它根据给定的概率分布,从一组离散选项中随机抽取样本。简单来说,它就像一个"加权随机选择器",概率越高的事项越有可能被选中。

官方文档

根据 PyTorch 官方文档,torch.multinomial 的签名如下:

python 复制代码
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)
  • input:输入张量,包含每个选项的概率或权重(非负数,不必归一化)。
  • num_samples:要抽取的样本数。
  • replacement :是否放回采样(True 表示可重复,False 表示不重复)。
  • generator:可选的随机数生成器,用于控制随机性。
  • out:可选的输出张量。
返回值
  • 返回一个张量,包含采样得到的索引(整数),形状取决于 inputnum_samples

2. torch.multinomial 的工作原理
多项式分布

torch.multinomial 基于多项式分布(Multinomial Distribution) ,这是离散概率分布的一种,用于描述在 ( n n n ) 次独立试验中,每个类别出现次数的分布。给定 ( k k k ) 个类别,每个类别有概率 ( p i p_i pi )(满足 ( ∑ p i = 1 \sum p_i = 1 ∑pi=1)),multinomial 从中采样,返回类别索引。

执行过程
  1. 输入处理

    • input 是一个张量,其值表示每个类别的权重或概率。
    • 如果值未归一化(如 [1.0, 2.0, 3.0]),函数会自动将其视为相对权重,归一化为概率(这里相当于 [1/6, 2/6, 3/6])。
  2. 采样

    • 根据权重,随机选择索引。
    • replacement=False,每个索引只能被选中一次;若 True,可重复选中。
  3. 输出

    • 返回采样得到的索引,类型为 torch.long
数学示例
  • 输入:[0.2, 0.5, 0.3](已归一化)。
  • 采样:torch.multinomial([0.2, 0.5, 0.3], 1)
  • 可能输出:
    • 0(概率20%),
    • 1(概率50%),
    • 2(概率30%)。

3. 在代码中的具体应用
上下文解析

generate_step 方法中:

python 复制代码
logits = self.forward(torch.tensor([seq], dtype=torch.long).to(device))[:, -1, :]  # [vocab_size]
probs = F.softmax(logits / temperature, dim=-1)  # [vocab_size]
return torch.multinomial(probs, 1).item()
  • logits :模型输出的原始得分,形状 [vocab_size](例如 [0.1, 0.6, 0.3])。
  • probs :通过 softmax 转换为概率,形状仍为 [vocab_size](例如 [0.231, 0.524, 0.245])。
  • temperature:调整概率分布的平滑度(值越大,越均匀;值越小,越尖锐)。
  • torch.multinomial(probs, 1) :从概率分布中采样1个索引,返回形状 [1] 的张量。
  • .item():将单值张量转为Python标量。
作用
  • 随机生成下一个token :在生成任务(如语言模型)中,multinomial 根据概率分布选择下一个词的索引,而不是简单取最大值(贪婪解码),增加了输出的多样性。
  • 示例
    • probs = [0.1, 0.6, 0.3]
    • torch.multinomial(probs, 1) 可能返回 [1](概率最高),也可能返回 [2][0]

4. 代码示例与模拟
简单示例
python 复制代码
import torch

# 定义概率分布
probs = torch.tensor([0.1, 0.6, 0.3])

# 采样1次
sample = torch.multinomial(probs, 1)
print(f"Sampled index: {sample.item()}")  # 可能是 0, 1, 2

# 采样多次
samples = torch.multinomial(probs, 3, replacement=True)
print(f"Multiple samples: {samples}")  # 可能如 [1, 2, 1]
模拟生成过程

假设词汇表 [0, 1, 2] 表示 [我, 喜欢, 是]

  • 输入probs = [0.1, 0.6, 0.3]

  • 采样torch.multinomial(probs, 1)

  • 结果

    • 60% 概率返回 1("喜欢"),
    • 30% 概率返回 2("是"),
    • 10% 概率返回 0("我")。
  • 运行多次

    python 复制代码
    for _ in range(5):
        print(torch.multinomial(probs, 1).item())

    输出可能为 [1, 1, 2, 1, 0],反映概率分布。

与温度结合
python 复制代码
logits = torch.tensor([0.1, 0.6, 0.3])
for temp in [0.5, 1.0, 2.0]:
    probs = F.softmax(logits / temp, dim=-1)
    sample = torch.multinomial(probs, 1).item()
    print(f"Temperature {temp}: probs={probs.tolist()}, sample={sample}")
  • 输出示例

    复制代码
    Temperature 0.5: probs=[0.204, 0.553, 0.243], sample=1
    Temperature 1.0: probs=[0.231, 0.524, 0.245], sample=1
    Temperature 2.0: probs=[0.282, 0.424, 0.294], sample=2
  • 观察:温度越低,分布越尖锐(偏向高概率项);温度越高,分布越平滑(更随机)。


5. multinomial 的使用场景
生成任务
  • 语言模型:生成下一个词时,增加多样性(如本文代码)。
  • 对话系统:避免重复生成最高概率的回答。
其他领域
  • 强化学习:从动作概率分布中采样。
  • 数据增强:根据权重随机选择样本。
argmax 的对比
  • argmax :取最大概率索引,确定性输出(如 1)。
  • multinomial:随机采样,引入多样性,适合探索性任务。

6. 注意事项
  • 输入要求input 必须非负,若全为0会报错。
  • 归一化:无需手动归一化,函数内部处理。
  • 性能:小规模采样效率高,大规模需考虑替代方法。

7. 总结

torch.multinomial 是一个强大的工具,用于从概率分布中随机采样索引。在生成任务中,它通过 probsnum_samples 参数,从词汇表中选择下一个token,结合温度参数(如 temperature)控制随机性。示例展示了其在语言生成中的作用:从概率 [0.1, 0.6, 0.3] 中采样,可能生成多样化的结果。希望这篇博客让你彻底理解 multinomial 的本质!如需更多探讨,欢迎提问。

后记

2025年3月2日14点27分于上海,在grok3大模型辅助下完成。

相关推荐
apcipot_rain1 小时前
【应用密码学】实验五 公钥密码2——ECC
前端·数据库·python
yu4106211 小时前
2025年中期大语言模型实力深度剖析
人工智能·语言模型·自然语言处理
小彭律师1 小时前
门禁人脸识别系统详细技术文档
笔记·python
鸿业远图科技2 小时前
分式注记种表达方式arcgis
python·arcgis
别让别人觉得你做不到3 小时前
Python(1) 做一个随机数的游戏
python
feng995203 小时前
技术伦理双轨认证如何重构AI工程师能力评估体系——基于AAIA框架的技术解析与行业实证研究
人工智能·aaif·aaia·iaaai
2301_776681654 小时前
【用「概率思维」重新理解生活】
开发语言·人工智能·自然语言处理
蜡笔小新..4 小时前
从零开始:用PyTorch构建CIFAR-10图像分类模型达到接近1的准确率
人工智能·pytorch·机器学习·分类·cifar-10
小彭律师4 小时前
人脸识别门禁系统技术文档
python
富唯智能4 小时前
转运机器人可以绕障吗?
人工智能·智能机器人·转运机器人