pytorch nn.NLLLoss和nn.CrossEntropyLoss函数区别

nn.CrossEntropyLoss(交叉熵损失函数) 和nn.NLLLoss (负对数似然损失函数)的区别

  1. 输入格式

    • nn.CrossEntropyLoss:直接接受未归一化的 logits 作为输入,并在内部自动应用 log_softmax 来计算对数概率。
    • nn.NLLLoss:接受对数概率 (log-probabilities)作为输入,也就是说,输入需要先通过 log_softmax处理。
  2. 计算流程

    • nn.CrossEntropyLoss 的计算流程是:
      1. 先对 logits 应用 softmax,将其转换为概率分布。
      2. 再对概率分布取对数,变为对数概率(log-probabilities)。
      3. 最后,对真实类别对应的对数概率取负值,得到损失。
    • nn.NLLLoss 的计算流程是:
      1. 直接使用对数概率作为输入。
      2. 对真实类别对应的对数概率取负值,得到损失。

代码示例

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])  # 未归一化的 logits
target = torch.tensor([0])  # 真实标签

# 1. 使用 nn.CrossEntropyLoss
loss_fn_ce = nn.CrossEntropyLoss()
loss_ce = loss_fn_ce(logits, target)  # nn.CrossEntropyLoss 自动进行 log_softmax
print("CrossEntropyLoss:", loss_ce.item())

# 2. 使用 nn.NLLLoss
log_probs = F.log_softmax(logits, dim=1)  # 先手动进行 log_softmax
loss_fn_nll = nn.NLLLoss()
loss_nll = loss_fn_nll(log_probs, target)  # 直接传入对数概率
print("NLLLoss:", loss_nll.item())

在这个例子中,nn.CrossEntropyLossnn.NLLLoss 的最终损失值是相同的,都是 0.4170。

nn.NLLLoss 的优势

虽然在大多数场景下,使用 nn.CrossEntropyLoss 更方便(因为它直接接受 logits),但 nn.NLLLoss 也有它的优势和特定的应用场景:

  1. 灵活性

    • nn.NLLLoss 允许用户直接传入经过 log_softmax 处理的对数概率。这在某些需要自定义概率分布或概率结构的任务中是有用的,用户可以手动处理 log_softmax,甚至对其进行进一步的修改和调整。
  2. 与自定义模型结构兼容

    • 当模型的输出不是传统的 logits 而是已经计算好对数概率的复杂结构时,nn.NLLLoss 更适合,因为它直接接受对数概率,不再需要依赖 CrossEntropyLoss 的内部处理。
  3. 分离 log_softmaxNLLLoss 计算

    • 在一些场景下,我们可能希望将 log_softmax 的计算和损失函数的计算分离,以便在不同的地方使用对数概率。例如,在序列生成任务中,可能需要在生成过程中反复使用 log_softmax 计算对数概率,而不是每次都重新计算。

选择何时使用

  • 使用 nn.CrossEntropyLoss
    • 大多数情况下,我们的网络输出的是 logits(未归一化的分数),并且希望简化代码,那么 nn.CrossEntropyLoss 是更方便的选择,因为它可以直接处理 logits。
  • 使用 nn.NLLLoss
    • 当你的模型或任务需要自定义对数概率的计算过程,或者你需要对 log_softmax 进行额外操作,nn.NLLLoss 更加灵活,可以帮助你处理已经是对数概率的输出。
相关推荐
数据科学作家31 分钟前
学数据分析必囤!数据分析必看!清华社9本书覆盖Stata/SPSS/Python全阶段学习路径
人工智能·python·机器学习·数据分析·统计·stata·spss
HXQ_晴天2 小时前
CASToR 生成的文件进行转换
python
CV缝合救星2 小时前
【Arxiv 2025 预发行论文】重磅突破!STAR-DSSA 模块横空出世:显著性+拓扑双重加持,小目标、大场景统统拿下!
人工智能·深度学习·计算机视觉·目标跟踪·即插即用模块
java1234_小锋3 小时前
Scikit-learn Python机器学习 - 特征预处理 - 标准化 (Standardization):StandardScaler
python·机器学习·scikit-learn
Python×CATIA工业智造3 小时前
Python带状态生成器完全指南:从基础到高并发系统设计
python·pycharm
向qian看_-_3 小时前
Linux 使用pip报错(error: externally-managed-environment )解决方案
linux·python·pip
Nicole-----4 小时前
Python - Union联合类型注解
开发语言·python
TDengine (老段)4 小时前
从 ETL 到 Agentic AI:工业数据管理变革与 TDengine IDMP 的治理之道
数据库·数据仓库·人工智能·物联网·时序数据库·etl·tdengine
蓝桉8024 小时前
如何进行神经网络的模型训练(视频代码中的知识点记录)
人工智能·深度学习·神经网络
星期天要睡觉5 小时前
深度学习——数据增强(Data Augmentation)
人工智能·深度学习