PyTorch 基础学习(15)- 损失函数

系列文章:
《PyTorch 基础学习》文章索引

介绍

在深度学习中,损失函数(Loss Function)用于评估模型预测值与实际目标之间的差异,是模型训练的核心部分。在 PyTorch 中,损失函数通过 torch.nn 模块提供,有多种不同类型的损失函数可供选择,每种损失函数都有其特定的用途和应用场景。本文将详细介绍几种常见的 PyTorch 损失函数,包括它们的用途、公式、典型应用场景以及实例代码。

1. 均方误差损失(nn.MSELoss)

用途

nn.MSELoss 主要用于回归任务中,衡量模型预测值与真实值之间的差异。

公式

MSE ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE(y,y^)=n1i=1∑n(yi−y^i)2

其中, y i y_i yi 是真实值, y ^ i \hat{y}_i y^i 是预测值, n n n 是样本的数量。

典型应用场景

  • 预测连续数值,如房价预测、股票价格预测等。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.MSELoss()
input = torch.tensor([0.0, 0.5, 1.0])
target = torch.tensor([0.0, 1.0, 1.0])
loss = loss_fn(input, target)
print(f'MSE Loss: {loss.item()}')

2. 交叉熵损失(nn.CrossEntropyLoss)

用途

nn.CrossEntropyLoss 主要用于多分类任务,用于衡量预测类别分布与真实分布之间的差异。

公式

CrossEntropy ( y , y ^ ) = − 1 n ∑ i = 1 n ∑ c = 1 C y i , c log ⁡ ( y ^ i , c ) \text{CrossEntropy}(y, \hat{y}) = - \frac{1}{n} \sum_{i=1}^{n} \sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c}) CrossEntropy(y,y^)=−n1i=1∑nc=1∑Cyi,clog(y^i,c)

其中, y i , c y_{i,c} yi,c 是样本 i i i 属于类别 c c c 的真实概率, y ^ i , c \hat{y}_{i,c} y^i,c 是预测的概率, C C C 是类别数。

典型应用场景

  • 图像分类任务,如手写数字识别(MNIST)、图像分类(CIFAR-10)等。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
input = torch.tensor([[0.5, 1.0, 0.5], [0.3, 0.2, 0.5]])
target = torch.tensor([1, 2])
loss = loss_fn(input, target)
print(f'Cross Entropy Loss: {loss.item()}')

3. 二元交叉熵损失(nn.BCELoss)

用途

nn.BCELoss 主要用于二分类任务,用于衡量预测概率与真实标签之间的差异。

公式

BCE ( y , y ^ ) = − 1 n ∑ i = 1 n [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] \text{BCE}(y, \hat{y}) = - \frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right] BCE(y,y^)=−n1i=1∑n[yilog(y^i)+(1−yi)log(1−y^i)]

其中, y i y_i yi 是二分类的真实标签(0 或 1), y ^ i \hat{y}_i y^i 是预测的概率值。

典型应用场景

  • 二分类任务,如垃圾邮件检测、二元图像分类等。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.BCELoss()
input = torch.tensor([0.1, 0.9])
target = torch.tensor([0.0, 1.0])
loss = loss_fn(input, target)
print(f'BCE Loss: {loss.item()}')

4. 带有Logits的二元交叉熵损失(nn.BCEWithLogitsLoss)

用途

nn.BCEWithLogitsLoss 用于二分类任务,但与 nn.BCELoss 不同的是,输入值未经过 Sigmoid 函数。该函数在内部结合了 SigmoidBCELoss,可以更稳定地计算损失。

公式

BCEWithLogits ( y , z ) = 1 n ∑ i = 1 n [ max ⁡ ( z i , 0 ) − z i ⋅ y i + log ⁡ ( 1 + e − ∣ z i ∣ ) ] \text{BCEWithLogits}(y, z) = \frac{1}{n} \sum_{i=1}^{n} \left[ \max(z_i, 0) - z_i \cdot y_i + \log\left(1 + e^{-|z_i|}\right) \right] BCEWithLogits(y,z)=n1i=1∑n[max(zi,0)−zi⋅yi+log(1+e−∣zi∣)]

其中, z i z_i zi 是模型输出的未经过 Sigmoid 函数的值, y i y_i yi 是真实标签(0 或 1)。

典型应用场景

  • 二分类任务,特别是在需要直接处理 logits(未归一化预测值)时,如文本分类中的情感分析。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.BCEWithLogitsLoss()
input = torch.tensor([0.1, 0.9])
target = torch.tensor([0.0, 1.0])
loss = loss_fn(input, target)
print(f'BCE With Logits Loss: {loss.item()}')

5. 边际排名损失(nn.MarginRankingLoss)

用途

nn.MarginRankingLoss 主要用于学习排序任务,评估两个输入值的相对差异是否符合目标标签的排序。

公式

MarginRanking ( x 1 , x 2 , y ) = max ⁡ ( 0 , − y ⋅ ( x 1 − x 2 ) + margin ) \text{MarginRanking}(x_1, x_2, y) = \max(0, -y \cdot (x_1 - x_2) + \text{margin}) MarginRanking(x1,x2,y)=max(0,−y⋅(x1−x2)+margin)

其中, x 1 x_1 x1 和 x 2 x_2 x2 是两个输入值, y y y 是目标标签(+1 或 -1),margin 是定义的边距值。

典型应用场景

  • 信息检索和推荐系统,评估文档或物品的相关性排序。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.MarginRankingLoss(margin=1.0)
input1 = torch.tensor([0.8])
input2 = torch.tensor([0.3])
target = torch.tensor([1.0])  # 目标:input1应大于input2
loss = loss_fn(input1, input2, target)
print(f'Margin Ranking Loss: {loss.item()}')

6. 铰链嵌入损失(nn.HingeEmbeddingLoss)

用途

nn.HingeEmbeddingLoss 常用于支持向量机(SVM)以及其他涉及到嵌入学习的任务。

公式

HingeEmbedding ( y , y ^ ) = 1 n ∑ i = 1 n [ max ⁡ ( 0 , 1 − y i ⋅ y ^ i ) ] \text{HingeEmbedding}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} \left[ \max(0, 1 - y_i \cdot \hat{y}_i) \right] HingeEmbedding(y,y^)=n1i=1∑n[max(0,1−yi⋅y^i)]

其中, y i y_i yi 是目标标签(+1 或 -1), y ^ i \hat{y}_i y^i 是预测值。

典型应用场景

  • 用于二分类任务,特别是支持向量机模型中。

实例代码

python 复制代码
import torch
import torch.nn as nn

loss_fn = nn.HingeEmbeddingLoss()
input = torch.tensor([0.8, -0.5])
target = torch.tensor([1, -1])  # 标签是1或-1
loss = loss_fn(input, target)
print(f'Hinge Embedding Loss: {loss.item()}')

总结

损失函数在深度学习模型训练中起着至关重要的作用。通过选择合适的损失函数,可以有效引导模型优化目标,从而提高预测的准确性和可靠性。本文介绍了几种常见的 PyTorch 损失函数,包括它们的用途、公式、典型应用场景以及示例代码,希望对你在深度学习模型的构建和训练中有所帮助。

相关推荐
Dovir多多7 分钟前
渗透测试入门学习——php文件上传与文件包含
前端·后端·学习·安全·web安全·php
专业白嫖怪8 分钟前
网络工程师学习笔记——网络互连与互联网
网络·笔记·学习·交换机·sdn可编程交换机
李小星同志25 分钟前
DroidBot-GPT: GPT-powered UI Automation for Android论文学习
gpt·学习
发光的小豆芽28 分钟前
TMStarget学习——T1 Segmentation数据处理及解bug
学习
shuxianshrng33 分钟前
鹰眼降尘模型
大数据·服务器·人工智能·经验分享·机器人
金智维科技官方44 分钟前
如何选择适合企业的高效财税自动化软件
大数据·人工智能·自动化
FL16238631291 小时前
[数据集][目标检测]高铁受电弓检测数据集VOC+YOLO格式1245张2类别
人工智能·yolo·目标检测
雅菲奥朗1 小时前
FinOps三人行:共话FinOps云成本管理与AI的未来在线分享(文字+视频)
人工智能·aigc·finops·云财务管理·云成本管理
GoppViper1 小时前
golang学习笔记24——golang微服务中配置管理问题的深度剖析
笔记·后端·学习·微服务·golang·配置管理
向往风的男子1 小时前
【从问题中去学习k8s】k8s中的常见面试题(夯实理论基础)(三十)
学习·容器·kubernetes