关于图像分类任务中划分数据集,并且生成分类类别的josn字典文件

1. 前言

在做图像分类任务的时候,数据格式是文件夹格式,相同文件夹下存放同一类型的类别

不少网上的数据,没有划分数据集,虽然代码简单,每次重新编写还是颇为麻烦,这里记录一下

如下,有的数据集这样摆放:

可以看出这是个三分类任务,不过没有划分测试集、验证集

代码存放位置:和数据集dataset 同一路径

2. 完整代码

如下:

python 复制代码
import random
import os
import shutil
from tqdm import tqdm
import json


def split_data(root, test_rate, flag=True):
    # 待分类数据的当前目录
    classes_directory = [i for i in os.listdir(root) if os.path.isdir(os.path.join(root, i))]

    # 建立生成后的目录,方便拷贝
    for i in classes_directory:
        os.makedirs(os.path.join('./data/train', i))  # 训练集
        os.makedirs(os.path.join('./data/test', i))  # 测试集

    # 是否生成类别的 json 字典文件,默认生成
    if flag:
        class_indices = dict((k, v) for v, k in enumerate(classes_directory))
        json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json_str)

    # 遍历每个文件夹下的文件
    for cla in classes_directory:
        cla_path = os.path.join(root, cla)  # 每个文件夹的路径
        images_path = [os.path.join(root, cla, i) for i in os.listdir(cla_path)]

        # 按比例随机采样测试集样本
        test_split_path = random.sample(images_path, k=int(len(images_path) * test_rate))

        # 划分数据
        for i in tqdm(images_path, desc=cla):
            if i in test_split_path:
                shutil.copy(i, os.path.join('./data/test', cla))
            else:
                shutil.copy(i, os.path.join('./data/train', cla))


if __name__ == '__main__':
    rawDataSet = './dataset'  # 原始数据的路径

    if os.path.exists('./data'):  # 如果之前有,那么删除
        shutil.rmtree('./data')

    os.makedirs('./data/train')
    os.makedirs('./data/test')

    # 划分数据
    split_data(root=rawDataSet, test_rate=0.2)

运行代码过程:

运行结果:

生成的json文件:

3. 代码介绍

首先,rawDataSet 传入的是待划分的数据集根目录,这里会将之前划分的删掉,这样每次生成的结果不一样。训练集和测试集的比例为0.2

这里按照本人平时的习惯,划分好的目录结构如下

--data-train- 不同类别的文件夹

--data-test- 不同类别的文件夹

接下来这部分是读取每个子文件夹,或者说分类的classes(因为分类任务的文件夹就是class)

这里根据子文件夹名生成对应的json字典文件

划分数据,测试集会根据总数据的个数 * 划分比例 (test_rate)

遍历全部的数据,如果目标在测试集,那么就是测试集数据;否则为训练数据

如果是目标检测或者分割,数据和标签是分开的单独文件,划分的过程类似,后续会看着写写看

相关推荐
余炜yw13 分钟前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐29 分钟前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
如若1231 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr1 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络
ChaseDreamRunner1 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习
Guofu_Liao1 小时前
大语言模型---梯度的简单介绍;梯度的定义;梯度计算的方法
人工智能·语言模型·矩阵·llama
我爱学Python!1 小时前
大语言模型与图结构的融合: 推荐系统中的新兴范式
人工智能·语言模型·自然语言处理·langchain·llm·大语言模型·推荐系统
果冻人工智能1 小时前
OpenAI 是怎么“压力测试”大型语言模型的?
人工智能·语言模型·压力测试
日出等日落1 小时前
Windows电脑本地部署llamafile并接入Qwen大语言模型远程AI对话实战
人工智能·语言模型·自然语言处理
麦麦大数据2 小时前
Python棉花病虫害图谱系统CNN识别+AI问答知识neo4j vue+flask深度学习神经网络可视化
人工智能·python·深度学习