torch.distribution函数详解

一、torch.distributions 核心概念

torch.distributions 是 PyTorch 专门用于概率分布建模的模块,核心作用是:

  1. 生成服从特定概率分布的随机样本(如正态分布、均匀分布、二项分布等);
  2. 计算概率相关指标(如概率密度/质量函数 PDF/PMF、对数概率、累积分布函数 CDF 等);
  3. 支持自动微分(与 PyTorch 张量无缝衔接,适合深度学习中的概率建模、强化学习、变分推断等场景)。

核心术语说明

  • PDF(概率密度函数):针对连续分布(如正态分布),描述某一取值的概率密度;
  • PMF(概率质量函数):针对离散分布(如二项分布),描述某一取值的概率;
  • 样本采样:从分布中生成符合该分布规律的随机数。

二、常用分布及案例

下面通过 连续分布离散分布 两类场景,结合代码+可视化,直观展示核心用法。

前置准备
python 复制代码
import torch
import matplotlib.pyplot as plt
torch.manual_seed(42)  # 固定随机种子,结果可复现

案例1:连续分布 - 正态分布(Normal Distribution)

最常用的连续分布,也叫高斯分布,参数为 loc(均值)、scale(标准差)。

python 复制代码
# 1. 定义正态分布:均值=0,标准差=1(标准正态分布)
normal_dist = torch.distributions.Normal(loc=0.0, scale=1.0)

# 2. 核心操作
## (1)采样:生成1000个服从该分布的样本
samples = normal_dist.sample(sample_shape=(1000,))  # shape: (1000,)

## (2)计算PDF:样本对应的概率密度
pdf_vals = normal_dist.log_prob(samples).exp()  # log_prob是对数PDF,exp还原为PDF

# 3. 可视化:展示样本分布+PDF曲线
plt.figure(figsize=(10, 6))
# 绘制样本直方图(反映实际分布)
plt.hist(samples.numpy(), bins=30, density=True, alpha=0.7, label="样本直方图", color="#1f77b4")
# 绘制PDF曲线(理论分布)
x = torch.linspace(-4, 4, 1000)
pdf_curve = normal_dist.log_prob(x).exp()
plt.plot(x.numpy(), pdf_curve.numpy(), color="red", linewidth=2, label="PDF曲线")

plt.xlabel("取值")
plt.ylabel("概率密度")
plt.title("标准正态分布(Normal):样本+PDF")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

# 输出关键统计量验证
print("样本均值:", samples.mean().item())  # 接近0
print("样本标准差:", samples.std().item())  # 接近1

效果说明

  • 直方图呈现"钟形",与红色PDF曲线高度吻合;
  • 样本均值≈0、标准差≈1,符合标准正态分布特征。

案例2:连续分布 - 均匀分布(Uniform Distribution)

取值在 lowhigh 之间均匀分布,参数为 low(下限)、high(上限)。

python 复制代码
# 1. 定义均匀分布:取值范围[0, 10]
uniform_dist = torch.distributions.Uniform(low=0.0, high=10.0)

# 2. 核心操作
samples = uniform_dist.sample(sample_shape=(10000,))  # 生成10000个样本(越多越均匀)
pdf_vals = uniform_dist.log_prob(samples).exp()

# 3. 可视化
plt.figure(figsize=(10, 6))
plt.hist(samples.numpy(), bins=20, density=True, alpha=0.7, label="样本直方图", color="#ff7f0e")
# 绘制PDF曲线(均匀分布的PDF是常数:1/(high-low))
x = torch.linspace(-1, 11, 1000)
pdf_curve = uniform_dist.log_prob(x).exp()
plt.plot(x.numpy(), pdf_curve.numpy(), color="red", linewidth=2, label="PDF曲线")

plt.xlabel("取值")
plt.ylabel("概率密度")
plt.title("均匀分布(Uniform):[0,10]")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("样本均值:", samples.mean().item())  # 接近5((0+10)/2)
print("均匀分布PDF值:", pdf_vals[0].item())  # 0.1(1/(10-0))

效果说明

  • 直方图各区间高度接近,体现"均匀";
  • PDF曲线在[0,10]内为常数0.1,区间外为0。

案例3:离散分布 - 二项分布(Binomial Distribution)

描述 n 次独立伯努利试验中成功的次数,参数为 total_count(试验次数n)、probs(单次成功概率p)。

python 复制代码
# 1. 定义二项分布:10次试验,单次成功概率0.5
binomial_dist = torch.distributions.Binomial(total_count=10, probs=0.5)

# 2. 核心操作
samples = binomial_dist.sample(sample_shape=(10000,))  # 生成10000个样本(成功次数)
pmf_vals = binomial_dist.log_prob(samples).exp()  # 离散分布用PMF(概率质量)

# 3. 可视化:离散分布用柱状图
plt.figure(figsize=(10, 6))
# 统计每个成功次数的频数(归一化为概率)
counts = torch.bincount(samples.int()) / len(samples)
plt.bar(range(11), counts.numpy(), alpha=0.7, label="样本概率", color="#2ca02c")
# 绘制理论PMF
x = torch.arange(0, 11)
pmf_curve = binomial_dist.log_prob(x).exp()
plt.plot(x.numpy(), pmf_curve.numpy(), color="red", marker="o", linestyle="--", label="PMF曲线")

plt.xlabel("成功次数")
plt.ylabel("概率")
plt.title("二项分布(Binomial):n=10, p=0.5")
plt.xticks(range(11))
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("样本均值:", samples.mean().item())  # 接近5(n*p=10*0.5)

效果说明

  • 成功次数集中在5附近,符合p=0.5的二项分布特征;
  • 样本均值≈5,与理论值(n*p)一致。

案例4:离散分布 - 类别分布(Categorical Distribution)

描述离散类别取值的概率(如分类任务的标签分布),参数为 probs(各类别概率)或 logits(未归一化的对数概率)。

python 复制代码
# 1. 定义类别分布:3个类别,概率分别为[0.2, 0.5, 0.3]
probs = torch.tensor([0.2, 0.5, 0.3])
categorical_dist = torch.distributions.Categorical(probs=probs)

# 2. 核心操作
samples = categorical_dist.sample(sample_shape=(10000,))  # 生成10000个样本(0/1/2)
log_prob_vals = categorical_dist.log_prob(samples)  # 对数PMF

# 3. 可视化
plt.figure(figsize=(8, 5))
counts = torch.bincount(samples) / len(samples)
plt.bar(["类别0", "类别1", "类别2"], counts.numpy(), alpha=0.7, color=["#1f77b4", "#ff7f0e", "#2ca02c"])
plt.axhline(y=0.2, color="red", linestyle="--", label="类别0理论概率")
plt.axhline(y=0.5, color="orange", linestyle="--", label="类别1理论概率")
plt.axhline(y=0.3, color="green", linestyle="--", label="类别2理论概率")

plt.ylabel("概率")
plt.title("类别分布(Categorical):[0.2, 0.5, 0.3]")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("各类别样本占比:", counts.numpy())  # 接近[0.2, 0.5, 0.3]

效果说明

  • 类别1的样本占比最高(≈0.5),类别0最低(≈0.2),与设定概率一致;
  • 是分类任务中"标签分布""预测分布"的核心建模工具。

案例5:实战场景 - 分布的KL散度计算

KL散度(Kullback-Leibler Divergence)衡量两个分布的差异,torch.distributions 内置 kl_divergence 函数。

python 复制代码
# 定义两个正态分布
dist1 = torch.distributions.Normal(loc=0.0, scale=1.0)  # 标准正态
dist2 = torch.distributions.Normal(loc=1.0, scale=2.0)  # 均值1,标准差2

# 计算KL散度(dist1 || dist2)
kl_div = torch.distributions.kl_divergence(dist1, dist2)
print("KL散度(dist1 || dist2):", kl_div.item())  # 输出约0.8132

# 物理意义:值越大,两个分布差异越大;值为0则分布完全相同

三、核心API总结

操作 方法 说明
定义分布 dist = DistributionClass(参数) Normal(loc, scale)Binomial(total_count, probs)
采样 dist.sample(sample_shape) 生成指定形状的样本,支持批量
概率计算 dist.log_prob(x) 计算x的对数PDF/PMF(数值更稳定)
概率计算 dist.log_prob(x).exp() 还原为原始PDF/PMF值
分布差异 kl_divergence(dist1, dist2) 计算两个分布的KL散度

总结

  1. 核心作用torch.distributions 是PyTorch中概率建模的核心模块,支持连续/离散分布的采样、概率计算,且兼容自动微分;
  2. 关键用法
    • 连续分布(Normal/Uniform):关注PDF和样本分布的连续性;
    • 离散分布(Binomial/Categorical):关注PMF和类别/次数的离散性;
  3. 实战价值:广泛用于深度学习的概率模型(如VAE)、强化学习(动作采样)、分类任务(标签分布)等场景,是连接"概率"和"张量计算"的桥梁。
相关推荐
盼小辉丶2 天前
PyTorch实战(30)——使用TorchScript和ONNX导出通用PyTorch模型
人工智能·pytorch·深度学习·模型部署
封奚泽优2 天前
使用mmdetection项目进行训练记录
pytorch·python·cuda·mmdetection·mmcv
tony3652 天前
pytorch分布式训练解释
人工智能·pytorch·分布式
weixin_贾2 天前
深度学习基础理论与 PyTorch 实战 —— 从传统机器学习到前沿模型全攻略
pytorch·深度学习·机器学习
大连好光景2 天前
PyTorch深度学习----优化器
pytorch·深度学习·学习
多恩Stone3 天前
【3D-AICG 系列-11】Trellis 2 的 Shape VAE 训练流程梳理
人工智能·pytorch·算法·3d·aigc
隔壁大炮4 天前
08. PyTorch_张量基本创建方式
人工智能·pytorch·python
隔壁大炮4 天前
07. PyTorch框架简介
人工智能·pytorch·python
大鹏的NLP博客4 天前
Rust + PyTorch 实现 BGE 向量检索系统
人工智能·pytorch·rust
勾股导航6 天前
蚁群优化算法
人工智能·pytorch·python