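# Language: python
# (The stray "copy code" button label from the web page this snippet was copied from has been dropped.)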
import os, cv2, shutil
from glob import glob
import random
import sys
from tqdm import tqdm
import random
import numpy as np
# 1. 加载图片路径
def load_files(path):
    """Return the files directly inside ``path`` in random order.

    Only regular files are returned. Sub-directories are skipped because the
    entries are later handed to ``Image`` one by one and read as images.

    :param path: directory to scan (not recursive).
    :return: list of file paths, shuffled with ``random.shuffle``; empty when
        the directory is empty or does not exist.
    """
    entries = glob("{}/*".format(path))
    files = [entry for entry in entries if os.path.isfile(entry)]
    random.shuffle(files)
    return files
# 3. 检查增强目录是否存在,存在就删除,然后重新生成
def mkdir(path):
    """Create ``path`` as an empty directory, discarding whatever was there.

    An existing directory is removed together with its contents; an existing
    file or symlink at ``path`` is removed as well. Missing parent directories
    are created.

    :param path: directory to (re)create.
    """
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        # A plain file or a symlink: rmtree would raise on these.
        os.remove(path)
    # makedirs (rather than mkdir) so a missing parent is not an error.
    os.makedirs(path)
class Image:
    """An image read from disk plus the bookkeeping needed to save an augmented copy.

    Attributes set by the constructor:
        src: path the image was loaded from.
        cv2_image: pixel data returned by ``cv2.imread``; stays ``None`` when
            the file could not be read as an image.
        filename: base name of ``src``.
        aug_name: file name to use for the augmented copy.
        is_aug: ``False`` initially; set to ``True`` by augmentation steps.
    """

    def __init__(self, image_path):
        """
        :param image_path: path of the image file to load. The process exits
            with status 1 when the path does not exist.
        """
        self.src = image_path  # original image path
        self.cv2_image = None
        self.filename = os.path.basename(image_path)
        self.__init()
        self.generate_aug_name()
        self.is_aug = False

    def __init(self):
        """Load the pixel data of ``self.src`` into ``self.cv2_image``."""
        if not os.path.exists(self.src):
            print("image_src = {}".format(self.src))
            print("-----------------------------------------------------")
            print("------------ [IMG] file_dot't exist, exit -----------")
            print("-----------------------------------------------------")
            # Non-zero status: a missing input is a failure, not a clean exit.
            sys.exit(1)
        try:
            self.cv2_image = cv2.imread(self.src)
        except cv2.error:
            # Same handling as before for a read that raises: report the file
            # and remove it from the input directory.
            print("image error", self.src)
            os.remove(self.src)
            return
        if self.cv2_image is None:
            # imread signals an unreadable/undecodable file by returning None
            # rather than raising, so report it here. The file is left alone.
            print("image error", self.src)

    def generate_aug_name(self):
        """Build the file name used when saving the augmented image.

        ``_aug`` is inserted before the extension: ``cat.jpg`` becomes
        ``cat_aug.jpg``. Dots inside the stem are kept and a name without an
        extension simply gets ``_aug`` appended.
        """
        stem, extension = os.path.splitext(self.filename)
        self.aug_name = stem + "_aug" + extension
""" 数据增强 """
# 1. 随机上下镜像
# 2. 随机左右镜像
# 3. 随机左右旋转45度以内
# 4. 随机裁剪
# 5. 随机透视变换,拉伸
# 6. 随机平移
class ImageAugmentation:
    """Randomly augment every image of a directory and save the results.

    For each input file the steps flip, rotate, centre-crop and translate are
    each applied independently with their own probability. Images that
    received at least one step are written to ``<image_path>_aug``, which is
    wiped and recreated when the object is constructed.
    """

    def __init__(self, image_path, flip_prob=0.5, revolve=None, crop=None, translate_prob=0.5):
        """
        Augmentation parameters.

        :param image_path: directory containing the source images.
        :param flip_prob: probability of mirroring an image.
        :param revolve: ``[probability, max_angle]`` for rotation; the
            direction is random. Defaults to ``[0.5, 15]``.
        :param crop: ``[probability, ratio]`` for centre cropping; ``ratio``
            is the smallest fraction of each side that is kept. Defaults to
            ``[0.5, 0.75]``.
        :param translate_prob: probability of shifting an image; the
            direction (left/right/up/down) is random.
        """
        if revolve is None:
            revolve = [0.5, 15]
        if crop is None:
            crop = [0.5, 0.75]
        self.image_path = image_path
        self.flip_prob = flip_prob
        self.revolve_prob = revolve[0]
        self.revolve_angle = revolve[1]
        self.crop_prob = crop[0]
        self.crop_rate = crop[1]
        self.translate_prob = translate_prob
        self.__init()
        self.file_list = load_files(self.image_path)

    def __init(self):
        """Create (or wipe and recreate) the output directory."""
        # Drop trailing separators so "data/" maps to "data_aug" next to the
        # input instead of "data/_aug" inside it.
        separators = os.sep + (os.altsep or "")
        base = self.image_path.rstrip(separators) or self.image_path
        self.aug_path = base + "_aug"
        mkdir(self.aug_path)

    @staticmethod
    def _should_apply(probability):
        """Return True with the given probability (0 -> never, 1 -> always)."""
        return random.random() < probability

    def flip(self, image):
        """
        Mirror the image along a randomly chosen axis.

        :param image: object whose ``cv2_image`` is replaced by the flipped
            array and whose ``is_aug`` is set to ``True``.
        :return: None
        """
        image.is_aug = True
        # cv2.flip codes: 0, 1 or -1, picked with equal probability.
        flip_code = (0, 1, -1)[random.randint(1, 3) - 1]
        image.cv2_image = cv2.flip(image.cv2_image, flip_code)

    def revolve(self, image):
        """
        Rotate the image about its centre by a random whole number of degrees.

        The magnitude is drawn from ``1..revolve_angle`` and the sign is
        random. The output keeps the input width and height. Nothing happens
        when ``revolve_angle`` is below 1.

        :param image: object whose ``cv2_image`` is replaced in place.
        :return: None
        """
        max_angle = int(self.revolve_angle)
        if max_angle < 1:
            # randint(1, max_angle) would raise; there is nothing to rotate by.
            return
        image.is_aug = True
        negate = random.randint(1, 2) == 1
        angle = random.randint(1, max_angle)
        if negate:
            angle = -angle
        height, width = image.cv2_image.shape[:2]
        center = (width / 2, height / 2)
        rotate_matrix = cv2.getRotationMatrix2D(center=center, angle=angle, scale=1)
        image.cv2_image = cv2.warpAffine(src=image.cv2_image, M=rotate_matrix, dsize=(width, height))

    def crop(self, image):
        """
        Keep a centred window covering a random fraction of each side.

        The fraction is a whole percentage drawn from
        ``crop_rate * 100 .. 100`` and is applied to width and height alike.

        :param image: object whose ``cv2_image`` is replaced by the cropped
            view and whose ``is_aug`` is set to ``True``.
        :return: None
        """
        image.is_aug = True
        # Clamp to 1..100 so an out-of-range crop_rate cannot make randint
        # raise or produce an empty crop.
        min_percent = min(100, max(1, int(round(self.crop_rate * 100))))
        percent = random.randint(min_percent, 100)
        height, width = image.cv2_image.shape[:2]
        # Integer arithmetic avoids float truncation such as int(57.999...).
        crop_height = max(1, height * percent // 100)
        crop_width = max(1, width * percent // 100)
        top = (height - crop_height) // 2
        left = (width - crop_width) // 2
        # The array is indexed [row, column], i.e. [y, x]: rows first.
        image.cv2_image = image.cv2_image[top:top + crop_height, left:left + crop_width]

    def translate(self, image):
        """
        Shift the image by 1-33% of its size in a random direction.

        The direction is one of right, left, down or up. The output keeps the
        input width and height.

        :param image: object whose ``cv2_image`` is replaced in place.
        :return: None
        """
        image.is_aug = True
        height, width = image.cv2_image.shape[:2]
        direction = random.randint(1, 4)
        ratio = random.randint(1, 33) * 0.01
        shift_x = 0
        shift_y = 0
        if direction == 1:
            shift_x = width * ratio  # shift right
        elif direction == 2:
            shift_x = -(width * ratio)  # shift left
        elif direction == 3:
            shift_y = height * ratio  # shift down
        else:
            shift_y = -(height * ratio)  # shift up
        matrix = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        image.cv2_image = cv2.warpAffine(image.cv2_image, matrix, (width, height))

    def run(self):
        """Augment every file in ``file_list`` and save the changed images.

        Files that cannot be read as images are reported and skipped. An image
        is written to ``aug_path`` only when at least one step was applied.
        """
        for file in tqdm(self.file_list):
            img = Image(file)
            if img.cv2_image is None:
                print("[warn] skip unreadable image: {}".format(file))
                continue
            # Each step fires with its configured probability.
            if self._should_apply(self.flip_prob):
                self.flip(img)
            if self._should_apply(self.revolve_prob):
                self.revolve(img)
            if self._should_apply(self.crop_prob):
                self.crop(img)
            if self._should_apply(self.translate_prob):
                self.translate(img)
            if img.is_aug:
                target = os.path.join(self.aug_path, img.aug_name)
                # imwrite reports failure through its return value.
                if not cv2.imwrite(target, img.cv2_image):
                    print("[warn] failed to save image: {}".format(target))
# Default input directory, used when no directory is given on the command line.
test_path = r"D:\user\code\python\data_process\aug"
if __name__ == '__main__':
    # Augmentation parameters:
    #   image_path:     directory containing the images
    #   flip_prob:      probability of mirroring an image
    #   revolve:        [probability, max angle] for rotation, random direction
    #   crop:           [probability, ratio] for cropping
    #   translate_prob: probability of shifting, random direction
    #
    # An optional first command-line argument overrides the default directory
    # so the script can run on machines where test_path does not exist.
    image_dir = sys.argv[1] if len(sys.argv) > 1 else test_path
    aug = ImageAugmentation(image_dir, flip_prob=0.4, revolve=[0.4, 30], crop=[0.3, 0.85], translate_prob=0.4)
    aug.run()