当前手语识别训练用的是多分类交叉熵损失函数:
python
criterion = nn.CrossEntropyLoss()
位置在:
训练时是这样算的:
python
logits = model(keypoints, valid_mask)
loss = criterion(logits, labels)
也就是:
text
模型输出 logits: [B, 500]
真实标签 labels: [B]
交叉熵损失 loss
为什么用 CrossEntropyLoss
你的任务是:
text
CSL2018 500 类孤立手语识别
每个样本只属于一个类别,比如:
text
label = 0
label = 128
label = 499
这是标准的单标签多分类任务,所以用交叉熵是合理的。
logits 是什么
模型最后输出:
text
[B, 500]
例如一个 batch 有 32 个样本:
text
[32, 500]
每一行是一个样本对 500 个手语类别的原始分数:
text
[类别0分数, 类别1分数, ..., 类别499分数]
注意这里不是概率,还没经过 softmax。
CrossEntropyLoss 内部会自动做:
text
softmax + negative log likelihood
所以代码里不需要手动写 softmax。
labels 是什么
labels 形状是:
text
[B]
例如:
python
labels = [0, 12, 88, 499]
表示每个样本的真实类别编号。
损失函数想让模型学什么
交叉熵的目标是:
text
让真实类别的预测分数尽量高
让其他类别的预测分数尽量低
例如某个样本真实类别是 128,模型输出 500 个类别分数,交叉熵会推动:
text
第 128 类分数升高
其他类别分数降低
你的训练输出里的 loss 是什么
例如:
text
train_loss=0.2942
val_loss=0.8619
含义是:
text
train_loss:训练集平均交叉熵损失
val_loss:验证集平均交叉熵损失
一般来说:
text
loss 越低越好
top1 越高越好
但论文里通常主要报告:
text
Top-1 Accuracy
Top-5 Accuracy
loss 更多用于观察训练是否收敛、是否过拟合。
当前损失函数有没有改进空间
有。现在是最基础的:
python
nn.CrossEntropyLoss()
后续可以考虑:
python
nn.CrossEntropyLoss(label_smoothing=0.1)
也就是 label smoothing,标签平滑。
它的作用是缓解模型过度自信,对你这种 500 类任务可能有帮助,尤其你之前出现过:
text
train_top1 很高
val_top1 明显低一些
说明有一定过拟合。
不过现在论文实验建议先保持基础交叉熵作为主设置,后续可以把 label smoothing 作为优化实验或正则化实验。
Label Smoothing Cross Entropy 本质上是在防止模型对训练标签过度自信。
普通交叉熵里,标签是 one-hot:
text
真实类别 = 3
目标分布:
类别0: 0
类别1: 0
类别2: 0
类别3: 1
类别4: 0
也就是说,模型被要求把真实类别概率推到 1.0,其他类别全部推到 0.0。
但在手语识别里,很多类别动作很像,关键点也可能有噪声。如果模型被迫对训练样本"绝对自信",就容易记住训练集细节,验证集泛化变差。
Label Smoothing 会把硬标签变软。比如 label_smoothing=0.1,500 类任务里大概变成:
text
真实类别: 0.9 左右
其他类别: 分到很小的一点概率
它表达的意思是:
text
这个样本主要属于正确类别,但不要认为其他类别概率必须绝对为 0。
对你的项目来说,它主要解决三个问题:
-
缓解过拟合
你之前 FULL 模型训练集 top1 到 98% 左右,验证集停在 82% 左右,说明模型已经很会记训练集了。Label smoothing 会降低这种"死记硬背"。
-
降低模型过度自信
普通 CE 容易让 logits 拉得很大,模型输出非常尖锐。Label smoothing 会让输出分布更平滑。
-
适合细粒度相似类别
手语类别之间可能只有手形、方向、轨迹的一点差异。软标签允许模型保留一点不确定性,更符合这个任务特性。
一句话概括:
普通 Cross Entropy 是"必须百分百相信这个标签";Label Smoothing Cross Entropy 是"相信这个标签,但别过度自信"。
论文里可以这样写:
为缓解孤立手语类别间细粒度相似性和关键点噪声导致的过拟合问题,本文在交叉熵损失中引入标签平滑策略,将硬标签分布转化为软标签分布,从而抑制模型对训练样本的过度置信,提高模型泛化能力。