数据增广是提升模型泛化能力的重要手段之一。CopyPaste 是一种新颖的数据增强技巧，已在目标检测和实例分割任务中验证了其有效性。利用 CopyPaste 可以合成文本实例，从而平衡训练图像中正负样本的比例；相比之下，传统的随机旋转、随机翻转和随机裁剪等方法无法做到这一点。
CopyPaste 的主要步骤包括：
- 随机选择两幅训练图像;
- 随机尺度抖动缩放;
- 随机水平翻转;
- 随机选择一幅图像中的目标子集;
- 粘贴在另一幅图像中随机的位置。
这种方式能较好地提升样本的丰富度，同时增强模型对不同环境的鲁棒性。如下图所示，将左下角图像中裁剪出的文本随机旋转、缩放后粘贴到左上角的图像中，进一步丰富了该文本在不同背景下的多样性。
参考代码如下：
python
# !/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time : 2024.07
# @Author : 绿色羽毛
# @Email : lvseyumao@foxmail.com
# @Blog : https://blog.csdn.net/ViatorSun
# @Note :
import os
import cv2
import json
import logging
import random
import numpy as np
import matplotlib.pyplot as plt
def create_operators(op_param_list, global_config=None):
    """Instantiate the operator pipeline described by a config list.

    Args:
        op_param_list (list): list of single-key dicts ``{OpName: params}``,
            where ``params`` is a dict of keyword arguments or ``None``.
        global_config (dict, optional): keyword arguments merged into every
            operator's params; they override same-named per-operator keys.

    Returns:
        list: the instantiated operators, in config order.

    Raises:
        AssertionError: if the config is not a list of single-key dicts.
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
        op_name = next(iter(operator))
        raw_param = operator[op_name]
        # Copy before merging so global_config is not written back into the
        # caller's config dict (the original mutated it in place).
        param = {} if raw_param is None else dict(raw_param)
        if global_config is not None:
            param.update(global_config)
        # SECURITY NOTE: eval() resolves the operator class by name in this
        # module's namespace, so the config must come from a trusted source.
        # The classes used by the demo (DecodeImage, DetLabelEncode, CopyPaste)
        # are not imported in the visible code -- presumably they come from
        # PaddleOCR's ppocr.data.imaug; confirm they are in scope.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops
def transform(data, ops=None):
    """Feed *data* through every callable in *ops*, in order.

    Each op receives the previous op's output. If any op returns ``None``,
    the remaining ops are skipped and ``None`` is returned. With ``ops``
    omitted, *data* is returned unchanged.
    """
    if ops is None:
        return data
    result = data
    for step in ops:
        result = step(result)
        if result is None:
            break
    return result
# CopyPaste示例的类
class CopyPasteDemo(object):
    """Small dataset wrapper used to demonstrate the CopyPaste augmentation.

    Each line of the label file(s) is split on a tab into
    ``<image path relative to data_dir>`` and ``<label>``. ``__getitem__``
    reads one image, collects extra "donor" samples via :meth:`get_ext_data`,
    and runs the full operator pipeline on the result.
    """

    def __init__(self,
                 data_dir="/media/sun/Data/Dataset/OCR_Data/det/train/",
                 label_file_list="/media/sun/Data/Dataset/OCR_Data/det/train.txt",
                 transforms=None):
        """Load the label lines and build the operator pipeline.

        Args:
            data_dir (str): directory that image paths in the label file are
                relative to.
            label_file_list (str | list[str]): one or more label files.
            transforms (list, optional): operator config passed to
                ``create_operators``. Defaults to decode -> label-encode ->
                CopyPaste. ``get_ext_data`` runs only the first two entries
                on donor images, so keep the decode/label ops first.
        """
        self.data_dir = data_dir
        self.label_file_list = label_file_list
        self.data_lines = self.get_image_info_list(self.label_file_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if transforms is None:
            transforms = [
                {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
                {"DetLabelEncode": {}},
                {"CopyPaste": {"objects_paste_ratio": 1.0}},
            ]
        self.ops = create_operators(transforms)

    def _parse_line(self, data_line):
        """Split a decoded label line into ``{'img_path': ..., 'label': ...}``.

        Raises:
            IndexError: if the line contains no tab separator.
        """
        substr = data_line.strip("\n").split("\t")
        file_name = substr[0]
        label = substr[1]
        return {'img_path': os.path.join(self.data_dir, file_name), 'label': label}

    @staticmethod
    def _read_image(img_path):
        """Return the raw (still encoded) bytes of the image file."""
        with open(img_path, 'rb') as f:
            return f.read()

    def get_ext_data(self, idx, ext_data_num=1):
        """Collect up to *ext_data_num* donor samples whose content can be pasted.

        Candidates are taken from the lines after *idx*, wrapping around,
        and each line is tried at most once. Lines that are malformed, point
        to a missing image, or are rejected by the load ops are skipped, so
        the returned list may be shorter than requested (possibly empty).

        Args:
            idx (int): index of the sample being built; the scan starts after it.
            ext_data_num (int): how many donor samples to collect.

        Returns:
            list[dict]: donor samples after the first two pipeline ops.
        """
        ext_data = []
        # Only the first two ops (decode + label-encode in the default
        # config) are applied to donors, so donors are not augmented themselves.
        load_data_ops = self.ops[:2]
        num_samples = len(self)
        next_idx = idx
        # Bounded to one full pass: the original unbounded while-loop spun
        # forever when no other line yielded a loadable image.
        for _ in range(num_samples):
            if len(ext_data) >= ext_data_num:
                break
            next_idx = (next_idx + 1) % num_samples
            file_idx = self.data_idx_order_list[next_idx]
            try:
                data = self._parse_line(self.data_lines[file_idx].decode('utf-8'))
            except (UnicodeDecodeError, IndexError):
                continue  # malformed donor line: try the next one
            if not os.path.exists(data['img_path']):
                continue
            data['image'] = self._read_image(data['img_path'])
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_image_info_list(self, file_list):
        """Read all lines (as raw bytes, newlines kept) from one or more label files."""
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for file in file_list:
            with open(file, "rb") as f:
                data_lines.extend(f.readlines())
        return data_lines

    def __getitem__(self, idx):
        """Build sample *idx* with CopyPaste applied.

        Returns the pipeline output, or ``None`` if anything fails. The error
        is printed rather than raised, so one bad sample does not abort the demo.
        """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            data = self._parse_line(data_line)
            img_path = data['img_path']
            if not os.path.exists(img_path):
                raise FileNotFoundError("{} does not exist!".format(img_path))
            data['image'] = self._read_image(img_path)
            data['ext_data'] = self.get_ext_data(idx)
            outs = transform(data, self.ops)
        except Exception as e:
            print("When parsing line {}, error happened with msg: {}".format(data_line, e))
            outs = None
        return outs

    def __len__(self):
        """Number of label lines."""
        return len(self.data_idx_order_list)
def close_polygon(points):
    """Return ``(xs, ys)`` coordinate lists for *points*, closed for plotting.

    The first vertex is appended again at the end so ``plt.plot`` draws the
    closing edge. An empty input yields ``([], [])``.
    """
    pts = list(points)
    if not pts:
        return [], []
    xs = [p[0] for p in pts]
    ys = [p[1] for p in pts]
    xs.append(xs[0])
    ys.append(ys[0])
    return xs, ys


def main(idx=1):
    """Apply CopyPaste to sample *idx* and plot its boxes.

    Original boxes are drawn in red and pasted boxes in blue.
    """
    copy_paste_demo = CopyPasteDemo()
    data1 = copy_paste_demo[idx]
    if data1 is None:
        # __getitem__ already printed the cause; stop instead of failing on None.keys().
        raise SystemExit("Sample {} could not be loaded.".format(idx))
    print(data1.keys())
    print(data1["img_path"])
    ext_data = data1.get("ext_data") or []
    if ext_data:
        print(ext_data[0]["img_path"])

    # Original annotations, parsed straight from the label line's JSON part.
    infos = json.loads(copy_paste_demo.data_lines[idx].decode('utf-8').split("\t")[1])

    img3 = data1["image"].copy()
    plt.figure(figsize=(15, 10))
    # Image was decoded with img_mode="BGR"; reverse channels for RGB display.
    plt.imshow(img3[:, :, ::-1])

    # Original annotations (red).
    for info in infos:
        xs, ys = close_polygon(info["points"])
        plt.plot(xs, ys, "r")

    # Newly pasted annotations (blue). Assumes the pipeline keeps the original
    # polygons first and appends pasted ones after them -- TODO confirm
    # against the CopyPaste/DetLabelEncode implementation.
    for poly in data1["polys"][len(infos):]:
        xs, ys = close_polygon(poly)
        plt.plot(xs, ys, "b")
    plt.show()


if __name__ == '__main__':
    main()
生成后的图像如下所示：