「Pytorch」CopyPaste 数据增强

数据增强是提升模型泛化能力的重要手段之一。CopyPaste 是一种新颖的数据增强技巧,其有效性已经在目标检测和实例分割任务中得到验证。利用 CopyPaste,可以合成文本实例来平衡训练图像中正负样本之间的比例。相比之下,传统的图像旋转、随机翻转和随机裁剪等增强方法无法做到这一点。

CopyPaste 主要步骤包括:

  1. 随机选择两幅训练图像;
  2. 随机尺度抖动缩放;
  3. 随机水平翻转;
  4. 随机选择一幅图像中的目标子集;
  5. 粘贴在另一幅图像中随机的位置。

这样就比较好地提升了样本丰富度,同时也增加了模型对环境的鲁棒性。如下图所示,将左下角图像中裁剪出来的文本经过随机旋转、缩放之后,粘贴到左上角的图像中,进一步丰富了该文本在不同背景下的多样性。

参考代码:

python 复制代码
#  !/usr/bin/env  python
#  -*- coding:utf-8 -*-
# @Time   :  2024.07
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :



import os
import cv2
import json
import logging
import random
import numpy as np

import matplotlib.pyplot as plt



def create_operators(op_param_list, global_config=None):
    """
    Create operator instances from a config list.

    Args:
        op_param_list(list): a list of single-key dicts; each key is the name
            of an operator class and each value is the dict of keyword
            arguments for it (or None for no arguments).
        global_config(dict|None): extra keyword arguments merged into every
            operator's arguments; on a key clash the global value wins.

    Returns:
        list: the instantiated operators, in the same order as the config.
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        # Work on a copy: merging global_config must not leak into the
        # caller's config, which may be reused to build operators again.
        param = {} if operator[op_name] is None else dict(operator[op_name])
        if global_config is not None:
            param.update(global_config)
        # NOTE(security): eval() resolves op_name in this module's namespace
        # and will execute arbitrary expressions; only pass trusted configs.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops


def transform(data, ops=None):
    """Run ``data`` through each operator in ``ops``, in order.

    Each operator is called with the output of the previous one. As soon as
    an operator returns None the remaining operators are skipped and None is
    returned. With ``ops`` omitted (or None) ``data`` is returned unchanged.
    """
    pipeline = [] if ops is None else ops
    result = data
    for step in pipeline:
        result = step(result)
        if result is None:
            # An operator rejected the sample; stop the pipeline here.
            break
    return result



# CopyPaste示例的类
class CopyPasteDemo(object):
    """Small indexable dataset used to demonstrate the CopyPaste operator.

    Each line of a label file is ``<file_name>\\t<label>``; ``file_name`` is
    joined onto ``data_dir`` to locate the image. Indexing the dataset loads
    one sample, attaches one extra sample under ``'ext_data'`` and runs the
    configured operator pipeline over it.
    """

    def __init__(self,
                 data_dir="/media/sun/Data/Dataset/OCR_Data/det/train/",
                 label_file_list="/media/sun/Data/Dataset/OCR_Data/det/train.txt"):
        """
        Args:
            data_dir(str): directory the image file names are relative to.
            label_file_list(str|list): one label file path, or a list of them.
        """
        self.data_dir = data_dir
        self.label_file_list = label_file_list
        self.data_lines = self.get_image_info_list(self.label_file_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        transforms = [
            {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
            {"DetLabelEncode": {}},
            {"CopyPaste": {"objects_paste_ratio": 1.0}},
        ]
        self.ops = create_operators(transforms)

    def _parse_line(self, data_line):
        """Split one decoded label line into an ``img_path``/``label`` dict.

        Raises IndexError when the line has no tab-separated label field.
        """
        substr = data_line.strip("\n").split("\t")
        file_name = substr[0]
        label = substr[1]
        return {'img_path': os.path.join(self.data_dir, file_name),
                'label': label}

    @staticmethod
    def _read_image(data):
        """Store the raw bytes of ``data['img_path']`` under ``data['image']``."""
        with open(data['img_path'], 'rb') as f:
            data['image'] = f.read()
        return data

    def get_ext_data(self, idx):
        """Load the extra sample whose content is pasted into sample ``idx``.

        Candidates are tried in order starting from the sample after ``idx``
        and wrapping around. A candidate is skipped when its image file is
        missing or the loading operators (the first two of ``self.ops``)
        return None for it.

        Returns:
            list: the loaded samples. It holds one item when a usable
            candidate exists and is empty otherwise; the search stops after a
            single pass over the dataset instead of looping forever.
        """
        ext_data_num = 1
        ext_data = []
        load_data_ops = self.ops[:2]
        total = len(self)

        # One full pass visits every candidate once (ending on idx itself),
        # so retrying beyond that could never find anything new.
        for offset in range(1, total + 1):
            if len(ext_data) >= ext_data_num:
                break
            file_idx = self.data_idx_order_list[(idx + offset) % total]
            data_line = self.data_lines[file_idx].decode('utf-8')
            data = self._parse_line(data_line)
            if not os.path.exists(data['img_path']):
                continue
            self._read_image(data)
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_image_info_list(self, file_list):
        """Read every line of the given label file(s) as raw bytes.

        Args:
            file_list(str|list): one label file path, or a list of them.

        Returns:
            list: all lines of all files, in order, each still ending with
            its newline.
        """
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for file in file_list:
            with open(file, "rb") as f:
                data_lines.extend(f.readlines())
        return data_lines

    def __getitem__(self, idx):
        """Load and transform sample ``idx``.

        Returns:
            The output of the operator pipeline, or None when anything goes
            wrong while parsing, loading or transforming the sample (the
            error is printed rather than raised).
        """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            data = self._parse_line(data_line)
            img_path = data['img_path']
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            self._read_image(data)
            data['ext_data'] = self.get_ext_data(idx)
            return transform(data, self.ops)
        except Exception as e:
            # Best-effort loading: report the bad sample and signal it with
            # None instead of aborting the caller.
            print("When parsing line {}, error happened with msg: {}".format(data_line, e))
            return None

    def __len__(self):
        """Return the number of samples (label lines) in the dataset."""
        return len(self.data_idx_order_list)



if __name__ == '__main__':
    # Demo: load one augmented sample and plot its polygons over the image.
    copy_paste_demo = CopyPasteDemo()

    idx = 1
    data1 = copy_paste_demo[idx]
    # __getitem__ returns None (after printing the error) when the sample
    # cannot be loaded; stop here instead of failing on None below.
    if data1 is None:
        raise SystemExit("Failed to load sample {}; see the error printed above.".format(idx))
    print(data1.keys())
    print(data1["img_path"])
    print(data1["ext_data"][0]["img_path"])
    # Re-parse the raw label line to get the annotations before augmentation.
    infos = copy_paste_demo.data_lines[idx]
    infos = json.loads(infos.decode('utf-8').split("\t")[1])

    img3 = data1["image"].copy()
    plt.figure(figsize=(15, 10))
    # Reverse the channel axis for display (the demo decodes in "BGR" mode).
    plt.imshow(img3[:, :, ::-1])

    # Original annotations are drawn in red. Polygons past the original count
    # are drawn in blue -- presumably the ones added by CopyPaste.
    original_polys = [info["points"] for info in infos]
    added_polys = data1["polys"][len(infos):]
    for polys, color in ((original_polys, "r"), (added_polys, "b")):
        for poly in polys:
            xs, ys = zip(*poly)
            # Repeat the first vertex so the outline is drawn closed.
            plt.plot(list(xs) + [xs[0]], list(ys) + [ys[0]], color)
    plt.show()

生成后的图像

相关推荐
赛丽曼14 分钟前
Python中的TCP
python
小白~小黑15 分钟前
软件测试基础二十(接口测试 Postman)
python·自动化·postman
codists15 分钟前
《Django 5 By Example》阅读笔记:p76-p104
python·django·编程人
欧阳枫落23 分钟前
python 2小时学会八股文-数据结构
开发语言·数据结构·python
天天要nx27 分钟前
D64【python 接口自动化学习】- python基础之数据库
数据库·python
华清元宇宙实验中心36 分钟前
【每天学点AI】前向传播、损失函数、反向传播
深度学习·机器学习·ai人工智能
feifeikon1 小时前
Python Day5 进阶语法(列表表达式/三元/断言/with-as/异常捕获/字符串方法/lambda函数
开发语言·python
龙的爹23331 小时前
论文 | The Capacity for Moral Self-Correction in LargeLanguage Models
人工智能·深度学习·机器学习·语言模型·自然语言处理·prompt
杰仔正在努力2 小时前
python成长技能之枚举类
开发语言·python
Eiceblue2 小时前
通过Python 调整Excel行高、列宽
开发语言·vscode·python·pycharm·excel