torch.distribution函数详解

一、torch.distributions 核心概念

torch.distributions 是 PyTorch 专门用于概率分布建模的模块,核心作用是:

  1. 生成服从特定概率分布的随机样本(如正态分布、均匀分布、二项分布等);
  2. 计算概率相关指标(如概率密度/质量函数 PDF/PMF、对数概率、累积分布函数 CDF 等);
  3. 支持自动微分(与 PyTorch 张量无缝衔接,适合深度学习中的概率建模、强化学习、变分推断等场景)。

核心术语说明

  • PDF(概率密度函数):针对连续分布(如正态分布),描述某一取值的概率密度;
  • PMF(概率质量函数):针对离散分布(如二项分布),描述某一取值的概率;
  • 样本采样:从分布中生成符合该分布规律的随机数。

二、常用分布及案例

下面通过 连续分布离散分布 两类场景,结合代码+可视化,直观展示核心用法。

前置准备
python 复制代码
import torch
import matplotlib.pyplot as plt
torch.manual_seed(42)  # 固定随机种子,结果可复现

案例1:连续分布 - 正态分布(Normal Distribution)

最常用的连续分布,也叫高斯分布,参数为 loc(均值)、scale(标准差)。

python 复制代码
# 1. 定义正态分布:均值=0,标准差=1(标准正态分布)
normal_dist = torch.distributions.Normal(loc=0.0, scale=1.0)

# 2. 核心操作
## (1)采样:生成1000个服从该分布的样本
samples = normal_dist.sample(sample_shape=(1000,))  # shape: (1000,)

## (2)计算PDF:样本对应的概率密度
pdf_vals = normal_dist.log_prob(samples).exp()  # log_prob是对数PDF,exp还原为PDF

# 3. 可视化:展示样本分布+PDF曲线
plt.figure(figsize=(10, 6))
# 绘制样本直方图(反映实际分布)
plt.hist(samples.numpy(), bins=30, density=True, alpha=0.7, label="样本直方图", color="#1f77b4")
# 绘制PDF曲线(理论分布)
x = torch.linspace(-4, 4, 1000)
pdf_curve = normal_dist.log_prob(x).exp()
plt.plot(x.numpy(), pdf_curve.numpy(), color="red", linewidth=2, label="PDF曲线")

plt.xlabel("取值")
plt.ylabel("概率密度")
plt.title("标准正态分布(Normal):样本+PDF")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

# 输出关键统计量验证
print("样本均值:", samples.mean().item())  # 接近0
print("样本标准差:", samples.std().item())  # 接近1

效果说明

  • 直方图呈现"钟形",与红色PDF曲线高度吻合;
  • 样本均值≈0、标准差≈1,符合标准正态分布特征。

案例2:连续分布 - 均匀分布(Uniform Distribution)

取值在 lowhigh 之间均匀分布,参数为 low(下限)、high(上限)。

python 复制代码
# 1. 定义均匀分布:取值范围[0, 10]
uniform_dist = torch.distributions.Uniform(low=0.0, high=10.0)

# 2. 核心操作
samples = uniform_dist.sample(sample_shape=(10000,))  # 生成10000个样本(越多越均匀)
pdf_vals = uniform_dist.log_prob(samples).exp()

# 3. 可视化
plt.figure(figsize=(10, 6))
plt.hist(samples.numpy(), bins=20, density=True, alpha=0.7, label="样本直方图", color="#ff7f0e")
# 绘制PDF曲线(均匀分布的PDF是常数:1/(high-low))
x = torch.linspace(-1, 11, 1000)
pdf_curve = uniform_dist.log_prob(x).exp()
plt.plot(x.numpy(), pdf_curve.numpy(), color="red", linewidth=2, label="PDF曲线")

plt.xlabel("取值")
plt.ylabel("概率密度")
plt.title("均匀分布(Uniform):[0,10]")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("样本均值:", samples.mean().item())  # 接近5((0+10)/2)
print("均匀分布PDF值:", pdf_vals[0].item())  # 0.1(1/(10-0))

效果说明

  • 直方图各区间高度接近,体现"均匀";
  • PDF曲线在[0,10]内为常数0.1,区间外为0。

案例3:离散分布 - 二项分布(Binomial Distribution)

描述 n 次独立伯努利试验中成功的次数,参数为 total_count(试验次数n)、probs(单次成功概率p)。

python 复制代码
# 1. 定义二项分布:10次试验,单次成功概率0.5
binomial_dist = torch.distributions.Binomial(total_count=10, probs=0.5)

# 2. 核心操作
samples = binomial_dist.sample(sample_shape=(10000,))  # 生成10000个样本(成功次数)
pmf_vals = binomial_dist.log_prob(samples).exp()  # 离散分布用PMF(概率质量)

# 3. 可视化:离散分布用柱状图
plt.figure(figsize=(10, 6))
# 统计每个成功次数的频数(归一化为概率)
counts = torch.bincount(samples.int()) / len(samples)
plt.bar(range(11), counts.numpy(), alpha=0.7, label="样本概率", color="#2ca02c")
# 绘制理论PMF
x = torch.arange(0, 11)
pmf_curve = binomial_dist.log_prob(x).exp()
plt.plot(x.numpy(), pmf_curve.numpy(), color="red", marker="o", linestyle="--", label="PMF曲线")

plt.xlabel("成功次数")
plt.ylabel("概率")
plt.title("二项分布(Binomial):n=10, p=0.5")
plt.xticks(range(11))
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("样本均值:", samples.mean().item())  # 接近5(n*p=10*0.5)

效果说明

  • 成功次数集中在5附近,符合p=0.5的二项分布特征;
  • 样本均值≈5,与理论值(n*p)一致。

案例4:离散分布 - 类别分布(Categorical Distribution)

描述离散类别取值的概率(如分类任务的标签分布),参数为 probs(各类别概率)或 logits(未归一化的对数概率)。

python 复制代码
# 1. 定义类别分布:3个类别,概率分别为[0.2, 0.5, 0.3]
probs = torch.tensor([0.2, 0.5, 0.3])
categorical_dist = torch.distributions.Categorical(probs=probs)

# 2. 核心操作
samples = categorical_dist.sample(sample_shape=(10000,))  # 生成10000个样本(0/1/2)
log_prob_vals = categorical_dist.log_prob(samples)  # 对数PMF

# 3. 可视化
plt.figure(figsize=(8, 5))
counts = torch.bincount(samples) / len(samples)
plt.bar(["类别0", "类别1", "类别2"], counts.numpy(), alpha=0.7, color=["#1f77b4", "#ff7f0e", "#2ca02c"])
plt.axhline(y=0.2, color="red", linestyle="--", label="类别0理论概率")
plt.axhline(y=0.5, color="orange", linestyle="--", label="类别1理论概率")
plt.axhline(y=0.3, color="green", linestyle="--", label="类别2理论概率")

plt.ylabel("概率")
plt.title("类别分布(Categorical):[0.2, 0.5, 0.3]")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("各类别样本占比:", counts.numpy())  # 接近[0.2, 0.5, 0.3]

效果说明

  • 类别1的样本占比最高(≈0.5),类别0最低(≈0.2),与设定概率一致;
  • 是分类任务中"标签分布""预测分布"的核心建模工具。

案例5:实战场景 - 分布的KL散度计算

KL散度(Kullback-Leibler Divergence)衡量两个分布的差异,torch.distributions 内置 kl_divergence 函数。

python 复制代码
# 定义两个正态分布
dist1 = torch.distributions.Normal(loc=0.0, scale=1.0)  # 标准正态
dist2 = torch.distributions.Normal(loc=1.0, scale=2.0)  # 均值1,标准差2

# 计算KL散度(dist1 || dist2)
kl_div = torch.distributions.kl_divergence(dist1, dist2)
print("KL散度(dist1 || dist2):", kl_div.item())  # 输出约0.8132

# 物理意义:值越大,两个分布差异越大;值为0则分布完全相同

三、核心API总结

操作 方法 说明
定义分布 dist = DistributionClass(参数) Normal(loc, scale)Binomial(total_count, probs)
采样 dist.sample(sample_shape) 生成指定形状的样本,支持批量
概率计算 dist.log_prob(x) 计算x的对数PDF/PMF(数值更稳定)
概率计算 dist.log_prob(x).exp() 还原为原始PDF/PMF值
分布差异 kl_divergence(dist1, dist2) 计算两个分布的KL散度

总结

  1. 核心作用torch.distributions 是PyTorch中概率建模的核心模块,支持连续/离散分布的采样、概率计算,且兼容自动微分;
  2. 关键用法
    • 连续分布(Normal/Uniform):关注PDF和样本分布的连续性;
    • 离散分布(Binomial/Categorical):关注PMF和类别/次数的离散性;
  3. 实战价值:广泛用于深度学习的概率模型(如VAE)、强化学习(动作采样)、分类任务(标签分布)等场景,是连接"概率"和"张量计算"的桥梁。
相关推荐
退休钓鱼选手5 小时前
[ Pytorch教程 ] 神经网络的基本骨架 torch.nn -Neural Network
pytorch·深度学习·神经网络
DeniuHe5 小时前
用 PyTorch 库创建了一个随机张量,并演示了多种张量取整和分解操作
pytorch
Network_Engineer10 小时前
从零手写LSTM:从门控原理到PyTorch源码级实现
人工智能·pytorch·lstm
多恩Stone12 小时前
【3D-AICG 系列-1】Trellis v1 和 Trellis v2 的区别和改进
人工智能·pytorch·python·算法·3d·aigc
2501_9011478314 小时前
PyTorch DDP官方文档学习笔记(核心干货版)
pytorch·笔记·学习·算法·面试
铁手飞鹰14 小时前
[深度学习]常用的库与操作
人工智能·pytorch·python·深度学习·numpy·scikit-learn·matplotlib
青春不朽51214 小时前
PyTorch 入门指南:深度学习的瑞士军刀
人工智能·pytorch·深度学习
DeniuHe15 小时前
Pytorch中统计学相关的函数
pytorch·python·深度学习
林深现海1 天前
【刘二大人】PyTorch深度学习实践笔记 —— 第四集:反向传播(凝练版)
pytorch·python·numpy