pytorch导入数据集

1、概念:

Dataset:一种数据结构,存储数据及其标签

Dataloader:一种工具,可以将Dataset里的数据分批、打乱、批量加载并进行迭代等

(方便模型训练和验证)

Dataset就像一个大书架,存放着带有标签的数据书籍,并且这些书有编号(0,1,2...);

而Dataloader就像一个图书管理员,负责从书架上按需取出书籍并分批提供给读者。

2、Dataset的组织形式

train:训练集 val:验证集

一种方式是label作为数据文件夹的名字,

另一种方式是label和数据本身分开成两个文件夹(label文件夹里装的是和每个数据对应的.txt)

3、处理图像:PIL(Python Imaging Library).Image

|--------------------------------------|--------------------------------------|
| pip install Pillow | 安装PIL |
| from PIL import Image | 引入Image类(代表图像对象, 可以通过创建Image实例来操作图像) |
| img=Image.open('图像路径') 打开图像 | img.show() 显示图像 |
| print(img.size) 输出(宽度,高度) | print(img.format) 输出图像格式(JPEG、PNG等) |
| resized_img=img.resize((宽度,高度)) 调整大小 | |
| resized_img=img.save('新路径') 保存为新文件 | |

4、处理目录和文件:os

|---------------------------------------------|--------------------------------------|
| import os | |
| cur_dir=os.getcwd() | 获取当前工作目录 |
| files=os.listdir(cur_dir) | 列举当前目录下的所有子目录(文件和文件夹) |
| os.makedirs('new_folder') | 创建新文件夹(如果不存在) |
| os.remove('file.txt') | 删除文件(os.rmdir('empty_folder')删除空文件夹) |
| os.path.exists('some_path') | 检查路径是否存在 |
| file_path=os.path.join('folder','file.txt') | 拼接路径 |
| abs_path=os.path.abspath('file.txt) | 获取文件的绝对路径 |

5、代码

python 复制代码
from torch.utils.data import Dataset #从torch的常用工具箱utils中拿data工具,然后引入Dataset类
from PIL import Image #处理图片要用到
import os #访问目录、获取图片的地址要用到

class MyData(Dataset): #让MyData类继承Dataset类
    def __init__(self,root_dir,label_dir): #数据集的初始化:要用到根目录和标签目录(这里把label作为数据文件夹的名字了)
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir) #根目录+标签目录=数据集的路径
        self.img_dir_list=os.listdir(self.path) #列举数据集目录下的每个数据(文件)

    def __getitem__(self,idx): #获取索引对应的数据
        img_dir=self.img_dir_list[idx] #得到索引对应的数据文件
        img_path=os.path.join(self.root_dir,self.label_dir,img_dir) #数据集路径+数据文件=数据文件路径
        img=Image.open(img_path)
        label=self.label_dir
        return img,label

    def __len__(self):
        return len(self.img_dir_list) #数据长度=数据集目录下的子文件数量

root_dir=r"dataset/hymenoptera_data/train"
ants_label_dir="ants"
ants_dataset=MyData(root_dir,ants_label_dir)
bees_label_dir="bees"
bees_dataset=MyData(root_dir,bees_label_dir)

train_dataset=ants_dataset+bees_dataset
相关推荐
企业架构师老王1 分钟前
药品生产环节:用实在Agent自动生成批记录与打印领料单的合规设计与架构落地
大数据·人工智能·ai·架构
m0_588758485 分钟前
高效实现分组内跨行时间戳匹配:为每组生成布尔标记列 user_rejects
jvm·数据库·python
黎阳之光5 分钟前
视频孪生重构轨交数字孪生新范式|黎阳之光以自主核心技术破解落地难题
大数据·人工智能·算法·安全·数字孪生
ai产品老杨5 分钟前
告别重复造轮子:深度解析支持源码交付的 AI 视频平台架构,实现 X86/ARM 与 GPU/NPU 异构算力融合
人工智能·架构·音视频
好运的阿财7 分钟前
OpenClaw工具拆解之 web_fetch+image_generate
前端·python·机器学习·ai·ai编程·openclaw·openclaw工具
写代码的小阿帆9 分钟前
AI工具使用——外挂AI插件、AI原生IDE与AI终端
ide·人工智能·ai-native
谢谢 啊sir9 分钟前
L2-060 大语言模型的推理 - java
java·人工智能·语言模型
阿杰学AI9 分钟前
AI核心知识140—大语言模型之 推理期算力(简洁且通俗易懂版)
人工智能·语言模型·自然语言处理·思维链·思维树·慢思考·推理期算力
云淡风轻~窗明几净9 分钟前
关于TSP的sealine算法与角谷猜想(2026-04-25)
数据结构·人工智能·算法·动态规划·模拟退火算法
wayz1110 分钟前
Day 13:朴素贝叶斯分类器
人工智能·算法·机器学习·朴素贝叶斯