【深度学习实战(22)】解决分类不均衡问题之Focal Loss

一、Focal Loss公式介绍

Focal loss是何恺明大神提出的一种新的loss计算方案。其具有两个重要的特点。

1、控制正负样本的权重

2、控制容易分类和难分类样本的权重

论文:

二分类问题交叉熵损失

公式:

我们可以利用如下Pt简化交叉熵loss。

此时:

代码:

cpp 复制代码
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

正负样本平衡项

-想要降低负样本的影响,可以在常规的损失函数前增加一个系数αt。与Pt类似,当label=1的时候,αt=α;当label=otherwise的时候,αt=1 - α,a的范围也是0到1。此时我们便可以通过设置α实现控制正负样本对loss的贡献。

公式:

其中:

分解开就是:

难易样本平衡项

样本属于某个类,且预测结果中该类的概率越大,其越容易分类 ,在二分类问题中,正样本的标签为1,负样本的标签为0,p代表样本为1类的概率。

对于正样本而言,1-p的值越大,样本越难分类。

对于负样本而言,p的值越大,样本越难分类。

Pt的定义如下

所以利用1-Pt就可以计算出每个样本属于容易分类或者难分类。

具体实现方式如下。

两种权重控制方法合并,就得到了Focal Loss

通过如下公式就可以实现控制正负样本的权重和控制容易分类和难分类样本的权重。

分解开就是:

二、Focal Loss代码实现

cpp 复制代码
import torch
import torch.nn as nn
import torch.functional as F

class WeightedFocalLoss(nn.Module):
    "Non weighted version of Focal Loss"    
    def __init__(self, alpha=.25, gamma=2):
            super(WeightedFocalLoss, self).__init__()  
            # --------------#
            #   平衡正负样本系数
            # --------------#      
            self.alpha = torch.tensor([alpha, 1-alpha]).cuda()      
            # --------------#
            #   平衡难易样本系数
            # --------------#   
            self.gamma = gamma
            
    def forward(self, inputs, targets):
            # --------------#
            #   分类交叉熵损失
            # --------------# 
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')   
            # --------------#
            #   标签GT
            # --------------#      
            targets = targets.type(torch.long)     
            # --------------#
            #   计算at
            # --------------#    
            at = self.alpha.gather(0, targets.data.view(-1))   
            # --------------#
            #   计算pt: BEC_loss = -log(pt)  --> pt = torch.exp(-BCE_loss)   
            # --------------#       
            pt = torch.exp(-BCE_loss)   
            # --------------#
            #   计算Focal Loss
            # --------------#       
            F_loss = at*(1-pt)**self.gamma * BCE_loss        
            return F_loss.mean()
相关推荐
@不误正业2 分钟前
鸿蒙小艺智能体开放平台实战-接入系统级AI-Agent能力
人工智能·华为·harmonyos
月诸清酒3 分钟前
47-260429 AI 科技日报 (HappyHorse 1.0 登顶文本转视频模型排行榜)
人工智能
byoass3 分钟前
智巢AI知识库深度解析:企业文档管理从大海捞针到精准狙击的进化之路
开发语言·网络·人工智能·安全·c#·云计算
掘金一周13 分钟前
你们觉得房贷多少,没有压力 | 沸点周刊 4.30
前端·人工智能·后端
美狐美颜SDK开放平台13 分钟前
多场景美颜SDK解决方案:直播APP(iOS/安卓)开发接入详解
android·人工智能·ios·音视频·美颜sdk·第三方美颜sdk·短视频美颜sdk
桜吹雪28 分钟前
Langchain.js官方文档:构建具备按需加载技能的 SQL 助手
javascript·人工智能·node.js
ting945200032 分钟前
深入解析 Social Fetch 机制:原理、架构、应用场景、实战落地与性能优化全攻略
人工智能·性能优化·架构
阿瑞说项目管理32 分钟前
2026 实战入门指南:企业 Agent 到底能解决哪些工作问题?
大数据·人工智能·agent·智能体·企业级ai
ZOOOOOOU33 分钟前
云边端协同架构下,门禁权限引擎的离线决策与策略续存实现
大数据·人工智能·架构
han_34 分钟前
一篇看懂国内外主流大模型:GPT、Claude、Gemini、DeepSeek、通义千问有什么区别?
前端·人工智能·llm