【深度学习】动态交叉熵损失函数Focal Loss

FocalLoss,由何恺明等人在2017年的论文《Focal Loss for Dense Object Detection》中提出。它是一种专门为解决,目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。当然,也是可以适用于图像分割。

本文将详细介绍FocalLoss的基本概念、思想原理,以及如何设置 FocalLoss中的关键参数,并提供PyTorch的实现代码。

交叉熵损失函数

Focalloss是基于二分类交叉熵,可以理解成一个动态缩放的交叉熵损失。即:

  • 通过一个动态放缩因子,可以在不同类别的样本数量极不平衡情况下,动态降低训练过程中数量极多的样本的权重,避免数量极少的样本被淹没

  • 通过一个动态缩放因子,可以动态降低训练过程中易区分样本的权重,从而将重心快速聚焦在那些难区分的样本

那就先回顾下交叉熵损失的定义。

在二分类任务中,一般使用 Sigmoid 作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat y y^。二分类非正即负,所以样本为负的概率值为 1 − y ^ 1-\hat y 1−y^ 。二分类交叉熵损失的计算公式为:

B C E = − y ⋅ l n ( y ^ ) − ( 1 − y ) ⋅ l n ( 1 − y ^ ) BCE=-y\cdot ln(\hat y)-(1-y)\cdot ln(1-\hat y) BCE=−y⋅ln(y^)−(1−y)⋅ln(1−y^)

其中:

  • y y y:target标签,正样本为 1,负样本为 0

  • y ^ \hat y y^: Sigmoid 激活函数的输出值,表示模型预测为正样本的概率

在多分类情况下,一般使用 Softmax 作为最后的激活函数,输出有多个值,对应每个类别的概率值,且这些概率之和为 1。多分类交叉熵损失的计算公式为:

C E = − ∑ i = 0 N y i ⋅ l n ( y ^ i ) CE=-\sum_{i=0}^Ny_i\cdot ln(\hat y_i) CE=−i=0∑Nyi⋅ln(y^i)

其中:

  • N N N:类别的总数

  • y i y_i yi:target标签的 one-hot 编码,若样本属于第 i i i 类,则 y i = 1 y_i =1 yi=1,否则 y i = 0 y_i=0 yi=0

  • y ^ i \hat y_i y^i : Softmax 激活函数输出结果中第 i i i 类的概率

直观解释:

无论是二分类问题,还是多分类问题,如果目标标签是类别 k k k,那么此时的交叉熵可以简化成:

C E = − l n ( y ^ k ) = − l n ( 模型预测类别 k 的概率 ) CE=-ln(\hat y_k)=-ln(模型预测类别k的概率) CE=−ln(y^k)=−ln(模型预测类别k的概率)

更根本的就是,如果目标标签是什么,那么 C E CE CE就是模型预测成该目标标签的概率,经过 − l n -ln −ln的计算之后的结果。

关于交叉熵损失函数的具体讲解,这里就不赘述了。可以参考博文:

FocalLoss基本概念

样本数量不平衡

我们在做实际模型训练的时候,经常会遇到各类样本数量比例不平衡的情况。比如:对于二分类任务,有可能负样本的数量远远多于正样本的数量,导致模型更多关注在负样本上,从而忽略正样本。

因此,在使用交叉熵损失的时候,通常会增加一个平衡参数用来调节正负样本的比重。

引入样本平衡参数 α k ∈ [ 0 , 1 ] \alpha_k \in[0,1] αk∈[0,1]:

C E = − α k ⋅ l n ( y ^ k ) CE=-\alpha_k \cdot ln(\hat y_k) CE=−αk⋅ln(y^k)

如果每个 α k \alpha_k αk都相等,那么就可以理解为:退化成交叉熵了。

一般情况下,如果是二分类问题,我们保证 α 正样本 + α 负样本 = 1 \alpha_{正样本}+\alpha_{负样本}=1 α正样本+α负样本=1,因此只需要指定正样本的平衡参数 α \alpha α就行了,此时负样本的平衡参数就是 1 − α 1-\alpha 1−α。如果是多分类问题,则保证每个 α k ∈ [ 0 , 1 ] \alpha_k \in[0,1] αk∈[0,1]就行了。即, α \alpha α的类别总数就是target标签的类别总数。

也就是说,如果有 N N N种类别,那么 α \alpha α就有 N N N种值。每个样本的target标签是 k k k,那么它的平衡参数就是 α k \alpha_k αk。二分类, N = 2 N=2 N=2,就是 α \alpha α、 1 − α 1-\alpha 1−α;多分类, N N N,就是 α 1 \alpha_1 α1、 α 2 \alpha_2 α2、...、 α N \alpha_N αN。

如果类别 k k k的样本数量较多,那么可以适当的降低 α k \alpha_k αk的值,降低loss;如果类别 k k k的样本数量较少,那么可以适当的提高 α k \alpha_k αk的值,提高loss。也就是说, α k \alpha_k αk的值应该是样本数量成反比。

例如,对于一个二分类问题,假如正样本数量占0.2,负样本数量占0.8,那么 α 正样本 = 0.8 \alpha_{正样本}=0.8 α正样本=0.8、 α 负样本 = 0.2 \alpha_{负样本}=0.2 α负样本=0.2(当然也可以微调),以此来达到平衡正负样本的目的。这样理解,看来是没有问题的。

难易分类样本

我们在做实际模型训练的时候,也会遇到各种难分类样本、易分类样本。比如:当易分类样本超级多时,整个训练过程将会围绕着易分类样本进行,进而淹没难分类样本,造成大损失。

因此,在使用交叉熵损失的时候,通常会增加一个平衡参数用来调节难易分类样本的比重。

引入样本平衡参数 γ ∈ [ 0 , 5 ] \gamma \in [0,5] γ∈[0,5]:

C E = − ( 1 − y ^ k ) γ ⋅ l n ( y ^ k ) CE=-(1-\hat y_k)^{\gamma}\cdot ln(\hat y_k) CE=−(1−y^k)γ⋅ln(y^k)

如果 γ = 0 \gamma=0 γ=0,那么就退化成交叉熵了。

当 γ ≠ 0 \gamma\neq 0 γ=0时,如果类别 k k k的模型预测概率 y ^ k \hat y_k y^k较高,即易分类样本;如果类别 k k k的模型预测概率 y ^ k \hat y_k y^k较低,即难分类样本;那么该平衡参数对易分类的loss降低幅度,比对难分类的loss降低幅度大很多。如果 γ \gamma γ越大,两者的幅度差距会更大。这样理解,看来是没有问题的。

FocalLoss

通过以上针对样本数量不平衡以及难易分类样本,可以得到应该最终的FocalLoss形式:

F L = − α k ( 1 − y ^ k ) γ ⋅ l n ( y ^ k ) = − α k ( 1 − p k ) γ ⋅ l n ( p k ) \begin{aligned}FL&=-\alpha_k(1-\hat y_k)^{\gamma}\cdot ln(\hat y_k)\\&=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k)\end{aligned} FL=−αk(1−y^k)γ⋅ln(y^k)=−αk(1−pk)γ⋅ln(pk)

即:即通过 α k \alpha_k αk可以抑制样本的数量失衡,通过 γ \gamma γ可以控制难易分类样本失衡。

FocalLoss关键点解释

接下来,会对FocalLoss形式中的计算结果进行关键的的分析和解释:

F L = − α k ( 1 − p k ) γ ⋅ l n ( p k ) FL=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k) FL=−αk(1−pk)γ⋅ln(pk)

p k p_k pk

p k p_k pk就是交叉熵里面的 y ^ k \hat y_k y^k,即模型预测类别k的概率。

对于二分类的情形,如果target标签是正样本,那 p k p_k pk就是模型预测正样本的概率;如果target标签是负样本,那 p k p_k pk就是模型预测负样本的概率。

对于多分类的情形, p k p_k pk就是模型预测target标签的概率。

也就是说, p k p_k pk的值代表了模型对样本预测正确的概率。因此, p k p_k pk 的大小实际能反映出样本难易分类的程度。

如果 p k p_k pk的值越逼近1,那么该样本就比较容易分类;如果 p k p_k pk的值越逼近0,那么该样本就不容易分类。

甚至可以这么粗暴的认为:

  • 易分类样本:模型预测正确的概率较高,即 p k p_k pk 较大(通常 p k > 0.5 p_k>0.5 pk>0.5)

  • 难分类样本:模型预测正确的概率较低,即 p k p_k pk 较小(通常 p k < 0.5 p_k<0.5 pk<0.5)

超参数 γ \gamma γ

上面说过, p k p_k pk代表了样本难易分类的程度。在训练模型的时候,我们希望模型更加关注难分类样本,所以会考虑将难分类样本在损失函数中的比重加大。

作者在原始的二分类交叉函数中增加了一项 ( 1 − p k ) γ (1-p_k)^{\gamma} (1−pk)γ,对原始交叉熵损失做了衰减。

经过对 p k p_k pk的分析可知,难分类样本的 p k p_k pk 值小, 1 − p k 1-p_k 1−pk 大;易分类样本的 p k p_k pk 值大, 1 − p k 1-p_k 1−pk 值小。尽管,无论是难分类还是易分类样本,FocalLoss相对于原始的交叉熵都做了衰减,但是难分类样本相对于易分类样本衰减得更轻微。

也就是说,从相对的角度上看,对于易分类样本是衰减的,对于难分类样本是增强的。

而这里的超参数 γ \gamma γ,则决定了两者相对的衰减/增强的属性, γ \gamma γ越大,相对性越明显。即,当 γ \gamma γ增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。

如果不是很好理解,我们用一个例子说明:

对于一个二分类情形,我们取 {\\gamma}=2,当target样本为正样本时:

  • 如果模型预测正样本的 p = 0.3 p = 0.3 p=0.3,表明是难分类样本,则 p k = 0.3 p_k = 0.3 pk=0.3, ( 1 − p k ) γ = ( 1 − 0.3 ) 2 = 0.49 ( 1 − p_k )^{\gamma} = ( 1 − 0.3 )^ 2 = 0.49 (1−pk)γ=(1−0.3)2=0.49,相当于原始的CE Loss的0.49倍

  • 如果模型预测正样本的 p = 0.7 p = 0.7 p=0.7,表明是易分类样本,则 p k = 0.7 p_k = 0.7 pk=0.7, ( 1 − p k ) γ = ( 1 − 0.7 ) 2 = 0.09 ( 1 − p_k )^{\gamma} = ( 1 − 0.7 )^ 2 = 0.09 (1−pk)γ=(1−0.7)2=0.09,相当于原始的CE Loss的0.09倍

也就是说,尽管都是衰减,但是难分类样本衰减成了0.49,但是易分类样本衰减成了0.09。可以看出,这样会导致模型更关注难分类样本。

超参数 α \alpha α

上面说过, α \alpha α用于缓解各类样本数量比例不平衡的情况。这个比较容易理解,如果某类的样本数量很多,那么就给它更小的平衡参数,让模型更关注样本数量较少的样本。

论文作者指出:加入 α \alpha α平衡参数比不加时精度有所提升。并且给出了实验参数,在作者的实验中当 {\\alpha}=0.25 , , , {\\gamma}=2时精度最高。

这时就有一个问题了, α \alpha α代表计算损失时对应正样本的调节权重,而正样本数量一般要小于负样本的数量,所以理论上正样本的权重应该大于负样本的权重的。但是作者实验中最佳的正样本权重( {\\alpha}=0.25 )为啥比负样本权重( )为啥比负样本权重( )为啥比负样本权重( 1-{\\alpha}=0.75)还要低呢?

明明负样本的数量已经远远大于正样本的数量了,为啥还要增加损失函数中负样本的比重呢?这不是矛盾吗?

其实,作者在论文里给出了解释。

即: α \alpha α和 γ \gamma γ如果是单独使用,自然是符合我们分析的结果的。但是在FocalLoss中,两者是混合使用的,它们之间也是有影响的主次关系的。一般来说, γ \gamma γ是占主导地位的,随着 {\\gamma} 的增大, 的增大, 的增大,{\\alpha}要相应的减小。

FocalLoss梯度分析

根据FocalLoss的公式:

F L = − α k ( 1 − p k ) γ ⋅ l n ( p k ) FL=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k) FL=−αk(1−pk)γ⋅ln(pk)

对FocalLoss进行求导:

∂ ( F L ) ∂ p k = − α k ⋅ γ ⋅ ( 1 − p k ) γ − 1 l n ( p k ) − α k ⋅ ( 1 − p k ) γ ⋅ 1 p k \frac{\partial (FL)}{\partial p_k}=-\alpha_k\cdot\gamma\cdot(1-p_k)^{\gamma-1}ln(p_k)-\alpha_k\cdot(1-p_k)^{\gamma}\cdot\frac{1}{p_k} ∂pk∂(FL)=−αk⋅γ⋅(1−pk)γ−1ln(pk)−αk⋅(1−pk)γ⋅pk1

可以看出, p t p_{t} pt接近1时,FocalLoss的梯度趋于0, p t p_{t} pt靠近0,FocalLoss的梯度越来越大。

也就是说,那么预测值 p t p_t pt和真实值target非常接近的时候,梯度极小,网络参数几乎不变,当预测值 p t p_t pt和真实值target差距较大时,梯度变大,网络参数开始调整。

小Tips

  1. 在使用FocalLoss时, γ \gamma γ占主导因素,同时 γ \gamma γ是用来控制难易分类样本的,所以当数据中难分类样本较多时, γ \gamma γ可以设置的大一些

  2. 如果要分类的目标特征比较明显(建筑、道路),最好不要用FocalLoss

代码实战

二分类问题,此时 α \alpha α指定正样本的平衡系数即可,此时负样本的平衡系数就是 1 − α 1-\alpha 1−α, γ \gamma γ本身就是一个标量。

python 复制代码
class BinaryFocalLoss(torch.nn.Module):
  """
  二分类BinaryFocalLoss
  """

  def __init__(self, alpha=0.25, gamma=2):
    super(BinaryFocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma

  def forward(self, preds, labels):
    """
    preds: sigmoid的输出结果
    labels: 标签, 1为正样本, 0为负样本
    """
    eps = 1e-7
    loss_1 = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labels
    loss_0 = -1 * (1 - self.alpha) * torch.pow(preds, self.gamma) * torch.log(1 - preds + eps) * (1 - labels)
    loss = loss_0 + loss_1
    return torch.mean(loss)

多分类问题,此时 α \alpha α是一个向量,大小为所有样本的长度,表示每个样本所属的target类别的平衡系数, γ \gamma γ本身就是一个标量。

python 复制代码
class FocalLoss(torch.nn.Module):
  """
  多分类FocalLoss
  """

  def __init__(self, alpha, gamma=2):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma

  def forward(self, preds, labels):
    """
    preds: 模型的输出, 未经过 Softmax, 形状为 [N, C]
    labels: 标签, 形状为 [N], 取值为类别的索引
    """
    eps = 1e-7
    # 计算 Log-Softmax
    preds_logsoft = F.log_softmax(preds, dim=1)
    # 计算 softmax
    preds_softmax = torch.exp(preds_logsoft)
    # unsqueeze:将 labels 的形状从 [N] 转换为 [N, 1],便于与 preds_softmax 对应
    # gather:根据 labels 提取每个样本真实类别对应的概率值,结果形状为 [N, 1]
    preds_softmax = preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)
    preds_logsoft = preds_logsoft.gather(1, labels.unsqueeze(1)).squeeze(1)

    # 计算 focal loss
    loss = -self.alpha * torch.pow(1 - preds_softmax, self.gamma) * preds_logsoft
    return loss.mean()

这边对代码preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)分析:

python 复制代码
import torch

# 假设 preds_softmax 是模型的 Softmax 输出,形状为 [N, C]
preds_softmax = torch.tensor([
    [0.1, 0.7, 0.2],  # 样本 1 的类别概率分布
    [0.3, 0.4, 0.3],  # 样本 2 的类别概率分布
    [0.2, 0.5, 0.3]   # 样本 3 的类别概率分布
])

# 假设 labels 是目标标签,形状为 [N]
labels = torch.tensor([1, 2, 0])  # 样本 1 属于类别 1,样本 2 属于类别 2,样本 3 属于类别 0

# 使用 gather 提取每个样本真实类别的概率
labels_unsqueezed = labels.unsqueeze(1)  # 将 labels 的形状从 [N] 转换为 [N, 1]
print("labels.unsqueeze(1):")
print(labels_unsqueezed)

# 根据 labels 提取 preds_softmax 中对应的概率
gathered_probs = preds_softmax.gather(1, labels_unsqueezed)  # 结果形状为 [N, 1]
print("\npreds_softmax.gather(1, labels.unsqueeze(1)):")
print(gathered_probs)

# 使用 squeeze(1) 去掉第 1 维,结果形状为 [N]
final_probs = gathered_probs.squeeze(1)
print("\npreds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1):")
print(final_probs)

输出结果:

复制代码
labels.unsqueeze(1):
tensor([[1],
        [2],
        [0]])

preds_softmax.gather(1, labels.unsqueeze(1)):
tensor([[0.7000],
        [0.3000],
        [0.2000]])

preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1):
tensor([0.7000, 0.3000, 0.2000])

需要注意的是,这边的代码写法应该与输入的shape、参数的shape有关系的,需要针对于具体的情形进行适配和修改。

相关阅读

相关推荐
老蒋每日coding2 小时前
AI Agent 设计模式系列(十四)—— 知识检索(RAG)模式
人工智能·设计模式·langchain
劈星斩月2 小时前
3Blue1Brown-深度学习之神经网络
人工智能·深度学习·神经网络
顾北122 小时前
RAG 入门到实战:Spring AI 搭建旅游问答知识库(本地 + 阿里云百炼双方案)
java·人工智能·阿里云
云雾J视界2 小时前
AI服务器供电革命:为何交错并联Buck成为算力时代的必然选择
服务器·人工智能·nvidia·算力·buck·dgx·交错并联
阳艳讲ai2 小时前
九尾狐AI:重构企业AI生产力的实战革命
大数据·人工智能
大势智慧2 小时前
大势智慧与土耳其合作发展中心、蕾奥规划签署土耳其智慧城市项目战略合作协议
人工智能·ai·智慧城市·三维建模·实景三维·发展趋势·创新
爱看科技2 小时前
苹果Siri或升级机器人“CAMPOS”亮相,微美全息加速AI与机器人结合培育动能
人工智能·microsoft·机器人
Nowl2 小时前
基于langchain的个人情感陪伴agent
人工智能·机器学习·langchain
UI设计兰亭妙微2 小时前
零售门店选址评估小程序界面设计
人工智能·小程序·零售