20260113-np.random.multinomial 与 torch.multinomial

np.random.multinomial vs torch.multinomial

np.random.multinomial 与 torch.multinomial 详解

一、多项分布基础

多项分布是二项分布的推广,描述的是在进行 n 次独立试验时,每个 k 个可能结果出现次数的概率分布。其概率质量函数为:

P(X1=x1,...,Xk=xk)=n!x1!...xk!⋅p1x1...pkxkP(X_1=x_1,...,X_k=x_k) = \frac{n!}{x_1!...x_k!} \cdot p_1^{x_1}...p_k^{x_k}P(X1=x1,...,Xk=xk)=x1!...xk!n!⋅p1x1...pkxk

其中:

  • n:试验总次数
  • p :各结果的概率向量(∑pi=1\sum p_i = 1∑pi=1)
  • x :各结果出现次数的向量(∑xi=n\sum x_i = n∑xi=n)

二、np.random.multinomial(NumPy)

1. 原理

从多项分布中一次性抽取多个独立样本 ,返回的是计数结果(各类别出现的次数)。

2. 函数签名

python 复制代码
np.random.multinomial(n, pvals, size=None)

3. 参数详解

参数 类型 说明
n int 每次试验的抽取次数(如掷骰子次数)
pvals array_like 概率向量,长度 k 表示类别数,总和必须为 1
size int/tuple 输出形状,表示进行多少次独立的 n 次试验

4. 返回值

  • 形状size 形状 + (len(pvals),)
  • 含义 :最后一个维度表示 每个类别被抽中的次数

5. 示例

python 复制代码
import numpy as np

# 定义一个6面骰子的概率分布(不均匀)
probs = [0.1, 0.2, 0.3, 0.2, 0.15, 0.05]

# 掷骰子10次,记录每个面出现的次数
result = np.random.multinomial(10, probs)
print(result)  
# 可能输出: array([0, 3, 4, 2, 1, 0])  # 6个面的出现次数,总和为10

# 进行5次独立的实验,每次掷骰子10次
results = np.random.multinomial(10, probs, size=5)
print(results)
# 可能输出:
# [[1 2 3 2 2 0]
#  [0 1 5 2 2 0]
#  [0 3 4 1 1 1]
#  [3 1 2 2 1 1]
#  [0 2 4 3 0 1]]

三、torch.multinomial(PyTorch)

1. 原理

从类别分布中进行多次有放回抽样 ,返回的是样本的索引(类别编号),而非计数。

2. 函数签名

python 复制代码
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)

3. 参数详解

参数 类型 说明
input Tensor 概率矩阵,形状 (N, C)(C,)
num_samples int 每个分布要抽取的样本数
replacement bool 是否允许重复抽样(有放回/无放回)
generator torch.Generator 可选的随机数生成器

4. 返回值

  • 形状 :与 input 相同的前导维度 + (num_samples,)
  • 含义 :抽取的样本类别 索引

5. 示例

python 复制代码
import torch

# 单个概率分布
probs = torch.tensor([0.1, 0.2, 0.3, 0.2, 0.15, 0.05])
samples = torch.multinomial(probs, num_samples=10, replacement=True)
print(samples)  
# 可能输出: tensor([2, 1, 2, 2, 3, 2, 4, 1, 2, 1])  # 类别索引

# 批量概率分布 (3个不同的分布)
batch_probs = torch.tensor([
    [0.1, 0.2, 0.3, 0.4],  # 分布1
    [0.5, 0.3, 0.2, 0.0],  # 分布2
    [0.0, 0.0, 0.5, 0.5]   # 分布3
])

# 每个分布抽取5个样本
batch_samples = torch.multinomial(batch_probs, num_samples=5, replacement=True)
print(batch_samples)
# 可能输出:
# tensor([[2, 3, 3, 2, 1],
#         [0, 0, 1, 0, 2],
#         [2, 3, 2, 3, 2]])

# 无放回抽样 (num_samples不能超过类别数)
samples_no_replace = torch.multinomial(probs, num_samples=3, replacement=False)
print(samples_no_replace)  # 输出如: tensor([2, 1, 3]),不会重复

四、核心区别对比

特性 np.random.multinomial torch.multinomial
返回内容 计数:各类别被抽中的次数 索引:被抽中类别的编号
抽样方式 隐含有放回 明确指定 replacement
输入格式 一维概率向量 支持批量输入(可多维)
典型用途 统计模拟、文档生成 强化学习采样、批量决策
梯度支持 不支持 支持(可集成在计算图中)
性能 CPU计算 GPU加速(支持CUDA)

五、注意事项

  1. 概率归一化 :两个函数都要求概率总和为1(在容差范围内,但是概率传输时并不严格要求总和为1,内部会自己处理
  2. 数值稳定性:极小的概率值可能导致采样偏差,建议设置最小阈值
  3. 随机种子 :NumPy用 np.random.seed(),PyTorch用 torch.manual_seed()
  4. 无放回限制torch.multinomialreplacement=False 时,num_samples 不能超过类别数
相关推荐
你好潘先生5 小时前
别再记命令了,用 yeero do 说句人话就能跑脚本,而且不烧 token
服务器·python·命令行
Agent_大师5 小时前
WebSocket 行情重连成功,K线缺口不会自动消失
python
荣码5 小时前
LLM结构化输出:让AI返回JSON而不是废话,我踩了4个坑
java·python
copyer_xyf5 小时前
FastAPI 如何连接 MySQL
后端·python
apocelipes19 小时前
常用编程语言和库的正则表达式性能对比
c语言·c++·python·性能优化·golang·开发工具和环境
用户83562907805121 小时前
使用 Python 在 PDF 中创建与管理书签
后端·python
MeixianAgent1 天前
Python 回测数据入口怎么验?历史 K 线入库前先做 5 个检查
后端·python
咕白m6251 天前
用 Python 实现一键批量查找与替换 Excel 数据
后端·python
SelectDB2 天前
Apache Doris Python UDF:让 SQL 直接调用 Python 生态,支撑 Agent 时代复杂业务逻辑
大数据·数据库·python
荣码2 天前
GraphRAG:普通RAG只能回答"点"的问题,我踩了4个坑才搞懂
java·python