【PyTorch笔记 04】F.cross_entropy的使用

torch.nn.functional.cross_entropy是PyTorch中用于计算交叉熵损失的函数,非常适合用于多分类问题。这个函数结合了log_softmax操作和nll_loss(负对数似然损失)的计算,因此输入得分(即模型输出)不需要事先经过softmax处理。

下面是一个使用torch.nn.functional.cross_entropy的示例,展示了如何在一个简单的神经网络模型中应用它来计算损失:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设我们有一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)
    
    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = SimpleModel(input_size=10, num_classes=3)

# 模拟一批输入数据和标签
inputs = torch.randn(5, 10)  # 假设批大小为5,输入特征大小为10
labels = torch.tensor([0, 2, 1, 0, 2])  # 真实标签

# 模型前向传播
outputs = model(inputs)

# 计算交叉熵损失
# outputs: torch.Size([5, 3])
# labels: torch.Size([5]), 注意这个参数必须为long型的
loss = F.cross_entropy(outputs, labels)

print("Loss:", loss.item())

在这个示例中:

  • SimpleModel是一个简单的线性模型,其输出大小等于类别数。
  • 我们创建了一批输入inputs和对应的标签labels
  • 模型的输出outputs是直接传递给F.cross_entropy的,不需要额外的softmax层,因为cross_entropy内部已经处理了这部分。
  • labels应该是每个样本的类别索引形式,而不是one-hot编码。
  • F.cross_entropy计算了从模型输出到真实标签的交叉熵损失。

这种方式使得实现多分类问题的模型训练变得简单而直接。

相关推荐
likerhood6 分钟前
3. pytorch中数据集加载和处理
人工智能·pytorch·python
Robot侠7 分钟前
ROS1从入门到精通 10:URDF机器人建模(从零构建机器人模型)
人工智能·机器人·ros·机器人操作系统·urdf机器人建模
haiyu_y8 分钟前
Day 46 TensorBoard 使用介绍
人工智能·深度学习·神经网络
阿里云大数据AI技术13 分钟前
DataWorks 又又又升级了,这次我们通过 Arrow 列存格式让数据同步速度提升10倍!
大数据·人工智能
做科研的周师兄14 分钟前
中国土壤有机质数据集
人工智能·算法·机器学习·分类·数据挖掘
IT一氪15 分钟前
一款 AI 驱动的 Word 文档翻译工具
人工智能·word
lovingsoft18 分钟前
Vibe coding 氛围编程
人工智能
百***074524 分钟前
GPT-Image-1.5 极速接入全流程及关键要点
人工智能·gpt·计算机视觉
yiersansiwu123d38 分钟前
AI二创的版权迷局与健康生态构建之道
人工智能
Narrastory44 分钟前
拆解指数加权平均:5 分钟看懂机器学习的 “数据平滑神器”
人工智能·机器学习