Focal Loss 原理详解及 PyTorch 代码实现

Focal Loss 原理详解及 PyTorch 代码实现

介绍

一、Focal Loss 背景

Focal Loss 是为解决类别不平衡问题设计的损失函数,通过引入 gamma 参数降低易分类样本的权重,使用 alpha 参数调节正负样本比例。在目标检测等类别不平衡场景中表现优异。

二、代码逐行解析

1. 类定义与初始化

python 复制代码
class FocalLoss(nn.Module):
    """应用 Focal Loss 通过 gamma 和 alpha 参数改进 BCEWithLogitsLoss 以处理类别不平衡"""
    
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super().__init__()
        self.loss_fcn = loss_fcn  # 必须使用 nn.BCEWithLogitsLoss()
        self.gamma = gamma        # 调节难易样本权重的指数参数
        self.alpha = alpha        # 平衡正负样本比例的权重系数
        
        # 修改原损失函数的 reduction 为 'none' 进行逐元素计算
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = "none"

	def forward(self, pred, true):
	    # 计算基础交叉熵损失
	    loss = self.loss_fcn(pred, true)
	    
	    # 通过 sigmoid 获取概率预测值(范围0-1)
	    pred_prob = torch.sigmoid(pred)
	    
	    # 计算 p_t(真实类别对应的预测概率)(正确分类的概率)
	    p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
	    
	    # 计算 alpha 因子:正样本乘 alpha,负样本乘 (1-alpha) (类别权重)
	    alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
	    
	    # 计算调制因子:难分类样本权重更大 (困难样本权重)
	    modulating_factor = (1.0 - p_t) ** self.gamma
	    
	    # 组合得到最终的 Focal Loss
	    loss *= alpha_factor * modulating_factor
	    
	    # 根据 reduction 设置返回结果
	    if self.reduction == "mean":
	        return loss.mean()
	    elif self.reduction == "sum":
	        return loss.sum()
	    else:  # 'none'
	        return loss

三、核心参数作用

  1. Gamma (γ)

    • γ > 0 时,降低易分类样本(p_t 接近 1)的损失权重
    • 典型取值范围:0.5-5.0
    • 示例:当 p_t=0.9,γ=2 → 调制因子 = 0.01
  2. Alpha (α)

    • 调节正负样本权重比例
    • α 接近 1 时强调正样本
    • α 接近 0 时强调负样本

四、使用示例

python 复制代码
# 初始化
criterion = FocalLoss(
    loss_fcn=nn.BCEWithLogitsLoss(),
    gamma=2.0,
    alpha=0.75
)

# 计算损失
pred = model(inputs)
loss = criterion(pred, targets)

五、应用场景

  • 目标检测(如 RetinaNet)
  • 医学图像分析
  • 任何存在严重类别不平衡的分类任务

六、总结

Focal Loss 通过两个关键参数实现了:

  1. 降低大量易分类样本的损失贡献
  2. 平衡正负样本的权重比例
  3. 改善模型对困难样本的学习能力
相关推荐
东风西巷1 天前
Balabolka:免费高效的文字转语音软件
前端·人工智能·学习·语音识别·软件需求
非门由也1 天前
《sklearn机器学习——管道和复合估计器》联合特征(FeatureUnion)
人工智能·机器学习·sklearn
l12345sy1 天前
Day21_【机器学习—决策树(1)—信息增益、信息增益率、基尼系数】
人工智能·决策树·机器学习·信息增益·信息增益率·基尼指数
非门由也1 天前
《sklearn机器学习——管道和复合估算器》异构数据的列转换器
人工智能·机器学习·sklearn
计算机毕业设计指导1 天前
基于ResNet50的智能垃圾分类系统
人工智能·分类·数据挖掘
飞哥数智坊1 天前
终端里用 Claude Code 太难受?我把它接进 TRAE,真香!
人工智能·claude·trae
java1234_小锋1 天前
Scikit-learn Python机器学习 - 特征降维 压缩数据 - 特征提取 - 主成分分析 (PCA)
python·机器学习·scikit-learn
java1234_小锋1 天前
Scikit-learn Python机器学习 - 特征降维 压缩数据 - 特征提取 - 线性判别分析 (LDA)
python·机器学习·scikit-learn
小王爱学人工智能1 天前
OpenCV的阈值处理
人工智能·opencv·计算机视觉
新智元1 天前
刚刚,光刻机巨头 ASML 杀入 AI!豪掷 15 亿押注「欧版 OpenAI」,成最大股东
人工智能·openai