【PyTorch笔记 04】F.cross_entropy的使用

torch.nn.functional.cross_entropy是PyTorch中用于计算交叉熵损失的函数,非常适合用于多分类问题。这个函数结合了log_softmax操作和nll_loss(负对数似然损失)的计算,因此输入得分(即模型输出)不需要事先经过softmax处理。

下面是一个使用torch.nn.functional.cross_entropy的示例,展示了如何在一个简单的神经网络模型中应用它来计算损失:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设我们有一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)
    
    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = SimpleModel(input_size=10, num_classes=3)

# 模拟一批输入数据和标签
inputs = torch.randn(5, 10)  # 假设批大小为5,输入特征大小为10
labels = torch.tensor([0, 2, 1, 0, 2])  # 真实标签

# 模型前向传播
outputs = model(inputs)

# 计算交叉熵损失
# outputs: torch.Size([5, 3])
# labels: torch.Size([5]), 注意这个参数必须为long型的
loss = F.cross_entropy(outputs, labels)

print("Loss:", loss.item())

在这个示例中:

  • SimpleModel是一个简单的线性模型,其输出大小等于类别数。
  • 我们创建了一批输入inputs和对应的标签labels
  • 模型的输出outputs是直接传递给F.cross_entropy的,不需要额外的softmax层,因为cross_entropy内部已经处理了这部分。
  • labels应该是每个样本的类别索引形式,而不是one-hot编码。
  • F.cross_entropy计算了从模型输出到真实标签的交叉熵损失。

这种方式使得实现多分类问题的模型训练变得简单而直接。

相关推荐
嵌入式×边缘AI:打怪升级日志1 分钟前
USB设备枚举过程详解:从插入到正常工作
开发语言·数据库·笔记
学习是生活的调味剂4 分钟前
在大模型开发中,是否需要先完整学习 TensorFlow,再学 PyTorch?
pytorch·学习·tensorflow·transformers
Jerryhut4 分钟前
OpenCv总结5——图像特征——harris角点检测
人工智能·opencv·计算机视觉
笨鸟先飞的橘猫7 分钟前
mongo权威指南(第三版)学习笔记
笔记·学习
图欧学习资源库9 分钟前
人工智能领域、图欧科技、IMYAI智能助手2025年12月更新月报
人工智能·科技
技术小泽9 分钟前
java转go速成入门笔记篇(一)
java·笔记·golang
光羽隹衡10 分钟前
机器学习——贝叶斯
人工智能·机器学习
夏天是冰红茶10 分钟前
YOLO目标检测数据集扩充
人工智能·yolo·目标检测
Noushiki10 分钟前
RabbitMQ 进阶 学习笔记2
笔记·学习·rabbitmq
lpfasd12312 分钟前
Spring AI 集成国内大模型实战:千问/豆包(含多模态)+ Spring Boot 4.0.1 全攻略
人工智能·spring boot·spring