图片数据增强

数据增强

数据增强脚本

  1. 随机上下镜像
  2. 随机左右镜像
  3. 随机左右旋转(最大角度可配置,默认 15 度,示例中为 30 度)
  4. 随机裁剪
  5. 随机透视变换,拉伸(未实现)
  6. 随机平移
Python 代码如下:
import os, cv2, shutil
from glob import glob
import random
import sys
from tqdm import tqdm
import random
import numpy as np


# 1. Collect the entries of the source image directory
def load_files(path):
    """Return every non-hidden entry directly inside *path*, in random order.

    The pattern is ``"<path>/*"``, so sub-directories are returned too and
    dot-files are not. The list is shuffled in place before being returned.
    """
    candidates = glob(f"{path}/*")
    random.shuffle(candidates)
    return candidates


# 2. (Re)create the output directory: delete it if it already exists, then create it empty
def mkdir(path):
    """Make *path* an empty directory, first removing any existing tree there.

    The parent directory must already exist (``os.mkdir`` is not recursive).
    """
    if not os.path.exists(path):
        os.mkdir(path)
        return
    # Wipe leftovers from a previous run before recreating the directory.
    shutil.rmtree(path)
    os.mkdir(path)


class Image:
    """One source image loaded with OpenCV, plus the name for its augmented copy.

    Attributes (all set by ``__init__``):
        src: path of the original image file.
        cv2_image: pixel data returned by ``cv2.imread``; ``None`` when OpenCV
            could not decode the file (callers should skip such images).
        filename: base name of ``src``.
        aug_name: file name for the augmented copy, ``<stem>_aug<ext>``.
        is_aug: flag that augmentation code sets to True once it modifies the image.
    """

    def __init__(self, image_path):
        self.src = image_path  # path of the original image
        self.cv2_image = None
        self.filename = os.path.basename(image_path)
        self.__init()
        self.generate_aug_name()
        self.is_aug = False

    def __init(self):
        """Load ``self.src`` into ``self.cv2_image``.

        Exits the process with status 1 when the file does not exist.
        ``cv2.imread`` does not raise for unreadable or non-image files; it
        returns ``None``. That case is therefore checked explicitly: it is
        reported and ``cv2_image`` is left as ``None``. The file is not deleted.
        """
        if not os.path.exists(self.src):
            print("image_src = {}".format(self.src))
            print("-----------------------------------------------------")
            print("------------ [IMG] file_dot't exist, exit -----------")
            print("-----------------------------------------------------")
            # Non-zero status: a missing input file is an error, not a success.
            sys.exit(1)
        try:
            self.cv2_image = cv2.imread(self.src)
        except cv2.error:
            self.cv2_image = None
        if self.cv2_image is None:
            print("image error", self.src)

    def generate_aug_name(self):
        """Set ``self.aug_name`` to ``<stem>_aug<ext>`` (e.g. ``a.b.jpg`` -> ``a.b_aug.jpg``).

        ``os.path.splitext`` splits on the LAST dot only. Multi-dot file names
        therefore keep their full stem and cannot collide after augmentation.
        """
        stem, ext = os.path.splitext(self.filename)
        self.aug_name = stem + "_aug" + ext


""" 数据增强 """


# 1. 随机上下镜像
# 2. 随机左右镜像
# 3. 随机左右旋转45度以内
# 4. 随机裁剪
# 5. 随机透视变换,拉伸
# 6. 随机平移
class ImageAugmentation:
    def __init__(self, image_path, flip_prob=0.5, revolve=None, crop=None, translate_prob=0.5):
        """
        数据增强参数
        :param image_path: 图片路径
        :param flip_prob: 图片镜像概率
        :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
        :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
        :param translate_prob: 图片平移参数概率,方向随机,左右和上下
        """
        if revolve is None:
            revolve = [0.5, 15]
        if crop is None:
            crop = [0.5, 0.75]
        self.image_path = image_path
        self.flip_prob = flip_prob
        self.revolve_prob = revolve[0]
        self.revolve_angle = revolve[1]
        self.crop_prob = crop[0]
        self.crop_rate = crop[1]
        self.translate_prob = translate_prob
        self.__init()
        self.file_list = load_files(self.image_path)

    def __init(self):
        self.aug_path = self.image_path + "_aug"
        mkdir(self.aug_path)

    def flip(self, image):
        """
        随机镜像图片
        :param image:
        :return:
        """
        image.is_aug = True
        flip_type = random.randint(1, 3)
        if flip_type == 1:
            image.cv2_image = cv2.flip(image.cv2_image, 0)
        elif flip_type == 2:
            image.cv2_image = cv2.flip(image.cv2_image, 1)
        else:
            image.cv2_image = cv2.flip(image.cv2_image, -1)

    def revolve(self, image):
        image.is_aug = True
        revolve_type = random.randint(1, 2)
        revolve_angle = random.randint(1, self.revolve_angle)
        if revolve_type == 1:
            revolve_angle = -revolve_angle
        # dividing height and width by 2 to get the center of the image
        height, width = image.cv2_image.shape[:2]
        # get the center coordinates of the image to create the 2D rotation matrix
        center = (width / 2, height / 2)
        # using cv2.getRotationMatrix2D() to get the rotation matrix
        rotate_matrix = cv2.getRotationMatrix2D(center=center, angle=revolve_angle, scale=1)
        image.cv2_image = cv2.warpAffine(src=image.cv2_image, M=rotate_matrix, dsize=(width, height))

    def crop(self, image):
        image.is_aug = True
        min_rate = int(self.crop_rate * 100)
        rate = random.randint(min_rate, 100) * 0.01
        height, width = image.cv2_image.shape[:2]
        center = (width / 2, height / 2)
        crop_height = int(height * rate)
        crop_width = int(width * rate)
        left = int((width - crop_width) / 2)
        top = int((height - crop_height) / 2)
        right = left + crop_width
        bottom = top + crop_height

        image.cv2_image = image.cv2_image[left:right, top:bottom]

    def translate(self, image):
        image.is_aug = True
        height, width = image.cv2_image.shape[:2]
        """ 随机平移类型(上下左右) """
        translate_type = random.randint(1, 4)
        translate_x = 0
        translate_y = 0
        translate_length_radio = random.randint(1, 33) * 0.01

        # print(translate_type)
        if translate_type == 1:
            """ 图片右移 """
            translate_x = width * translate_length_radio
        elif translate_type == 2:
            """ 图片左移 """
            translate_x = - (width * translate_length_radio)
        elif translate_type == 3:
            """ 图片下移 """
            translate_y = height * translate_length_radio
        elif translate_type == 4:
            """ 图片上移 """
            translate_y = - (height * translate_length_radio)
        else:
            print("[error] 不符合要求的随机数")
            raise TypeError

        M = np.float32([[1, 0, translate_x], [0, 1, translate_y]])
        image.cv2_image = cv2.warpAffine(image.cv2_image, M, (width, height))

    def run(self):
        for file in tqdm(self.file_list):
            img = Image(file)
            # 随机镜像图片
            flip_prob = random.random()
            if flip_prob >= self.flip_prob:
                self.flip(img)
            # 随机旋转图片
            revolve_prob = random.random()
            if revolve_prob >= self.revolve_prob:
                self.revolve(img)
            # 随机裁剪图片
            crop_prob = random.random()
            if crop_prob >= self.crop_prob:
                self.crop(img)
            translate_prob = random.random()
            if translate_prob >= self.translate_prob:
                self.translate(img)

            # save image
            if img.is_aug:
                cv2.imwrite(os.path.join(self.aug_path, img.aug_name), img.cv2_image)


test_path = r"D:\user\code\python\data_process\aug"
if __name__ == '__main__':
    """
   数据增强参数
   :param image_path: 图片路径
   :param flip_prob: 图片镜像概率
   :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
   :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
   :param translate_prob: 图片平移参数概率,方向随机,左右和上下
   """
    aug = ImageAugmentation(test_path, flip_prob=0.4, revolve=[0.4, 30], crop=[0.3, 0.85], translate_prob=0.4)
    aug.run()
相关推荐
梧桐树04292 小时前
python常用内建模块:collections
python
Dream_Snowar2 小时前
速通Python 第三节
开发语言·python
蓝天星空3 小时前
Python调用open ai接口
人工智能·python
jasmine s3 小时前
Pandas
开发语言·python
郭wes代码3 小时前
Cmd命令大全(万字详细版)
python·算法·小程序
leaf_leaves_leaf4 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零14 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
404NooFound4 小时前
Python轻量级NoSQL数据库TinyDB
开发语言·python·nosql
天天要nx4 小时前
D102【python 接口自动化学习】- pytest进阶之fixture用法
python·pytest
minstbe4 小时前
AI开发:使用支持向量机(SVM)进行文本情感分析训练 - Python
人工智能·python·支持向量机