深度学习:Pytorch常见损失函数Loss简介

深度学习:Pytorch常见损失函数Loss简介

  • [L1 Loss](#L1 Loss)
  • [MSE Loss](#MSE Loss)
  • [SmoothL1 Loss](#SmoothL1 Loss)
  • [CrossEntropy Loss](#CrossEntropy Loss)
  • [Focal Loss](#Focal Loss)

此篇博客主要对深度学习中常用的损失函数进行介绍,并结合Pytorch的函数进行分析,讲解其用法。

L1 Loss

L1 Loss计算预测值和真值的平均绝对误差。

L o s s ( y , y ^ ) = ∣ y − y ^ ∣ Loss(y,\hat{y}) = |y-\hat{y}| Loss(y,y^)=∣y−y^∣

Pytorch函数:

python 复制代码
torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

参数:

  • size_average (bool, optional) -- 此参数已弃用;
  • reduce (bool, optional) -- 此参数已弃用;
  • reduction (str, optional) -- 由以下三个参数选其一:'none' | 'mean' | 'sum'. 'none':不对各个元素的误差处理, 'mean':输出是各个元素误差的平均值,'sum':输出是将各个元素的误差求和。 默认:'mean'。

MSE Loss

MSE Loss计算预测值和真值的均方误差。

L o s s ( y , y ^ ) = ( y − y ^ ) 2 Loss(y,\hat{y}) = (y-\hat{y})^2 Loss(y,y^)=(y−y^)2

Pytorch函数:

python 复制代码
torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

参数:

  • size_average (bool, optional) -- 此参数已弃用。
  • reduce (bool, optional) -- 此参数已弃用。
  • reduction (str, optional) -- 由以下三个参数选其一:'none' | 'mean' | 'sum'. 'none':不对各个元素的误差处理, 'mean':输出是各个元素误差的平均值,'sum':输出是将各个元素的误差求和。 默认:'mean'。

SmoothL1 Loss

在训练初期,当预测值和真值相差较大时,损失函数的值较大,容易导致训练不稳定,为了防止梯度爆炸(梯度值是指损失函数对输入的导数,梯度爆炸是指梯度值很大),同时当预测值和真值相差较小时,梯度值足够小,可以使用SmoothL1 Loss,它可以视作L1 Loss和L2 Loss(MSE Loss)的结合,计算公式如下:

KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ Loss(y,\hat{y}...

Pytorch函数:

python 复制代码
torch.nn.SmoothL1Loss(size_average=None, reduce=None, reduction='mean', beta=1.0)

参数:

  • size_average (bool, optional) -- 此参数已弃用。
  • reduce (bool, optional) -- 此参数已弃用。
  • reduction (str, optional) -- 由以下三个参数选其一:'none' | 'mean' | 'sum'. 'none':不对各个元素的误差处理, 'mean':输出是各个元素误差的平均值,'sum':输出是将各个元素的误差求和。 默认:'mean'。
  • beta ( float ,optional) -- 指定 L1 Loss和 L2 Loss之间变化的阈值。该值必须是非负数。默认值:1.0

CrossEntropy Loss

CrossEntropy Loss是在处理分类问题中常用的一种损失函数,如二分类和多分类。此函数来源于信息论中的交叉熵概念,用于衡量两个预估概率分布和真实概率分布之间的差异。交叉熵损失函数公式如下:

(1)对于二分类问题:
L o s s ( y , y ^ ) = − ∑ i = 1 n ( y i l o g ( y i ^ ) + ( 1 − y i ) l o g ( 1 − y i ^ ) ) Loss(y,\hat{y}) = -\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i})) Loss(y,y^)=−i=1∑n(yilog(yi^)+(1−yi)log(1−yi^))

其中, y y y是真值, y ^ \hat{y} y^是预测值,n是样本的数量,每个样本都会计算一个损失,如果reduction是'mean',那么会对所有样本的损失求平均;如果reduction是'sum',那么会对所有样本的损失求和。

(2)对于多分类问题:
L o s s ( y , y ^ ) = − ∑ i = 1 n ∑ j = 1 m y i j l o g ( y i j ^ ) Loss(y,\hat{y}) = - \sum_{i=1}^{n}\sum_{j=1}^{m}y_{ij}log(\hat{y_{ij}}) Loss(y,y^)=−i=1∑nj=1∑myijlog(yij^)

其中, y i j y_{ij} yij是第i个样本的真实标签在第j类的概率, y i j ^ \hat{y_{ij}} yij^是第i个样本预测为第j类的概率,n是样本数量,m是类别的数量,每个样本都会计算一个损失,如果reduction是'mean',那么会对所有样本的损失求平均;如果reduction是'sum',那么会对所有样本的损失求和。

Pytorch函数:

python 复制代码
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)

参数:

  • weight (Tensor, optional) -- 为每个类指定的手动缩放权重。如果给定,则必须是大小为C的张量。
  • size_average (bool, optional) -- 此参数已弃用。
  • ignore_index (int, optional) -- 指定被忽略且不会对输入梯度产生影响的目标值。
  • reduce (bool, optional) -- 此参数已弃用。
  • reduction (str, optional) -- 由以下三个参数选其一:'none' | 'mean' | 'sum'. 'none':不对各个元素的误差处理, 'mean':输出是各个元素误差的平均值,'sum':输出是将各个元素的误差求和。 默认:'mean'。
  • label_smoothing (float, optional) -- [0.0, 1.0] 中的浮点数。指定计算损失时的平滑量,其中 0.0 表示不平滑。默认值: 0.0.

Focal Loss

Focal Loss主要用来处理正负样本(特别是前景和背景样本的分类)不均衡的问题。样本不均衡会导致训练效率低,甚至可能会导致模型退化。Focal Loss可以视为对CrossENtropy Loss增加权重加以平衡(增加预测概率小的样本权重,其对应的损失函数值变大;反而降低预测概率大的样本权重,其对应的损失函数值变小)。参考公式如下:
L o s s ( y , y ^ ) = − ∑ i = 1 n ∑ j = 1 m ( 1 − y i j ^ ) γ y i j l o g ( y i j ^ ) Loss(y,\hat{y}) = - \sum_{i=1}^{n}\sum_{j=1}^{m}(1-\hat{y_{ij}})^{\gamma}y_{ij}log(\hat{y_{ij}}) Loss(y,y^)=−i=1∑nj=1∑m(1−yij^)γyijlog(yij^)

其中, γ \gamma γ常取2.

相关推荐
zhangfeng113311 小时前
小龙虾 wordbuddy 安装浏览器控制器 agent-browser npm install -g agent-browse
前端·人工智能·npm·node.js
阿里云大数据AI技术11 小时前
一条 SQL 生成广告:Hologres 如何实现素材生成到投放分析一体化
人工智能·sql
liudanzhengxi11 小时前
GitSubmodule避坑全攻略
人工智能·新人首发
用户4252108006011 小时前
Claude Code Linux 服务器部署与配置
人工智能
OJAC11111 小时前
学过Python却不敢投AI岗,他最后拿下12K offer
人工智能
Bigger11 小时前
因为看不懂小棉袄的画,我写了个 AI 程序帮我“翻译”她的世界
前端·人工智能·ai编程
CeshirenTester11 小时前
LangChain的工具调用 vs 原生Skill API:性能差在哪儿?
java·人工智能·langchain
爱问的艾文11 小时前
八周带你手搓AI应用-第二周-让AI更像人-第1天-流式输出改造
人工智能
多年小白12 小时前
【周末消息面汇总】2026年5月10日(周日)
人工智能·科技·机器学习·ai·金融
码农小白AI12 小时前
宠物用品耐磨检测走向标准化新阶段:IACheck让AI报告审核更无忧更稳定
人工智能