「Pytorch」CopyPaste 数据增强

数据增广是提升模型泛化能力重要的手段之一,CopyPaste 是一种新颖的数据增强技巧,已经在目标检测和实例分割任务中验证了有效性。利用 CopyPaste,可以合成文本实例来平衡训练图像中的正负样本之间的比例。相比而言,传统图像旋转、随机翻转和随机裁剪是无法做到的。

CopyPaste 主要步骤包括:

  1. 随机选择两幅训练图像;
  2. 随机尺度抖动缩放;
  3. 随机水平翻转;
  4. 随机选择一幅图像中的目标子集;
  5. 粘贴在另一幅图像中随机的位置。

这样就比较好地提升了样本丰富度,同时也增加了模型对环境的鲁棒性。如下图所示,通过在左下角的图中裁剪出来的文本,随机旋转缩放之后粘贴到左上角的图像中,进一步丰富了该文本在不同背景下的多样性。

参考代码:

python 复制代码
#  !/usr/bin/env  python
#  -*- coding:utf-8 -*-
# @Time   :  2024.07
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :



import os
import cv2
import json
import logging
import random
import numpy as np

import matplotlib.pyplot as plt



def create_operators(op_param_list, global_config=None):
    """
    create operators based on the config

    Args:
        params(list): a dict list, used to create some operators
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        if global_config is not None:
            param.update(global_config)
        op = eval(op_name)(**param)
        ops.append(op)
    return ops


def transform(data, ops=None):
    """ transform """
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data



# CopyPaste示例的类
class CopyPasteDemo(object):
    def __init__(self, ):
        self.data_dir = "/media/sun/Data/Dataset/OCR_Data/det/train/"
        self.label_file_list = "/media/sun/Data/Dataset/OCR_Data/det/train.txt"
        self.data_lines = self.get_image_info_list(self.label_file_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        transforms = [
            {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
            {"DetLabelEncode": {}},
            {"CopyPaste": {"objects_paste_ratio": 1.0}},
        ]
        self.ops = create_operators(transforms)

    # 选择一张图像,将其中的内容拷贝到当前图像中
    def get_ext_data(self, idx):
        ext_data_num = 1
        ext_data = []
        next_idx = idx
        load_data_ops = self.ops[:2]

        while len(ext_data) < ext_data_num:
            next_idx = (next_idx + 1) % len(self)
            file_idx = self.data_idx_order_list[next_idx]
            data_line = self.data_lines[file_idx]
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split("\t")
            file_name = substr[0]
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            data = {'img_path': img_path, 'label': label}
            if not os.path.exists(img_path):
                continue
            with open(data['img_path'], 'rb') as f:
                img = f.read()
                data['image'] = img
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    # 获取图像信息
    def get_image_info_list(self, file_list):
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for idx, file in enumerate(file_list):
            with open(file, "rb") as f:
                lines = f.readlines()
                data_lines.extend(lines)
        return data_lines

    # 获取DataSet中的一条数据
    def __getitem__(self, idx):
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split("\t")
            file_name = substr[0]
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            data = {'img_path': img_path, 'label': label}
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            with open(data['img_path'], 'rb') as f:
                img = f.read()
                data['image'] = img
            data['ext_data'] = self.get_ext_data(idx)
            outs = transform(data, self.ops)
        except Exception as e:
            print("When parsing line {}, error happened with msg: {}".format(data_line, e))
            outs = None
        if outs is None:
            return
        return outs

    def __len__(self):
        return len(self.data_idx_order_list)



if __name__ == '__main__':
    copy_paste_demo = CopyPasteDemo()

    idx = 1
    data1 = copy_paste_demo[idx]
    print(data1.keys())
    print(data1["img_path"])
    print(data1["ext_data"][0]["img_path"])
    infos = copy_paste_demo.data_lines[idx]
    infos = json.loads(infos.decode('utf-8').split("\t")[1])

    img3 = data1["image"].copy()
    plt.figure(figsize=(15, 10))
    plt.imshow(img3[:, :, ::-1])
    # 原始标注信息
    for info in infos:
        xs, ys = zip(*info["points"])
        xs = list(xs)
        ys = list(ys)
        xs.append(xs[0])
        ys.append(ys[0])
        plt.plot(xs, ys, "r")
    # 新增的标注信息
    for poly_idx in range(len(infos), len(data1["polys"])):
        poly = data1["polys"][poly_idx]
        xs, ys = zip(*poly)
        xs = list(xs)
        ys = list(ys)
        xs.append(xs[0])
        ys.append(ys[0])
        plt.plot(xs, ys, "b")
    plt.show()

生成后的图像

相关推荐
人类群星闪耀时21 分钟前
使用Python实现基因组数据分析:探索生命的奥秘
开发语言·python·数据分析
中杯可乐多加冰1 小时前
【玩转OCR | 腾讯云智能结构化OCR应用探索和场景实践】
人工智能·深度学习·信息可视化·云计算·ocr·腾讯云·玩转腾讯云ocr
Allen_LVyingbo2 小时前
Python 青铜宝剑十六维,破医疗数智化难关(上)
开发语言·笔记·python·健康医疗·集成学习
重整旗鼓~5 小时前
1.flask介绍、入门、基本用法
python·flask
杜小白也想的美5 小时前
FlaskAPI-交互式文档与includ_router
python·fastapi
2401_887406575 小时前
搭建一个高效且安全的APP分发平台
python
deephub6 小时前
SCOPE:面向大语言模型长序列生成的双阶段KV缓存优化框架
人工智能·深度学习·transformer·大语言模型·kv缓存
dangdanding7 小时前
udp分片报文发送和接收
linux·网络·python·网络协议·udp
herogus丶8 小时前
【LLM】Langflow 的简单使用
人工智能·python·langchain