图片数据增强

数据增强

数据增强脚本

  1. 随机上下镜像
  2. 随机左右镜像
  3. 随机左右旋转45度以内
  4. 随机裁剪
  5. 随机透视变换,拉伸(未实现)
  6. 随机平移
Python 代码如下:
import os, cv2, shutil
from glob import glob
import random
import sys
from tqdm import tqdm
import random
import numpy as np


# 1. 加载图片路径
def load_files(path):
    """Return every entry directly under ``path`` in random order.

    :param path: directory to scan; it is used as the prefix of a ``<path>/*`` glob pattern
    :return: list of matching paths, shuffled in place with ``random.shuffle``
    """
    entries = glob(f"{path}/*")
    random.shuffle(entries)
    return entries


# 3. 检查增强目录是否存在,存在就删除,然后重新生成
def mkdir(path):
    """Create ``path`` as an empty directory, discarding anything already there.

    If ``path`` exists, the whole tree is removed first, so every call leaves an
    empty directory behind. Missing parent directories are created as well.

    :param path: directory to (re)create
    """
    if os.path.exists(path):
        shutil.rmtree(path)
    # makedirs instead of mkdir so a nested output path does not fail just
    # because one of its parents has not been created yet.
    os.makedirs(path)


class Image:
    """One source image plus the bookkeeping needed to save its augmented copy.

    Attributes set by ``__init__``:
        src: path of the original image file.
        cv2_image: value returned by ``cv2.imread``; stays ``None`` when the
            file could not be loaded.
        filename: base name of ``src``.
        aug_name: file name to use when saving the augmented image.
        is_aug: flag that starts as ``False``; augmentation steps set it to
            ``True`` so the caller knows there is something to save.
    """

    def __init__(self, image_path):
        """
        :param image_path: path of the image file to load
        """
        self.src = image_path  # path of the original image
        self.cv2_image = None
        self.filename = os.path.basename(image_path)
        self.__init()
        self.generate_aug_name()
        self.is_aug = False

    def __init(self):
        """Load ``self.src`` into ``self.cv2_image``.

        Terminates the process (``sys.exit``) when the file does not exist.
        A file that exists but cannot be loaded is reported with an
        "image error" message and leaves ``cv2_image`` as ``None``.
        """
        if not os.path.exists(self.src):
            print("image_src = {}".format(self.src))
            print("-----------------------------------------------------")
            print("------------ [IMG] file_dot't exist, exit -----------")
            print("-----------------------------------------------------")
            # A missing input is a failure, so do not report success (status 0).
            sys.exit(1)
        try:
            self.cv2_image = cv2.imread(self.src)
        except Exception:
            # Narrowed from a bare ``except`` so that Ctrl-C / SystemExit can
            # no longer trigger deletion of the source file.
            print("image error", self.src)
            os.remove(self.src)
            return
        if self.cv2_image is None:
            # cv2.imread signals an unreadable file by returning None instead
            # of raising, so the failure has to be reported explicitly.
            print("image error", self.src)

    def generate_aug_name(self):
        """Build the file name under which the augmented image is saved.

        ``_aug`` is inserted in front of the final extension, e.g.
        ``cat.jpg`` -> ``cat_aug.jpg`` and ``a.b.jpg`` -> ``a.b_aug.jpg``.
        """
        # splitext only strips the last extension, so dots inside the stem are
        # preserved and a name without an extension does not get a bogus one.
        stem, extension = os.path.splitext(self.filename)
        self.aug_name = stem + "_aug" + extension


""" 数据增强 """


# 1. 随机上下镜像
# 2. 随机左右镜像
# 3. 随机左右旋转45度以内
# 4. 随机裁剪
# 5. 随机透视变换,拉伸
# 6. 随机平移
class ImageAugmentation:
    """Randomly flip, rotate, centre-crop and translate every image in a directory.

    Results are written to a directory named ``<image_path>_aug``, which is
    wiped and recreated each time an instance is constructed.
    """

    # cv2.flip codes, indexed by (random flip type - 1):
    # 0 = top/bottom mirror, 1 = left/right mirror, -1 = both.
    _FLIP_CODES = (0, 1, -1)

    def __init__(self, image_path, flip_prob=0.5, revolve=None, crop=None, translate_prob=0.5):
        """
        Augmentation parameters.

        :param image_path: directory containing the images
        :param flip_prob: probability of mirroring an image
        :param revolve: rotation settings ``[probability, max angle]``; the
            direction is random. Defaults to ``[0.5, 15]``.
        :param crop: crop settings ``[probability, minimum kept ratio]``.
            Defaults to ``[0.5, 0.75]``.
        :param translate_prob: probability of translating an image; the
            direction (left/right/up/down) is random
        """
        if revolve is None:
            revolve = [0.5, 15]
        if crop is None:
            crop = [0.5, 0.75]
        self.image_path = image_path
        self.flip_prob = flip_prob
        self.revolve_prob = revolve[0]
        self.revolve_angle = revolve[1]
        self.crop_prob = crop[0]
        self.crop_rate = crop[1]
        self.translate_prob = translate_prob
        self.__init()
        self.file_list = load_files(self.image_path)

    def __init(self):
        """Derive the output directory name and (re)create it empty."""
        self.aug_path = self.image_path + "_aug"
        mkdir(self.aug_path)

    @staticmethod
    def _should_apply(prob):
        """Return ``True`` with probability ``prob`` (0 -> never, 1 -> always)."""
        return random.random() < prob

    def flip(self, image):
        """
        Mirror the image along a randomly chosen axis (vertical, horizontal or both).

        :param image: object whose ``cv2_image`` is replaced and whose
            ``is_aug`` flag is set
        :return: None
        """
        image.is_aug = True
        flip_code = self._FLIP_CODES[random.randint(1, 3) - 1]
        image.cv2_image = cv2.flip(image.cv2_image, flip_code)

    def revolve(self, image):
        """
        Rotate the image about its centre by a random whole number of degrees.

        The magnitude is drawn from ``1..int(self.revolve_angle)`` and the sign
        is random. The output keeps the original width and height.

        :param image: object whose ``cv2_image`` is replaced and whose
            ``is_aug`` flag is set
        :return: None
        """
        max_angle = int(self.revolve_angle)
        if max_angle < 1:
            # randint(1, max_angle) would raise for an empty range; treat a
            # max angle below one degree as "no rotation".
            return
        image.is_aug = True
        direction = random.randint(1, 2)
        angle = random.randint(1, max_angle)
        if direction == 1:
            angle = -angle
        height, width = image.cv2_image.shape[:2]
        # Rotation is performed around the image centre, given as (x, y).
        center = (width / 2, height / 2)
        rotate_matrix = cv2.getRotationMatrix2D(center=center, angle=angle, scale=1)
        image.cv2_image = cv2.warpAffine(src=image.cv2_image, M=rotate_matrix, dsize=(width, height))

    def crop(self, image):
        """
        Keep a centred window covering a random fraction of each image side.

        The fraction is drawn in whole percent between ``self.crop_rate`` and 100%.

        :param image: object whose ``cv2_image`` is replaced and whose
            ``is_aug`` flag is set
        :return: None
        """
        image.is_aug = True
        # Clamp to 1..100 so an out-of-range crop_rate cannot produce an empty
        # randint range or a zero-sized crop.
        min_rate = max(1, min(100, int(self.crop_rate * 100)))
        rate = random.randint(min_rate, 100) * 0.01
        height, width = image.cv2_image.shape[:2]
        crop_height = max(1, int(height * rate))
        crop_width = max(1, int(width * rate))
        left = int((width - crop_width) / 2)
        top = int((height - crop_height) / 2)
        right = left + crop_width
        bottom = top + crop_height

        # Image arrays are indexed [row (y), column (x)]: rows first, then columns.
        image.cv2_image = image.cv2_image[top:bottom, left:right]

    def translate(self, image):
        """
        Shift the image by 1%-33% of its size in a random direction.

        The output keeps the original width and height.

        :param image: object whose ``cv2_image`` is replaced and whose
            ``is_aug`` flag is set
        :return: None
        """
        image.is_aug = True
        height, width = image.cv2_image.shape[:2]
        # Random direction: 1 = right, 2 = left, 3 = down, 4 = up.
        direction = random.randint(1, 4)
        ratio = random.randint(1, 33) * 0.01
        shift_x = 0
        shift_y = 0
        if direction == 1:
            shift_x = width * ratio
        elif direction == 2:
            shift_x = -(width * ratio)
        elif direction == 3:
            shift_y = height * ratio
        else:
            shift_y = -(height * ratio)

        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        image.cv2_image = cv2.warpAffine(image.cv2_image, M, (width, height))

    def run(self):
        """Augment every file in ``self.file_list`` and save the modified ones.

        Each operation is applied independently with its configured
        probability, in the order flip, rotate, crop, translate. Images that
        received no operation are not written, and entries that could not be
        loaded are skipped.
        """
        for file in tqdm(self.file_list):
            img = Image(file)
            if img.cv2_image is None:
                # Unreadable entry (non-image file, directory, ...): nothing to augment.
                continue
            # Random mirror
            if self._should_apply(self.flip_prob):
                self.flip(img)
            # Random rotation
            if self._should_apply(self.revolve_prob):
                self.revolve(img)
            # Random crop
            if self._should_apply(self.crop_prob):
                self.crop(img)
            # Random translation
            if self._should_apply(self.translate_prob):
                self.translate(img)

            # Save only images that were actually changed
            if img.is_aug:
                cv2.imwrite(os.path.join(self.aug_path, img.aug_name), img.cv2_image)


test_path = r"D:\user\code\python\data_process\aug"
if __name__ == '__main__':
    # Augmentation parameters used for this run:
    #   image_path     - directory containing the images
    #   flip_prob      - probability of mirroring an image
    #   revolve        - rotation settings [probability, max angle]; direction is random
    #   crop           - crop settings [probability, crop ratio]
    #   translate_prob - probability of translating; direction is random
    settings = dict(
        flip_prob=0.4,
        revolve=[0.4, 30],
        crop=[0.3, 0.85],
        translate_prob=0.4,
    )
    aug = ImageAugmentation(test_path, **settings)
    aug.run()
相关推荐
愤豆17 分钟前
05-Java语言核心-语法特性--模块化系统详解
java·开发语言·python
AI-Ming29 分钟前
程序员转行学习 AI 大模型: 踩坑记录:服务器内存不够,程序被killed
服务器·人工智能·python·gpt·深度学习·学习·agi
我材不敲代码42 分钟前
OpenCV 背景建模实战:三种方法实现运动目标检测
人工智能·opencv·目标检测
2401_873544921 小时前
使用Python处理计算机图形学(PIL/Pillow)
jvm·数据库·python
njidf1 小时前
自动化机器学习(AutoML)库TPOT使用指南
jvm·数据库·python
只与明月听1 小时前
RAG深入学习之向量数据库
前端·人工智能·python
极光代码工作室1 小时前
基于Hadoop的日志数据分析系统设计
大数据·hadoop·python·数据分析·数据可视化
AAI机器之心2 小时前
这个RAG框架绝了:无论多少跳,LLM只调用两次,成本暴降
人工智能·python·ai·llm·agent·产品经理·rag
Fairy要carry2 小时前
项目01-手搓Agent之loop
前端·javascript·python
郝学胜-神的一滴2 小时前
【技术实战】500G单行大文件读取难题破解!生成器+自定义函数最优方案解析
开发语言·python·程序人生·面试