图片数据增强

数据增强

数据增强脚本

  1. 随机上下镜像
  2. 随机左右镜像
  3. 随机左右旋转45度以内
  4. 随机裁剪
  5. 随机透视变换,拉伸(未实现)
  6. 随机平移
python 复制代码
import os, cv2, shutil
from glob import glob
import random
import sys
from tqdm import tqdm
import random
import numpy as np


# 1. 加载图片路径
def load_files(path):
    files = glob("{}/*".format(path))
    # files= os.listdir(path)
    random.shuffle(files)
    return files


# 3. 检查增强目录是否存在,存在就删除,然后重新生成
def mkdir(path):
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)


class Image:
    def __init__(self, image_path):
        self.src = image_path  # 原始图像
        self.cv2_image = None
        self.filename = os.path.basename(image_path)
        self.__init()
        self.generate_aug_name()
        self.is_aug = False

    def __init(self):
        """ load image"""
        if not os.path.exists(self.src):
            print("image_src = {}".format(self.src))
            print("-----------------------------------------------------")
            print("------------ [IMG] file_dot't exist, exit -----------")
            print("-----------------------------------------------------")
            sys.exit(0)
        try:
            self.cv2_image = cv2.imread(self.src)
        except:
            print("image error", self.src)
            os.remove(self.src)

    def generate_aug_name(self):
        """ 生成增强后保存图片的名称 """
        self.aug_name = self.filename.split(".")[0] + "_aug." + self.filename.split(".")[-1]
        # print(self.aug_name)


""" 数据增强 """


# 1. 随机上下镜像
# 2. 随机左右镜像
# 3. 随机左右旋转45度以内
# 4. 随机裁剪
# 5. 随机透视变换,拉伸
# 6. 随机平移
class ImageAugmentation:
    def __init__(self, image_path, flip_prob=0.5, revolve=None, crop=None, translate_prob=0.5):
        """
        数据增强参数
        :param image_path: 图片路径
        :param flip_prob: 图片镜像概率
        :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
        :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
        :param translate_prob: 图片平移参数概率,方向随机,左右和上下
        """
        if revolve is None:
            revolve = [0.5, 15]
        if crop is None:
            crop = [0.5, 0.75]
        self.image_path = image_path
        self.flip_prob = flip_prob
        self.revolve_prob = revolve[0]
        self.revolve_angle = revolve[1]
        self.crop_prob = crop[0]
        self.crop_rate = crop[1]
        self.translate_prob = translate_prob
        self.__init()
        self.file_list = load_files(self.image_path)

    def __init(self):
        self.aug_path = self.image_path + "_aug"
        mkdir(self.aug_path)

    def flip(self, image):
        """
        随机镜像图片
        :param image:
        :return:
        """
        image.is_aug = True
        flip_type = random.randint(1, 3)
        if flip_type == 1:
            image.cv2_image = cv2.flip(image.cv2_image, 0)
        elif flip_type == 2:
            image.cv2_image = cv2.flip(image.cv2_image, 1)
        else:
            image.cv2_image = cv2.flip(image.cv2_image, -1)

    def revolve(self, image):
        image.is_aug = True
        revolve_type = random.randint(1, 2)
        revolve_angle = random.randint(1, self.revolve_angle)
        if revolve_type == 1:
            revolve_angle = -revolve_angle
        # dividing height and width by 2 to get the center of the image
        height, width = image.cv2_image.shape[:2]
        # get the center coordinates of the image to create the 2D rotation matrix
        center = (width / 2, height / 2)
        # using cv2.getRotationMatrix2D() to get the rotation matrix
        rotate_matrix = cv2.getRotationMatrix2D(center=center, angle=revolve_angle, scale=1)
        image.cv2_image = cv2.warpAffine(src=image.cv2_image, M=rotate_matrix, dsize=(width, height))

    def crop(self, image):
        image.is_aug = True
        min_rate = int(self.crop_rate * 100)
        rate = random.randint(min_rate, 100) * 0.01
        height, width = image.cv2_image.shape[:2]
        center = (width / 2, height / 2)
        crop_height = int(height * rate)
        crop_width = int(width * rate)
        left = int((width - crop_width) / 2)
        top = int((height - crop_height) / 2)
        right = left + crop_width
        bottom = top + crop_height

        image.cv2_image = image.cv2_image[left:right, top:bottom]

    def translate(self, image):
        image.is_aug = True
        height, width = image.cv2_image.shape[:2]
        """ 随机平移类型(上下左右) """
        translate_type = random.randint(1, 4)
        translate_x = 0
        translate_y = 0
        translate_length_radio = random.randint(1, 33) * 0.01

        # print(translate_type)
        if translate_type == 1:
            """ 图片右移 """
            translate_x = width * translate_length_radio
        elif translate_type == 2:
            """ 图片左移 """
            translate_x = - (width * translate_length_radio)
        elif translate_type == 3:
            """ 图片下移 """
            translate_y = height * translate_length_radio
        elif translate_type == 4:
            """ 图片上移 """
            translate_y = - (height * translate_length_radio)
        else:
            print("[error] 不符合要求的随机数")
            raise TypeError

        M = np.float32([[1, 0, translate_x], [0, 1, translate_y]])
        image.cv2_image = cv2.warpAffine(image.cv2_image, M, (width, height))

    def run(self):
        for file in tqdm(self.file_list):
            img = Image(file)
            # 随机镜像图片
            flip_prob = random.random()
            if flip_prob >= self.flip_prob:
                self.flip(img)
            # 随机旋转图片
            revolve_prob = random.random()
            if revolve_prob >= self.revolve_prob:
                self.revolve(img)
            # 随机裁剪图片
            crop_prob = random.random()
            if crop_prob >= self.crop_prob:
                self.crop(img)
            translate_prob = random.random()
            if translate_prob >= self.translate_prob:
                self.translate(img)

            # save image
            if img.is_aug:
                cv2.imwrite(os.path.join(self.aug_path, img.aug_name), img.cv2_image)


test_path = r"D:\user\code\python\data_process\aug"
if __name__ == '__main__':
    """
   数据增强参数
   :param image_path: 图片路径
   :param flip_prob: 图片镜像概率
   :param revolve: 图片旋转参数,旋转方向随机 [旋转概率,旋转最大角度]
   :param crop: 图片裁剪参数,[裁剪概率,裁剪比率]
   :param translate_prob: 图片平移参数概率,方向随机,左右和上下
   """
    aug = ImageAugmentation(test_path, flip_prob=0.4, revolve=[0.4, 30], crop=[0.3, 0.85], translate_prob=0.4)
    aug.run()
相关推荐
凯歌的博客3 小时前
python虚拟环境应用
linux·开发语言·python
西柚小萌新3 小时前
【深入浅出PyTorch】--8.1.PyTorch生态--torchvision
人工智能·pytorch·python
MonkeyKing_sunyuhua3 小时前
can‘t read /etc/apt/sources.list: No such file or directory
python
多喝开水少熬夜4 小时前
损失函数系列:focal-Dice-vgg
图像处理·python·算法·大模型·llm
有为少年5 小时前
告别乱码:OpenCV 中文路径(Unicode)读写的解决方案
人工智能·opencv·计算机视觉
清风与日月5 小时前
halcon分类器使用标准流程
深度学习·目标检测·计算机视觉
初学小刘5 小时前
基于 U-Net 的医学图像分割
python·opencv·计算机视觉
B站计算机毕业设计之家5 小时前
Python招聘数据分析可视化系统 Boss直聘数据 selenium爬虫 Flask框架 数据清洗(附源码)✅
爬虫·python·selenium·机器学习·数据分析·flask
雪碧聊技术5 小时前
爬虫是什么?
大数据·爬虫·python·数据分析
FL16238631295 小时前
[yolov11改进系列]基于yolov11使用fasternet_t0替换backbone用于轻量化网络的python源码+训练源码
python·yolo·php