《深度学习》—— 神经网络中的数据增强

文章目录

一、为什么要进行数据增强?

  • 神经网络中的数据增强是一种通过增加训练数据的多样性和数量来提高模型泛化能力的技术 。在神经网络训练过程中,尤其是在数据集较小的情况下,模型容易出现过拟合问题,即模型在训练数据上表现良好,但在未见过的数据上表现不佳。数据增强技术通过对原始数据进行一系列变换,生成新的数据样本,从而有效缓解这一问题。
  • 下面是一张猫的图片经过一系列变换,生成新的数据样本

二、常见的数据增强方法

1. 几何变换

  • 裁剪(Cropping)

    • 中心裁剪(CenterCrop):从图像中心裁剪出指定大小的区域。
    • 随机裁剪(RandomCrop):从图像中随机裁剪出指定大小的区域。
    • 随机大小裁剪(RandomResizedCrop):随机裁剪一个区域,并将其调整为指定大小。
  • 翻转(Flipping)

    • 水平翻转(RandomHorizontalFlip):以一定概率水平翻转图像。
    • 垂直翻转(RandomVerticalFlip):以一定概率垂直翻转图像。
  • 旋转(Rotation)

    • 随机旋转(RandomRotation):以随机角度旋转图像。
  • 仿射变换(Affine):包括旋转、缩放、平移、倾斜等多种变换的组合。

2. 颜色变换

  • 色彩抖动(ColorJitter):随机改变图像的亮度、对比度、饱和度和色调。
  • 灰度化(Grayscale):将彩色图像转换为灰度图像。
  • 随机灰度化(RandomGrayscale):以一定概率将图像转换为灰度图像。

3. 尺寸变换

  • 缩放(Resize/Rescale):将图像缩放到指定的大小。

4. 填充

  • 在图像的周围添加边框,以便进行进一步的裁剪或保持图像大小。

5. 噪声添加

  • 可以通过添加随机噪声(如高斯噪声)来增加数据的多样性。

6. 组合变换

  • 可以使用 transforms.Compose 将多个变换组合起来,一次性应用到图像上。此外,还可以通过 transforms.RandomChoicetransforms.RandomApplytransforms.RandomOrder 等方法,使得数据增强的过程更加随机和灵活。

三、代码实现

  • 在PyTorch中,torchvision库的transforms模块提供了丰富的数据增强方法

    python 复制代码
    from torchvision import transforms 
    
    """ 中心裁剪(CenterCrop) """
    transforms.CenterCrop(256),  # 从中心开始裁剪,裁剪大小为 256x256
    
    """ 随机裁剪(RandomCrop) """
    transform = transforms.RandomCrop(size=224)  # 裁剪为224x224大小
    
    """ 随机大小裁剪(RandomResizedCrop) """
    transform = transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0))  # 随机裁剪并缩放至224x224
    
    """" 水平翻转(RandomHorizontalFlip) """
    transform = transforms.RandomHorizontalFlip(p=0.5)  # 以0.5的概率水平翻转
    
    """ 垂直翻转(RandomVerticalFlip) """
    transform = transforms.RandomVerticalFlip(p=0.5)  # 以0.5的概率垂直翻转
    
    """ 随机旋转(RandomRotation) """
    transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
    # 图像的亮度、对比度、饱和度和色调
    # 0.2 --> 可以被调整至原图的80%到120%之间
    
    """ 灰度化(Grayscale) """
    transform = transforms.Grayscale(num_output_channels=1)  # 转换为灰度图
    
    """ 缩放(Resize) """
    transform = transforms.Resize(size=(256, 256))  # 缩放至256x256
    
    """ 填充(Pad) """
    transform = transforms.Pad(padding=10, fill=0, padding_mode='constant')  # 四周填充10个像素,填充值为0
    
    """ 组合变换(Compose) """
    """ 可以将多个变换组合起来,一次性应用到图像上 """
    from torchvision import transforms  
    
    transform = transforms.Compose([  
        transforms.RandomResizedCrop(224),  
        transforms.RandomHorizontalFlip(),  
        transforms.ToTensor(),  
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  
    ])
    # 图像首先被随机裁剪并缩放至224x224大小,然后以0.5的概率进行水平翻转
    # 接着转换为Tensor类型,最后进行标准化处理

四、注意事项

  • 在使用数据增强技术时,应根据具体任务和数据集的特点选择合适的增强方法。
  • 过度使用数据增强可能会导致模型学习到不必要的噪声或变形特征,从而降低模型的性能。
  • 数据增强应与其他训练策略相结合,如正则化、早停等,以进一步提高模型的泛化能力。

五、总结

  • 神经网络中的数据增强是一种重要的技术手段,通过增加训练数据的多样性和数量来提高模型的泛化能力。
  • 在实际应用中,应根据具体需求和数据集特点选择合适的数据增强方法。
  • PyTorch 中的数据增强方法非常丰富,可以通过 torchvision 库中的 transforms 模块实现多种几何变换、颜色变换、尺寸变换等
相关推荐
ZHOU_WUYI5 分钟前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1237 分钟前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界35 分钟前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK2215136 分钟前
机器学习系列----关联分析
人工智能·机器学习
Robot2511 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
浊酒南街2 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归
畅联云平台2 小时前
美畅物联丨智能分析,安全管控:视频汇聚平台助力智慧工地建设
人工智能·物联网
加密新世界2 小时前
优化 Solana 程序
人工智能·算法·计算机视觉
hunteritself2 小时前
ChatGPT高级语音模式正在向Web网页端推出!
人工智能·gpt·chatgpt·openai·语音识别
Che_Che_3 小时前
Cross-Inlining Binary Function Similarity Detection
人工智能·网络安全·gnn·二进制相似度检测