「Pytorch」CopyPaste 数据增强

数据增广是提升模型泛化能力的重要手段之一。CopyPaste 是一种新颖的数据增强技巧,已经在目标检测和实例分割任务中验证了有效性。利用 CopyPaste,可以合成文本实例,从而平衡训练图像中正负样本之间的比例;相比之下,传统的图像旋转、随机翻转和随机裁剪等增强方式无法做到这一点。

CopyPaste 主要步骤包括:

  1. 随机选择两幅训练图像;
  2. 随机尺度抖动缩放;
  3. 随机水平翻转;
  4. 随机选择一幅图像中的目标子集;
  5. 粘贴在另一幅图像中随机的位置。

这样可以较好地提升样本的丰富度,同时也增强了模型对环境的鲁棒性。如下图所示,将左下角图中裁剪出来的文本随机旋转、缩放之后,粘贴到左上角的图像中,进一步丰富了该文本在不同背景下的多样性。

参考代码:

python 复制代码
#!/usr/bin/env python
#  -*- coding:utf-8 -*-
# @Time   :  2024.07
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :



import os
import cv2
import json
import logging
import random
import numpy as np

import matplotlib.pyplot as plt



def create_operators(op_param_list, global_config=None):
    """
    Create operator instances from a list of single-entry config dicts.

    Args:
        op_param_list(list): a list of dicts, each mapping exactly one
            operator class name to its keyword arguments (or to None when
            the operator takes no arguments).
        global_config(dict, optional): extra keyword arguments merged into
            every operator's parameters; on a key clash the global value
            wins.

    Returns:
        list: the instantiated operators, in the order they were given.

    Raises:
        AssertionError: if ``op_param_list`` is not a list, or one of its
            entries is not a dict with exactly one key.
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        # Work on a copy: merging global_config below must not leak into the
        # caller's config, which may be reused to build operators again.
        param = {} if operator[op_name] is None else dict(operator[op_name])
        if global_config is not None:
            param.update(global_config)
        # NOTE(review): eval() looks the name up in this module's namespace and
        # will execute any expression it is given -- only pass trusted configs.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops


def transform(data, ops=None):
    """
    Pass ``data`` through every operator in ``ops``, in order.

    Args:
        data: the sample handed to the first operator.
        ops(iterable, optional): callables taking and returning a sample;
            ``None`` means "no operators".

    Returns:
        The output of the last operator, or ``None`` as soon as any
        operator returns ``None`` (later operators are then skipped).
    """
    pipeline = () if ops is None else ops
    result = data
    for apply_op in pipeline:
        result = apply_op(result)
        if result is None:
            # An operator rejected the sample; stop the pipeline here.
            return None
    return result



# CopyPaste示例的类
class CopyPasteDemo(object):
    """
    Small dataset wrapper used to demonstrate the CopyPaste augmentation.

    Every line of a label file has the form ``file_name<TAB>label``; the file
    name is resolved relative to ``data_dir``.

    Args:
        data_dir(str): directory that contains the images.
        label_file_list(str | list): one label file path, or a list of them.
        transforms(list, optional): operator config passed to
            ``create_operators``. The first two operators are the ones
            applied to the extra (paste source) sample in ``get_ext_data``.
            Defaults to DecodeImage -> DetLabelEncode -> CopyPaste.
            NOTE(review): these operator classes are not defined in the
            visible part of this file; they must be resolvable by name
            where ``create_operators`` is defined -- confirm the import.
    """

    def __init__(self,
                 data_dir="/media/sun/Data/Dataset/OCR_Data/det/train/",
                 label_file_list="/media/sun/Data/Dataset/OCR_Data/det/train.txt",
                 transforms=None):
        self.data_dir = data_dir
        self.label_file_list = label_file_list
        self.data_lines = self.get_image_info_list(self.label_file_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if transforms is None:
            transforms = [
                {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
                {"DetLabelEncode": {}},
                {"CopyPaste": {"objects_paste_ratio": 1.0}},
            ]
        self.ops = create_operators(transforms)

    def _load_raw_sample(self, data_line):
        """
        Build the raw sample dict for one decoded label line.

        Args:
            data_line(str): a ``file_name<TAB>label`` line.

        Returns:
            dict: ``img_path``, ``label`` and ``image`` (the undecoded file
            bytes).

        Raises:
            IndexError: if the line has no tab-separated label field.
            FileNotFoundError: if the image file does not exist.
        """
        substr = data_line.strip("\n").split("\t")
        file_name = substr[0]
        label = substr[1]
        img_path = os.path.join(self.data_dir, file_name)
        if not os.path.exists(img_path):
            raise FileNotFoundError("{} does not exist!".format(img_path))
        with open(img_path, 'rb') as f:
            image = f.read()
        return {'img_path': img_path, 'label': label, 'image': image}

    def get_ext_data(self, idx):
        """
        Pick another sample whose content will be pasted into sample ``idx``.

        Candidates are tried in order starting from the sample after ``idx``
        and wrapping around. Samples whose image is missing, or that the
        loading operators reject, are skipped.

        Args:
            idx(int): position of the current sample.

        Returns:
            list: the loaded extra samples (at most one). The list is empty
            when no sample in the dataset could be loaded.
        """
        ext_data_num = 1
        ext_data = []
        load_data_ops = self.ops[:2]
        total = len(self)

        # Visit every position at most once so the search always terminates,
        # even when none of the images can be loaded.
        for offset in range(1, total + 1):
            if len(ext_data) >= ext_data_num:
                break
            file_idx = self.data_idx_order_list[(idx + offset) % total]
            data_line = self.data_lines[file_idx].decode('utf-8')
            try:
                data = self._load_raw_sample(data_line)
            except FileNotFoundError:
                continue
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_image_info_list(self, file_list):
        """
        Read all lines of the given label file(s).

        Args:
            file_list(str | list): one label file path, or a list of them.

        Returns:
            list: the raw lines as bytes, newline included, in file order.
        """
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for label_file in file_list:
            with open(label_file, "rb") as f:
                data_lines.extend(f.readlines())
        return data_lines

    def __getitem__(self, idx):
        """
        Load sample ``idx``, attach its extra sample and run all operators.

        Returns:
            The transformed sample, or ``None`` when loading or transforming
            failed (the error is printed, not raised).
        """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            data = self._load_raw_sample(data_line)
            data['ext_data'] = self.get_ext_data(idx)
            outs = transform(data, self.ops)
        except Exception as e:
            # Best effort: one bad sample is reported and skipped rather than
            # aborting the caller.
            print("When parsing line {}, error happened with msg: {}".format(data_line, e))
            outs = None
        return outs

    def __len__(self):
        """Return the number of samples."""
        return len(self.data_idx_order_list)



def _plot_closed_polygon(points, fmt):
    """
    Draw ``points`` as a closed outline on the current matplotlib axes.

    Args:
        points: iterable of (x, y) pairs.
        fmt(str): matplotlib format string, e.g. ``"r"``.
    """
    xs, ys = zip(*points)
    # Repeat the first vertex at the end so the outline is closed.
    plt.plot(list(xs) + [xs[0]], list(ys) + [ys[0]], fmt)


if __name__ == '__main__':
    copy_paste_demo = CopyPasteDemo()

    idx = 1
    data1 = copy_paste_demo[idx]
    if data1 is None:
        # __getitem__ returns None when the sample could not be loaded or
        # transformed; stop here rather than failing on data1.keys().
        raise SystemExit("Failed to load sample {}; see the error printed above.".format(idx))
    print(data1.keys())
    print(data1["img_path"])
    print(data1["ext_data"][0]["img_path"])
    infos = copy_paste_demo.data_lines[idx]
    infos = json.loads(infos.decode('utf-8').split("\t")[1])

    img3 = data1["image"].copy()
    plt.figure(figsize=(15, 10))
    # Reverse the channel axis for display.
    plt.imshow(img3[:, :, ::-1])
    # Original annotations, drawn in red.
    for info in infos:
        _plot_closed_polygon(info["points"], "r")
    # Newly added annotations, drawn in blue: every polygon past the original
    # count.
    for poly in data1["polys"][len(infos):]:
        _plot_closed_polygon(poly, "b")
    plt.show()

生成后的图像

相关推荐
xingshanchang16 分钟前
PyTorch 不支持旧GPU的异常状态与解决方案:CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH
人工智能·pytorch·python
昵称是6硬币3 小时前
YOLOv11: AN OVERVIEW OF THE KEY ARCHITECTURAL ENHANCEMENTS目标检测论文精读(逐段解析)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
费弗里3 小时前
Python全栈应用开发利器Dash 3.x新版本介绍(1)
python·dash
李少兄9 天前
解决OSS存储桶未创建导致的XML错误
xml·开发语言·python
就叫飞六吧9 天前
基于keepalived、vip实现高可用nginx (centos)
python·nginx·centos
Vertira9 天前
PyTorch中的permute, transpose, view, reshape和flatten函数详解(已解决)
人工智能·pytorch·python
heimeiyingwang9 天前
【深度学习加速探秘】Winograd 卷积算法:让计算效率 “飞” 起来
人工智能·深度学习·算法
学Linux的语莫9 天前
python基础语法
开发语言·python
匿名的魔术师9 天前
实验问题记录:PyTorch Tensor 也会出现 a = b 赋值后,修改 a 会影响 b 的情况
人工智能·pytorch·python
Ven%9 天前
PyTorch 张量(Tensors)全面指南:从基础到实战
人工智能·pytorch·python