PyTorch 之 torch.distributions.Categorical 详解

PyTorch 之 torch.distributions.Categorical 详解

PyTorch 之 torch.distributions.Categorical 详解

在深度学习的诸多任务中,我们常常需要处理离散概率分布,比如在自然语言处理中对词汇表中的单词进行采样,或者在强化学习中从策略网络输出的动作概率分布中选择动作。PyTorch 提供了 torch.distributions.Categorical 类,方便我们高效地创建和操作离散分类分布。本文将深入讲解这个类的用法,帮助你在实际项目中更好地利用它。

一、创建分类分布

(一)基本语法

torch.distributions.Categorical 的基本语法是:

复制代码
torch.distributions.Categorical(probs=None, logits=None)

其中:

  • probs:一个张量,表示每个类别的概率。它的值应该非负,并且所有元素的和为 1。例如,torch.tensor([0.1, 0.2, 0.3, 0.4]) 表示有四个类别,它们的概率分别是 0.1、0.2、0.3 和 0.4。
  • logits:一个张量,表示每个类别的未归一化对数概率。系统会自动将其转换为概率值。比如,torch.tensor([1.0, 2.0, 3.0, 4.0]) 会被处理成相应的概率分布。

(二)示例

python 复制代码
import torch

# 使用 probs 参数
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
categorical_dist = torch.distributions.Categorical(probs)

# 使用 logits 参数
logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
categorical_dist_logits = torch.distributions.Categorical(logits=logits)

二、采样

(一)方法

使用 sample 方法进行采样,语法是:

复制代码
sample(sample_shape=torch.Size())

其中,sample_shape 是一个元组,用于指定采样的样本数量和形状。默认为空,表示采样一个样本。

(二)示例

python 复制代码
# 采样一个样本
sample = categorical_dist.sample()
print(sample)  # 输出一个类别索引,比如 tensor(3)

# 采样多个样本
samples = categorical_dist.sample((5,))
print(samples)  # 输出一个形状为 [5] 的张量,包含 5 个类别索引,如 tensor([2, 0, 3, 1, 3])

三、计算概率

(一)方法

借助 prob 方法计算概率,语法如下:

复制代码
prob(value)

这里,value 是一个张量,表示类别索引,取值范围为 [0, num_categories - 1]

(二)示例

python 复制代码
# 计算单个值的概率
prob_value = categorical_dist.prob(torch.tensor(2))
print(prob_value)  # 输出类别索引为 2 的概率值,如 tensor(0.3)

# 计算多个值的概率
prob_values = categorical_dist.prob(torch.tensor([0, 1, 2, 3]))
print(prob_values)  # 输出一个形状为 [4] 的张量,包含每个类别索引对应的概率值,如 tensor([0.1, 0.2, 0.3, 0.4])

四、计算对数概率

(一)方法

调用 log_prob 方法计算对数概率,语法是:

复制代码
log_prob(value)

参数 value 的含义和 prob 方法中的相同。

(二)示例

python 复制代码
# 计算单个值的对数概率
log_prob_value = categorical_dist.log_prob(torch.tensor(1))
print(log_prob_value)  # 输出类别索引为 1 的对数概率值,比如 tensor(-1.6094)

# 计算多个值的对数概率
log_prob_values = categorical_dist.log_prob(torch.tensor([0, 1, 2, 3]))
print(log_prob_values)  # 输出一个形状为 [4] 的张量,包含每个类别索引对应的对数概率值,如 tensor([-2.3026, -1.6094, -1.2039, -0.9163])

五、其他方法

(一)计算熵

使用 entropy 方法计算分类分布的熵,熵反映了分布的不确定性。值越大,表示不确定性越高。示例代码如下:

python 复制代码
# 计算熵
entropy_value = categorical_dist.entropy()
print(entropy_value)  # 输出一个值,如 tensor(1.3777)

(二)枚举支持集

通过 enumerate_support 方法枚举分布的支持集,即所有可能的类别索引。示例如下:

python 复制代码
# 枚举支持集
support = categorical_dist.enumerate_support()
print(support)  # 输出一个形状为 [4] 的张量,如 tensor([0, 1, 2, 3])

(三)获取均值和方差

可以直接访问 meanvariance 属性,分别获取分布的均值和方差。示例:

python 复制代码
# 获取均值和方差
mean_value = categorical_dist.mean
variance_value = categorical_dist.variance
print(mean_value, variance_value)  # 输出类似 tensor(2.3000) tensor(1.8100)

六、实际应用场景

(一)强化学习中的策略选择

在强化学习里,策略网络常输出动作的概率分布。此时,可以利用 torch.distributions.Categorical 来创建这个分布,然后通过采样来选择动作。例如:

python 复制代码
import torch
import torch.nn as nn

# 假设策略网络的输出是 logits
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(4, 128)  # 输入状态维度为 4
        self.fc2 = nn.Linear(128, 2)  # 输出动作 logits,假设有 2 个动作

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

policy_net = PolicyNetwork()
state = torch.tensor([1.0, 2.0, 3.0, 4.0])  # 当前状态
logits = policy_net(state)
action_dist = torch.distributions.Categorical(logits=logits)
action = action_dist.sample()  # 采样得到动作
print(action)  # 输出动作索引,如 tensor(1)

(二)自然语言处理中的单词预测

在语言模型中,模型会预测下一个单词的概率分布。使用 torch.distributions.Categorical 可以方便地处理这个分布,比如进行采样生成文本。例如:

python 复制代码
import torch
import torch.nn as nn

# 假设语言模型的输出是单词的概率
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(LanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])  # 取最后一个时间步的输出预测下一个单词
        return x

vocab_size = 10000  # 词汇表大小
embedding_dim = 128
hidden_dim = 256
lm = LanguageModel(vocab_size, embedding_dim, hidden_dim)
input_word_indices = torch.tensor([1, 2, 3, 4])  # 输入单词的索引序列
probs = lm(input_word_indices)
word_dist = torch.distributions.Categorical(probs=probs)
next_word = word_dist.sample()  # 采样得到下一个单词的索引
print(next_word)  # 输出单词索引,如 tensor(125)

torch.distributions.Categorical 类是 PyTorch 中处理离散概率分布的有力工具。它丰富的功能使得在涉及到分类数据的概率操作时变得简单高效。掌握这个类的用法,能让你在强化学习、自然语言处理等诸多领域更加得心应手地构建和训练模型。建议你在实际项目中多加练习,深入理解其原理和应用场景。

相关推荐
爱分享的阿Q18 小时前
STM32现代化AI开发环境搭建:从Keil到VSCode+AI的范式转移
人工智能·vscode·stm32
LJ979511118 小时前
媒体发布新武器:Infoseek融媒体平台使用指南
大数据·人工智能
科技小花18 小时前
AI重塑数据治理:2026年核心方案评估与场景适配
大数据·人工智能·云原生·ai原生
Canace18 小时前
使用大模型来维护知识库
前端·人工智能
Ricky111zzz18 小时前
leetcode学python记录1
python·算法·leetcode·职场和发展
乐鑫科技 Espressif19 小时前
使用 MCP 服务器,把乐鑫文档接入 AI 工作流
人工智能·ai·esp32·乐鑫科技
云烟成雨TD19 小时前
Spring AI Alibaba 1.x 系列【5】ReactAgent 构建器深度源码解析
java·人工智能·spring
语戚19 小时前
Stable Diffusion 入门:架构、空间与生成流程概览
人工智能·ai·stable diffusion·aigc·模型
代码青铜19 小时前
如何用 Zion 实现 AI 图片分析与电商文案自动生成流程
大数据·人工智能
俊哥V19 小时前
每日 AI 研究简报 · 2026-04-08
人工智能·ai