一、torch.distributions 核心概念
torch.distributions 是 PyTorch 专门用于概率分布建模的模块,核心作用是:
- 生成服从特定概率分布的随机样本(如正态分布、均匀分布、二项分布等);
- 计算概率相关指标(如概率密度/质量函数 PDF/PMF、对数概率、累积分布函数 CDF 等);
- 支持自动微分(与 PyTorch 张量无缝衔接,适合深度学习中的概率建模、强化学习、变分推断等场景)。
核心术语说明:
- PDF(概率密度函数):针对连续分布(如正态分布),描述某一取值的概率密度;
- PMF(概率质量函数):针对离散分布(如二项分布),描述某一取值的概率;
- 样本采样:从分布中生成符合该分布规律的随机数。
二、常用分布及案例
下面通过 连续分布 和 离散分布 两类场景,结合代码+可视化,直观展示核心用法。
前置准备
python
import torch
import matplotlib.pyplot as plt
torch.manual_seed(42) # 固定随机种子,结果可复现
案例1:连续分布 - 正态分布(Normal Distribution)
最常用的连续分布,也叫高斯分布,参数为 loc(均值)、scale(标准差)。
python
# 1. 定义正态分布:均值=0,标准差=1(标准正态分布)
normal_dist = torch.distributions.Normal(loc=0.0, scale=1.0)
# 2. 核心操作
## (1)采样:生成1000个服从该分布的样本
samples = normal_dist.sample(sample_shape=(1000,)) # shape: (1000,)
## (2)计算PDF:样本对应的概率密度
pdf_vals = normal_dist.log_prob(samples).exp() # log_prob是对数PDF,exp还原为PDF
# 3. 可视化:展示样本分布+PDF曲线
plt.figure(figsize=(10, 6))
# 绘制样本直方图(反映实际分布)
plt.hist(samples.numpy(), bins=30, density=True, alpha=0.7, label="样本直方图", color="#1f77b4")
# 绘制PDF曲线(理论分布)
x = torch.linspace(-4, 4, 1000)
pdf_curve = normal_dist.log_prob(x).exp()
plt.plot(x.numpy(), pdf_curve.numpy(), color="red", linewidth=2, label="PDF曲线")
plt.xlabel("取值")
plt.ylabel("概率密度")
plt.title("标准正态分布(Normal):样本+PDF")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
# 输出关键统计量验证
print("样本均值:", samples.mean().item()) # 接近0
print("样本标准差:", samples.std().item()) # 接近1
效果说明:
- 直方图呈现"钟形",与红色PDF曲线高度吻合;
- 样本均值≈0、标准差≈1,符合标准正态分布特征。
案例2:连续分布 - 均匀分布(Uniform Distribution)
取值在 low 和 high 之间均匀分布,参数为 low(下限)、high(上限)。
python
# 1. 定义均匀分布:取值范围[0, 10]
uniform_dist = torch.distributions.Uniform(low=0.0, high=10.0)
# 2. 核心操作
samples = uniform_dist.sample(sample_shape=(10000,)) # 生成10000个样本(越多越均匀)
pdf_vals = uniform_dist.log_prob(samples).exp()
# 3. 可视化
plt.figure(figsize=(10, 6))
plt.hist(samples.numpy(), bins=20, density=True, alpha=0.7, label="样本直方图", color="#ff7f0e")
# 绘制PDF曲线(均匀分布的PDF是常数:1/(high-low))
x = torch.linspace(-1, 11, 1000)
pdf_curve = uniform_dist.log_prob(x).exp()
plt.plot(x.numpy(), pdf_curve.numpy(), color="red", linewidth=2, label="PDF曲线")
plt.xlabel("取值")
plt.ylabel("概率密度")
plt.title("均匀分布(Uniform):[0,10]")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
print("样本均值:", samples.mean().item()) # 接近5((0+10)/2)
print("均匀分布PDF值:", pdf_vals[0].item()) # 0.1(1/(10-0))
效果说明:
- 直方图各区间高度接近,体现"均匀";
- PDF曲线在[0,10]内为常数0.1,区间外为0。
案例3:离散分布 - 二项分布(Binomial Distribution)
描述 n 次独立伯努利试验中成功的次数,参数为 total_count(试验次数n)、probs(单次成功概率p)。
python
# 1. 定义二项分布:10次试验,单次成功概率0.5
binomial_dist = torch.distributions.Binomial(total_count=10, probs=0.5)
# 2. 核心操作
samples = binomial_dist.sample(sample_shape=(10000,)) # 生成10000个样本(成功次数)
pmf_vals = binomial_dist.log_prob(samples).exp() # 离散分布用PMF(概率质量)
# 3. 可视化:离散分布用柱状图
plt.figure(figsize=(10, 6))
# 统计每个成功次数的频数(归一化为概率)
counts = torch.bincount(samples.int()) / len(samples)
plt.bar(range(11), counts.numpy(), alpha=0.7, label="样本概率", color="#2ca02c")
# 绘制理论PMF
x = torch.arange(0, 11)
pmf_curve = binomial_dist.log_prob(x).exp()
plt.plot(x.numpy(), pmf_curve.numpy(), color="red", marker="o", linestyle="--", label="PMF曲线")
plt.xlabel("成功次数")
plt.ylabel("概率")
plt.title("二项分布(Binomial):n=10, p=0.5")
plt.xticks(range(11))
plt.legend()
plt.grid(alpha=0.3)
plt.show()
print("样本均值:", samples.mean().item()) # 接近5(n*p=10*0.5)
效果说明:
- 成功次数集中在5附近,符合p=0.5的二项分布特征;
- 样本均值≈5,与理论值(n*p)一致。
案例4:离散分布 - 类别分布(Categorical Distribution)
描述离散类别取值的概率(如分类任务的标签分布),参数为 probs(各类别概率)或 logits(未归一化的对数概率)。
python
# 1. 定义类别分布:3个类别,概率分别为[0.2, 0.5, 0.3]
probs = torch.tensor([0.2, 0.5, 0.3])
categorical_dist = torch.distributions.Categorical(probs=probs)
# 2. 核心操作
samples = categorical_dist.sample(sample_shape=(10000,)) # 生成10000个样本(0/1/2)
log_prob_vals = categorical_dist.log_prob(samples) # 对数PMF
# 3. 可视化
plt.figure(figsize=(8, 5))
counts = torch.bincount(samples) / len(samples)
plt.bar(["类别0", "类别1", "类别2"], counts.numpy(), alpha=0.7, color=["#1f77b4", "#ff7f0e", "#2ca02c"])
plt.axhline(y=0.2, color="red", linestyle="--", label="类别0理论概率")
plt.axhline(y=0.5, color="orange", linestyle="--", label="类别1理论概率")
plt.axhline(y=0.3, color="green", linestyle="--", label="类别2理论概率")
plt.ylabel("概率")
plt.title("类别分布(Categorical):[0.2, 0.5, 0.3]")
plt.legend()
plt.grid(alpha=0.3)
plt.show()
print("各类别样本占比:", counts.numpy()) # 接近[0.2, 0.5, 0.3]
效果说明:
- 类别1的样本占比最高(≈0.5),类别0最低(≈0.2),与设定概率一致;
- 是分类任务中"标签分布""预测分布"的核心建模工具。
案例5:实战场景 - 分布的KL散度计算
KL散度(Kullback-Leibler Divergence)衡量两个分布的差异,torch.distributions 内置 kl_divergence 函数。
python
# 定义两个正态分布
dist1 = torch.distributions.Normal(loc=0.0, scale=1.0) # 标准正态
dist2 = torch.distributions.Normal(loc=1.0, scale=2.0) # 均值1,标准差2
# 计算KL散度(dist1 || dist2)
kl_div = torch.distributions.kl_divergence(dist1, dist2)
print("KL散度(dist1 || dist2):", kl_div.item()) # 输出约0.8132
# 物理意义:值越大,两个分布差异越大;值为0则分布完全相同
三、核心API总结
| 操作 | 方法 | 说明 |
|---|---|---|
| 定义分布 | dist = DistributionClass(参数) |
如 Normal(loc, scale)、Binomial(total_count, probs) |
| 采样 | dist.sample(sample_shape) |
生成指定形状的样本,支持批量 |
| 概率计算 | dist.log_prob(x) |
计算x的对数PDF/PMF(数值更稳定) |
| 概率计算 | dist.log_prob(x).exp() |
还原为原始PDF/PMF值 |
| 分布差异 | kl_divergence(dist1, dist2) |
计算两个分布的KL散度 |
总结
- 核心作用 :
torch.distributions是PyTorch中概率建模的核心模块,支持连续/离散分布的采样、概率计算,且兼容自动微分; - 关键用法 :
- 连续分布(Normal/Uniform):关注PDF和样本分布的连续性;
- 离散分布(Binomial/Categorical):关注PMF和类别/次数的离散性;
- 实战价值:广泛用于深度学习的概率模型(如VAE)、强化学习(动作采样)、分类任务(标签分布)等场景,是连接"概率"和"张量计算"的桥梁。