分类中的样本不平衡问题——Asymmetric Loss

收到采集条件,真实情况的限制,不管是二分类的正负样本还是多分类中,都会普遍存在样本不平衡的问题。

Asymmetric Loss(非对称损失函数,简称 ASL)

. 概率计算与非对称裁剪 (Asymmetric Clipping)

复制代码
xs_pos = torch.sigmoid(x)
xs_neg = 1 - xs_pos

# Asymmetric Clipping
if self.clip is not None and self.clip > 0:
    xs_neg = (xs_neg + self.clip).clamp(max=1)
  • 原理解读 :首先通过 Sigmoid 将模型输出的 logits (x) 转换为正样本的概率 xs_pos,那么负样本的概率就是 xs_neg = 1 - xs_pos
  • 代码亮点 :这里实现了 ASL 独有的概率位移(Probability Margin Shifting) 。代码对负样本概率进行了 (xs_neg + clip) 的操作,然后截断在 1 以内。这相当于给简单负样本设定了一个硬阈值边界,当负样本的原始预测概率极低时,加上 clip 后的对数损失会被大幅削弱甚至归零,从而直接丢弃这些无用的梯度。注意,这个操作只针对负样本,体现了非对称性。

基础交叉熵计算与掩码分离 (Basic CE & Masking)

复制代码
los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
loss = los_pos + los_neg
  • 原理解读:这是标准的二元交叉熵(BCE)计算过程。
  • 代码亮点:交叉指的是权重和log部分分别是真实标签和预测值;
  • 掩码则指的是利用标签 y(0或1矩阵)作为天然的掩码(Mask) 。因为当 y=1 时,loss=los_pos;当 y=0时,loss=los_neg。所以 los_pos 实际上只保留了真实为正类的样本的损失,而 los_neg 只保留了真实为负类的样本的损失。两者相加后,就得到了一个包含所有样本基础损失的矩阵。

非对称聚焦权重 (Asymmetric Focusing)

复制代码
pt = xs_pos * y + xs_neg * (1 - y)
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
loss *= one_sided_w
  • 原理解读:这是 Focal Loss 核心调制因子 (1−pt)γ的非对称升级版。
  • 代码亮点
    • pt 巧妙地提取了每个样本对应的正确类别概率(正样本取 xs_pos,负样本取 xs_neg)。含义是分对的概率/自信程度。
    • one_sided_gamma 同样利用标签 y 作为掩码,为每一个样本动态分配 γ 值:正样本获得 gamma_pos,负样本获得 gamma_neg
    • pt是自信程度,那么1-pt就是犹豫程度,使用pow可以使得犹豫程度低的压低权重,犹豫程度高的接近1,则不被抑制。
    • 针对负样本(gamma_neg :在多标签分类中,一张图通常只有少数几个正标签,剩下的绝大多数都是负标签。因为负样本数量极多,且其中大部分是毫无难度的"简单负样本",所以必须设置一个较大的 gamma_neg(如默认值 4),以强力抑制它们,防止模型被负样本主导。
    • 针对正样本(gamma_pos :正样本本身就是稀缺资源。如果给正样本设置过大的 γγ ,会导致大量正样本的损失被削弱,模型反而学不到东西。因此,gamma_pos 通常设置得较小,甚至为 0(即不对正样本做 Focal 降权,保持标准交叉熵的学习强度)。
    • 抑制程度就是gamma控制,正/负样本分别对应一个gamma,gamma越大,抑制越厉害。
    • 最后将这个非对称权重 one_sided_w 乘到基础 loss 上,完成了对难易样本的动态降权
    • 正负失衡 vs 类间失衡。ASL 的原始设计并不直接处理这种类间的不平衡,它默认所有正样本(无论是人还是长颈鹿)在训练初期都同样稀缺。
    • 对于4分类,pt = [0.1, 0.8, 0.05, 0.05],那么 1 - pt = [0.9, 0.2, 0.95, 0.95],真实标签y=[0,1,0,0],one_sided_gamma = 4, 1, 4, 4,指数之后的权重张量 one_sided_w = [0.6561, 0.2, 0.8145, 0.8145]。可以看到,索引0,2,3处本来错误率很高,加权之后权重相对索引1被降低了,所以模型可以更专注于正样本索引1.

梯度控制优化 (Gradient Control)

复制代码
  if self.disable_torch_grad_focal_loss:
      torch.set_grad_enabled(False)
  # ... 计算 one_sided_w ...
  if self.disable_torch_grad_focal_loss:
      torch.set_grad_enabled(True)

工程细节:这是一个非常高级的工程优化。在计算聚焦权重 one_sided_w 时,如果开启了此选项,会临时关闭 PyTorch 的自动求导功能,计算出one_sided_w之后,乘到loss之前,再打开梯度。

  • 因为one_sided_w 是一个常数权重,但是 PyTorch 并不知道你的"数学意图"。它会忠实地把 torch.pow* 这两个操作都加入计算图。它不仅会计算基础 BCE Loss 对模型预测概率 pt 的梯度,还会额外计算 one_sided_w 对 pt 的梯度,然后利用链式法则把它们加起来。这样反向传播时需要多做一次复杂的链式求导,拖慢了训练速度,还容易显存爆炸。

  • 所以临时设置torch.set_grad_enabled(False),我们通常只需要让基础 BCE Loss 参与梯度回传,而不需要对这个复杂的权重项再求一次高阶导数。这不仅能节省显存,还能加快训练速度。

相关推荐
-山中问答-1 小时前
【智能体工具使用实战04】构建执行沙盒与安全边界
人工智能·安全·智能体·沙盒
AI客栈1 小时前
云原生 AI 平台架构设计:从模型服务到弹性调度的全链路工程实践
人工智能
AI原来如此1 小时前
阿里云百炼上线DeepSeek,OpenAI发布GPT-5.5,模型服务战升级
人工智能·gpt·阿里云·ai·大模型·ai编程
果丁智能1 小时前
物联网智能锁在网约房、民宿场景的落地实践:身份核验与远程授权的全链路解决方案
人工智能·物联网·智能家居
jinxindeep1 小时前
ω-EVA:基于隐变量交互式世界模型的机器人动作生成新范式(星源智)
人工智能·机器人
hnult1 小时前
2026在线笔试平台选型指南:考试云九重防作弊与六大AI能力解析
人工智能·笔记·microsoft·课程设计
Mr. zhihao1 小时前
SDD(规范驱动开发):AI 编程时代的范式革命——因果链视角
人工智能·ai编程
大腾智能1 小时前
华为开发者大会2026观察:鸿蒙底座成型,大腾智能锚定工业AI路径
人工智能·华为·harmonyos
rising start1 小时前
ReAct Agent:让 AI 学会思考与行动
人工智能·agent