PyTorch 基础学习(15)- 损失函数

系列文章:
《PyTorch 基础学习》文章索引

介绍

在深度学习中,损失函数(Loss Function)用于评估模型预测值与实际目标之间的差异,是模型训练的核心部分。在 PyTorch 中,损失函数通过 torch.nn 模块提供,有多种不同类型的损失函数可供选择,每种损失函数都有其特定的用途和应用场景。本文将详细介绍几种常见的 PyTorch 损失函数,包括它们的用途、公式、典型应用场景以及实例代码。

1. 均方误差损失(nn.MSELoss)

用途

nn.MSELoss 主要用于回归任务中,衡量模型预测值与真实值之间的差异。

公式

MSE ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE(y,y^)=n1i=1∑n(yi−y^i)2

其中, y i y_i yi 是真实值, y ^ i \hat{y}_i y^i 是预测值, n n n 是样本的数量。

典型应用场景

  • 预测连续数值,如房价预测、股票价格预测等。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.MSELoss()
input = torch.tensor([0.0, 0.5, 1.0])
target = torch.tensor([0.0, 1.0, 1.0])
loss = loss_fn(input, target)
print(f'MSE Loss: {loss.item()}')

2. 交叉熵损失(nn.CrossEntropyLoss)

用途

nn.CrossEntropyLoss 主要用于多分类任务,用于衡量预测类别分布与真实分布之间的差异。

公式

CrossEntropy ( y , y ^ ) = − 1 n ∑ i = 1 n ∑ c = 1 C y i , c log ⁡ ( y ^ i , c ) \text{CrossEntropy}(y, \hat{y}) = - \frac{1}{n} \sum_{i=1}^{n} \sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c}) CrossEntropy(y,y^)=−n1i=1∑nc=1∑Cyi,clog(y^i,c)

其中, y i , c y_{i,c} yi,c 是样本 i i i 属于类别 c c c 的真实概率, y ^ i , c \hat{y}_{i,c} y^i,c 是预测的概率, C C C 是类别数。

典型应用场景

  • 图像分类任务,如手写数字识别(MNIST)、图像分类(CIFAR-10)等。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
input = torch.tensor([[0.5, 1.0, 0.5], [0.3, 0.2, 0.5]])
target = torch.tensor([1, 2])
loss = loss_fn(input, target)
print(f'Cross Entropy Loss: {loss.item()}')

3. 二元交叉熵损失(nn.BCELoss)

用途

nn.BCELoss 主要用于二分类任务,用于衡量预测概率与真实标签之间的差异。

公式

BCE ( y , y ^ ) = − 1 n ∑ i = 1 n y i log ⁡ ( y \^ i ) + ( 1 − y i ) log ⁡ ( 1 − y \^ i ) \text{BCE}(y, \hat{y}) = - \frac{1}{n} \sum_{i=1}^{n} \left y_i \\log(\\hat{y}_i) + (1 - y_i) \\log(1 - \\hat{y}_i) \\right BCE(y,y^)=−n1i=1∑nyilog(y\^i)+(1−yi)log(1−y\^i)

其中, y i y_i yi 是二分类的真实标签(0 或 1), y ^ i \hat{y}_i y^i 是预测的概率值。

典型应用场景

  • 二分类任务,如垃圾邮件检测、二元图像分类等。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.BCELoss()
input = torch.tensor([0.1, 0.9])
target = torch.tensor([0.0, 1.0])
loss = loss_fn(input, target)
print(f'BCE Loss: {loss.item()}')

4. 带有Logits的二元交叉熵损失(nn.BCEWithLogitsLoss)

用途

nn.BCEWithLogitsLoss 用于二分类任务,但与 nn.BCELoss 不同的是,输入值未经过 Sigmoid 函数。该函数在内部结合了 SigmoidBCELoss,可以更稳定地计算损失。

公式

BCEWithLogits ( y , z ) = 1 n ∑ i = 1 n max ⁡ ( z i , 0 ) − z i ⋅ y i + log ⁡ ( 1 + e − ∣ z i ∣ ) \text{BCEWithLogits}(y, z) = \frac{1}{n} \sum_{i=1}^{n} \left \\max(z_i, 0) - z_i \\cdot y_i + \\log\\left(1 + e\^{-\|z_i\|}\\right) \\right BCEWithLogits(y,z)=n1i=1∑nmax(zi,0)−zi⋅yi+log(1+e−∣zi∣)

其中, z i z_i zi 是模型输出的未经过 Sigmoid 函数的值, y i y_i yi 是真实标签(0 或 1)。

典型应用场景

  • 二分类任务,特别是在需要直接处理 logits(未归一化预测值)时,如文本分类中的情感分析。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.BCEWithLogitsLoss()
input = torch.tensor([0.1, 0.9])
target = torch.tensor([0.0, 1.0])
loss = loss_fn(input, target)
print(f'BCE With Logits Loss: {loss.item()}')

5. 边际排名损失(nn.MarginRankingLoss)

用途

nn.MarginRankingLoss 主要用于学习排序任务,评估两个输入值的相对差异是否符合目标标签的排序。

公式

MarginRanking ( x 1 , x 2 , y ) = max ⁡ ( 0 , − y ⋅ ( x 1 − x 2 ) + margin ) \text{MarginRanking}(x_1, x_2, y) = \max(0, -y \cdot (x_1 - x_2) + \text{margin}) MarginRanking(x1,x2,y)=max(0,−y⋅(x1−x2)+margin)

其中, x 1 x_1 x1 和 x 2 x_2 x2 是两个输入值, y y y 是目标标签(+1 或 -1),margin 是定义的边距值。

典型应用场景

  • 信息检索和推荐系统,评估文档或物品的相关性排序。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.MarginRankingLoss(margin=1.0)
input1 = torch.tensor([0.8])
input2 = torch.tensor([0.3])
target = torch.tensor([1.0])  # 目标:input1应大于input2
loss = loss_fn(input1, input2, target)
print(f'Margin Ranking Loss: {loss.item()}')

6. 铰链嵌入损失(nn.HingeEmbeddingLoss)

用途

nn.HingeEmbeddingLoss 常用于支持向量机(SVM)以及其他涉及到嵌入学习的任务。

公式

HingeEmbedding ( y , y ^ ) = 1 n ∑ i = 1 n max ⁡ ( 0 , 1 − y i ⋅ y \^ i ) \text{HingeEmbedding}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} \left \\max(0, 1 - y_i \\cdot \\hat{y}_i) \\right HingeEmbedding(y,y^)=n1i=1∑nmax(0,1−yi⋅y\^i)

其中, y i y_i yi 是目标标签(+1 或 -1), y ^ i \hat{y}_i y^i 是预测值。

典型应用场景

  • 用于二分类任务,特别是支持向量机模型中。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.HingeEmbeddingLoss()
input = torch.tensor([0.8, -0.5])
target = torch.tensor([1, -1])  # 标签是1或-1
loss = loss_fn(input, target)
print(f'Hinge Embedding Loss: {loss.item()}')

总结

损失函数在深度学习模型训练中起着至关重要的作用。通过选择合适的损失函数,可以有效引导模型优化目标,从而提高预测的准确性和可靠性。本文介绍了几种常见的 PyTorch 损失函数,包括它们的用途、公式、典型应用场景以及示例代码,希望对你在深度学习模型的构建和训练中有所帮助。

相关推荐
意图共鸣10 小时前
意图共鸣科技《AI记忆链商业化白皮书3.0》假设场景解析:从母亲到消防员,专属AI如何重塑记忆与传承
人工智能·科技·架构
ai产品老杨10 小时前
解耦安防碎片化:基于 Docker 与边缘计算的 AI 视频管理平台架构演进(附 GB28181/RTSP 统一接入与源码交付实践)
人工智能·docker·边缘计算
OpenAnolis小助手10 小时前
如何利用 AI Agent 实现热补丁的自动化生成
人工智能·安全·ai·操作系统·agent·龙蜥
米核AI易山10 小时前
扣子工作流项目交付全流程:从需求分析到上线维护的实战方法论
人工智能·需求分析·coze·扣子工作流·米核ai易山
沫儿笙10 小时前
弧焊机器人保护气智能节气阀
人工智能·机器人
DS随心转插件10 小时前
AI 导出鸭实操教程:Markdown 转 Word 高效协作与隐私交付实战指南
人工智能·ai·word·豆包·deepseek·ai导出鸭
腾讯云开发者10 小时前
探访香港科创高地,洞见 Agentic AI 时代的出海新范式
人工智能
产业家10 小时前
“绿算协同×Token工厂”新范式,润建股份探索出一个AI新样本
人工智能
暗夜猎手-大魔王10 小时前
hermes源码学习8-上下文压缩与缓存
人工智能·缓存
菜鸟‍11 小时前
【论文学习】Segment Anything 分割一切
深度学习·学习·计算机视觉