目标检测之数据增强

数据翻转,需要把bbox相应的坐标值也进行交换

代码:

python 复制代码
import random
from torchvision.transforms import functional as F


class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target


class RandomHorizontalFlip(object):
    """随机水平翻转图像以及bboxes"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # 水平翻转图片
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # 翻转对应bbox坐标信息
            target["boxes"] = bbox
        return image, target

对图像及其对应的标注文件(XML格式)进行数据增强,并将增强后的图像和标注文件保存到指定的目录中

  • root:XML文件所在的目录路径。

  • image_id:XML文件的名称(不包含扩展名)。

代码:

python 复制代码
import xml.etree.ElementTree as ET
import pickle
import os
from os import getcwd
import numpy as np
from PIL import Image
import shutil
import matplotlib.pyplot as plt

import imgaug as ia
from imgaug import augmenters as iaa


ia.seed(1)


def read_xml_annotation(root, image_id):
    in_file = open(os.path.join(root, image_id))
    tree = ET.parse(in_file)
    root = tree.getroot()
    bndboxlist = []

    for object in root.findall('object'):  # 找到root节点下的所有country节点
        bndbox = object.find('bndbox')  # 子节点下节点rank的值

        xmin = int(bndbox.find('xmin').text)
        xmax = int(bndbox.find('xmax').text)
        ymin = int(bndbox.find('ymin').text)
        ymax = int(bndbox.find('ymax').text)
        # print(xmin,ymin,xmax,ymax)
        bndboxlist.append([xmin, ymin, xmax, ymax])
        # print(bndboxlist)

    bndbox = root.find('object').find('bndbox')
    return bndboxlist


# (506.0000, 330.0000, 528.0000, 348.0000) -> (520.4747, 381.5080, 540.5596, 398.6603)
def change_xml_annotation(root, image_id, new_target):
    new_xmin = new_target[0]
    new_ymin = new_target[1]
    new_xmax = new_target[2]
    new_ymax = new_target[3]

    in_file = open(os.path.join(root, str(image_id) + '.xml'))  # 这里root分别由两个意思
    tree = ET.parse(in_file)
    xmlroot = tree.getroot()
    object = xmlroot.find('object')
    bndbox = object.find('bndbox')
    xmin = bndbox.find('xmin')
    xmin.text = str(new_xmin)
    ymin = bndbox.find('ymin')
    ymin.text = str(new_ymin)
    xmax = bndbox.find('xmax')
    xmax.text = str(new_xmax)
    ymax = bndbox.find('ymax')
    ymax.text = str(new_ymax)
    tree.write(os.path.join(root, str("%06d" % (str(id) + '.xml'))))


def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
    in_file = open(os.path.join(root, str(image_id) + '.xml'))  # 这里root分别由两个意思
    tree = ET.parse(in_file)
    elem = tree.find('filename')
    elem.text = (id + '.jpg')
    xmlroot = tree.getroot()
    index = 0

    for object in xmlroot.findall('object'):  # 找到root节点下的所有country节点
        bndbox = object.find('bndbox')  # 子节点下节点rank的值

        # xmin = int(bndbox.find('xmin').text)
        # xmax = int(bndbox.find('xmax').text)
        # ymin = int(bndbox.find('ymin').text)
        # ymax = int(bndbox.find('ymax').text)

        new_xmin = new_target[index][0]
        new_ymin = new_target[index][1]
        new_xmax = new_target[index][2]
        new_ymax = new_target[index][3]

        xmin = bndbox.find('xmin')
        xmin.text = str(new_xmin)
        ymin = bndbox.find('ymin')
        ymin.text = str(new_ymin)
        xmax = bndbox.find('xmax')
        xmax.text = str(new_xmax)
        ymax = bndbox.find('ymax')
        ymax.text = str(new_ymax)

        index = index + 1

    tree.write(os.path.join(saveroot, id + '.xml'))


def mkdir(path):
    # 去除首位空格
    path = path.strip()
    # 去除尾部 \ 符号
    path = path.rstrip("\\")
    # 判断路径是否存在
    # 存在     True
    # 不存在   False
    isExists = os.path.exists(path)
    # 判断结果
    if not isExists:
        # 如果不存在则创建目录
        # 创建目录操作函数
        os.makedirs(path)
        print(path + ' 创建成功')
        return True
    else:
        # 如果目录存在则不创建,并提示目录已存在
        print(path + ' 目录已存在')
        return False


if __name__ == "__main__":

    IMG_DIR = "VOCdevkit/VOC2007/JPEGImages3"
    XML_DIR = "VOCdevkit/VOC2007/Annotations3"

    AUG_XML_DIR = "VOCdevkit/VOC2007/Annotations"  # 存储增强后的XML文件夹路径
    try:
        shutil.rmtree(AUG_XML_DIR)
    except FileNotFoundError as e:
        a = 1
    mkdir(AUG_XML_DIR)
    
    AUG_IMG_DIR = "VOCdevkit/VOC2007/JPEGImages"  # 存储增强后的影像文件夹路径
    try:
        shutil.rmtree(AUG_IMG_DIR)
    except FileNotFoundError as e:
        a = 1
    mkdir(AUG_IMG_DIR)

    AUGLOOP = 8  # 每张影像增强的数量

    boxes_img_aug_list = []
    new_bndbox = []
    new_bndbox_list = []

    # 影像增强
    seq = iaa.Sequential([
        iaa.Flipud(0.5),  # vertically flip 20% of all images
        iaa.Fliplr(0.5),  # 镜像
        iaa.Multiply((1.2, 1.5)),  # change brightness, doesn't affect BBs
        iaa.GaussianBlur(sigma=(0, 3.0)),  # iaa.GaussianBlur(0.5),
        iaa.Affine(
            translate_px={"x": 15, "y": 15},
            scale=(0.8, 0.95),
            rotate=(-30, 30)
        )  # translate by 40/60px on x/y axis, and scale to 50-70%, affects BBs
    ])

    for root, sub_folders, files in os.walk(XML_DIR):

        for name in files:

            bndbox = read_xml_annotation(XML_DIR, name)
            shutil.copy(os.path.join(XML_DIR, name), AUG_XML_DIR)
            shutil.copy(os.path.join(IMG_DIR, name[:-4] + '.jpg'), AUG_IMG_DIR)

            for epoch in range(AUGLOOP):
                seq_det = seq.to_deterministic()  # 保持坐标和图像同步改变,而不是随机
                # 读取图片
                img = Image.open(os.path.join(IMG_DIR, name[:-4] + '.jpg'))
                # sp = img.size
                img = np.asarray(img)
                # bndbox 坐标增强
                for i in range(len(bndbox)):
                    bbs = ia.BoundingBoxesOnImage([
                        ia.BoundingBox(x1=bndbox[i][0], y1=bndbox[i][1], x2=bndbox[i][2], y2=bndbox[i][3]),
                    ], shape=img.shape)

                    bbs_aug = seq_det.augment_bounding_boxes([bbs])[0]
                    boxes_img_aug_list.append(bbs_aug)

                    # new_bndbox_list:[[x1,y1,x2,y2],...[],[]]
                    n_x1 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x1)))
                    n_y1 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y1)))
                    n_x2 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x2)))
                    n_y2 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y2)))
                    if n_x1 == 1 and n_x1 == n_x2:
                        n_x2 += 1
                    if n_y1 == 1 and n_y2 == n_y1:
                        n_y2 += 1
                    if n_x1 >= n_x2 or n_y1 >= n_y2:
                        print('error', name)
                    new_bndbox_list.append([n_x1, n_y1, n_x2, n_y2])
                # 存储变化后的图片
                image_aug = seq_det.augment_images([img])[0]
                path = os.path.join(AUG_IMG_DIR,
                                    str("%06d" % (len(files)*epoch))+ name[:-4] + '.jpg')
                image_auged = bbs.draw_on_image(image_aug, thickness=0)
                Image.fromarray(image_auged).save(path)

                # 存储变化后的XML
                change_xml_list_annotation(XML_DIR, name[:-4], new_bndbox_list, AUG_XML_DIR,
                                           str("%06d" % (len(files)*epoch))+ name[:-4])
                print(str("%06d" % (len(files)*epoch))+ name[:-4] + '.jpg')
                new_bndbox_list = []

代码结构解读:

1. 导入模块
python 复制代码
import xml.etree.ElementTree as ET
import pickle
import os
from os import getcwd
import numpy as np
from PIL import Image
import shutil
import matplotlib.pyplot as plt

import imgaug as ia
from imgaug import augmenters as iaa
  • xml.etree.ElementTree:用于解析和操作XML文件。

  • numpyPIL:用于图像处理。

  • imgaug:用于图像增强。

  • 其他模块用于文件操作和路径管理。

2. 数据增强的随机种子
  • 设置随机种子,确保每次运行代码时增强操作的一致性。
python 复制代码
ia.seed(1)
3. 读取XML标注文件
python 复制代码
def read_xml_annotation(root, image_id):
    in_file = open(os.path.join(root, image_id))
    tree = ET.parse(in_file)
    root = tree.getroot()
    bndboxlist = []

    for object in root.findall('object'):
        bndbox = object.find('bndbox')
        xmin = int(bndbox.find('xmin').text)
        xmax = int(bndbox.find('xmax').text)
        ymin = int(bndbox.find('ymin').text)
        ymax = int(bndbox.find('ymax').text)
        bndboxlist.append([xmin, ymin, xmax, ymax])

    return bndboxlist
  • 输入:XML文件所在的目录和文件名。

  • 功能:解析XML文件,提取所有目标对象的边界框坐标。

  • 输出:边界框列表,每个边界框用 [xmin, ymin, xmax, ymax] 表示。

4. 更新单个XML标注文件
python 复制代码
def change_xml_annotation(root, image_id, new_target):
    new_xmin, new_ymin, new_xmax, new_ymax = new_target
    in_file = open(os.path.join(root, str(image_id) + '.xml'))
    tree = ET.parse(in_file)
    xmlroot = tree.getroot()
    object = xmlroot.find('object')
    bndbox = object.find('bndbox')
    xmin = bndbox.find('xmin')
    xmin.text = str(new_xmin)
    ymin = bndbox.find('ymin')
    ymin.text = str(new_ymin)
    xmax = bndbox.find('xmax')
    xmax.text = str(new_xmax)
    ymax = bndbox.find('ymax')
    ymax.text = str(new_ymax)
    tree.write(os.path.join(root, str("%06d" % (str(id) + '.xml'))))
  • 输入:XML文件所在的目录、文件名和新的边界框坐标。

  • 功能:更新XML文件中第一个目标对象的边界框坐标。

  • 输出:保存更新后的XML文件。

5. 更新多个XML标注文件
python 复制代码
def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
    in_file = open(os.path.join(root, str(image_id) + '.xml'))
    tree = ET.parse(in_file)
    elem = tree.find('filename')
    elem.text = (id + '.jpg')
    xmlroot = tree.getroot()
    index = 0

    for object in xmlroot.findall('object'):
        bndbox = object.find('bndbox')
        new_xmin = new_target[index][0]
        new_ymin = new_target[index][1]
        new_xmax = new_target[index][2]
        new_ymax = new_target[index][3]

        xmin = bndbox.find('xmin')
        xmin.text = str(new_xmin)
        ymin = bndbox.find('ymin')
        ymin.text = str(new_ymin)
        xmax = bndbox.find('xmax')
        xmax.text = str(new_xmax)
        ymax = bndbox.find('ymax')
        ymax.text = str(new_ymax)

        index += 1

    tree.write(os.path.join(saveroot, id + '.xml'))
  • 输入:原始XML目录、文件名、新的边界框列表、保存目录和新的文件名。

  • 功能:更新XML文件中所有目标对象的边界框坐标。

  • 输出:保存更新后的XML文件。

6. 创建目录
python 复制代码
def mkdir(path):
    path = path.strip()
    path = path.rstrip("\\")
    isExists = os.path.exists(path)
    if not isExists:
        os.makedirs(path)
        print(path + ' 创建成功')
        return True
    else:
        print(path + ' 目录已存在')
        return False
  • 输入:目标目录路径。

  • 功能:创建目录,如果目录已存在,则提示。

7. 主程序
python 复制代码
if __name__ == "__main__":
    IMG_DIR = "VOCdevkit/VOC2007/JPEGImages3"
    XML_DIR = "VOCdevkit/VOC2007/Annotations3"

    AUG_XML_DIR = "VOCdevkit/VOC2007/Annotations"
    try:
        shutil.rmtree(AUG_XML_DIR)
    except FileNotFoundError as e:
        pass
    mkdir(AUG_XML_DIR)

    AUG_IMG_DIR = "VOCdevkit/VOC2007/JPEGImages"
    try:
        shutil.rmtree(AUG_IMG_DIR)
    except FileNotFoundError as e:
        pass
    mkdir(AUG_IMG_DIR)

    AUGLOOP = 8  # 每张影像增强的数量

    seq = iaa.Sequential([
        iaa.Flipud(0.5),  # 垂直翻转
        iaa.Fliplr(0.5),  # 水平翻转
        iaa.Multiply((1.2, 1.5)),  # 调整亮度
        iaa.GaussianBlur(sigma=(0, 3.0)),  # 高斯模糊
        iaa.Affine(
            translate_px={"x": 15, "y": 15},
            scale=(0.8, 0.95),
            rotate=(-30, 30)
        )  # 平移、缩放、旋转
    ])

    for root, sub_folders, files in os.walk(XML_DIR):
        for name in files:
            bndbox = read_xml_annotation(XML_DIR, name)
            shutil.copy(os.path.join(XML_DIR, name), AUG_XML_DIR)
            shutil.copy(os.path.join(IMG_DIR, name[:-4] + '.jpg'), AUG_IMG_DIR)

            for epoch in range(AUGLOOP):
                seq_det = seq.to_deterministic()
                img = Image.open(os.path.join(IMG_DIR, name[:-4] + '.jpg'))
                img = np.asarray(img)

                for i in range(len(bndbox)):
                    bbs = ia.BoundingBoxesOnImage([
                        ia.BoundingBox(x1=bndbox[i][0], y1=bndbox[i][1], x2=bndbox[i][2], y2=bndbox[i][3]),
                    ], shape=img.shape)
                    bbs_aug = seq_det.augment_bounding_boxes([bbs])[0]
                    n_x1 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x1)))
                    n_y1 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y1)))
                    n_x2 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x2)))
                    n_y2 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y2)))
                    if n_x1 == 1 and n_x1 == n_x2:
                        n_x2 += 1
                    if n_y1 == 1 and n_y2 == n_y1:
                        n_y2 += 1
                    if n_x1 >= n_x2 or n_y1 >= n_y2:
                        print('error', name)
                    new_bndbox_list.append([n_x1, n_y1, n_x2, n_y2])

                image_aug = seq_det.augment_images([img])[0]
                path = os.path.join(AUG_IMG_DIR, str("%06d" % (len(files) * epoch)) + name
相关推荐
思绪无限1 小时前
YOLOv5至YOLOv12升级:车型识别与计数系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·yolov12·yolo全家桶·车型识别与计数
思绪无限1 小时前
YOLOv5至YOLOv12升级:田间杂草检测系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·田间杂草检测·yolov12·yolo全家桶
思绪无限2 小时前
YOLOv5至YOLOv12升级:常见车型识别系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·yolo·目标检测·目标跟踪·yolov12·yolo全家桶
jay神2 小时前
鸟类识别数据集 - CUB_200
人工智能·深度学习·目标检测·计算机视觉·目标跟踪·毕业设计
QQ676580083 小时前
智慧工地物料堆积识别 工地钢筋木材图像识别 工地砖块目标检测 建筑物大理石图像识别 建筑物工地材料识别 物料堆积识别10349期
人工智能·目标检测·计算机视觉·工地物料堆积·工地钢筋木材图像识别·工地砖块目标检测·建筑物大理石图像
探物 AI3 小时前
【感知实战·数据增强篇】深度解析目标检测中的图片数据增强算法,多图演示效果
人工智能·算法·目标检测
思绪无限4 小时前
YOLOv5至YOLOv12升级:交通信号灯识别系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·交通信号灯识别·yolov12·yolo全家桶
学习论之费曼学习法4 小时前
AI 入门 30 天挑战 - Day 15 费曼学习法版 - 目标检测基础
人工智能·学习·目标检测
流年残碎念5 小时前
用TensorFlow Lite在树莓派上部署目标检测
人工智能·目标检测·tensorflow
思绪无限6 小时前
YOLOv5至YOLOv12升级:水下目标检测系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·yolo·目标检测·水下目标检测·yolov12·yolo全家桶