FocalLoss,由何恺明等人在2017年的论文《Focal Loss for Dense Object Detection》中提出。它是一种专门为解决,目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。当然,也是可以适用于图像分割。
本文将详细介绍FocalLoss的基本概念、思想原理,以及如何设置 FocalLoss中的关键参数,并提供PyTorch的实现代码。
交叉熵损失函数
Focalloss是基于二分类交叉熵,可以理解成一个动态缩放的交叉熵损失。即:
-
通过一个动态放缩因子,可以在不同类别的样本数量极不平衡情况下,动态降低训练过程中数量极多的样本的权重,避免数量极少的样本被淹没
-
通过一个动态缩放因子,可以动态降低训练过程中易区分样本的权重,从而将重心快速聚焦在那些难区分的样本
那就先回顾下交叉熵损失的定义。
在二分类任务中,一般使用 Sigmoid 作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat y y^。二分类非正即负,所以样本为负的概率值为 1 − y ^ 1-\hat y 1−y^ 。二分类交叉熵损失的计算公式为:
B C E = − y ⋅ l n ( y ^ ) − ( 1 − y ) ⋅ l n ( 1 − y ^ ) BCE=-y\cdot ln(\hat y)-(1-y)\cdot ln(1-\hat y) BCE=−y⋅ln(y^)−(1−y)⋅ln(1−y^)
其中:
-
y y y:target标签,正样本为 1,负样本为 0
-
y ^ \hat y y^: Sigmoid 激活函数的输出值,表示模型预测为正样本的概率
在多分类情况下,一般使用 Softmax 作为最后的激活函数,输出有多个值,对应每个类别的概率值,且这些概率之和为 1。多分类交叉熵损失的计算公式为:
C E = − ∑ i = 0 N y i ⋅ l n ( y ^ i ) CE=-\sum_{i=0}^Ny_i\cdot ln(\hat y_i) CE=−i=0∑Nyi⋅ln(y^i)
其中:
-
N N N:类别的总数
-
y i y_i yi:target标签的 one-hot 编码,若样本属于第 i i i 类,则 y i = 1 y_i =1 yi=1,否则 y i = 0 y_i=0 yi=0
-
y ^ i \hat y_i y^i : Softmax 激活函数输出结果中第 i i i 类的概率
直观解释:
无论是二分类问题,还是多分类问题,如果目标标签是类别 k k k,那么此时的交叉熵可以简化成:
C E = − l n ( y ^ k ) = − l n ( 模型预测类别 k 的概率 ) CE=-ln(\hat y_k)=-ln(模型预测类别k的概率) CE=−ln(y^k)=−ln(模型预测类别k的概率)
更根本的就是,如果目标标签是什么,那么 C E CE CE就是模型预测成该目标标签的概率,经过 − l n -ln −ln的计算之后的结果。
关于交叉熵损失函数的具体讲解,这里就不赘述了。可以参考博文:
FocalLoss基本概念
样本数量不平衡
我们在做实际模型训练的时候,经常会遇到各类样本数量比例不平衡的情况。比如:对于二分类任务,有可能负样本的数量远远多于正样本的数量,导致模型更多关注在负样本上,从而忽略正样本。
因此,在使用交叉熵损失的时候,通常会增加一个平衡参数用来调节正负样本的比重。
引入样本平衡参数 α k ∈ [ 0 , 1 ] \alpha_k \in[0,1] αk∈[0,1]:
C E = − α k ⋅ l n ( y ^ k ) CE=-\alpha_k \cdot ln(\hat y_k) CE=−αk⋅ln(y^k)
如果每个 α k \alpha_k αk都相等,那么就可以理解为:退化成交叉熵了。
一般情况下,如果是二分类问题,我们保证 α 正样本 + α 负样本 = 1 \alpha_{正样本}+\alpha_{负样本}=1 α正样本+α负样本=1,因此只需要指定正样本的平衡参数 α \alpha α就行了,此时负样本的平衡参数就是 1 − α 1-\alpha 1−α。如果是多分类问题,则保证每个 α k ∈ [ 0 , 1 ] \alpha_k \in[0,1] αk∈[0,1]就行了。即, α \alpha α的类别总数就是target标签的类别总数。
也就是说,如果有 N N N种类别,那么 α \alpha α就有 N N N种值。每个样本的target标签是 k k k,那么它的平衡参数就是 α k \alpha_k αk。二分类, N = 2 N=2 N=2,就是 α \alpha α、 1 − α 1-\alpha 1−α;多分类, N N N,就是 α 1 \alpha_1 α1、 α 2 \alpha_2 α2、...、 α N \alpha_N αN。
如果类别 k k k的样本数量较多,那么可以适当的降低 α k \alpha_k αk的值,降低loss;如果类别 k k k的样本数量较少,那么可以适当的提高 α k \alpha_k αk的值,提高loss。也就是说, α k \alpha_k αk的值应该是样本数量成反比。
例如,对于一个二分类问题,假如正样本数量占0.2,负样本数量占0.8,那么 α 正样本 = 0.8 \alpha_{正样本}=0.8 α正样本=0.8、 α 负样本 = 0.2 \alpha_{负样本}=0.2 α负样本=0.2(当然也可以微调),以此来达到平衡正负样本的目的。这样理解,看来是没有问题的。
难易分类样本
我们在做实际模型训练的时候,也会遇到各种难分类样本、易分类样本。比如:当易分类样本超级多时,整个训练过程将会围绕着易分类样本进行,进而淹没难分类样本,造成大损失。
因此,在使用交叉熵损失的时候,通常会增加一个平衡参数用来调节难易分类样本的比重。
引入样本平衡参数 γ ∈ [ 0 , 5 ] \gamma \in [0,5] γ∈[0,5]:
C E = − ( 1 − y ^ k ) γ ⋅ l n ( y ^ k ) CE=-(1-\hat y_k)^{\gamma}\cdot ln(\hat y_k) CE=−(1−y^k)γ⋅ln(y^k)
如果 γ = 0 \gamma=0 γ=0,那么就退化成交叉熵了。
当 γ ≠ 0 \gamma\neq 0 γ=0时,如果类别 k k k的模型预测概率 y ^ k \hat y_k y^k较高,即易分类样本;如果类别 k k k的模型预测概率 y ^ k \hat y_k y^k较低,即难分类样本;那么该平衡参数对易分类的loss降低幅度,比对难分类的loss降低幅度大很多。如果 γ \gamma γ越大,两者的幅度差距会更大。这样理解,看来是没有问题的。
FocalLoss
通过以上针对样本数量不平衡以及难易分类样本,可以得到应该最终的FocalLoss形式:
F L = − α k ( 1 − y ^ k ) γ ⋅ l n ( y ^ k ) = − α k ( 1 − p k ) γ ⋅ l n ( p k ) \begin{aligned}FL&=-\alpha_k(1-\hat y_k)^{\gamma}\cdot ln(\hat y_k)\\&=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k)\end{aligned} FL=−αk(1−y^k)γ⋅ln(y^k)=−αk(1−pk)γ⋅ln(pk)
即:即通过 α k \alpha_k αk可以抑制样本的数量失衡,通过 γ \gamma γ可以控制难易分类样本失衡。
FocalLoss关键点解释
接下来,会对FocalLoss形式中的计算结果进行关键的的分析和解释:
F L = − α k ( 1 − p k ) γ ⋅ l n ( p k ) FL=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k) FL=−αk(1−pk)γ⋅ln(pk)
p k p_k pk
p k p_k pk就是交叉熵里面的 y ^ k \hat y_k y^k,即模型预测类别k的概率。
对于二分类的情形,如果target标签是正样本,那 p k p_k pk就是模型预测正样本的概率;如果target标签是负样本,那 p k p_k pk就是模型预测负样本的概率。
对于多分类的情形, p k p_k pk就是模型预测target标签的概率。
也就是说, p k p_k pk的值代表了模型对样本预测正确的概率。因此, p k p_k pk 的大小实际能反映出样本难易分类的程度。
如果 p k p_k pk的值越逼近1,那么该样本就比较容易分类;如果 p k p_k pk的值越逼近0,那么该样本就不容易分类。
甚至可以这么粗暴的认为:
-
易分类样本:模型预测正确的概率较高,即 p k p_k pk 较大(通常 p k > 0.5 p_k>0.5 pk>0.5)
-
难分类样本:模型预测正确的概率较低,即 p k p_k pk 较小(通常 p k < 0.5 p_k<0.5 pk<0.5)
超参数 γ \gamma γ
上面说过, p k p_k pk代表了样本难易分类的程度。在训练模型的时候,我们希望模型更加关注难分类样本,所以会考虑将难分类样本在损失函数中的比重加大。
作者在原始的二分类交叉函数中增加了一项 ( 1 − p k ) γ (1-p_k)^{\gamma} (1−pk)γ,对原始交叉熵损失做了衰减。
经过对 p k p_k pk的分析可知,难分类样本的 p k p_k pk 值小, 1 − p k 1-p_k 1−pk 大;易分类样本的 p k p_k pk 值大, 1 − p k 1-p_k 1−pk 值小。尽管,无论是难分类还是易分类样本,FocalLoss相对于原始的交叉熵都做了衰减,但是难分类样本相对于易分类样本衰减得更轻微。
也就是说,从相对的角度上看,对于易分类样本是衰减的,对于难分类样本是增强的。
而这里的超参数 γ \gamma γ,则决定了两者相对的衰减/增强的属性, γ \gamma γ越大,相对性越明显。即,当 γ \gamma γ增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。
如果不是很好理解,我们用一个例子说明:
对于一个二分类情形,我们取 {\\gamma}=2,当target样本为正样本时:
-
如果模型预测正样本的 p = 0.3 p = 0.3 p=0.3,表明是难分类样本,则 p k = 0.3 p_k = 0.3 pk=0.3, ( 1 − p k ) γ = ( 1 − 0.3 ) 2 = 0.49 ( 1 − p_k )^{\gamma} = ( 1 − 0.3 )^ 2 = 0.49 (1−pk)γ=(1−0.3)2=0.49,相当于原始的CE Loss的0.49倍
-
如果模型预测正样本的 p = 0.7 p = 0.7 p=0.7,表明是易分类样本,则 p k = 0.7 p_k = 0.7 pk=0.7, ( 1 − p k ) γ = ( 1 − 0.7 ) 2 = 0.09 ( 1 − p_k )^{\gamma} = ( 1 − 0.7 )^ 2 = 0.09 (1−pk)γ=(1−0.7)2=0.09,相当于原始的CE Loss的0.09倍
也就是说,尽管都是衰减,但是难分类样本衰减成了0.49,但是易分类样本衰减成了0.09。可以看出,这样会导致模型更关注难分类样本。
超参数 α \alpha α
上面说过, α \alpha α用于缓解各类样本数量比例不平衡的情况。这个比较容易理解,如果某类的样本数量很多,那么就给它更小的平衡参数,让模型更关注样本数量较少的样本。
论文作者指出:加入 α \alpha α平衡参数比不加时精度有所提升。并且给出了实验参数,在作者的实验中当 {\\alpha}=0.25 , , , {\\gamma}=2时精度最高。
这时就有一个问题了, α \alpha α代表计算损失时对应正样本的调节权重,而正样本数量一般要小于负样本的数量,所以理论上正样本的权重应该大于负样本的权重的。但是作者实验中最佳的正样本权重( {\\alpha}=0.25 )为啥比负样本权重( )为啥比负样本权重( )为啥比负样本权重( 1-{\\alpha}=0.75)还要低呢?
明明负样本的数量已经远远大于正样本的数量了,为啥还要增加损失函数中负样本的比重呢?这不是矛盾吗?
其实,作者在论文里给出了解释。
即: α \alpha α和 γ \gamma γ如果是单独使用,自然是符合我们分析的结果的。但是在FocalLoss中,两者是混合使用的,它们之间也是有影响的主次关系的。一般来说, γ \gamma γ是占主导地位的,随着 {\\gamma} 的增大, 的增大, 的增大,{\\alpha}要相应的减小。
FocalLoss梯度分析
根据FocalLoss的公式:
F L = − α k ( 1 − p k ) γ ⋅ l n ( p k ) FL=-\alpha_k(1-p_k)^\gamma\cdot ln(p_k) FL=−αk(1−pk)γ⋅ln(pk)
对FocalLoss进行求导:
∂ ( F L ) ∂ p k = − α k ⋅ γ ⋅ ( 1 − p k ) γ − 1 l n ( p k ) − α k ⋅ ( 1 − p k ) γ ⋅ 1 p k \frac{\partial (FL)}{\partial p_k}=-\alpha_k\cdot\gamma\cdot(1-p_k)^{\gamma-1}ln(p_k)-\alpha_k\cdot(1-p_k)^{\gamma}\cdot\frac{1}{p_k} ∂pk∂(FL)=−αk⋅γ⋅(1−pk)γ−1ln(pk)−αk⋅(1−pk)γ⋅pk1
可以看出, p t p_{t} pt接近1时,FocalLoss的梯度趋于0, p t p_{t} pt靠近0,FocalLoss的梯度越来越大。
也就是说,那么预测值 p t p_t pt和真实值target非常接近的时候,梯度极小,网络参数几乎不变,当预测值 p t p_t pt和真实值target差距较大时,梯度变大,网络参数开始调整。
小Tips
-
在使用FocalLoss时, γ \gamma γ占主导因素,同时 γ \gamma γ是用来控制难易分类样本的,所以当数据中难分类样本较多时, γ \gamma γ可以设置的大一些
-
如果要分类的目标特征比较明显(建筑、道路),最好不要用FocalLoss
代码实战
二分类问题,此时 α \alpha α指定正样本的平衡系数即可,此时负样本的平衡系数就是 1 − α 1-\alpha 1−α, γ \gamma γ本身就是一个标量。
python
class BinaryFocalLoss(torch.nn.Module):
"""
二分类BinaryFocalLoss
"""
def __init__(self, alpha=0.25, gamma=2):
super(BinaryFocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, preds, labels):
"""
preds: sigmoid的输出结果
labels: 标签, 1为正样本, 0为负样本
"""
eps = 1e-7
loss_1 = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labels
loss_0 = -1 * (1 - self.alpha) * torch.pow(preds, self.gamma) * torch.log(1 - preds + eps) * (1 - labels)
loss = loss_0 + loss_1
return torch.mean(loss)
多分类问题,此时 α \alpha α是一个向量,大小为所有样本的长度,表示每个样本所属的target类别的平衡系数, γ \gamma γ本身就是一个标量。
python
class FocalLoss(torch.nn.Module):
"""
多分类FocalLoss
"""
def __init__(self, alpha, gamma=2):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, preds, labels):
"""
preds: 模型的输出, 未经过 Softmax, 形状为 [N, C]
labels: 标签, 形状为 [N], 取值为类别的索引
"""
eps = 1e-7
# 计算 Log-Softmax
preds_logsoft = F.log_softmax(preds, dim=1)
# 计算 softmax
preds_softmax = torch.exp(preds_logsoft)
# unsqueeze:将 labels 的形状从 [N] 转换为 [N, 1],便于与 preds_softmax 对应
# gather:根据 labels 提取每个样本真实类别对应的概率值,结果形状为 [N, 1]
preds_softmax = preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)
preds_logsoft = preds_logsoft.gather(1, labels.unsqueeze(1)).squeeze(1)
# 计算 focal loss
loss = -self.alpha * torch.pow(1 - preds_softmax, self.gamma) * preds_logsoft
return loss.mean()
这边对代码preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1)分析:
python
import torch
# 假设 preds_softmax 是模型的 Softmax 输出,形状为 [N, C]
preds_softmax = torch.tensor([
[0.1, 0.7, 0.2], # 样本 1 的类别概率分布
[0.3, 0.4, 0.3], # 样本 2 的类别概率分布
[0.2, 0.5, 0.3] # 样本 3 的类别概率分布
])
# 假设 labels 是目标标签,形状为 [N]
labels = torch.tensor([1, 2, 0]) # 样本 1 属于类别 1,样本 2 属于类别 2,样本 3 属于类别 0
# 使用 gather 提取每个样本真实类别的概率
labels_unsqueezed = labels.unsqueeze(1) # 将 labels 的形状从 [N] 转换为 [N, 1]
print("labels.unsqueeze(1):")
print(labels_unsqueezed)
# 根据 labels 提取 preds_softmax 中对应的概率
gathered_probs = preds_softmax.gather(1, labels_unsqueezed) # 结果形状为 [N, 1]
print("\npreds_softmax.gather(1, labels.unsqueeze(1)):")
print(gathered_probs)
# 使用 squeeze(1) 去掉第 1 维,结果形状为 [N]
final_probs = gathered_probs.squeeze(1)
print("\npreds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1):")
print(final_probs)
输出结果:
labels.unsqueeze(1):
tensor([[1],
[2],
[0]])
preds_softmax.gather(1, labels.unsqueeze(1)):
tensor([[0.7000],
[0.3000],
[0.2000]])
preds_softmax.gather(1, labels.unsqueeze(1)).squeeze(1):
tensor([0.7000, 0.3000, 0.2000])
需要注意的是,这边的代码写法应该与输入的shape、参数的shape有关系的,需要针对于具体的情形进行适配和修改。