在国产芯片上实现YOLOv5/v8图像AI识别-【4.1】RK3588训练数据时进行图像增强(更多内容见视频)

本专栏主要是提供一种国产化图像识别的解决方案,专栏中实现了YOLOv5/v8在国产化芯片上的使用部署,并可以实现网页端实时查看。根据自己的具体需求可以直接产品化部署使用。

B站配套视频:https://www.bilibili.com/video/BV1or421T74f

图像增强的必要性

在我们日常进行训练的时候,经常会遇到数据集不足的情况,比如对特定物品进行识别时,我们很难收集到满足训练数量的有效数据集。这时候就可以考虑采用图像增强的方式来增加我们的数据量。

不废话直接上代码

python 复制代码
from albumentations import *
import os
import cv2
from tqdm import tqdm

class enhancement:
    """Offline augmentation of a YOLO-format detection dataset.

    The input image folder and label folder are scanned once at construction
    time.  An image and a label file are treated as a pair when they share the
    same base name (``foo.jpg`` <-> ``foo.txt``); files without a partner are
    reported and skipped, so one missing file can never shift the pairing of
    the files that follow it.

    Calling the instance runs the whole pipeline: every pair is read, passed
    through the albumentations pipeline returned by :meth:`get_transform`, and
    the augmented image plus its bounding boxes are written to the output
    folders with ``OUTPUT_PREFIX`` prepended to the original file names.

    Attributes:
        picture_name: Sorted image file names that have a matching label file.
        label_name: Label file names, index-aligned with ``picture_name``.
        picture_path: Full paths of the images, index-aligned as above.
        label_path: Full paths of the label files, index-aligned as above.
        save_img_path: Folder that receives the augmented images.
        save_lable_path: Folder that receives the augmented label files.
    """

    # Extensions (compared in lower case) accepted when scanning for images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')
    # Prefix prepended to every augmented image / label file name.
    OUTPUT_PREFIX = 'aug2_camber_'

    def __init__(self, picture_path, label_path, save_img_path, save_lable_path):
        """Scan the input folders and pair images with label files.

        Args:
            picture_path: Folder containing the original images.
            label_path: Folder containing the YOLO ``.txt`` label files.
            save_img_path: Folder for the augmented images.
            save_lable_path: Folder for the augmented label files.
        """
        images = self._index_by_stem(picture_path, self.IMAGE_EXTENSIONS)
        labels = self._index_by_stem(label_path, ('.txt',))

        # Pair by base name instead of by position in two independently
        # sorted lists; positional pairing breaks as soon as one file is
        # missing on either side.
        paired = sorted(images.keys() & labels.keys())
        unpaired = sorted(images.keys() ^ labels.keys())
        if unpaired:
            print(f'以下文件没有对应的图片或标签,已跳过: {unpaired}')

        self.picture_name = [images[stem] for stem in paired]
        self.label_name = [labels[stem] for stem in paired]
        self.picture_path = [os.path.join(picture_path, name) for name in self.picture_name]
        self.label_path = [os.path.join(label_path, name) for name in self.label_name]
        self.save_img_path = save_img_path
        self.save_lable_path = save_lable_path

    @staticmethod
    def _index_by_stem(folder, extensions):
        """Return ``{base name: file name}`` for files in *folder* ending with *extensions*."""
        return {
            os.path.splitext(name)[0]: name
            for name in sorted(os.listdir(folder))
            if name.lower().endswith(extensions)
        }

    @staticmethod
    def _read_labels(l_path):
        """Parse a label file into ``(class_labels, bboxes)``.

        Each non-blank line is split on whitespace: the first field is read as
        an integer class id and the remaining fields as floats.  Blank lines
        are ignored and the last line is kept even without a trailing newline.
        """
        class_labels, bboxes = [], []
        with open(l_path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                class_labels.append(int(fields[0]))
                bboxes.append([float(value) for value in fields[1:]])
        return class_labels, bboxes

    def iter(self, batch_size=10):
        """Yield the dataset in consecutive batches with a progress bar.

        Args:
            batch_size: Number of image/label pairs per batch.

        Yields:
            ``(picture_batch, label_batch, [start, end])`` where ``start`` and
            ``end`` are the half-open index range of the batch inside
            ``picture_path`` / ``label_path``.
        """
        total = len(self.picture_path)
        for start in tqdm(range(0, total, batch_size), desc='批次进度'):
            # Clamp so the reported range never points past the last item.
            end = min(start + batch_size, total)
            yield self.picture_path[start:end], self.label_path[start:end], [start, end]

    def get_transform(self):
        """Build the albumentations pipeline applied to every image.

        Only ``GridDropout`` is active.  The commented entries are alternative
        transforms that can be enabled by uncommenting them.  Bounding boxes
        are declared as YOLO format, with class ids carried in the
        ``class_labels`` label field.
        """
        transform = Compose([
            # Mean (box) smoothing filter.
            # Blur(blur_limit=7, always_apply=False, p=0.5),
            # Vertical flip.
            # VerticalFlip(always_apply=False, p=0.5),
            # Horizontal flip.
            # HorizontalFlip(always_apply=False, p=1),
            # Center crop.
            # CenterCrop(200, 200, always_apply=False, p=1.0),
            # RandomFog(fog_coef_lower=0.3, fog_coef_upper=0.7, alpha_coef=0.08, always_apply=False, p=1),
            # RandomCrop(width=200, height=200)
            # Downscale(always_apply=False,p=1)
            # Random hue / saturation / value shift.
            # HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=True, p=0.5),
            # Gaussian blur with a random kernel size.
            # GaussianBlur(blur_limit=7, always_apply=False, p=0.5),
            # Randomly placed filled (occluding) rectangles.
            # CoarseDropout(max_holes=30, max_height=110, max_width=110, min_holes=10, min_height=50, min_width=50,
            #               fill_value=0, mask_fill_value=None, always_apply=True, p=0.5),
            # p=1.0 applies the transform to every image; it replaces the
            # deprecated ``always_apply=True`` flag with the same effect.
            GridDropout(p=1.0),
        ], bbox_params=BboxParams(format='yolo', label_fields=['class_labels']))
        return transform

    def augmentations(self, image, bboxes, class_labels):
        """Run the pipeline on one image and its annotations.

        Returns:
            ``(augmented_image, augmented_bboxes, augmented_labels)`` taken
            from the transform result.
        """
        transformed = self.get_transform()(image=image, bboxes=bboxes, class_labels=class_labels)
        return transformed['image'], transformed['bboxes'], transformed['class_labels']

    def augmented_image_bboxes(self, img_path, l_path):
        """Load one image/label pair from disk and augment it.

        Args:
            img_path: Path of the image file.
            l_path: Path of the matching label file.

        Returns:
            ``(augmented_image, augmented_bboxes, augmented_labels,
            original_image)``.  When the image cannot be decoded the result is
            ``(None, [], [], None)`` so the caller can skip the file.
        """
        print(l_path)
        class_labels, original_bboxes = self._read_labels(l_path)
        original_image = cv2.imread(img_path)
        if original_image is None:
            # cv2.imread signals an unreadable file by returning None rather
            # than raising; do not hand that to the transform.
            return None, [], [], None
        augmented_image, augmented_bboxes, augmented_labels = self.augmentations(
            original_image, original_bboxes, class_labels)
        return augmented_image, augmented_bboxes, augmented_labels, original_image

    def parsing_data(self, p_l_i):
        """Augment one pair and save the result when it still has labels.

        Args:
            p_l_i: Sequence ``(image path, label path, index)`` where *index*
                is the position of the pair in ``picture_name``.
        """
        img_path, l_path, index = p_l_i[0], p_l_i[1], p_l_i[2]
        augmented_image, augmented_bboxes, augmented_labels, original_image = self.augmented_image_bboxes(
            img_path, l_path)
        if original_image is None:
            print(f'{self.picture_name[index]}无法读取,不做保存')
            return

        # Kept on the instance because show_img reads them from there.
        self.augmented_image, self.augmented_bboxes = augmented_image, augmented_bboxes

        # len() instead of truthiness: safe whether the library hands back a
        # list or an array-like.
        if len(augmented_labels) == 0:
            print(f'{self.picture_name[index]}该图片没有标签,不做保存')
            return

        data = '\n'.join(
            ' '.join(map(str, [label] + list(bbox)))
            for label, bbox in zip(augmented_labels, augmented_bboxes)
        )
        self.show_img()
        self.save_img_lable(data, self.augmented_image, self.save_img_path, self.save_lable_path, index)

    def save_img_lable(self, data, img, save_img_path, save_lable_path, index):
        """Write one augmented image and its label text to the output folders.

        Args:
            data: Label file content, one annotation per line.
            img: Augmented image to write.
            save_img_path: Output folder for the image.
            save_lable_path: Output folder for the label file.
            index: Position of the pair in ``picture_name`` / ``label_name``;
                used to derive the output file names.
        """
        # Create missing output folders; an empty path means the current
        # directory and needs no creation.
        for folder in (save_img_path, save_lable_path):
            if folder:
                os.makedirs(folder, exist_ok=True)
        cv2.imwrite(os.path.join(save_img_path, self.OUTPUT_PREFIX + self.picture_name[index]), img)
        label_file = os.path.join(save_lable_path, self.OUTPUT_PREFIX + self.label_name[index])
        with open(label_file, 'w', encoding='utf-8') as f:
            f.write(data)

    def __call__(self):
        """Process every image/label pair, batch by batch."""
        for picture_batch, label_batch, index_bin in self.iter():
            indices = range(index_bin[0], index_bin[1])
            for item in zip(picture_batch, label_batch, indices):
                self.parsing_data(item)

    def show_img(self, boxe=False):
        """Optionally draw the augmented boxes onto the augmented image.

        Args:
            boxe: When true, each box in ``self.augmented_bboxes`` is read as
                normalized ``(x_center, y_center, width, height)`` and drawn
                in place on ``self.augmented_image``.  When false nothing
                happens.

        The interactive preview is disabled; uncomment the lines at the end of
        the method to display the image in a window.
        """
        if not boxe:
            return
        height, width = self.augmented_image.shape[:2]
        for x, y, w, h in self.augmented_bboxes:
            top_left = (int((x - w / 2) * width), int((y - h / 2) * height))
            bottom_right = (int((x + w / 2) * width), int((y + h / 2) * height))
            cv2.rectangle(self.augmented_image, top_left, bottom_right, (255, 0, 0), 2)
        # cv2.imshow('Augmented Image', self.augmented_image)
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()
 
 
if __name__ == '__main__':
    # Build the augmenter from the source folders (original images and their
    # labels) and the target folders (augmented images and labels), then run
    # it over the whole dataset.
    augmenter = enhancement(
        picture_path='images/train/',
        label_path='labels/train/',
        save_img_path='images/train-aug-point/',
        save_lable_path='labels/train-aug-point/',
    )
    augmenter()

代码说明

除了常规的opencv之外,我们还需要安装albumentations和tqdm(代码中用tqdm显示批次进度)

bash 复制代码
pip install albumentations tqdm

代码讲解请查看视频:https://www.bilibili.com/video/BV1or421T74f

相关推荐
默子昂2 小时前
yolo自动化项目实例解析(一)日志格式输出、并发异步多线程、websocket、循环截图、yolo推理、3d寻路
运维·yolo·自动化
说私域2 小时前
开源 AI 智能名片 S2B2C 商城小程序相关角色的探索
人工智能·搜索引擎·小程序
大数据AI人工智能培训专家培训讲师叶梓2 小时前
大模型从失败中学习 —— 微调大模型以提升Agent性能
人工智能·学习·性能优化·微调·agent·代理·大模型微调
PlumCarefree3 小时前
基于鸿蒙API10的RTSP播放器(十:USB视频流转H.265测试)
音视频·harmonyos·h.265
youcans_4 小时前
OpenAI全新发布o1模型:开启 AGI 的新时代
人工智能·chatgpt·agi
黑色叉腰丶大魔王4 小时前
《自然语言处理 Transformer 模型详解》
人工智能·自然语言处理·transformer
ersaijun7 小时前
【Obsidian】当笔记接入AI,Copilot插件推荐
人工智能·笔记·copilot
lgbisha7 小时前
828华为云征文|华为云Flexus X实例docker部署Jitsi构建属于自己的音视频会议系统
docker·华为云·音视频
格林威8 小时前
Baumer工业相机堡盟工业相机如何通过BGAPISDK使用短曝光功能(曝光可设置1微秒)(C语言)
c语言·开发语言·人工智能·数码相机·计算机视觉
仰望大佬0078 小时前
HalconDotNet中的图像视频采集
数码相机·计算机视觉·c#·音视频·halcon