卷积神经网络|制作自己的Dataset

在编写代码训练神经网络之前,导入数据是必不可少的。PyTorch提供了许多预加载的数据集(如FashionMNIST),这些数据集 子类并实现特定于特定数据的函数。

它们可用于对模型进行原型设计和基准测试,加载这些数据集是十分简单的。好吧,那如何加载自己制作的数据集呢?

简单来讲,自定义数据集类必须实现三个函数:initlen__和__getitem。下面代码就实现了一个Dataset

复制代码
import osimport torchfrom torch.utils.data import Datasetfrom torchvision import transformsfrom PIL import Imageimport numpy as np​class MyDataset(Dataset):    def __init__(self, path_file,transform=None,label_transform=None):        self.path_file=path_file        self.imgs=[name for name in os.listdir(path_file)]#获取path_file路径下所有文件名        self.transform = transform        self.label_transform = label_transform​    def __len__(self):        return len(self.imgs)​    def __getitem__(self, idx):        #get the image        img_path = os.path.join(self.path_file,self.imgs[idx])#获得图片完整路径        image=Image.open(img_path)        image=image.resize((28,28))#修改图片为默认大小        image = np.array(image)        image=torch.from_numpy(image)#将numpy数组转换为张量        image=image.permute(2,0,1)#将H,W,C转换为C,H,W​        if self.transform:            image = self.transform(image)​        #get the label        str1=self.imgs[idx].split('.')        label=torch.tensor(eval(str1[1]))​        if self.label_transform:            label=self.label_transform(label) ​        return image, label

注:上述代码从路径path_file读取文件,准确来讲应该是我们准备的训练图片,格式如下:

cat1.0.jpg

cat2.0.jpg

...

dog1.1.jpg

dog2.1.jpg

...

图片名重要含义:类别(0,1等)

而cat1,dog1这些并不重要,因为0,1,已经反映了图片的类别,这里仅仅是一个习惯,同样jpg也是如此。

实际上,在我们准备图片时,图片名往往不是这样,但直接写个简单的文件处理程序便很容易转变为上述格式

之所以这样命名,就是为容易获得图片和对应的类别,也就是实现自己的Dataset。当然,其它还有许多方法,但核心就是加载自己的数据时获得图片和对应的类别。

再次看一下实现自己的Dataset的架构:

复制代码
class CustomImageDataset(Dataset):    def __init__(self, path_file, transform=None, target_transform=None):        ...        ...        ...​    def __len__(self):        return len(...)​​​    def __getitem__(self, idx):        ...        ...        ...        if self.transform:            image = self.transform(image)        if self.label_transform:            label = self.label_transform(label)        return image, label

在训练模型时,我们通常希望 在"小批量"中传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的 加快数据检索速度。

**DataLoader是一个迭代对象,它在一个简单的 API 中为我们抽象了这种复杂性。**下面我们将Dataset带入DataLoader.

复制代码
path="E:\\3-10\\dogandcats\\train"#图片所在目录training_data=MyDataset(path)train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=2, shuffle=True)

让我们run一下:

复制代码
>>> trainimg,label=next(iter(train_dataloader))>>> trainimg.size()torch.Size([2, 3, 28, 28])>>> label.size()torch.Size([2])

结果符合预期,与在使用pytorch预加载的数据集格式一样!

点点点,赞和在看都在这儿!

相关推荐
机器学习之心2 小时前
PINN物理信息神经网络用于求解二阶常微分方程(ODE)的边值问题,Matlab实现
人工智能·神经网络·matlab·物理信息神经网络·二阶常微分方程
zandy10112 小时前
LLM与数据工程的融合:衡石Data Agent的语义层与Agent框架设计
大数据·人工智能·算法·ai·智能体
大千AI助手2 小时前
梯度消失问题:深度学习中的「记忆衰退」困境与解决方案
人工智能·深度学习·神经网络·梯度·梯度消失·链式法则·vanishing
计算机编程小央姐2 小时前
数据安全成焦点:基于Hadoop+Spark的信用卡诈骗分析系统实战教程
大数据·hadoop·python·spark·毕业设计·课程设计·dash
研梦非凡3 小时前
CVPR 2025|无类别词汇的视觉-语言模型少样本学习
人工智能·深度学习·学习·语言模型·自然语言处理
seegaler3 小时前
WrenAI:开源革命,重塑商业智能未来
人工智能·microsoft·ai
max5006003 小时前
本地部署开源数据生成器项目实战指南
开发语言·人工智能·python·深度学习·算法·开源
他们叫我技术总监3 小时前
【保姆级选型指南】2025年国产开源AI算力平台怎么选?覆盖企业级_制造业_国际化场景
人工智能·开源·算力调度·ai平台·gpu国产化
IT_陈寒3 小时前
🔥5个必学的JavaScript性能黑科技:让你的网页速度提升300%!
前端·人工智能·后端
czijin3 小时前
【论文阅读】Security of Language Models for Code: A Systematic Literature Review
论文阅读·人工智能·安全·语言模型·软件工程