快手面试题:样本不均衡问题

题目

在深度学习训练过程中,如何解决样本不均衡问题?

解答

在深度学习的训练过程中,样本不均衡是一个非常常见且棘手的问题。它指的是数据集中不同类别的样本数量差异巨大,会导致模型偏向于多数类,忽视少数类。

解决样本不均衡问题通常需要从数据层面模型层面损失函数层面三管齐下。以下是系统性的解决方案:

一、 数据层面

这是最直接的方法,核心是改变数据的分布,让模型在训练时看到更均衡的类别。

1. 重采样
  • 过采样:随机复制少数类的样本,使其数量与多数类持平。

    • 优点:简单,不会丢失信息。

    • 缺点 :容易导致过拟合,因为模型记住了复制样本的细节,缺乏泛化性。

    • 进阶 :使用 SMOTE 及其变种。SMOTE 通过在少数类样本之间进行插值来生成新的合成样本,而不是简单复制,能有效缓解过拟合。

  • 欠采样:随机丢弃多数类的样本,使其数量与少数类持平。

    • 优点:减少了训练数据量,加快了训练速度。

    • 缺点:可能会丢失多数类中的重要信息。

    • 进阶 :使用 Tomek LinksNearMiss 等方法,只删除那些对分类边界贡献不大的冗余多数类样本。

  • 混合采样:先过采样,再欠采样,结合两者优点。

2. 数据增强
  • 针对少数类:如果少数类是图像、文本或语音,可以对其进行特定的增强。

    • 图像:对少数类图片进行旋转、裁剪、翻转、颜色抖动、CutMix 等,生成更多样化的训练样本。

    • 文本:进行回译、同义词替换等。

    • 这种方法往往比简单的过采样效果更好,因为它引入了数据多样性。

二、 损失函数层面

不改变数据,而是让模型在计算误差时,意识到少数类样本分错的代价更大。

1. 加权损失函数
  • 原理:在计算损失(如 CrossEntropyLoss)时,为不同类别赋予不同的权重。多数类的权重低,少数类的权重高。

  • 权重设置

    • 通常设置为样本数的反比:权重 = 总样本数 / (类别数 * 该类样本数)

    • 或者使用中位数频率平衡等方法。

  • 实现 :PyTorch 的 torch.nn.CrossEntropyLoss 中自带 weight 参数;TensorFlow 的 class_weight 参数。

2. Focal Loss
  • 适用场景单阶段目标检测(如 RetinaNet)或极难样本分类。

  • 原理 :Focal Loss 在交叉熵损失的基础上增加了一个调制因子。它会降低易分类样本 (通常是多数类)的损失贡献,从而迫使模型关注难分类的样本(通常是少数类)。

  • 公式理念:当样本被正确分类且概率很高时,它的损失被大幅降低;当样本被错分或概率不高时,它的损失保留。

3. OHEM
  • 原理 :在每个训练批次中,不是对所有样本计算损失,而是选出损失最高(最难分)的那一批样本进行反向传播。这些难分样本往往包含了少数类。

三、 模型/训练策略层面

1. 难例挖掘
  • 在训练过程中,先用模型对样本进行预测,找出那些预测错误的样本(尤其是被预测为多数类的少数类样本),将它们加入训练集,或者加大这些样本在下一轮训练中的权重。
2. 异常检测视角
  • 如果少数类占比极低(如千分之一以下),可以考虑放弃分类方法,转而使用异常检测/单分类方法。

  • 方法:训练一个自动编码器只学习多数类的特征,当遇到少数类时,其重构误差会很大,据此来判断。

3. 两阶段训练
  • Stage 1:在原始不均衡数据上正常训练,让模型学到基本的特征分布。

  • Stage 2:对少数类进行过采样或重加权,微调模型,让模型调整决策边界,更关注少数类。

4. 集成学习
  • 平衡随机森林:每个子模型在训练时,都从多数类中随机抽取一部分,并结合全部少数类进行训练。最后综合多个子模型的结果。

四、 评估指标的调整

在解决不均衡问题时,千万不要只盯着 Accuracy(准确率)看。因为如果正负比是 99:1,模型把所有样本判为负类,准确率也有 99%,但这个模型毫无价值。

建议使用的指标:

  • 混淆矩阵

  • 精确率

  • 召回率

  • F1-score(精确率和召回率的调和平均)

  • AUC-ROC

总结建议

在实际操作中,通常按以下顺序尝试:

  1. 如果数据量较大 :尝试欠采样 + 集成学习 ,或者直接使用加权损失函数(这是最简单且效果不错的起点)。

  2. 如果数据量中等 :尝试SMOTE 过采样数据增强

  3. 如果极度不均衡 :尝试Focal Loss ,或者考虑改用异常检测思路。

  4. 评估时 :务必关注 F1-score召回率

相关推荐
V搜xhliang02462 小时前
自然语言理解与语音识别(ASR)
大数据·人工智能·机器学习·自然语言处理·机器人·语音识别·xcode
人工智能培训2 小时前
数字孪生在航空领域的应用方法及案例
人工智能·机器学习·知识图谱·数字孪生·企业ai培训
Dfreedom.2 小时前
归一化技术全景指南
深度学习·算法·机器学习·归一化
南滑散修3 小时前
机器学习(三):SVM支持向量机算法
算法·机器学习·支持向量机
浩哥依然3 小时前
【论文笔记之 ULCNET】Ultra Low Complexity Deep Learning Based Noise Suppression
论文阅读·深度学习·神经网络·语音增强·语音降噪小模型
renhongxia13 小时前
人工智能代理能生成微服务吗?我们离多远了?
人工智能·深度学习·学习·微服务·云原生·架构·机器人
盼小辉丶4 小时前
视觉Transformer实战 | Cross-Attention Multi-Scale Vision Transformer(CrossViT)详解与实现
深度学习·计算机视觉·transformer
智星云算力4 小时前
实验室无GPU如何深度学习
人工智能·深度学习·阿里云·智星云·gpu算力租用
zh路西法4 小时前
【宇树机器人强化学习】(四):Go2基础训练以及参数调节与解析
python·深度学习·ubuntu·机器学习·机器人