Focal Loss 原理详解及 PyTorch 代码实现
介绍
一、Focal Loss 背景
Focal Loss 是为解决类别不平衡问题设计的损失函数,通过引入 gamma 参数降低易分类样本的权重,使用 alpha 参数调节正负样本比例。在目标检测等类别不平衡场景中表现优异。
二、代码逐行解析
1. 类定义与初始化
python
class FocalLoss(nn.Module):
"""应用 Focal Loss 通过 gamma 和 alpha 参数改进 BCEWithLogitsLoss 以处理类别不平衡"""
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
super().__init__()
self.loss_fcn = loss_fcn # 必须使用 nn.BCEWithLogitsLoss()
self.gamma = gamma # 调节难易样本权重的指数参数
self.alpha = alpha # 平衡正负样本比例的权重系数
# 修改原损失函数的 reduction 为 'none' 进行逐元素计算
self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = "none"
def forward(self, pred, true):
# 计算基础交叉熵损失
loss = self.loss_fcn(pred, true)
# 通过 sigmoid 获取概率预测值(范围0-1)
pred_prob = torch.sigmoid(pred)
# 计算 p_t(真实类别对应的预测概率)(正确分类的概率)
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
# 计算 alpha 因子:正样本乘 alpha,负样本乘 (1-alpha) (类别权重)
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
# 计算调制因子:难分类样本权重更大 (困难样本权重)
modulating_factor = (1.0 - p_t) ** self.gamma
# 组合得到最终的 Focal Loss
loss *= alpha_factor * modulating_factor
# 根据 reduction 设置返回结果
if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
return loss.sum()
else: # 'none'
return loss
三、核心参数作用
-
Gamma (γ):
- γ > 0 时,降低易分类样本(p_t 接近 1)的损失权重
- 典型取值范围:0.5-5.0
- 示例:当 p_t=0.9,γ=2 → 调制因子 = 0.01
-
Alpha (α):
- 调节正负样本权重比例
- α 接近 1 时强调正样本
- α 接近 0 时强调负样本
四、使用示例
python
# 初始化
criterion = FocalLoss(
loss_fcn=nn.BCEWithLogitsLoss(),
gamma=2.0,
alpha=0.75
)
# 计算损失
pred = model(inputs)
loss = criterion(pred, targets)
五、应用场景
- 目标检测(如 RetinaNet)
- 医学图像分析
- 任何存在严重类别不平衡的分类任务
六、总结
Focal Loss 通过两个关键参数实现了:
- 降低大量易分类样本的损失贡献
- 平衡正负样本的权重比例
- 改善模型对困难样本的学习能力