《深度学习》—— 神经网络中的数据增强

文章目录

一、为什么要进行数据增强?

  • 神经网络中的数据增强是一种通过增加训练数据的多样性和数量来提高模型泛化能力的技术 。在神经网络训练过程中,尤其是在数据集较小的情况下,模型容易出现过拟合问题,即模型在训练数据上表现良好,但在未见过的数据上表现不佳。数据增强技术通过对原始数据进行一系列变换,生成新的数据样本,从而有效缓解这一问题。
  • 下面是一张猫的图片经过一系列变换,生成新的数据样本

二、常见的数据增强方法

1. 几何变换

  • 裁剪(Cropping)

    • 中心裁剪(CenterCrop):从图像中心裁剪出指定大小的区域。
    • 随机裁剪(RandomCrop):从图像中随机裁剪出指定大小的区域。
    • 随机大小裁剪(RandomResizedCrop):随机裁剪一个区域,并将其调整为指定大小。
  • 翻转(Flipping)

    • 水平翻转(RandomHorizontalFlip):以一定概率水平翻转图像。
    • 垂直翻转(RandomVerticalFlip):以一定概率垂直翻转图像。
  • 旋转(Rotation)

    • 随机旋转(RandomRotation):以随机角度旋转图像。
  • 仿射变换(Affine):包括旋转、缩放、平移、倾斜等多种变换的组合。

2. 颜色变换

  • 色彩抖动(ColorJitter):随机改变图像的亮度、对比度、饱和度和色调。
  • 灰度化(Grayscale):将彩色图像转换为灰度图像。
  • 随机灰度化(RandomGrayscale):以一定概率将图像转换为灰度图像。

3. 尺寸变换

  • 缩放(Resize/Rescale):将图像缩放到指定的大小。

4. 填充

  • 在图像的周围添加边框,以便进行进一步的裁剪或保持图像大小。

5. 噪声添加

  • 可以通过添加随机噪声(如高斯噪声)来增加数据的多样性。

6. 组合变换

  • 可以使用 transforms.Compose 将多个变换组合起来,一次性应用到图像上。此外,还可以通过 transforms.RandomChoicetransforms.RandomApplytransforms.RandomOrder 等方法,使得数据增强的过程更加随机和灵活。

三、代码实现

  • 在PyTorch中,torchvision库的transforms模块提供了丰富的数据增强方法

    python 复制代码
    from torchvision import transforms 
    
    """ 中心裁剪(CenterCrop) """
    transforms.CenterCrop(256),  # 从中心开始裁剪,裁剪大小为 256x256
    
    """ 随机裁剪(RandomCrop) """
    transform = transforms.RandomCrop(size=224)  # 裁剪为224x224大小
    
    """ 随机大小裁剪(RandomResizedCrop) """
    transform = transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0))  # 随机裁剪并缩放至224x224
    
    """" 水平翻转(RandomHorizontalFlip) """
    transform = transforms.RandomHorizontalFlip(p=0.5)  # 以0.5的概率水平翻转
    
    """ 垂直翻转(RandomVerticalFlip) """
    transform = transforms.RandomVerticalFlip(p=0.5)  # 以0.5的概率垂直翻转
    
    """ 随机旋转(RandomRotation) """
    transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
    # 图像的亮度、对比度、饱和度和色调
    # 0.2 --> 可以被调整至原图的80%到120%之间
    
    """ 灰度化(Grayscale) """
    transform = transforms.Grayscale(num_output_channels=1)  # 转换为灰度图
    
    """ 缩放(Resize) """
    transform = transforms.Resize(size=(256, 256))  # 缩放至256x256
    
    """ 填充(Pad) """
    transform = transforms.Pad(padding=10, fill=0, padding_mode='constant')  # 四周填充10个像素,填充值为0
    
    """ 组合变换(Compose) """
    """ 可以将多个变换组合起来,一次性应用到图像上 """
    from torchvision import transforms  
    
    transform = transforms.Compose([  
        transforms.RandomResizedCrop(224),  
        transforms.RandomHorizontalFlip(),  
        transforms.ToTensor(),  
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  
    ])
    # 图像首先被随机裁剪并缩放至224x224大小,然后以0.5的概率进行水平翻转
    # 接着转换为Tensor类型,最后进行标准化处理

四、注意事项

  • 在使用数据增强技术时,应根据具体任务和数据集的特点选择合适的增强方法。
  • 过度使用数据增强可能会导致模型学习到不必要的噪声或变形特征,从而降低模型的性能。
  • 数据增强应与其他训练策略相结合,如正则化、早停等,以进一步提高模型的泛化能力。

五、总结

  • 神经网络中的数据增强是一种重要的技术手段,通过增加训练数据的多样性和数量来提高模型的泛化能力。
  • 在实际应用中,应根据具体需求和数据集特点选择合适的数据增强方法。
  • PyTorch 中的数据增强方法非常丰富,可以通过 torchvision 库中的 transforms 模块实现多种几何变换、颜色变换、尺寸变换等
相关推荐
TTGGGFF16 小时前
控制系统建模仿真(四):线性控制系统的数学模型
人工智能·算法
UXbot16 小时前
UI设计工具推荐合集
前端·人工智能·ui
kicikng16 小时前
智能体来了(西南总部)实战指南:AI调度官与AI Agent指挥官的Prompt核心逻辑
人工智能·prompt·多智能体系统
抓个马尾女孩16 小时前
为什么self-attention除以根号dk而不是其他值
人工智能·深度学习·机器学习·transformer
叫我辉哥e116 小时前
新手进阶Python:办公看板集成ERP跨系统同步+自动备份+AI异常复盘
开发语言·人工智能·python
Loo国昌16 小时前
【LangChain1.0】第五阶段:RAG高级篇(高级检索与优化)
人工智能·后端·语言模型·架构
伊克罗德信息科技16 小时前
技术分享 | 用Dify搭建个人AI知识助手
人工智能
TOPGUS16 小时前
谷歌发布三大AI购物新功能:从对话式搜索到AI代你下单
大数据·人工智能·搜索引擎·chatgpt·谷歌·seo·数字营销
Godspeed Zhao16 小时前
从零开始学AI4——背景知识3
人工智能