20260113-np.random.multinomial 与 torch.multinomial

np.random.multinomial vs torch.multinomial

np.random.multinomial 与 torch.multinomial 详解

一、多项分布基础

多项分布是二项分布的推广,描述的是在进行 n 次独立试验时,每个 k 个可能结果出现次数的概率分布。其概率质量函数为:

P(X1=x1,...,Xk=xk)=n!x1!...xk!⋅p1x1...pkxkP(X_1=x_1,...,X_k=x_k) = \frac{n!}{x_1!...x_k!} \cdot p_1^{x_1}...p_k^{x_k}P(X1=x1,...,Xk=xk)=x1!...xk!n!⋅p1x1...pkxk

其中:

  • n:试验总次数
  • p :各结果的概率向量(∑pi=1\sum p_i = 1∑pi=1)
  • x :各结果出现次数的向量(∑xi=n\sum x_i = n∑xi=n)

二、np.random.multinomial(NumPy)

1. 原理

从多项分布中一次性抽取多个独立样本 ,返回的是计数结果(各类别出现的次数)。

2. 函数签名

python 复制代码
np.random.multinomial(n, pvals, size=None)

3. 参数详解

参数 类型 说明
n int 每次试验的抽取次数(如掷骰子次数)
pvals array_like 概率向量,长度 k 表示类别数,总和必须为 1
size int/tuple 输出形状,表示进行多少次独立的 n 次试验

4. 返回值

  • 形状size 形状 + (len(pvals),)
  • 含义 :最后一个维度表示 每个类别被抽中的次数

5. 示例

python 复制代码
import numpy as np

# 定义一个6面骰子的概率分布(不均匀)
probs = [0.1, 0.2, 0.3, 0.2, 0.15, 0.05]

# 掷骰子10次,记录每个面出现的次数
result = np.random.multinomial(10, probs)
print(result)  
# 可能输出: array([0, 3, 4, 2, 1, 0])  # 6个面的出现次数,总和为10

# 进行5次独立的实验,每次掷骰子10次
results = np.random.multinomial(10, probs, size=5)
print(results)
# 可能输出:
# [[1 2 3 2 2 0]
#  [0 1 5 2 2 0]
#  [0 3 4 1 1 1]
#  [3 1 2 2 1 1]
#  [0 2 4 3 0 1]]

三、torch.multinomial(PyTorch)

1. 原理

从类别分布中进行多次有放回抽样 ,返回的是样本的索引(类别编号),而非计数。

2. 函数签名

python 复制代码
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)

3. 参数详解

参数 类型 说明
input Tensor 概率矩阵,形状 (N, C)(C,)
num_samples int 每个分布要抽取的样本数
replacement bool 是否允许重复抽样(有放回/无放回)
generator torch.Generator 可选的随机数生成器

4. 返回值

  • 形状 :与 input 相同的前导维度 + (num_samples,)
  • 含义 :抽取的样本类别 索引

5. 示例

python 复制代码
import torch

# 单个概率分布
probs = torch.tensor([0.1, 0.2, 0.3, 0.2, 0.15, 0.05])
samples = torch.multinomial(probs, num_samples=10, replacement=True)
print(samples)  
# 可能输出: tensor([2, 1, 2, 2, 3, 2, 4, 1, 2, 1])  # 类别索引

# 批量概率分布 (3个不同的分布)
batch_probs = torch.tensor([
    [0.1, 0.2, 0.3, 0.4],  # 分布1
    [0.5, 0.3, 0.2, 0.0],  # 分布2
    [0.0, 0.0, 0.5, 0.5]   # 分布3
])

# 每个分布抽取5个样本
batch_samples = torch.multinomial(batch_probs, num_samples=5, replacement=True)
print(batch_samples)
# 可能输出:
# tensor([[2, 3, 3, 2, 1],
#         [0, 0, 1, 0, 2],
#         [2, 3, 2, 3, 2]])

# 无放回抽样 (num_samples不能超过类别数)
samples_no_replace = torch.multinomial(probs, num_samples=3, replacement=False)
print(samples_no_replace)  # 输出如: tensor([2, 1, 3]),不会重复

四、核心区别对比

特性 np.random.multinomial torch.multinomial
返回内容 计数:各类别被抽中的次数 索引:被抽中类别的编号
抽样方式 隐含有放回 明确指定 replacement
输入格式 一维概率向量 支持批量输入(可多维)
典型用途 统计模拟、文档生成 强化学习采样、批量决策
梯度支持 不支持 支持(可集成在计算图中)
性能 CPU计算 GPU加速(支持CUDA)

五、注意事项

  1. 概率归一化 :两个函数都要求概率总和为1(在容差范围内,但是概率传输时并不严格要求总和为1,内部会自己处理
  2. 数值稳定性:极小的概率值可能导致采样偏差,建议设置最小阈值
  3. 随机种子 :NumPy用 np.random.seed(),PyTorch用 torch.manual_seed()
  4. 无放回限制torch.multinomialreplacement=False 时,num_samples 不能超过类别数
相关推荐
智航GIS2 小时前
11.6 Pandas数据处理进阶:缺失值处理与数据类型转换完全指南
python·pandas
小希smallxi2 小时前
Java 程序调用 FFmpeg 教程
java·python·ffmpeg
学习的学习者2 小时前
CS课程项目设计22:基于Transformer的智能机器翻译算法
人工智能·python·深度学习·transformer·机器翻译
小陈phd2 小时前
langGraph从入门到精通(四)——基于LangGraph的State状态模式设计
python·microsoft·状态模式
3824278272 小时前
JS正则表达式实战:核心语法解析
开发语言·前端·javascript·python·html
Engineer邓祥浩2 小时前
设计模式学习(10) 23-8 装饰者模式
python·学习·设计模式
ybdesire3 小时前
Joern服务器启动后cpgqls-client结合python编程进行扫描
运维·服务器·python
autho3 小时前
conda
linux·python·conda
知乎的哥廷根数学学派3 小时前
基于注意力机制的多尺度脉冲神经网络旋转机械故障诊断(西储大学轴承数据,Pytorch)
人工智能·pytorch·python·深度学习·神经网络·机器学习