深度学习——划分自定义数据集

深度学习------划分自定义数据集

以人脸表情数据集raf_db为例,初始目录如下:

需要经过处理后返回

train_images, train_label, val_images, val_label

定义 read_split_data(root: str, val_rate: float = 0.2) 方法来解决,代码如下:

python 复制代码
# root:数据集所在路径
# val_rate:划分测试集的比例

def read_split_data(root: str, val_rate: float = 0.2):

    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    file_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证各平台顺序一致
    file_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(file_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images = []  # 存储训练集的所有图片路径
    train_label = []  # 存储训练集图片对应索引信息
    val_images = []  # 存储验证集的所有图片路径
    val_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型

    # 遍历每个文件夹下的文件
    for cla in file_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 排序,保证各平台顺序一致
        images.sort()
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images.append(img_path)
                val_label.append(image_class)
            else:  # 否则存入训练集
                train_images.append(img_path)
                train_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images)))
    print("{} images for validation.".format(len(val_images)))
    assert len(train_images) > 0, "number of training images must greater than 0."
    assert len(val_images) > 0, "number of validation images must greater than 0."

    return train_images, train_label, val_images, val_label

此时可通过以下代码获得训练集和测试集数据:

python 复制代码
train_images, train_label, val_images, val_label = read_split_data(data_path)

完结撒花。

相关推荐
兰亭妙微3 小时前
用户体验的真正边界在哪里?对的 “认知负荷” 设计思考
人工智能·ux
13631676419侯3 小时前
智慧物流与供应链追踪
人工智能·物联网
TomCode先生3 小时前
MES 离散制造核心流程详解(含关键动作、角色与异常处理)
人工智能·制造·mes
zd2005724 小时前
AI辅助数据分析和学习了没?
人工智能·学习
johnny2334 小时前
强化学习RL
人工智能
乌恩大侠4 小时前
无线网络规划与优化方式的根本性变革
人工智能·usrp
放羊郎4 小时前
基于萤火虫+Gmapping、分层+A*优化的导航方案
人工智能·slam·建图·激光slam
王哈哈^_^4 小时前
【数据集+完整源码】水稻病害数据集,yolov8水稻病害检测数据集 6715 张,目标检测水稻识别算法实战训推教程
人工智能·算法·yolo·目标检测·计算机视觉·视觉检测·毕业设计
SEOETC4 小时前
数字人技术:虚实交融的未来图景正在展开
人工智能
boonya4 小时前
从阿里云大模型服务平台百炼看AI应用集成与实践
人工智能·阿里云·云计算