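# Python script: offline augmentation of image/mask pairs with albumentations.
# (The original first two lines were "python" / "copy code" web-paste residue.)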
import os
import cv2
import albumentations as A
from tqdm import tqdm
from glob import glob
import numpy as np
# Method 1: add noise.
# Pipeline: brightness/contrast jitter, hue/saturation jitter, one blur
# variant, then optional Gaussian noise.
trans2 = A.Compose(
    [
        A.RandomBrightnessContrast(p=0.5),
        A.HueSaturationValue(p=0.5),
        # One blur variant is selected per call. Per the albumentations docs,
        # the children's p values are normalised inside OneOf and act as
        # selection weights, so equal values mean a uniform choice.
        A.OneOf(
            [
                A.AdvancedBlur(p=0.5),
                A.Blur(p=0.5),
                A.Defocus(p=0.5),
                A.GaussianBlur(p=0.5),
                A.GlassBlur(p=0.5),
                A.MedianBlur(p=0.5),
                A.MotionBlur(p=0.5),
            ],
            p=1.0,
        ),
        A.GaussNoise(p=0.5),
    ]
)
# Method 2: histogram equalisation.
# CLAHE (contrast-limited adaptive histogram equalisation) is always applied;
# a random gamma adjustment follows half of the time.
trans3 = A.Compose(
    [
        A.CLAHE(p=1),
        A.RandomGamma(p=0.5),
    ]
)
def create_dir(path):
    """Create directory *path*, including missing parents, if it is absent.

    Calling this on an existing directory is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    # exist_ok=True replaces the separate "exists?" check, which could race
    # with another process creating the same directory in between.
    os.makedirs(path, exist_ok=True)
def load_data_aug(path):
    """Collect the sorted PNG paths of the train and val splits under *path*.

    Images are read from ``<path>/<split>/images`` and masks from
    ``<path>/<split>/masks/0``. Both lists are sorted by path, so an image and
    its mask line up index-by-index only when they share the same file name.
    The test split is not loaded.

    Args:
        path: Dataset root directory.

    Returns:
        ``((train_x, train_y), (val_x, val_y))`` where each element is a
        sorted list of ``*.png`` paths (empty when the directory is missing).
    """

    def _sorted_pngs(subdir):
        # Return the sorted *.png paths found directly inside path/subdir.
        return sorted(glob(os.path.join(path, subdir, "*.png")))

    train_x = _sorted_pngs("train/images")
    train_y = _sorted_pngs("train/masks/0")

    val_x = _sorted_pngs("val/images")
    val_y = _sorted_pngs("val/masks/0")

    return (train_x, train_y), (val_x, val_y)
def augment_data(images, masks, save_path, augment=False, *, transform=None,
                 suffix="trans3"):
    """Run an augmentation pipeline over image/mask pairs and save the results.

    For every pair, the image is written to
    ``<save_path>/images/<name>_<suffix>.png`` and the mask to
    ``<save_path>/masks/0/<name>_<suffix>.png``, where ``<name>`` is the
    source image's file name without its extension. Both output directories
    are created if they do not exist.

    Args:
        images: Sequence of image file paths.
        masks: Sequence of mask file paths, aligned index-by-index with
            ``images``.
        save_path: Output root for this split.
        augment: Unused; kept so existing callers keep working.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries.
            Defaults to the module-level ``trans3`` pipeline.
        suffix: Tag appended to each output file name.

    Raises:
        ValueError: If ``images`` and ``masks`` differ in length.
        OSError: If an input file cannot be read or an output file cannot be
            written.
    """
    # zip() would silently drop the surplus entries and hide a broken dataset.
    if len(images) != len(masks):
        raise ValueError(
            f"images/masks length mismatch: {len(images)} != {len(masks)}"
        )

    if transform is None:
        transform = trans3

    image_dir = os.path.join(save_path, "images")
    mask_dir = os.path.join(save_path, "masks", "0")
    # cv2.imwrite does not create missing directories; it just fails.
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    # Compression level 0: fastest write, largest files.
    png_params = [int(cv2.IMWRITE_PNG_COMPRESSION), 0]

    for x, y in tqdm(zip(images, masks), total=len(images)):
        # basename/splitext works with any path separator and keeps dots that
        # are part of the stem ("a.b.png" -> "a.b"), avoiding name collisions.
        name = os.path.splitext(os.path.basename(x))[0]

        # cv2.imread returns None instead of raising when a file is unreadable.
        img = cv2.imread(x)
        if img is None:
            raise OSError(f"cannot read image: {x}")
        # NOTE: the mask is loaded with imread's default flags, i.e. not as a
        # single-channel image (use cv2.imread(y, 0) for grayscale).
        mask = cv2.imread(y)
        if mask is None:
            raise OSError(f"cannot read mask: {y}")

        transformed = transform(image=img, mask=mask)

        # Image and mask share one file name, in different directories.
        out_name = f"{name}_{suffix}.png"
        outputs = (
            (os.path.join(image_dir, out_name), transformed['image']),
            (os.path.join(mask_dir, out_name), transformed['mask']),
        )
        for out_path, array in outputs:
            # imwrite signals failure through its return value.
            if not cv2.imwrite(out_path, array, png_params):
                raise OSError(f"cannot write file: {out_path}")
if __name__ == '__main__':
    # Placeholder paths: replace with the dataset root and the output root.
    root_path = 'path/to/数据集根目录'
    out_path ='path/to/保存路径(数据集根目录)'

    (train_x, train_y), (val_x, val_y) = load_data_aug(root_path)

    # Only the train and val splits are augmented; the test split is left out.
    splits = (
        (out_path + "/train/", train_x, train_y),
        (out_path + "/val/", val_x, val_y),
    )
    for split_dir, split_images, split_masks in splits:
        # Make sure the output directories exist before any file is written;
        # cv2.imwrite does not create them.
        create_dir(os.path.join(split_dir, "images"))
        create_dir(os.path.join(split_dir, "masks", "0"))
        augment_data(split_images, split_masks, split_dir, augment=False)