PyTorch 创建数据集

图片数据和标签数据准备

1.本文所用图片数据在同级文件夹中 ,文件路径为'train/'

2.标签数据在同级文件,文件路径为'train.csv'

3。将标签数据提取

python 复制代码
train_csv=pd.read_csv('train.csv')

创建继承类

第一步,首先创建数据类对象 此时可以想象为单个数据单元的创建 { 图像,标签}

继承的是Dataset类 (数据集类)

python 复制代码
from torch.utils.data import Dataset
from PIL import Image          //从文件路径中提取图片所需要的函数

class Imagedata(Dataset):        //继承Dataset类
	def __init__(self,df,dir,transform=None):     //往类里传输需要的数据必须在这定义,后面初始化函数才能使用传入的数据,
	                                              //df表示传入的标签数据,dir表示图像数据文件地址,transform是图像增强的处理操作
	      super().__init__()                      //声明后面操作需要用的数据
	      self.df=df                           
	      self.dir=dir
	      self.transform=transform
    def __len__(self):                     //模板函数,没什么卵用
        return len(self.df)
    def __getitem__(self, idex):           //将单个数据和标签整合到一块的初始化函数
        img_id=self.df.iloc[idex,0]        //图片的名称在df文件中,标签也在df的文件中,如下图,为的就是提出图像数据文件中的图片,否则从图片数据文件中一张一张提取出来很难,名称太长
        img=Image.open(self.dir+img_id)   //拿到了图片的整个完整地址  
        img=np.array(img)                //Image提取出来的为image类型,需要转换为numpy数组,才能存储到数据集中
                                         //上面两行也可以换为cv2.imread(dir),直接读取的数据就可以往里面存,避免了数据转换
        label=self.df.iloc[idex,1]       //从df中提取对应的标签,就是同一张图像的标签,由idex固定
        return img,label                 //返回整理好的单个数据单元(图像+标签)
		

第二步,创造好了单个数据单元对象,那么需要将多个数据单元整合起来构成一个完整的数据集

先将单个数据单元实现,因为上面的代码为类对象代码,并没有实现

python 复制代码
train_dataset=ImageDataset(df=train_csv,dir='train/')  //df为标签文件,dir表示你图像存储的文件地址

得到了单个数据单元,那么开始将数据整合,先调用数据整合函数:

python 复制代码
from torch.utils.data import DataLoader

通过数据流来整合

python 复制代码
train_data=DataLoader(train_dataset,batch_size=32)    //train_dataset 为单个对象     batch_size为设置几个为一小组,为后面的分组训练做准备

那么最后得到的train_data就是带有图像和标签的数据集,可以验证一下:

python 复制代码
for img,label in train_data:
    print(img,label)
图像增强技术(降噪,标准化)

上面没有加入图像增强代码,创建数据集时候,可以先将图像增强后再存入数据集,增强的主要目的就是提高训练准确率,标准化可以使图像在神经网络训练的更快,因为图像的数据明显变小,举个例子,由像素[233,221,222]可以直接变为[2.33,2.21,2.22]

如下使图像增强代码,用的使torchvision,每行代码都有注释

python 复制代码
from torchvision import transforms

transform_train = transforms.Compose([transforms.ToTensor(),        //将图像变为Tensor张量,并将图像像素由255-0变为1-0,压缩,并将图像的维度从 (H x W x C) 转换为 (C x H x W)
                                      transforms.Pad(32, padding_mode='symmetric')   //表示在图像的四周各填充 32 个像素。
                                      transforms.RandomHorizontalFlip(),    //以一定的概率对图像进行随机水平翻转。这有助于增加数据的多样性,提高模型的泛化能力。防止拟合
                                      transforms.RandomVerticalFlip(),      //以一定的概率对图像进行随机垂直翻转。同样是为了增加数据多样性
                                      transforms.RandomRotation(10),       //以一定的概率对图像进行随机旋转,旋转角度在 -10 到 10 度之间。增加数据的多样性
                                      transforms.Normalize((0.485, 0.456, 0.406),     //指定每个通道的均值。通常是在 ImageNet 数据集上计算得到的均值。
                                                           (0.229, 0.224, 0.225))])   //指定每个通道的标准差。也是在 ImageNet 数据集上计算得到的标准差。
                                                           

那么在数据单元创建的时候加入,以下是完整代码:

python 复制代码
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, df, dir, transform=None): 
        super().__init__()
        
        self.df = df
        self.dir = dir
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_id = self.df.iloc[idx,0]
        img_path = self.dir + img_id
        image = cv2.imread(img_path)            //这里用了cv2直接读取图片,避免了转换numpy
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   //opencv里的数据增强
        label = self.df.iloc[idx,1]
        
        if self.transform is not None:
            image = self.transform(image)
        return image, label


-----------------------图像增强技术------------------------
from torchvision import transforms
transform_train = transforms.Compose([transforms.ToTensor(),
                                      transforms.Pad(32, padding_mode='symmetric'),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(10),
                                      transforms.Normalize((0.485, 0.456, 0.406),
                                                           (0.229, 0.224, 0.225))])
transform_test = transforms.Compose([transforms.ToTensor(),
                                     transforms.Pad(32, padding_mode='symmetric'),
                                        transforms.Normalize((0.485, 0.456, 0.406),
                                                           (0.229, 0.224, 0.225))])


from torch.utils.data import DataLoader
dataset_train = ImageDataset(df=train_df, img_dir='train/',transform=transform_train)
loader_train = DataLoader(dataset=dataset_train, batch_size=32, shuffle=True)
相关推荐
LCG元23 分钟前
垂直Agent才是未来:详解让大模型"专业对口"的三大核心技术
人工智能
我不是QI43 分钟前
周志华《机器学习—西瓜书》二
人工智能·安全·机器学习
BBB努力学习程序设计1 小时前
Python面向对象编程:从代码搬运工到架构师
python·pycharm
操练起来1 小时前
【昇腾CANN训练营·第八期】Ascend C生态兼容:基于PyTorch Adapter的自定义算子注册与自动微分实现
人工智能·pytorch·acl·昇腾·cann
rising start1 小时前
五、python正则表达式
python·正则表达式
KG_LLM图谱增强大模型1 小时前
[500页电子书]构建自主AI Agent系统的蓝图:谷歌重磅发布智能体设计模式指南
人工智能·大模型·知识图谱·智能体·知识图谱增强大模型·agenticai
声网1 小时前
活动推荐丨「实时互动 × 对话式 AI」主题有奖征文
大数据·人工智能·实时互动
caiyueloveclamp1 小时前
【功能介绍03】ChatPPT好不好用?如何用?用户操作手册来啦!——【AI溯源篇】
人工智能·信息可视化·powerpoint·ai生成ppt·aippt
BBB努力学习程序设计1 小时前
Python错误处理艺术:从崩溃到优雅恢复的蜕变
python·pycharm
q***48411 小时前
Vanna AI:告别代码,用自然语言轻松查询数据库,领先的RAG2SQL技术让结果更智能、更精准!
人工智能·microsoft