【深度学习】重采样(Resampling)

在深度学习的背景下,重采样主要涉及两个方面:

  1. 数据层面的重采样:处理不平衡数据集。
  2. 模型层面的重采样:在神经网络内部进行上采样(UpSampling)或下采样(DownSampling),常见于架构如编码器-解码器(Encoder-Decoder)或生成对抗网络(GAN)。

1. 数据层面的重采样:处理类别不平衡

在许多现实世界的数据集中(如医疗诊断、欺诈检测),不同类别的样本数量可能差异巨大。例如,99%的样本是正常交易,只有1%是欺诈交易。如果直接用这样的数据训练模型,模型会倾向于预测多数类,导致对少数类的识别率极差。

解决方法就是通过重采样来调整训练集的分布。

A. 过采样(Oversampling)

增加少数类的样本数量,使其与多数类相当。

  • 随机过采样:随机复制少数类样本。

    • 优点:简单。
    • 缺点:容易导致过拟合,因为模型会多次看到完全相同的样本。
  • SMOTE(Synthetic Minority Over-sampling Technique)创建新的合成样本,而不是简单复制。

    • 原理:对每一个少数类样本,从其K个最近邻中随机选择一个样本,然后在这两个样本的连线上随机插值一点,作为新样本。
    • 优点:有效增加了样本多样性,缓解了过拟合问题。
    • 缺点:可能会在多数类样本密集的区域创造一些"模糊"的样本,增加类别间的重叠。
B. 欠采样(Undersampling)

减少多数类的样本数量,使其与少数类相当。

  • 随机欠采样 :随机地从多数类中删除一些样本。
    • 优点:简单,减少训练时间。
    • 缺点:可能会丢失多数类中包含的重要信息,导致模型欠拟合。

通常,SMOTE(或其变体,如ADASYN)与随机欠采样结合使用被认为是效果更好的策略。

在代码中的实现(以imbalanced-learn库为例)
python 复制代码
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline
from collections import Counter

# 假设 X_train, y_train 是你的训练数据和标签
print(f'Original dataset shape {Counter(y_train)}')

# 定义一个采样管道:先SMOTE过采样,再随机欠采样
over = SMOTE(sampling_strategy=0.5)  # 使少数类达到多数类的一半数量
under = RandomUnderSampler(sampling_strategy=0.5) # 使多数类降到少数类的两倍数量
steps = [('o', over), ('u', under)]
pipeline = Pipeline(steps=steps)

# 应用重采样
X_resampled, y_resampled = pipeline.fit_resample(X_train, y_train)

print(f'Resampled dataset shape {Counter(y_resampled)}')

现代替代方案 :除了重采样,还可以使用加权损失函数 (如class_weight in PyTorch's CrossEntropyLoss or TensorFlow/Keras)。这种方法在计算损失时,给少数类的错误预测赋予更高的权重,从而让模型更关注少数类。这通常是更受欢迎的方法,因为它不改变数据分布,计算高效。


2. 模型层面的重采样:特征图的空间变换

在卷积神经网络(CNN)架构中,特别是用于分割(如U-Net)、检测(如SSD)或生成(如GAN)的模型中,网络需要在不同分辨率的特征图之间进行转换。这就用到了上采样和下采样操作。

A. 下采样(DownSampling)

目的:增大感受野 ,提取更抽象、更全局的特征,同时减少计算量

  • 池化层(Pooling Layers)

    • 最大池化(Max Pooling):取窗口内的最大值。能更好地保留纹理特征。
    • 平均池化(Average Pooling):取窗口内的平均值。能更好地保留整体数据的特征。
    • 目前,最大池化更为常用
  • 带步长的卷积(Strided Convolution)

    • 使用stride > 1的卷积层,在计算卷积的同时直接实现下采样。
    • 例如,一个3x3卷积核,stride=2,输出特征图的高和宽会减半。
    • 这是现代架构(如ResNet)的首选,因为卷积层可以学习到最优的下采样方式,而池化层是确定性的、不可学习的。
B. 上采样(UpSampling)

目的:恢复空间分辨率,将压缩的、抽象的特征图映射回高分辨率的空间,用于像素级预测(如图像分割)或生成图像。

  • 转置卷积(Transposed Convolution / Deconvolution)

    • 虽然不是真正的反卷积,但它是可学习的上采样方法
    • 它通过插入零值或进行插值来扩展输入特征图的大小,然后进行常规卷积操作。
    • 缺点:如果核大小和步长参数设置不当,容易产生"棋盘效应"(checkerboard artifacts)。
  • 上采样 + 卷积(Upsampling + Convolution)

    • 先使用最近邻插值(Nearest Neighbor)双线性插值(Bilinear) 等不可学习的插值方法将特征图尺寸放大。
    • 然后跟一个普通的1x13x3卷积来平滑和细化特征。
    • 这种方法可以有效避免棋盘效应,是许多现代架构(如SRGAN)的选择。
  • Unpooling

    • 通常与Max Pooling配对使用。记录下Max Pooling时最大值的位置,在Unpooling时,将值放回原位置,其他位置填0。
    • 在U-Net等网络中有所应用,但不如前两种方法普遍。
在代码中的实现(以PyTorch为例)
python 复制代码
import torch
import torch.nn as nn

# 下采样
downsample_by_pool = nn.MaxPool2d(kernel_size=2, stride=2) # 使用池化
downsample_by_conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1) # 使用步长卷积

# 上采样
upsample_by_transpose = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2) # 转置卷积
upsample_by_interpolation = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),  # 或 'bilinear'
    nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
)

总结

方面 类型 方法 目的 关键点
数据重采样 过采样 随机过采样、SMOTE 平衡类别分布,解决不平衡问题 SMOTE创建合成样本,优于随机复制
欠采样 随机欠采样 平衡类别分布,解决不平衡问题 可能丢失信息,常与过采样结合使用
模型重采样 下采样 池化层、步长卷积 扩大感受野,减少计算量 步长卷积是现代首选
上采样 转置卷积、插值+卷积 恢复空间分辨率,用于密集预测 插值+卷积可避免棋盘效应

选择哪种重采样技术完全取决于你的具体任务:

  • 如果你的数据标签不平衡 ,优先考虑加权损失函数 ,如果效果不佳再尝试SMOTE结合欠采样
  • 如果你在设计网络结构 (如图像分割),步长卷积 用于下采样,最近邻/双线性上采样 + 卷积是稳健且高效的上采样选择。
相关推荐
草莓熊Lotso41 分钟前
Linux 文件描述符与重定向实战:从原理到 minishell 实现
android·linux·运维·服务器·数据库·c++·人工智能
Coder_Boy_2 小时前
技术发展的核心规律是「加法打底,减法优化,重构平衡」
人工智能·spring boot·spring·重构
会飞的老朱4 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º5 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
Codebee7 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º8 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys8 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56788 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子8 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能9 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算