pytorch nn.NLLLoss和nn.CrossEntropyLoss函数区别

nn.CrossEntropyLoss(交叉熵损失函数) 和nn.NLLLoss (负对数似然损失函数)的区别

  1. 输入格式

    • nn.CrossEntropyLoss:直接接受未归一化的 logits 作为输入,并在内部自动应用 log_softmax 来计算对数概率。
    • nn.NLLLoss:接受对数概率 (log-probabilities)作为输入,也就是说,输入需要先通过 log_softmax处理。
  2. 计算流程

    • nn.CrossEntropyLoss 的计算流程是:
      1. 先对 logits 应用 softmax,将其转换为概率分布。
      2. 再对概率分布取对数,变为对数概率(log-probabilities)。
      3. 最后,对真实类别对应的对数概率取负值,得到损失。
    • nn.NLLLoss 的计算流程是:
      1. 直接使用对数概率作为输入。
      2. 对真实类别对应的对数概率取负值,得到损失。

代码示例

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])  # 未归一化的 logits
target = torch.tensor([0])  # 真实标签

# 1. 使用 nn.CrossEntropyLoss
loss_fn_ce = nn.CrossEntropyLoss()
loss_ce = loss_fn_ce(logits, target)  # nn.CrossEntropyLoss 自动进行 log_softmax
print("CrossEntropyLoss:", loss_ce.item())

# 2. 使用 nn.NLLLoss
log_probs = F.log_softmax(logits, dim=1)  # 先手动进行 log_softmax
loss_fn_nll = nn.NLLLoss()
loss_nll = loss_fn_nll(log_probs, target)  # 直接传入对数概率
print("NLLLoss:", loss_nll.item())

在这个例子中,nn.CrossEntropyLossnn.NLLLoss 的最终损失值是相同的,都是 0.4170。

nn.NLLLoss 的优势

虽然在大多数场景下,使用 nn.CrossEntropyLoss 更方便(因为它直接接受 logits),但 nn.NLLLoss 也有它的优势和特定的应用场景:

  1. 灵活性

    • nn.NLLLoss 允许用户直接传入经过 log_softmax 处理的对数概率。这在某些需要自定义概率分布或概率结构的任务中是有用的,用户可以手动处理 log_softmax,甚至对其进行进一步的修改和调整。
  2. 与自定义模型结构兼容

    • 当模型的输出不是传统的 logits 而是已经计算好对数概率的复杂结构时,nn.NLLLoss 更适合,因为它直接接受对数概率,不再需要依赖 CrossEntropyLoss 的内部处理。
  3. 分离 log_softmaxNLLLoss 计算

    • 在一些场景下,我们可能希望将 log_softmax 的计算和损失函数的计算分离,以便在不同的地方使用对数概率。例如,在序列生成任务中,可能需要在生成过程中反复使用 log_softmax 计算对数概率,而不是每次都重新计算。

选择何时使用

  • 使用 nn.CrossEntropyLoss
    • 大多数情况下,我们的网络输出的是 logits(未归一化的分数),并且希望简化代码,那么 nn.CrossEntropyLoss 是更方便的选择,因为它可以直接处理 logits。
  • 使用 nn.NLLLoss
    • 当你的模型或任务需要自定义对数概率的计算过程,或者你需要对 log_softmax 进行额外操作,nn.NLLLoss 更加灵活,可以帮助你处理已经是对数概率的输出。
相关推荐
谦行7 分钟前
工欲善其事,必先利其器—— PyTorch 深度学习基础操作
pytorch·深度学习·ai编程
逢生博客33 分钟前
使用 Python 项目管理工具 uv 快速创建 MCP 服务(Cherry Studio、Trae 添加 MCP 服务)
python·sqlite·uv·deepseek·trae·cherry studio·mcp服务
xwz小王子36 分钟前
Nature Communications 面向形状可编程磁性软材料的数据驱动设计方法—基于随机设计探索与神经网络的协同优化框架
深度学习
堕落似梦40 分钟前
Pydantic增强SQLALchemy序列化(FastAPI直接输出SQLALchemy查询集)
python
白熊18843 分钟前
【计算机视觉】CV实战项目 - 基于YOLOv5的人脸检测与关键点定位系统深度解析
人工智能·yolo·计算机视觉
nenchoumi31191 小时前
VLA 论文精读(十六)FP3: A 3D Foundation Policy for Robotic Manipulation
论文阅读·人工智能·笔记·学习·vln
后端小肥肠1 小时前
文案号搞钱潜规则:日入四位数的Coze工作流我跑通了
人工智能·coze
LCHub低代码社区1 小时前
钧瓷产业原始创新的许昌共识:技术破壁·产业再造·生态重构(一)
大数据·人工智能·维格云·ai智能体·ai自动化·大禹智库·钧瓷码
-曾牛1 小时前
Spring AI 快速入门:从环境搭建到核心组件集成
java·人工智能·spring·ai·大模型·spring ai·开发环境搭建
阿川20151 小时前
云智融合普惠大模型AI,政务服务重构数智化路径
人工智能·华为云·政务·deepseek