nn.CrossEntropyLoss
(交叉熵损失函数) 和nn.NLLLoss
(负对数似然损失函数)的区别
-
输入格式:
nn.CrossEntropyLoss
:直接接受未归一化的 logits 作为输入,并在内部自动应用log_softmax
来计算对数概率。nn.NLLLoss
:接受对数概率 (log-probabilities)作为输入,也就是说,输入需要先通过log_softmax
处理。
-
计算流程:
nn.CrossEntropyLoss
的计算流程是:- 先对 logits 应用
softmax
,将其转换为概率分布。 - 再对概率分布取对数,变为对数概率(log-probabilities)。
- 最后,对真实类别对应的对数概率取负值,得到损失。
- 先对 logits 应用
nn.NLLLoss
的计算流程是:- 直接使用对数概率作为输入。
- 对真实类别对应的对数概率取负值,得到损失。
代码示例
import torch
import torch.nn as nn
import torch.nn.functional as F
logits = torch.tensor([[2.0, 1.0, 0.1]]) # 未归一化的 logits
target = torch.tensor([0]) # 真实标签
# 1. 使用 nn.CrossEntropyLoss
loss_fn_ce = nn.CrossEntropyLoss()
loss_ce = loss_fn_ce(logits, target) # nn.CrossEntropyLoss 自动进行 log_softmax
print("CrossEntropyLoss:", loss_ce.item())
# 2. 使用 nn.NLLLoss
log_probs = F.log_softmax(logits, dim=1) # 先手动进行 log_softmax
loss_fn_nll = nn.NLLLoss()
loss_nll = loss_fn_nll(log_probs, target) # 直接传入对数概率
print("NLLLoss:", loss_nll.item())
在这个例子中,nn.CrossEntropyLoss
和 nn.NLLLoss
的最终损失值是相同的,都是 0.4170。
nn.NLLLoss
的优势
虽然在大多数场景下,使用 nn.CrossEntropyLoss
更方便(因为它直接接受 logits),但 nn.NLLLoss
也有它的优势和特定的应用场景:
-
灵活性:
nn.NLLLoss
允许用户直接传入经过log_softmax
处理的对数概率。这在某些需要自定义概率分布或概率结构的任务中是有用的,用户可以手动处理log_softmax
,甚至对其进行进一步的修改和调整。
-
与自定义模型结构兼容:
- 当模型的输出不是传统的 logits 而是已经计算好对数概率的复杂结构时,
nn.NLLLoss
更适合,因为它直接接受对数概率,不再需要依赖CrossEntropyLoss
的内部处理。
- 当模型的输出不是传统的 logits 而是已经计算好对数概率的复杂结构时,
-
分离
log_softmax
和NLLLoss
计算:- 在一些场景下,我们可能希望将
log_softmax
的计算和损失函数的计算分离,以便在不同的地方使用对数概率。例如,在序列生成任务中,可能需要在生成过程中反复使用log_softmax
计算对数概率,而不是每次都重新计算。
- 在一些场景下,我们可能希望将
选择何时使用
- 使用
nn.CrossEntropyLoss
:- 大多数情况下,我们的网络输出的是 logits(未归一化的分数),并且希望简化代码,那么
nn.CrossEntropyLoss
是更方便的选择,因为它可以直接处理 logits。
- 大多数情况下,我们的网络输出的是 logits(未归一化的分数),并且希望简化代码,那么
- 使用
nn.NLLLoss
:- 当你的模型或任务需要自定义对数概率的计算过程,或者你需要对
log_softmax
进行额外操作,nn.NLLLoss
更加灵活,可以帮助你处理已经是对数概率的输出。
- 当你的模型或任务需要自定义对数概率的计算过程,或者你需要对