图片数据增强

数据增强

数据增强脚本

  1. 随机上下镜像
  2. 随机左右镜像
  3. 随机左右旋转45度以内
  4. 随机裁剪
  5. 随机透视变换,拉伸(未实现)
  6. 随机平移
python 复制代码
import os, cv2, shutil
from glob import glob
import random
import sys
from tqdm import tqdm
import random
import numpy as np


# 1. 加载图片路径
def load_files(path):
    files = glob("{}/*".format(path))
    # files= os.listdir(path)
    random.shuffle(files)
    return files


# 3. 检查增强目录是否存在,存在就删除,然后重新生成
def mkdir(path):
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)


class Image:
    def __init__(self, image_path):
        self.src = image_path  # 原始图像
        self.cv2_image = None
        self.filename = os.path.basename(image_path)
        self.__init()
        self.generate_aug_name()
        self.is_aug = False

    def __init(self):
        """ load image"""
        if not os.path.exists(self.src):
            print("image_src = {}".format(self.src))
            print("-----------------------------------------------------")
            print("------------ [IMG] file_dot't exist, exit -----------")
            print("-----------------------------------------------------")
            sys.exit(0)
        try:
            self.cv2_image = cv2.imread(self.src)
        except:
            print("image error", self.src)
            os.remove(self.src)

    def generate_aug_name(self):
        """ 生成增强后保存图片的名称 """
        self.aug_name = self.filename.split(".")[0] + "_aug." + self.filename.split(".")[-1]
        # print(self.aug_name)


""" 数据增强 """


# 1. 随机上下镜像
# 2. 随机左右镜像
# 3. 随机左右旋转45度以内
# 4. 随机裁剪
# 5. 随机透视变换,拉伸
# 6. 随机平移
class ImageAugmentation:
    def __init__(self, image_path, flip_prob=0.5, revolve=None, crop=None, translate_prob=0.5):
        """
        数据增强参数
        :param image_path: 图片路径
        :param flip_prob: 图片镜像概率
        :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
        :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
        :param translate_prob: 图片平移参数概率,方向随机,左右和上下
        """
        if revolve is None:
            revolve = [0.5, 15]
        if crop is None:
            crop = [0.5, 0.75]
        self.image_path = image_path
        self.flip_prob = flip_prob
        self.revolve_prob = revolve[0]
        self.revolve_angle = revolve[1]
        self.crop_prob = crop[0]
        self.crop_rate = crop[1]
        self.translate_prob = translate_prob
        self.__init()
        self.file_list = load_files(self.image_path)

    def __init(self):
        self.aug_path = self.image_path + "_aug"
        mkdir(self.aug_path)

    def flip(self, image):
        """
        随机镜像图片
        :param image:
        :return:
        """
        image.is_aug = True
        flip_type = random.randint(1, 3)
        if flip_type == 1:
            image.cv2_image = cv2.flip(image.cv2_image, 0)
        elif flip_type == 2:
            image.cv2_image = cv2.flip(image.cv2_image, 1)
        else:
            image.cv2_image = cv2.flip(image.cv2_image, -1)

    def revolve(self, image):
        image.is_aug = True
        revolve_type = random.randint(1, 2)
        revolve_angle = random.randint(1, self.revolve_angle)
        if revolve_type == 1:
            revolve_angle = -revolve_angle
        # dividing height and width by 2 to get the center of the image
        height, width = image.cv2_image.shape[:2]
        # get the center coordinates of the image to create the 2D rotation matrix
        center = (width / 2, height / 2)
        # using cv2.getRotationMatrix2D() to get the rotation matrix
        rotate_matrix = cv2.getRotationMatrix2D(center=center, angle=revolve_angle, scale=1)
        image.cv2_image = cv2.warpAffine(src=image.cv2_image, M=rotate_matrix, dsize=(width, height))

    def crop(self, image):
        image.is_aug = True
        min_rate = int(self.crop_rate * 100)
        rate = random.randint(min_rate, 100) * 0.01
        height, width = image.cv2_image.shape[:2]
        center = (width / 2, height / 2)
        crop_height = int(height * rate)
        crop_width = int(width * rate)
        left = int((width - crop_width) / 2)
        top = int((height - crop_height) / 2)
        right = left + crop_width
        bottom = top + crop_height

        image.cv2_image = image.cv2_image[left:right, top:bottom]

    def translate(self, image):
        image.is_aug = True
        height, width = image.cv2_image.shape[:2]
        """ 随机平移类型(上下左右) """
        translate_type = random.randint(1, 4)
        translate_x = 0
        translate_y = 0
        translate_length_radio = random.randint(1, 33) * 0.01

        # print(translate_type)
        if translate_type == 1:
            """ 图片右移 """
            translate_x = width * translate_length_radio
        elif translate_type == 2:
            """ 图片左移 """
            translate_x = - (width * translate_length_radio)
        elif translate_type == 3:
            """ 图片下移 """
            translate_y = height * translate_length_radio
        elif translate_type == 4:
            """ 图片上移 """
            translate_y = - (height * translate_length_radio)
        else:
            print("[error] 不符合要求的随机数")
            raise TypeError

        M = np.float32([[1, 0, translate_x], [0, 1, translate_y]])
        image.cv2_image = cv2.warpAffine(image.cv2_image, M, (width, height))

    def run(self):
        for file in tqdm(self.file_list):
            img = Image(file)
            # 随机镜像图片
            flip_prob = random.random()
            if flip_prob >= self.flip_prob:
                self.flip(img)
            # 随机旋转图片
            revolve_prob = random.random()
            if revolve_prob >= self.revolve_prob:
                self.revolve(img)
            # 随机裁剪图片
            crop_prob = random.random()
            if crop_prob >= self.crop_prob:
                self.crop(img)
            translate_prob = random.random()
            if translate_prob >= self.translate_prob:
                self.translate(img)

            # save image
            if img.is_aug:
                cv2.imwrite(os.path.join(self.aug_path, img.aug_name), img.cv2_image)


test_path = r"D:\user\code\python\data_process\aug"
if __name__ == '__main__':
    """
   数据增强参数
   :param image_path: 图片路径
   :param flip_prob: 图片镜像概率
   :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
   :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
   :param translate_prob: 图片平移参数概率,方向随机,左右和上下
   """
    aug = ImageAugmentation(test_path, flip_prob=0.4, revolve=[0.4, 30], crop=[0.3, 0.85], translate_prob=0.4)
    aug.run()
相关推荐
Java后端的Ai之路11 小时前
【Python 教程15】-Python和Web
python
冬奇Lab13 小时前
一天一个开源项目(第15篇):MapToPoster - 用代码将城市地图转换为精美的海报设计
python·开源
二十雨辰15 小时前
[python]-AI大模型
开发语言·人工智能·python
Yvonne爱编码15 小时前
JAVA数据结构 DAY6-栈和队列
java·开发语言·数据结构·python
晚霞的不甘15 小时前
CANN 在工业质检中的亚像素级视觉检测系统设计
人工智能·计算机视觉·架构·开源·视觉检测
前端摸鱼匠16 小时前
YOLOv8 环境配置全攻略:Python、PyTorch 与 CUDA 的和谐共生
人工智能·pytorch·python·yolo·目标检测
WangYaolove131416 小时前
基于python的在线水果销售系统(源码+文档)
python·mysql·django·毕业设计·源码
AALoveTouch16 小时前
大麦网协议分析
javascript·python
ZH154558913116 小时前
Flutter for OpenHarmony Python学习助手实战:自动化脚本开发的实现
python·学习·flutter
xcLeigh17 小时前
Python入门:Python3 requests模块全面学习教程
开发语言·python·学习·模块·python3·requests