深度学习——划分自定义数据集

深度学习------划分自定义数据集

以人脸表情数据集raf_db为例,初始目录如下:

需要经过处理后返回

train_images, train_label, val_images, val_label

定义 read_split_data(root: str, val_rate: float = 0.2) 方法来解决,代码如下:

python 复制代码
# root:数据集所在路径
# val_rate:划分测试集的比例

def read_split_data(root: str, val_rate: float = 0.2):

    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    file_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证各平台顺序一致
    file_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(file_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images = []  # 存储训练集的所有图片路径
    train_label = []  # 存储训练集图片对应索引信息
    val_images = []  # 存储验证集的所有图片路径
    val_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型

    # 遍历每个文件夹下的文件
    for cla in file_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 排序,保证各平台顺序一致
        images.sort()
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images.append(img_path)
                val_label.append(image_class)
            else:  # 否则存入训练集
                train_images.append(img_path)
                train_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images)))
    print("{} images for validation.".format(len(val_images)))
    assert len(train_images) > 0, "number of training images must greater than 0."
    assert len(val_images) > 0, "number of validation images must greater than 0."

    return train_images, train_label, val_images, val_label

此时可通过以下代码获得训练集和测试集数据:

python 复制代码
train_images, train_label, val_images, val_label = read_split_data(data_path)

完结撒花。

相关推荐
运器1236 分钟前
【一起来学AI大模型】PyTorch DataLoader 实战指南
大数据·人工智能·pytorch·python·深度学习·ai·ai编程
超龄超能程序猿20 分钟前
(5)机器学习小白入门 YOLOv:数据需求与图像不足应对策略
人工智能·python·机器学习·numpy·pandas·scipy
卷福同学21 分钟前
【AI编程】AI+高德MCP不到10分钟搞定上海三日游
人工智能·算法·程序员
帅次29 分钟前
系统分析师-计算机系统-输入输出系统
人工智能·分布式·深度学习·神经网络·架构·系统架构·硬件架构
AndrewHZ1 小时前
【图像处理基石】如何入门大规模三维重建?
人工智能·深度学习·大模型·llm·三维重建·立体视觉·大规模三维重建
5G行业应用1 小时前
【赠书福利,回馈公号读者】《智慧城市与智能网联汽车,融合创新发展之路》
人工智能·汽车·智慧城市
悟空胆好小1 小时前
分音塔科技(BABEL Technology) 的公司背景、股权构成、产品类型及技术能力的全方位解读
网络·人工智能·科技·嵌入式硬件
探讨探讨AGV1 小时前
以科技赋能未来,科聪持续支持青年创新实践 —— 第七届“科聪杯”浙江省大学生智能机器人创意竞赛圆满落幕
人工智能·科技·机器人
cwn_1 小时前
回归(多项式回归)
人工智能·机器学习·数据挖掘·回归
聚客AI2 小时前
🔥 大模型开发进阶:基于LangChain的异步流式响应与性能优化
人工智能·langchain·agent