系列文章:
《PyTorch 基础学习》文章索引
介绍
在深度学习中,损失函数(Loss Function)用于评估模型预测值与实际目标之间的差异,是模型训练的核心部分。在 PyTorch 中,损失函数通过 torch.nn
模块提供,有多种不同类型的损失函数可供选择,每种损失函数都有其特定的用途和应用场景。本文将详细介绍几种常见的 PyTorch 损失函数,包括它们的用途、公式、典型应用场景以及实例代码。
1. 均方误差损失(nn.MSELoss)
用途
nn.MSELoss
主要用于回归任务中,衡量模型预测值与真实值之间的差异。
公式
MSE ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE(y,y^)=n1i=1∑n(yi−y^i)2
其中, y i y_i yi 是真实值, y ^ i \hat{y}_i y^i 是预测值, n n n 是样本的数量。
典型应用场景
- 预测连续数值,如房价预测、股票价格预测等。
实例代码
python
import torch
import torch.nn as nn
loss_fn = nn.MSELoss()
input = torch.tensor([0.0, 0.5, 1.0])
target = torch.tensor([0.0, 1.0, 1.0])
loss = loss_fn(input, target)
print(f'MSE Loss: {loss.item()}')
2. 交叉熵损失(nn.CrossEntropyLoss)
用途
nn.CrossEntropyLoss
主要用于多分类任务,用于衡量预测类别分布与真实分布之间的差异。
公式
CrossEntropy ( y , y ^ ) = − 1 n ∑ i = 1 n ∑ c = 1 C y i , c log ( y ^ i , c ) \text{CrossEntropy}(y, \hat{y}) = - \frac{1}{n} \sum_{i=1}^{n} \sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c}) CrossEntropy(y,y^)=−n1i=1∑nc=1∑Cyi,clog(y^i,c)
其中, y i , c y_{i,c} yi,c 是样本 i i i 属于类别 c c c 的真实概率, y ^ i , c \hat{y}_{i,c} y^i,c 是预测的概率, C C C 是类别数。
典型应用场景
- 图像分类任务,如手写数字识别(MNIST)、图像分类(CIFAR-10)等。
实例代码
python
import torch
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
input = torch.tensor([[0.5, 1.0, 0.5], [0.3, 0.2, 0.5]])
target = torch.tensor([1, 2])
loss = loss_fn(input, target)
print(f'Cross Entropy Loss: {loss.item()}')
3. 二元交叉熵损失(nn.BCELoss)
用途
nn.BCELoss
主要用于二分类任务,用于衡量预测概率与真实标签之间的差异。
公式
BCE ( y , y ^ ) = − 1 n ∑ i = 1 n [ y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ] \text{BCE}(y, \hat{y}) = - \frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right] BCE(y,y^)=−n1i=1∑n[yilog(y^i)+(1−yi)log(1−y^i)]
其中, y i y_i yi 是二分类的真实标签(0 或 1), y ^ i \hat{y}_i y^i 是预测的概率值。
典型应用场景
- 二分类任务,如垃圾邮件检测、二元图像分类等。
实例代码
python
import torch
import torch.nn as nn
loss_fn = nn.BCELoss()
input = torch.tensor([0.1, 0.9])
target = torch.tensor([0.0, 1.0])
loss = loss_fn(input, target)
print(f'BCE Loss: {loss.item()}')
4. 带有Logits的二元交叉熵损失(nn.BCEWithLogitsLoss)
用途
nn.BCEWithLogitsLoss
用于二分类任务,但与 nn.BCELoss
不同的是,输入值未经过 Sigmoid
函数。该函数在内部结合了 Sigmoid
和 BCELoss
,可以更稳定地计算损失。
公式
BCEWithLogits ( y , z ) = 1 n ∑ i = 1 n [ max ( z i , 0 ) − z i ⋅ y i + log ( 1 + e − ∣ z i ∣ ) ] \text{BCEWithLogits}(y, z) = \frac{1}{n} \sum_{i=1}^{n} \left[ \max(z_i, 0) - z_i \cdot y_i + \log\left(1 + e^{-|z_i|}\right) \right] BCEWithLogits(y,z)=n1i=1∑n[max(zi,0)−zi⋅yi+log(1+e−∣zi∣)]
其中, z i z_i zi 是模型输出的未经过 Sigmoid 函数的值, y i y_i yi 是真实标签(0 或 1)。
典型应用场景
- 二分类任务,特别是在需要直接处理 logits(未归一化预测值)时,如文本分类中的情感分析。
实例代码
python
import torch
import torch.nn as nn
loss_fn = nn.BCEWithLogitsLoss()
input = torch.tensor([0.1, 0.9])
target = torch.tensor([0.0, 1.0])
loss = loss_fn(input, target)
print(f'BCE With Logits Loss: {loss.item()}')
5. 边际排名损失(nn.MarginRankingLoss)
用途
nn.MarginRankingLoss
主要用于学习排序任务,评估两个输入值的相对差异是否符合目标标签的排序。
公式
MarginRanking ( x 1 , x 2 , y ) = max ( 0 , − y ⋅ ( x 1 − x 2 ) + margin ) \text{MarginRanking}(x_1, x_2, y) = \max(0, -y \cdot (x_1 - x_2) + \text{margin}) MarginRanking(x1,x2,y)=max(0,−y⋅(x1−x2)+margin)
其中, x 1 x_1 x1 和 x 2 x_2 x2 是两个输入值, y y y 是目标标签(+1 或 -1),margin
是定义的边距值。
典型应用场景
- 信息检索和推荐系统,评估文档或物品的相关性排序。
实例代码
python
import torch
import torch.nn as nn
loss_fn = nn.MarginRankingLoss(margin=1.0)
input1 = torch.tensor([0.8])
input2 = torch.tensor([0.3])
target = torch.tensor([1.0]) # 目标:input1应大于input2
loss = loss_fn(input1, input2, target)
print(f'Margin Ranking Loss: {loss.item()}')
6. 铰链嵌入损失(nn.HingeEmbeddingLoss)
用途
nn.HingeEmbeddingLoss
常用于支持向量机(SVM)以及其他涉及到嵌入学习的任务。
公式
HingeEmbedding ( y , y ^ ) = 1 n ∑ i = 1 n [ max ( 0 , 1 − y i ⋅ y ^ i ) ] \text{HingeEmbedding}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} \left[ \max(0, 1 - y_i \cdot \hat{y}_i) \right] HingeEmbedding(y,y^)=n1i=1∑n[max(0,1−yi⋅y^i)]
其中, y i y_i yi 是目标标签(+1 或 -1), y ^ i \hat{y}_i y^i 是预测值。
典型应用场景
- 用于二分类任务,特别是支持向量机模型中。
实例代码
python
import torch
import torch.nn as nn
loss_fn = nn.HingeEmbeddingLoss()
input = torch.tensor([0.8, -0.5])
target = torch.tensor([1, -1]) # 标签是1或-1
loss = loss_fn(input, target)
print(f'Hinge Embedding Loss: {loss.item()}')
总结
损失函数在深度学习模型训练中起着至关重要的作用。通过选择合适的损失函数,可以有效引导模型优化目标,从而提高预测的准确性和可靠性。本文介绍了几种常见的 PyTorch 损失函数,包括它们的用途、公式、典型应用场景以及示例代码,希望对你在深度学习模型的构建和训练中有所帮助。