【深度学习】动态交叉熵损失函数Focal Loss

FocalLoss,由何恺明等人在2017年的论文《Focal Loss for Dense Object Detection》中提出。它是一种专门为解决,目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。当然,也是可以适用于图像分割。

本文将详细介绍FocalLoss的基本概念、思想原理,以及如何设置 FocalLoss中的关键参数,并提供PyTorch的实现代码。

交叉熵损失函数

Focalloss是基于二分类交叉熵,可以理解成一个动态缩放的交叉熵损失。即:

  • 通过一个动态放缩因子,可以在不同类别的样本数量极不平衡情况下,动态降低训练过程中数量极多的样本的权重,避免数量极少的样本被淹没

  • 通过一个动态缩放因子,可以动态降低训练过程中易区分样本的权重,从而将重心快速聚焦在那些难区分的样本

那就先回顾下交叉熵损失的定义。

在二分类任务中,一般使用 Sigmoid 作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat y y^。二分类非正即负,所以样本为负的概率值为 1 − y ^ 1-\hat y 1−y^ 。二分类交叉熵损失的计算公式为:

B C E = − y ⋅ l n ( y ^ ) − ( 1 − y ) ⋅ l n ( 1 − y ^ ) BCE=-y\cdot ln(\hat y)-(1-y)\cdot ln(1-\hat y) BCE=−y⋅ln(y^)−(1−y)⋅ln(1−y^)

其中:

  • y y y:target标签,正样本为 1,负样本为 0

  • y ^ \hat y y^: Sigmoid 激活函数的输出值,表示模型预测为正样本的概率

在多分类情况下,一般使用 Softmax 作为最后的激活函数,输出有多个值,对应每个类别的概率值,且这些概率之和为 1。多分类交叉熵损失的计算公式为:

C E = − ∑ i = 0 N y i ⋅ l n ( y ^ i ) CE=-\sum_{i=0}^Ny_i\cdot ln(\hat y_i) CE=−i=0∑Nyi⋅ln(y^i)

其中:

  • N N N:类别的总数

  • y i y_i yi:target标签的 one-hot 编码,若样本属于第 i i i 类,则 y i = 1 y_i =1 yi=1,否则 y i = 0 y_i=0 yi=0

  • y ^ i \hat y_i y^i : Softmax 激活函数输出结果中第 i i i 类的概率

直观解释:

无论是二分类问题,还是多分类问题,如果目标标签是类别 k k k,那么此时的交叉熵可以简化成:

C E = − l n ( y ^ k ) = − l n ( 模型预测类别 k 的概率 ) CE=-ln(\hat y_k)=-ln(模型预测类别k的概率) CE=−ln(y^k)=−ln(模型预测类别k的概率)

更根本的就是,如果目标标签是什么,那么 C E CE CE就是模型预测成该目标标签的概率,经过 − l n -ln −ln的计算之后的结果。

关于交叉熵损失函数的具体讲解,这里就不赘述了。可以参考博文:

FocalLoss基本概念

样本数量不平衡

我们在做实际模型训练的时候,经常会遇到各类样本数量比例不平衡的情况。比如:对于二分类任务,有可能负样本的数量远远多于正样本的数量,导致模型更多关注在负样本上,从而忽略正样本。

因此,在使用交叉熵损失的时候,通常会增加一个平衡参数用来调节正负样本的比重。

引入样本平衡参数 α k ∈ [ 0 , 1 ] \alpha_k \in[0,1] αk∈[0,1]:

C E = − α k ⋅ l n ( y ^ k ) CE=-\alpha_k \cdot ln(\hat y_k) CE=−αk⋅ln(y^k)

如果每个 α k \alpha_k αk都相等,那么就可以理解为:退化成交叉熵了。

一般情况下,如果是二分类问题,我们保证 α 正样本 + α 负样本 = 1 \alpha_{正样本}+\alpha_{负样本}=1 α正样本+α负样本=1,因此只需要指定正样本的平衡参数 α \alpha α就行了,此时负样本的平衡参数就是 1 − α 1-\alpha 1−α。如果是多分类问题,则保证每个 α k ∈ [ 0 , 1 ] \alpha_k \in[0,1] αk∈[0,1]就行了。即, α \alpha α的类别总数就是target标签的类别总数。

也就是说,如果有 N N N种类别,那么 α \alpha α就有 N N N种值。每个样本的target标签是 k k k,那么它的平衡参数就是 α k \alpha_k αk。二分类, N = 2 N=2 N=2,就是 α \alpha α、 1 − α 1-\alpha 1−α;多分类, N N N,就是 α 1 \alpha_1 α1、 α 2 \alpha_2 α2、...、 α N \alpha_N αN。

如果类别 k k k的样本数量较多,那么可以适当的降低 α k \alpha_k αk的值,降低loss;如果类别 k k k的样本数量较少,那么可以适当的提高 α k \alpha_k αk的值,提高loss。也就是说, α k \alpha_k αk的值应该是样本数量成反比。

例如,对于一个二分类问题,假如正样本数量占0.2,负样本数量占0.8,那么 α 正样本 = 0.8 \alpha_{正样本}=0.8 α正样本=0.8、 α 负样本 = 0.2 \alpha_{负样本}=0.2 α负样本=0.2(当然也可以微调),以此来达到平衡正负样本的目的。这样理解,看来是没有问题的。

难易分类样本

我们在做实际模型训练的时候,也会遇到各种难分类样本、易分类样本。比如:当易分类样本超级多时,整个训练过程将会围绕着易分类样本进行,进而淹没难分类样本,造成大损失。

因此,在使用交叉熵损失的时候,通常会增加一个平衡参数用来调节难易分类样本的比重。

引入样本平衡参数 γ ∈ [ 0 , 5 ] \gamma \in [0,5] γ∈[0,5]:

C E = − ( 1 − y ^ k ) γ ⋅ l n ( y ^ k ) CE=-(1-\hat y_k)^{\gamma}\cdot ln(\hat y_k) CE=−(1−y^k)γ⋅ln(y^k)

如果 γ = 0 \gamma=0 γ=0,那么就退化成交叉熵了。

当 γ ≠ 0 \gamma\neq 0 γ=0时,如果类别 k k k的模型预测概率 y ^ k \hat y_k y^k较高,即易分类样本;如果类别 k k k的模型预测概率 y ^ k \hat y_k y^k较低,即难分类样本;那么该平衡参数对易分类的loss降低幅度,比对难分类的loss降低幅度大很多。如果 γ \gamma γ越大,两者的幅度差距会更大。这样理解,看来是没有问题的。

FocalLoss

通过以上针对样本数量不平衡以及难易分类样本,可以得到应该最终的FocalLoss形式:

F L = − α k ( 1 − y ^ k ) γ ⋅ l n ( y ^ k ) = − α k ( 1 − p k ) γ ⋅ l n ( p k ) \begin{aligned}FL&=-\alpha_k(1-\hat y_k)^{\gamma}\cdot ln(\hat y_k)\\&=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k)\end{aligned} FL=−αk(1−y^k)γ⋅ln(y^k)=−αk(1−pk)γ⋅ln(pk)

即:即通过 α k \alpha_k αk可以抑制样本的数量失衡,通过 γ \gamma γ可以控制难易分类样本失衡。

FocalLoss关键点解释

接下来,会对FocalLoss形式中的计算结果进行关键的的分析和解释:

F L = − α k ( 1 − p k ) γ ⋅ l n ( p k ) FL=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k) FL=−αk(1−pk)γ⋅ln(pk)

p k p_k pk

p k p_k pk就是交叉熵里面的 y ^ k \hat y_k y^k,即模型预测类别k的概率。

对于二分类的情形,如果target标签是正样本,那 p k p_k pk就是模型预测正样本的概率;如果target标签是负样本,那 p k p_k pk就是模型预测负样本的概率。

对于多分类的情形, p k p_k pk就是模型预测target标签的概率。

也就是说, p k p_k pk的值代表了模型对样本预测正确的概率。因此, p k p_k pk 的大小实际能反映出样本难易分类的程度。

如果 p k p_k pk的值越逼近1,那么该样本就比较容易分类;如果 p k p_k pk的值越逼近0,那么该样本就不容易分类。

甚至可以这么粗暴的认为:

  • 易分类样本:模型预测正确的概率较高,即 p k p_k pk 较大(通常 p k > 0.5 p_k>0.5 pk>0.5)

  • 难分类样本:模型预测正确的概率较低,即 p k p_k pk 较小(通常 p k < 0.5 p_k<0.5 pk<0.5)

超参数 γ \gamma γ

上面说过, p k p_k pk代表了样本难易分类的程度。在训练模型的时候,我们希望模型更加关注难分类样本,所以会考虑将难分类样本在损失函数中的比重加大。

作者在原始的二分类交叉函数中增加了一项 ( 1 − p k ) γ (1-p_k)^{\gamma} (1−pk)γ,对原始交叉熵损失做了衰减。

经过对 p k p_k pk的分析可知,难分类样本的 p k p_k pk 值小, 1 − p k 1-p_k 1−pk 大;易分类样本的 p k p_k pk 值大, 1 − p k 1-p_k 1−pk 值小。尽管,无论是难分类还是易分类样本,FocalLoss相对于原始的交叉熵都做了衰减,但是难分类样本相对于易分类样本衰减得更轻微。

也就是说,从相对的角度上看,对于易分类样本是衰减的,对于难分类样本是增强的。

而这里的超参数 γ \gamma γ,则决定了两者相对的衰减/增强的属性, γ \gamma γ越大,相对性越明显。即,当 γ \gamma γ增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。

如果不是很好理解,我们用一个例子说明:

对于一个二分类情形,我们取 {\\gamma}=2,当target样本为正样本时:

  • 如果模型预测正样本的 p = 0.3 p = 0.3 p=0.3,表明是难分类样本,则 p k = 0.3 p_k = 0.3 pk=0.3, ( 1 − p k ) γ = ( 1 − 0.3 ) 2 = 0.49 ( 1 − p_k )^{\gamma} = ( 1 − 0.3 )^ 2 = 0.49 (1−pk)γ=(1−0.3)2=0.49,相当于原始的CE Loss的0.49倍

  • 如果模型预测正样本的 p = 0.7 p = 0.7 p=0.7,表明是易分类样本,则 p k = 0.7 p_k = 0.7 pk=0.7, ( 1 − p k ) γ = ( 1 − 0.7 ) 2 = 0.09 ( 1 − p_k )^{\gamma} = ( 1 − 0.7 )^ 2 = 0.09 (1−pk)γ=(1−0.7)2=0.09,相当于原始的CE Loss的0.09倍

也就是说,尽管都是衰减,但是难分类样本衰减成了0.49,但是易分类样本衰减成了0.09。可以看出,这样会导致模型更关注难分类样本。

超参数 α \alpha α

上面说过, α \alpha α用于缓解各类样本数量比例不平衡的情况。这个比较容易理解,如果某类的样本数量很多,那么就给它更小的平衡参数,让模型更关注样本数量较少的样本。

论文作者指出:加入 α \alpha α平衡参数比不加时精度有所提升。并且给出了实验参数,在作者的实验中当 {\\alpha}=0.25 , , , {\\gamma}=2时精度最高。

这时就有一个问题了, α \alpha α代表计算损失时对应正样本的调节权重,而正样本数量一般要小于负样本的数量,所以理论上正样本的权重应该大于负样本的权重的。但是作者实验中最佳的正样本权重( {\\alpha}=0.25 )为啥比负样本权重( )为啥比负样本权重( )为啥比负样本权重( 1-{\\alpha}=0.75)还要低呢?

明明负样本的数量已经远远大于正样本的数量了,为啥还要增加损失函数中负样本的比重呢?这不是矛盾吗?

其实,作者在论文里给出了解释。

即: α \alpha α和 γ \gamma γ如果是单独使用,自然是符合我们分析的结果的。但是在FocalLoss中,两者是混合使用的,它们之间也是有影响的主次关系的。一般来说, γ \gamma γ是占主导地位的,随着 {\\gamma} 的增大, 的增大, 的增大,{\\alpha}要相应的减小。

FocalLoss梯度分析

根据FocalLoss的公式:

F L = − α k ( 1 − p k ) γ ⋅ l n ( p k ) FL=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k) FL=−αk(1−pk)γ⋅ln(pk)

对FocalLoss进行求导:

∂ ( F L ) ∂ p k = − α k ⋅ γ ⋅ ( 1 − p k ) γ − 1 l n ( p k ) − α k ⋅ ( 1 − p k ) γ ⋅ 1 p k \frac{\partial (FL)}{\partial p_k}=-\alpha_k\cdot\gamma\cdot(1-p_k)^{\gamma-1}ln(p_k)-\alpha_k\cdot(1-p_k)^{\gamma}\cdot\frac{1}{p_k} ∂pk∂(FL)=−αk⋅γ⋅(1−pk)γ−1ln(pk)−αk⋅(1−pk)γ⋅pk1

可以看出, p t p_{t} pt接近1时,FocalLoss的梯度趋于0, p t p_{t} pt靠近0,FocalLoss的梯度越来越大。

也就是说,那么预测值 p t p_t pt和真实值target非常接近的时候,梯度极小,网络参数几乎不变,当预测值 p t p_t pt和真实值target差距较大时,梯度变大,网络参数开始调整。

小Tips

  1. 在使用FocalLoss时, γ \gamma γ占主导因素,同时 γ \gamma γ是用来控制难易分类样本的,所以当数据中难分类样本较多时, γ \gamma γ可以设置的大一些

  2. 如果要分类的目标特征比较明显(建筑、道路),最好不要用FocalLoss

代码实战

二分类问题,此时 α \alpha α指定正样本的平衡系数即可,此时负样本的平衡系数就是 1 − α 1-\alpha 1−α, γ \gamma γ本身就是一个标量。

python 复制代码
class BinaryFocalLoss(torch.nn.Module):
  """
  二分类BinaryFocalLoss
  """

  def __init__(self, alpha=0.25, gamma=2):
    super(BinaryFocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma

  def forward(self, preds, labels):
    """
    preds: sigmoid的输出结果
    labels: 标签, 1为正样本, 0为负样本
    """
    eps = 1e-7
    loss_1 = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labels
    loss_0 = -1 * (1 - self.alpha) * torch.pow(preds, self.gamma) * torch.log(1 - preds + eps) * (1 - labels)
    loss = loss_0 + loss_1
    return torch.mean(loss)

多分类问题,此时 α \alpha α是一个向量,大小为所有样本的长度,表示每个样本所属的target类别的平衡系数, γ \gamma γ本身就是一个标量。

python 复制代码
class FocalLoss(torch.nn.Module):
  """
  多分类FocalLoss
  """

  def __init__(self, alpha, gamma=2):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma

  def forward(self, preds, labels):
    """
    preds: 模型的输出, 未经过 Softmax, 形状为 [N, C]
    labels: 标签, 形状为 [N], 取值为类别的索引
    """
    eps = 1e-7
    # 计算 Log-Softmax
    preds_logsoft = F.log_softmax(preds, dim=1)
    # 计算 softmax
    preds_softmax = torch.exp(preds_logsoft)
    # unsqueeze:将 labels 的形状从 [N] 转换为 [N, 1],便于与 preds_softmax 对应
    # gather:根据 labels 提取每个样本真实类别对应的概率值,结果形状为 [N, 1]
    preds_softmax = preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)
    preds_logsoft = preds_logsoft.gather(1, labels.unsqueeze(1)).squeeze(1)

    # 计算 focal loss
    loss = -self.alpha * torch.pow(1 - preds_softmax, self.gamma) * preds_logsoft
    return loss.mean()

这边对代码preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)分析:

python 复制代码
import torch

# 假设 preds_softmax 是模型的 Softmax 输出,形状为 [N, C]
preds_softmax = torch.tensor([
    [0.1, 0.7, 0.2],  # 样本 1 的类别概率分布
    [0.3, 0.4, 0.3],  # 样本 2 的类别概率分布
    [0.2, 0.5, 0.3]   # 样本 3 的类别概率分布
])

# 假设 labels 是目标标签,形状为 [N]
labels = torch.tensor([1, 2, 0])  # 样本 1 属于类别 1,样本 2 属于类别 2,样本 3 属于类别 0

# 使用 gather 提取每个样本真实类别的概率
labels_unsqueezed = labels.unsqueeze(1)  # 将 labels 的形状从 [N] 转换为 [N, 1]
print("labels.unsqueeze(1):")
print(labels_unsqueezed)

# 根据 labels 提取 preds_softmax 中对应的概率
gathered_probs = preds_softmax.gather(1, labels_unsqueezed)  # 结果形状为 [N, 1]
print("\npreds_softmax.gather(1, labels.unsqueeze(1)):")
print(gathered_probs)

# 使用 squeeze(1) 去掉第 1 维,结果形状为 [N]
final_probs = gathered_probs.squeeze(1)
print("\npreds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1):")
print(final_probs)

输出结果:

复制代码
labels.unsqueeze(1):
tensor([[1],
        [2],
        [0]])

preds_softmax.gather(1, labels.unsqueeze(1)):
tensor([[0.7000],
        [0.3000],
        [0.2000]])

preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1):
tensor([0.7000, 0.3000, 0.2000])

需要注意的是,这边的代码写法应该与输入的shape、参数的shape有关系的,需要针对于具体的情形进行适配和修改。

相关阅读

相关推荐
AngelPP8 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年8 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼8 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS8 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区9 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈10 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang10 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx
shengjk111 小时前
NanoClaw 深度剖析:一个"AI 原生"架构的个人助手是如何运转的?
人工智能
西门老铁13 小时前
🦞OpenClaw 让 MacMini 脱销了,而我拿出了6年陈的安卓机
人工智能