【深度学习入门篇 ③】PyTorch的数据加载

【🍊 易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊 】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


掌握PyTorch数据通常的处理方法,是构建高效、可扩展模型的关键一步。今天,我们就利用PyTorch高效地处理数据,为模型训练打下坚实基础。

在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。

但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。

所以,接下来我们来学习pytorch中的数据加载的方法~

Dataset基类介绍

dataset定义了这个数据集的总长度,以及会返回哪些参数,模板:

python 复制代码
from torch.utils.data import Dataset
 
class MyDataset(Dataset):
    def __init__(self, ):
        # 定义数据集包含的数据和标签
 
    def __len__(self):
        return len(...)
    def __getitem__(self, index):
        # 当数据集被读取时,返回一个包含数据和标签的元组

数据加载案例

数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

该数据集包含了5574条 短信,其中正常短信(标记为"ham")4831条 ,骚扰短信(标记为"spam")743条。

python 复制代码
from torch.utils.data import Dataset,DataLoader
import pandas as pd

data_path = r"data/SMSSpamCollection"    # 路径

class SMSDataset(Dataset):
    def __init__(self):
        lines = open(data_path,"r",encoding="utf-8")
        # 前4个为label,后面的为短信内容
        lines = [[i[:4].strip(),i[4:].strip()] for i in lines]
        # 转为dataFrame类型
        self.df = pd.DataFrame(lines,columns=["label","sms"])

    def __getitem__(self, index):
        single_item = self.df.iloc[index,:]
        return single_item.values[0],single_item.values[1]

    def __len__(self):
        return self.df.shape[0]

我们现在已经成功地构建了一个数据集类 SMSDataset,这个类能够加载SMS 垃圾短信数据集,并将每条短信及其对应的标签(hamspam)封装为可迭代的形式,以便于后续的数据加载和模型训练。

python 复制代码
d = SMSDataset()
for i in range(len(d)):
    print(i,d[i])

输出:

python 复制代码
...
5566 ('ham', "Why don't you wait 'til at least wednesday to see if you get your .")
5567 ('ham', 'Huh y lei...')
5568 ('spam', 'REMINDER FROM O2: To get 2.50 pounds free call credit and details of great offers pls reply 2 this text with your valid name, house no and postcode')
5569 ('spam', 'This is the 2nd time we have tried 2 contact u. U have won the £750 Pound prize. 2 claim is easy, call 087187272008 NOW1! Only 10p per minute. BT-national-rate.')
5570 ('ham', 'Will ü b going to esplanade fr home?')
5571 ('ham', 'Pity, * was in mood for that. So...any other suggestions?')
5572 ('ham', "The guy did some bitching but I acted like i'd be interested in buying something else next week and he gave it to us for free")
5573 ('ham', 'Rofl. Its true to its name')

DataLoader格式说明

python 复制代码
my_dataset = DataLoader(mydataset, batch_size=2, shuffle=True,num_workers=4)
 # num_workers:多进程读取数据

DataLoader的使用方法示例:

python 复制代码
from torch.utils.data import DataLoader

dataset = MyDataset()
data_loader = DataLoader(dataset=dataset,batch_size=10,shuffle=True,num_workers=2)

#遍历,获取其中的每个batch的结果
for index, (label, context) in enumerate(data_loader):
    print(index,label,context)
    print("*"*100)
  1. dataset:提前定义的dataset的实例

  2. batch_size:传入数据的batch的大小,常用128,256等等

  3. shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据

  4. num_workers:加载数据的线程数

导入两个列表到Dataset

python 复制代码
class MyDataset(Dataset):
    def __init__(self, ):
        # 定义数据集包含的数据和标签
        self.x_data = [i for i in range(10)]
        self.y_data = [2*i for i in range(10)]
 
    def __len__(self):
        return len(self.x_data)
    def __getitem__(self, index):
        # 当数据集被读取时,返回一个包含数据和标签的元组
        return self.x_data[index], self.y_data[index]
 
mydataset = MyDataset()
my_dataset = DataLoader(mydataset)
 
for x_i ,y_i in my_dataset:
    print(x_i,y_i)

💬输出:

python 复制代码
tensor([0]) tensor([0])
tensor([1]) tensor([2])
tensor([2]) tensor([4])
tensor([3]) tensor([6])
tensor([4]) tensor([8])
tensor([5]) tensor([10])
tensor([6]) tensor([12])
tensor([7]) tensor([14])
tensor([8]) tensor([16])
tensor([9]) tensor([18])

💬如果修改batch_size为2,则输出:

python 复制代码
tensor([0, 1]) tensor([0, 2])
tensor([2, 3]) tensor([4, 6])
tensor([4, 5]) tensor([ 8, 10])
tensor([6, 7]) tensor([12, 14])
tensor([8, 9]) tensor([16, 18])
  • 我们可以看出,这是管理每次输出的批次的
  • 还可以控制用多少个线程来加速读取数据(Num Workers),这参数和电脑cpu核心数有关系,尽量不超过电脑的核心数

我们看到可以不使用DataLoader,但这样就不能批次处理,只能for i in range(len(d))这样得到数据,也不能自动实现打乱逻辑,也不能串行加载。

python 复制代码
data_loader = DataLoader(dataset=Dataset,batch_size=10,shuffle=True,num_workers=2)
# 获取其中的每个batch的结果
for index, (label, context) in enumerate(data_loader):
    print(index,label,context)
    print("*"*100)

输出:

python 复制代码
555 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam') ("I forgot 2 ask ü all smth.. There's a card on da present lei... How? Ü all want 2 write smth or sign on it?", 'Am i that much dirty fellow?', 'have got * few things to do. may be in * pub later.', 'Ok lor. Anyway i thk we cant get tickets now cos like quite late already. U wan 2 go look 4 ur frens a not? Darren is wif them now...', 'When you came to hostel.', 'Well i know Z will take care of me. So no worries.', 'I REALLY NEED 2 KISS U I MISS U MY BABY FROM UR BABY 4EVA', 'Booked ticket for pongal?', 'Awww dat is sweet! We can think of something to do he he! Have a nice time tonight ill probably txt u later cos im lonely :( xxx.', 'We tried to call you re your reply to our sms for a video mobile 750 mins UNLIMITED TEXT + free camcorder Reply of call 08000930705 Now')
****************************************************************************************************
556 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam') (':-( sad puppy noise', 'G.W.R', 'Otherwise had part time job na-tuition..', 'They finally came to fix the ceiling.', 'The word "Checkmate" in chess comes from the Persian phrase "Shah Maat" which means; "the king is dead.." Goodmorning.. Have a good day..:)', 'Yup', 'I am real, baby! I want to bring out your inner tigress...', 'THANX4 TODAY CER IT WAS NICE 2 CATCH UP BUT WE AVE 2 FIND MORE TIME MORE OFTEN OH WELL TAKE CARE C U SOON.C', "She said,'' do u mind if I go into the bedroom for a minute ? '' ''OK'', I sed in a sexy mood. She came out 5 minuts latr wid a cake...n My Wife,", 'Ur cash-balance is currently 500 pounds - to maximize ur cash-in now send COLLECT to 83600 only 150p/msg. CC: 08718720201 PO BOX 114/14 TCR/W1')
****************************************************************************************************
557 ('ham', 'ham', 'ham', 'ham') ('It shall be fine. I have avalarr now. Will hollalater', "Nah it's straight, if you can just bring bud or drinks or something that's actually a little more useful than straight cash", 'U sleeping now.. Or you going to take? Haha.. I got spys wat.. Me online checking n replying mails lor..', 'In other news after hassling me to get him weed for a week andres has no money. HAUGHAIGHGTUJHYGUJ')
****************************************************************************************************

导入Excel数据到Dataset中

💥dataset只是一个类,因此数据可以从外部导入,我们也可以在dataset中规定数据在返回时进行更多的操作,数据在返回时也不一定是有两个。

python 复制代码
pip install pandas
pip install openpyxl
python 复制代码
class myDataset(Dataset):
    def __init__(self, data_loc):
        data = pd.read_ecl(data_loc)
        self.x1,self.x2,self.x3,self.x4,self.y = data['x1'],data['x2'],data['x3'] ,data['x4'],data['y']
 
    def __len__(self):
        return len(self.x1)
 
    def __getitem__(self, idx):
        return self.x1[idx],self.x2[idx],self.x3[idx],self.x4[idx],self.y[idx]
 
mydataset = myDataset(data_loc='e:\pythonProject Pytorch1\data.xls')
my_dataset = DataLoader(mydataset,batch_size=2)
for x1_i ,x2_i,x3_i,x4_i,y_i in my_dataset:
    print(x1_i,x2_i,x3_i,x4_i,y_i)

💯加载官方数据集

有一些数据集是PyTorch自带的,它被保存在TorchVisiontorchtext

  1. torchvision提供了对图片数据处理相关的api和数据

    • 数据位置:torchvision.datasets,例如:torchvision.datasets.MNIST(手写数字图片数据)
  2. torchtext提供了对文本数据处理相关的API和数据

    • 数据位置:torchtext.datasets,例如:torchtext.datasets.IMDB(电影评论文本数据)

我们以Mnist手写数字为例 ,看看pytorch如何加载其中自带的数据集

python 复制代码
torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)`
  1. root参数表示数据存放的位置

  2. train:bool类型,表示是使用训练集的数据还是测试集的数据

  3. download:bool类型,表示是否需要下载数据到root目录

  4. transform:实现的对图片的处理函数

相关推荐
新加坡内哥谈技术8 分钟前
口哨声、歌声、boing声和biotwang声:用AI识别鲸鱼叫声
人工智能·自然语言处理
wx74085132619 分钟前
小琳AI课堂:机器学习
人工智能·机器学习
FL162386312926 分钟前
[数据集][目标检测]车油口挡板开关闭合检测数据集VOC+YOLO格式138张2类别
人工智能·yolo·目标检测
YesPMP平台官方29 分钟前
AI+教育|拥抱AI智能科技,让课堂更生动高效
人工智能·科技·ai·数据分析·软件开发·教育
FL16238631291 小时前
AI健身体能测试之基于paddlehub实现引体向上计数个数统计
人工智能
黑客-雨1 小时前
构建你的AI职业生涯:从基础知识到专业实践的路线图
人工智能·产品经理·ai大模型·ai产品经理·大模型学习·大模型入门·大模型教程
子午1 小时前
动物识别系统Python+卷积神经网络算法+TensorFlow+人工智能+图像识别+计算机毕业设计项目
人工智能·python·cnn
大耳朵爱学习1 小时前
掌握Transformer之注意力为什么有效
人工智能·深度学习·自然语言处理·大模型·llm·transformer·大语言模型
TAICHIFEI1 小时前
目标检测-数据集
人工智能·目标检测·目标跟踪
qq_15321452641 小时前
【2023工业异常检测文献】SimpleNet
图像处理·人工智能·深度学习·神经网络·机器学习·计算机视觉·视觉检测