损失函数 的 硬截断 和 平滑衰减

损失函数 的 硬截断 和 平滑衰减

flyfish

在逐样本损失计算完成、取平均之前,对损失过高的样本做权重压制,不删除样本,只削弱它们对梯度的贡献,属于软降权------既保留了样本的监督信号,又避免极端难样本/疑似错标样本带偏整个模型。

损失硬截断

损失硬截断是给单样本损失设置一个上限,超过这个阈值的损失,直接按阈值计算。相当于一刀切,超过上限的样本梯度不再放大。

代码实现

python 复制代码
class FocalLossWithSmoothing(nn.Module):
    def __init__(self, gamma=2, alpha=None, smoothing=0.0, num_classes=2, max_loss=None):
        """
        :param max_loss: 单样本损失上限,None表示不开启截断;设置数值后,单样本损失不会超过该值
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = torch.tensor(alpha).to(DEVICE) if alpha else None
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.max_loss = max_loss  # 损失截断阈值

    def forward(self, inputs, targets):
        targets_one_hot = torch.zeros_like(inputs).scatter_(1, targets.unsqueeze(1), 1)
        soft_targets = targets_one_hot * (1 - self.smoothing) + self.smoothing / self.num_classes
        
        log_probs = torch.nn.functional.log_softmax(inputs, dim=1)
        probs = torch.exp(log_probs)
        
        p_t = (probs * targets_one_hot).sum(dim=1, keepdim=True)
        focal_weight = (1 - p_t) ** self.gamma
        
        ce_loss = (-soft_targets * log_probs).sum(dim=1)
        loss = focal_weight.squeeze() * ce_loss
        
        if self.alpha is not None:
            alpha_t = (self.alpha.unsqueeze(0) * targets_one_hot).sum(dim=1)
            loss = loss * alpha_t

        # ========== 损失截断 ==========
        if self.max_loss is not None:
            loss = torch.clamp(loss, max=self.max_loss)
        
        return loss.mean()

使用方式

在训练函数里初始化损失时,多加一个 max_loss 参数即可:

python 复制代码
# 示例:单样本损失最高不超过2.0,超过的全部按2.0计算
criterion = FocalLossWithSmoothing(
    gamma=FOCAL_GAMMA, 
    alpha=FOCAL_ALPHA, 
    smoothing=LABEL_SMOOTHING, 
    num_classes=NUM_CLASSES,
    max_loss=2.0  # 开启截断,阈值可按需调整
)

平滑衰减降权

硬截断是一刀切:损失超过阈值,直接砍平,损失值瞬间不再增长,像台阶一样突变;

平滑衰减是越涨越慢:损失低于阈值时正常计算,超过阈值后还能继续涨,但增长速度会越来越慢,过渡是顺滑的曲线,没有突变台阶。

它的目的:既保留损失越高、权重越大的相对顺序,又不让极端高损失样本无限放大梯度带偏模型,同时保证训练过程梯度平稳,不会出现跳变

代码实现 只需要把截断部分替换成平滑衰减逻辑即可:

python 复制代码
class FocalLossWithSmoothing(nn.Module):
    def __init__(self, gamma=2, alpha=None, smoothing=0.0, num_classes=3, loss_threshold=1.8):
        super().__init__()
        self.gamma = gamma
        self.alpha = torch.tensor(alpha).to(DEVICE) if alpha else None
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.loss_threshold = loss_threshold  # 平滑衰减阈值

    def forward(self, inputs, targets):
        targets_one_hot = torch.zeros_like(inputs).scatter_(1, targets.unsqueeze(1), 1)
        soft_targets = targets_one_hot * (1 - self.smoothing) + self.smoothing / self.num_classes
        
        log_probs = torch.nn.functional.log_softmax(inputs, dim=1)
        probs = torch.exp(log_probs)
        
        p_t = (probs * targets_one_hot).sum(dim=1, keepdim=True)
        focal_weight = (1 - p_t) ** self.gamma
        
        ce_loss = (-soft_targets * log_probs).sum(dim=1)
        loss = focal_weight.squeeze() * ce_loss
        
        if self.alpha is not None:
            alpha_t = (self.alpha.unsqueeze(0) * targets_one_hot).sum(dim=1)
            loss = loss * alpha_t

        # 平滑衰减降权:压制极端高损失样本
        if self.loss_threshold is not None:
            high_loss_mask = loss > self.loss_threshold
            loss[high_loss_mask] = self.loss_threshold + torch.log(1 + loss[high_loss_mask] - self.loss_threshold)
        
        return loss.mean()

假设设置阈值 = 1.5,看不同原始损失对应的处理结果:

原始单样本损失 硬截断后损失 变化特点
1.0(正常样本) 1.0 低于阈值,完全不变
1.4(较难样本) 1.4 低于阈值,完全不变
1.5(阈值点) 1.5 刚好等于阈值
1.6(难样本) 1.5 超过一点点,直接被砍成1.5,瞬间停止增长
3.0(极难/错标样本) 1.5 不管多高,全砍成1.5,和1.6的样本权重完全一样

硬截断的问题

  1. 阈值点处损失突变,梯度也会突变,训练过程容易出现震荡;
  2. 所有超过阈值的样本,损失都一样,丢失了难分程度的差异信息------3.0的极难样本和1.6的轻微难样本,对模型的贡献变得完全相同,有点矫枉过正。

平滑衰减的逻辑:两段式 + 对数压缩

代码里用的是阈值以下正常计算,阈值以上对数压缩的两段式策略,公式是:

处理后损失={原始损失原始损失≤阈值阈值+log⁡(1+原始损失−阈值)原始损失>阈值 \text{处理后损失} = \begin{cases} \text{原始损失} & \text{原始损失} \le 阈值 \\ 阈值 + \log(1 + \text{原始损失} - 阈值) & \text{原始损失} > 阈值 \end{cases} 处理后损失={原始损失阈值+log(1+原始损失−阈值)原始损失≤阈值原始损失>阈值

为什么用 log(对数)函数?

对数函数有两个完美匹配需求的特性:

  1. 单调递增:原始损失越大,处理后的损失也一定越大,不会改变谁更难、谁损失更高的排序,样本的相对权重关系保留了;
  2. 增速递减:x 越大,log(x) 涨得越慢。原始损失越高,压缩力度越强,正好符合极端样本降权更多的需求。

直观对比效果

还是设阈值 = 1.5,算一组真实数值,一眼就能看出区别:

原始单样本损失 硬截断后 平滑衰减后 直观感受
1.0 1.0 1.00 低于阈值,两者完全一样
1.4 1.4 1.40 低于阈值,两者完全一样
1.5 1.5 1.50 阈值点,两者对齐
1.6 1.5 1.595 只超了一点点,压缩很轻微,几乎和原值差不多
2.0 1.5 1.693 超了0.5,增长明显放缓,不再是直线涨
3.0 1.5 1.946 超了1.5,涨幅被大幅压缩,不会涨到3.0
5.0 1.5 2.208 超了3.5,增速进一步变慢,和3.0的差距被缩小

可以明显看到:

刚超过阈值时,损失几乎不受影响,过渡非常顺滑;

损失越高,被压缩得越厉害,但始终保持越高越重的排序;

不会像硬截断那样,所有高损失全变成同一个值。

对应代码

python 复制代码
loss[high_loss_mask] = self.loss_threshold + torch.log(1 + loss[high_loss_mask] - self.loss_threshold)

拆解开:

  1. loss[high_loss_mask] - self.loss_threshold:算出损失超出阈值的部分(增量);
  2. 1 + 增量:加1保证对数的输入大于0,避免出现负数报错;
  3. torch.log(...):对超出的增量做对数压缩,让增量涨得变慢;
  4. self.loss_threshold + 压缩后的增量:把基准阈值加回来,保证阈值点处数值连续、没有台阶。

什么时候用硬截断,什么时候用平滑衰减?

方案 场景 特点
硬截断 确定有大量标注错误,想直接屏蔽极端错标的影响 简单粗暴,可控性强,调试方便
平滑衰减 样本大多是标注正确的难样本(比如小目标、低对比度),只想削弱、不想完全屏蔽 更温和,梯度平稳,训练更稳定,保留难样本的相对差异信息