图片数据增强

数据增强

数据增强脚本

  1. 随机上下镜像
  2. 随机左右镜像
  3. 随机左右旋转45度以内
  4. 随机裁剪
  5. 随机透视变换,拉伸(未实现)
  6. 随机平移
python 复制代码
import os, cv2, shutil
from glob import glob
import random
import sys
from tqdm import tqdm
import random
import numpy as np


# 1. 加载图片路径
def load_files(path):
    files = glob("{}/*".format(path))
    # files= os.listdir(path)
    random.shuffle(files)
    return files


# 3. 检查增强目录是否存在,存在就删除,然后重新生成
def mkdir(path):
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)


class Image:
    def __init__(self, image_path):
        self.src = image_path  # 原始图像
        self.cv2_image = None
        self.filename = os.path.basename(image_path)
        self.__init()
        self.generate_aug_name()
        self.is_aug = False

    def __init(self):
        """ load image"""
        if not os.path.exists(self.src):
            print("image_src = {}".format(self.src))
            print("-----------------------------------------------------")
            print("------------ [IMG] file_dot't exist, exit -----------")
            print("-----------------------------------------------------")
            sys.exit(0)
        try:
            self.cv2_image = cv2.imread(self.src)
        except:
            print("image error", self.src)
            os.remove(self.src)

    def generate_aug_name(self):
        """ 生成增强后保存图片的名称 """
        self.aug_name = self.filename.split(".")[0] + "_aug." + self.filename.split(".")[-1]
        # print(self.aug_name)


""" 数据增强 """


# 1. 随机上下镜像
# 2. 随机左右镜像
# 3. 随机左右旋转45度以内
# 4. 随机裁剪
# 5. 随机透视变换,拉伸
# 6. 随机平移
class ImageAugmentation:
    def __init__(self, image_path, flip_prob=0.5, revolve=None, crop=None, translate_prob=0.5):
        """
        数据增强参数
        :param image_path: 图片路径
        :param flip_prob: 图片镜像概率
        :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
        :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
        :param translate_prob: 图片平移参数概率,方向随机,左右和上下
        """
        if revolve is None:
            revolve = [0.5, 15]
        if crop is None:
            crop = [0.5, 0.75]
        self.image_path = image_path
        self.flip_prob = flip_prob
        self.revolve_prob = revolve[0]
        self.revolve_angle = revolve[1]
        self.crop_prob = crop[0]
        self.crop_rate = crop[1]
        self.translate_prob = translate_prob
        self.__init()
        self.file_list = load_files(self.image_path)

    def __init(self):
        self.aug_path = self.image_path + "_aug"
        mkdir(self.aug_path)

    def flip(self, image):
        """
        随机镜像图片
        :param image:
        :return:
        """
        image.is_aug = True
        flip_type = random.randint(1, 3)
        if flip_type == 1:
            image.cv2_image = cv2.flip(image.cv2_image, 0)
        elif flip_type == 2:
            image.cv2_image = cv2.flip(image.cv2_image, 1)
        else:
            image.cv2_image = cv2.flip(image.cv2_image, -1)

    def revolve(self, image):
        image.is_aug = True
        revolve_type = random.randint(1, 2)
        revolve_angle = random.randint(1, self.revolve_angle)
        if revolve_type == 1:
            revolve_angle = -revolve_angle
        # dividing height and width by 2 to get the center of the image
        height, width = image.cv2_image.shape[:2]
        # get the center coordinates of the image to create the 2D rotation matrix
        center = (width / 2, height / 2)
        # using cv2.getRotationMatrix2D() to get the rotation matrix
        rotate_matrix = cv2.getRotationMatrix2D(center=center, angle=revolve_angle, scale=1)
        image.cv2_image = cv2.warpAffine(src=image.cv2_image, M=rotate_matrix, dsize=(width, height))

    def crop(self, image):
        image.is_aug = True
        min_rate = int(self.crop_rate * 100)
        rate = random.randint(min_rate, 100) * 0.01
        height, width = image.cv2_image.shape[:2]
        center = (width / 2, height / 2)
        crop_height = int(height * rate)
        crop_width = int(width * rate)
        left = int((width - crop_width) / 2)
        top = int((height - crop_height) / 2)
        right = left + crop_width
        bottom = top + crop_height

        image.cv2_image = image.cv2_image[left:right, top:bottom]

    def translate(self, image):
        image.is_aug = True
        height, width = image.cv2_image.shape[:2]
        """ 随机平移类型(上下左右) """
        translate_type = random.randint(1, 4)
        translate_x = 0
        translate_y = 0
        translate_length_radio = random.randint(1, 33) * 0.01

        # print(translate_type)
        if translate_type == 1:
            """ 图片右移 """
            translate_x = width * translate_length_radio
        elif translate_type == 2:
            """ 图片左移 """
            translate_x = - (width * translate_length_radio)
        elif translate_type == 3:
            """ 图片下移 """
            translate_y = height * translate_length_radio
        elif translate_type == 4:
            """ 图片上移 """
            translate_y = - (height * translate_length_radio)
        else:
            print("[error] 不符合要求的随机数")
            raise TypeError

        M = np.float32([[1, 0, translate_x], [0, 1, translate_y]])
        image.cv2_image = cv2.warpAffine(image.cv2_image, M, (width, height))

    def run(self):
        for file in tqdm(self.file_list):
            img = Image(file)
            # 随机镜像图片
            flip_prob = random.random()
            if flip_prob >= self.flip_prob:
                self.flip(img)
            # 随机旋转图片
            revolve_prob = random.random()
            if revolve_prob >= self.revolve_prob:
                self.revolve(img)
            # 随机裁剪图片
            crop_prob = random.random()
            if crop_prob >= self.crop_prob:
                self.crop(img)
            translate_prob = random.random()
            if translate_prob >= self.translate_prob:
                self.translate(img)

            # save image
            if img.is_aug:
                cv2.imwrite(os.path.join(self.aug_path, img.aug_name), img.cv2_image)


test_path = r"D:\user\code\python\data_process\aug"
if __name__ == '__main__':
    """
   数据增强参数
   :param image_path: 图片路径
   :param flip_prob: 图片镜像概率
   :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
   :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
   :param translate_prob: 图片平移参数概率,方向随机,左右和上下
   """
    aug = ImageAugmentation(test_path, flip_prob=0.4, revolve=[0.4, 30], crop=[0.3, 0.85], translate_prob=0.4)
    aug.run()
相关推荐
成功人chen某2 小时前
配置VScodePython环境Python was not found;
开发语言·python
2301_786964362 小时前
EXCEL Python 实现绘制柱状线型组合图和树状图(包含数据透视表)
python·microsoft·excel
skd89993 小时前
小蜗牛拨号助手用户使用手册
python
「QT(C++)开发工程师」3 小时前
STM32 | FreeRTOS 递归信号量
python·stm32·嵌入式硬件
史迪仔01123 小时前
[python] Python单例模式:__new__与线程安全解析
开发语言·python·单例模式
胡耀超3 小时前
18.自动化生成知识图谱的多维度质量评估方法论
人工智能·python·自动化·知识图谱·数据科学·逻辑学·质量评估
whoarethenext3 小时前
c/c++的opencv的轮廓匹配初识
c语言·c++·opencv
新手村领路人3 小时前
qt5.14.2 opencv调用摄像头显示在label
qt·opencv·命令模式
三块钱07943 小时前
【原创】基于视觉大模型gemma-3-4b实现短视频自动识别内容并生成解说文案
开发语言·python·音视频
神码小Z4 小时前
Ubuntu快速安装Python3.11及多版本管理
python