损失函数 的 硬截断 和 平滑衰减
flyfish
在逐样本损失计算完成、取平均之前,对损失过高的样本做权重压制,不删除样本,只削弱它们对梯度的贡献,属于软降权------既保留了样本的监督信号,又避免极端难样本/疑似错标样本带偏整个模型。
损失硬截断
损失硬截断是给单样本损失设置一个上限,超过这个阈值的损失,直接按阈值计算。相当于一刀切,超过上限的样本梯度不再放大。
代码实现
python
class FocalLossWithSmoothing(nn.Module):
def __init__(self, gamma=2, alpha=None, smoothing=0.0, num_classes=2, max_loss=None):
"""
:param max_loss: 单样本损失上限,None表示不开启截断;设置数值后,单样本损失不会超过该值
"""
super().__init__()
self.gamma = gamma
self.alpha = torch.tensor(alpha).to(DEVICE) if alpha else None
self.smoothing = smoothing
self.num_classes = num_classes
self.max_loss = max_loss # 损失截断阈值
def forward(self, inputs, targets):
targets_one_hot = torch.zeros_like(inputs).scatter_(1, targets.unsqueeze(1), 1)
soft_targets = targets_one_hot * (1 - self.smoothing) + self.smoothing / self.num_classes
log_probs = torch.nn.functional.log_softmax(inputs, dim=1)
probs = torch.exp(log_probs)
p_t = (probs * targets_one_hot).sum(dim=1, keepdim=True)
focal_weight = (1 - p_t) ** self.gamma
ce_loss = (-soft_targets * log_probs).sum(dim=1)
loss = focal_weight.squeeze() * ce_loss
if self.alpha is not None:
alpha_t = (self.alpha.unsqueeze(0) * targets_one_hot).sum(dim=1)
loss = loss * alpha_t
# ========== 损失截断 ==========
if self.max_loss is not None:
loss = torch.clamp(loss, max=self.max_loss)
return loss.mean()
使用方式
在训练函数里初始化损失时,多加一个 max_loss 参数即可:
python
# 示例:单样本损失最高不超过2.0,超过的全部按2.0计算
criterion = FocalLossWithSmoothing(
gamma=FOCAL_GAMMA,
alpha=FOCAL_ALPHA,
smoothing=LABEL_SMOOTHING,
num_classes=NUM_CLASSES,
max_loss=2.0 # 开启截断,阈值可按需调整
)
平滑衰减降权
硬截断是一刀切:损失超过阈值,直接砍平,损失值瞬间不再增长,像台阶一样突变;
平滑衰减是越涨越慢:损失低于阈值时正常计算,超过阈值后还能继续涨,但增长速度会越来越慢,过渡是顺滑的曲线,没有突变台阶。
它的目的:既保留损失越高、权重越大的相对顺序,又不让极端高损失样本无限放大梯度带偏模型,同时保证训练过程梯度平稳,不会出现跳变。
代码实现 只需要把截断部分替换成平滑衰减逻辑即可:
python
class FocalLossWithSmoothing(nn.Module):
def __init__(self, gamma=2, alpha=None, smoothing=0.0, num_classes=3, loss_threshold=1.8):
super().__init__()
self.gamma = gamma
self.alpha = torch.tensor(alpha).to(DEVICE) if alpha else None
self.smoothing = smoothing
self.num_classes = num_classes
self.loss_threshold = loss_threshold # 平滑衰减阈值
def forward(self, inputs, targets):
targets_one_hot = torch.zeros_like(inputs).scatter_(1, targets.unsqueeze(1), 1)
soft_targets = targets_one_hot * (1 - self.smoothing) + self.smoothing / self.num_classes
log_probs = torch.nn.functional.log_softmax(inputs, dim=1)
probs = torch.exp(log_probs)
p_t = (probs * targets_one_hot).sum(dim=1, keepdim=True)
focal_weight = (1 - p_t) ** self.gamma
ce_loss = (-soft_targets * log_probs).sum(dim=1)
loss = focal_weight.squeeze() * ce_loss
if self.alpha is not None:
alpha_t = (self.alpha.unsqueeze(0) * targets_one_hot).sum(dim=1)
loss = loss * alpha_t
# 平滑衰减降权:压制极端高损失样本
if self.loss_threshold is not None:
high_loss_mask = loss > self.loss_threshold
loss[high_loss_mask] = self.loss_threshold + torch.log(1 + loss[high_loss_mask] - self.loss_threshold)
return loss.mean()
假设设置阈值 = 1.5,看不同原始损失对应的处理结果:
| 原始单样本损失 | 硬截断后损失 | 变化特点 |
|---|---|---|
| 1.0(正常样本) | 1.0 | 低于阈值,完全不变 |
| 1.4(较难样本) | 1.4 | 低于阈值,完全不变 |
| 1.5(阈值点) | 1.5 | 刚好等于阈值 |
| 1.6(难样本) | 1.5 | 超过一点点,直接被砍成1.5,瞬间停止增长 |
| 3.0(极难/错标样本) | 1.5 | 不管多高,全砍成1.5,和1.6的样本权重完全一样 |
硬截断的问题:
- 阈值点处损失突变,梯度也会突变,训练过程容易出现震荡;
- 所有超过阈值的样本,损失都一样,丢失了难分程度的差异信息------3.0的极难样本和1.6的轻微难样本,对模型的贡献变得完全相同,有点矫枉过正。
平滑衰减的逻辑:两段式 + 对数压缩
代码里用的是阈值以下正常计算,阈值以上对数压缩的两段式策略,公式是:
处理后损失={原始损失原始损失≤阈值阈值+log(1+原始损失−阈值)原始损失>阈值 \text{处理后损失} = \begin{cases} \text{原始损失} & \text{原始损失} \le 阈值 \\ 阈值 + \log(1 + \text{原始损失} - 阈值) & \text{原始损失} > 阈值 \end{cases} 处理后损失={原始损失阈值+log(1+原始损失−阈值)原始损失≤阈值原始损失>阈值
为什么用 log(对数)函数?
对数函数有两个完美匹配需求的特性:
- 单调递增:原始损失越大,处理后的损失也一定越大,不会改变谁更难、谁损失更高的排序,样本的相对权重关系保留了;
- 增速递减:x 越大,log(x) 涨得越慢。原始损失越高,压缩力度越强,正好符合极端样本降权更多的需求。
直观对比效果
还是设阈值 = 1.5,算一组真实数值,一眼就能看出区别:
| 原始单样本损失 | 硬截断后 | 平滑衰减后 | 直观感受 |
|---|---|---|---|
| 1.0 | 1.0 | 1.00 | 低于阈值,两者完全一样 |
| 1.4 | 1.4 | 1.40 | 低于阈值,两者完全一样 |
| 1.5 | 1.5 | 1.50 | 阈值点,两者对齐 |
| 1.6 | 1.5 | 1.595 | 只超了一点点,压缩很轻微,几乎和原值差不多 |
| 2.0 | 1.5 | 1.693 | 超了0.5,增长明显放缓,不再是直线涨 |
| 3.0 | 1.5 | 1.946 | 超了1.5,涨幅被大幅压缩,不会涨到3.0 |
| 5.0 | 1.5 | 2.208 | 超了3.5,增速进一步变慢,和3.0的差距被缩小 |
可以明显看到:
刚超过阈值时,损失几乎不受影响,过渡非常顺滑;
损失越高,被压缩得越厉害,但始终保持越高越重的排序;
不会像硬截断那样,所有高损失全变成同一个值。
对应代码
python
loss[high_loss_mask] = self.loss_threshold + torch.log(1 + loss[high_loss_mask] - self.loss_threshold)
拆解开:
loss[high_loss_mask] - self.loss_threshold:算出损失超出阈值的部分(增量);1 + 增量:加1保证对数的输入大于0,避免出现负数报错;torch.log(...):对超出的增量做对数压缩,让增量涨得变慢;self.loss_threshold + 压缩后的增量:把基准阈值加回来,保证阈值点处数值连续、没有台阶。
什么时候用硬截断,什么时候用平滑衰减?
| 方案 | 场景 | 特点 |
|---|---|---|
| 硬截断 | 确定有大量标注错误,想直接屏蔽极端错标的影响 | 简单粗暴,可控性强,调试方便 |
| 平滑衰减 | 样本大多是标注正确的难样本(比如小目标、低对比度),只想削弱、不想完全屏蔽 | 更温和,梯度平稳,训练更稳定,保留难样本的相对差异信息 |