【深度学习入门篇 ③】PyTorch的数据加载

【🍊 易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊 】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


掌握PyTorch数据通常的处理方法,是构建高效、可扩展模型的关键一步。今天,我们就利用PyTorch高效地处理数据,为模型训练打下坚实基础。

在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。

但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。

所以,接下来我们来学习pytorch中的数据加载的方法~

Dataset基类介绍

dataset定义了这个数据集的总长度,以及会返回哪些参数,模板:

python 复制代码
from torch.utils.data import Dataset
 
class MyDataset(Dataset):
    def __init__(self, ):
        # 定义数据集包含的数据和标签
 
    def __len__(self):
        return len(...)
    def __getitem__(self, index):
        # 当数据集被读取时,返回一个包含数据和标签的元组

数据加载案例

数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

该数据集包含了5574条 短信,其中正常短信(标记为"ham")4831条 ,骚扰短信(标记为"spam")743条。

python 复制代码
from torch.utils.data import Dataset,DataLoader
import pandas as pd

data_path = r"data/SMSSpamCollection"    # 路径

class SMSDataset(Dataset):
    def __init__(self):
        lines = open(data_path,"r",encoding="utf-8")
        # 前4个为label,后面的为短信内容
        lines = [[i[:4].strip(),i[4:].strip()] for i in lines]
        # 转为dataFrame类型
        self.df = pd.DataFrame(lines,columns=["label","sms"])

    def __getitem__(self, index):
        single_item = self.df.iloc[index,:]
        return single_item.values[0],single_item.values[1]

    def __len__(self):
        return self.df.shape[0]

我们现在已经成功地构建了一个数据集类 SMSDataset,这个类能够加载SMS 垃圾短信数据集,并将每条短信及其对应的标签(hamspam)封装为可迭代的形式,以便于后续的数据加载和模型训练。

python 复制代码
d = SMSDataset()
for i in range(len(d)):
    print(i,d[i])

输出:

python 复制代码
...
5566 ('ham', "Why don't you wait 'til at least wednesday to see if you get your .")
5567 ('ham', 'Huh y lei...')
5568 ('spam', 'REMINDER FROM O2: To get 2.50 pounds free call credit and details of great offers pls reply 2 this text with your valid name, house no and postcode')
5569 ('spam', 'This is the 2nd time we have tried 2 contact u. U have won the £750 Pound prize. 2 claim is easy, call 087187272008 NOW1! Only 10p per minute. BT-national-rate.')
5570 ('ham', 'Will ü b going to esplanade fr home?')
5571 ('ham', 'Pity, * was in mood for that. So...any other suggestions?')
5572 ('ham', "The guy did some bitching but I acted like i'd be interested in buying something else next week and he gave it to us for free")
5573 ('ham', 'Rofl. Its true to its name')

DataLoader格式说明

python 复制代码
my_dataset = DataLoader(mydataset, batch_size=2, shuffle=True,num_workers=4)
 # num_workers:多进程读取数据

DataLoader的使用方法示例:

python 复制代码
from torch.utils.data import DataLoader

dataset = MyDataset()
data_loader = DataLoader(dataset=dataset,batch_size=10,shuffle=True,num_workers=2)

#遍历,获取其中的每个batch的结果
for index, (label, context) in enumerate(data_loader):
    print(index,label,context)
    print("*"*100)
  1. dataset:提前定义的dataset的实例

  2. batch_size:传入数据的batch的大小,常用128,256等等

  3. shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据

  4. num_workers:加载数据的线程数

导入两个列表到Dataset

python 复制代码
class MyDataset(Dataset):
    def __init__(self, ):
        # 定义数据集包含的数据和标签
        self.x_data = [i for i in range(10)]
        self.y_data = [2*i for i in range(10)]
 
    def __len__(self):
        return len(self.x_data)
    def __getitem__(self, index):
        # 当数据集被读取时,返回一个包含数据和标签的元组
        return self.x_data[index], self.y_data[index]
 
mydataset = MyDataset()
my_dataset = DataLoader(mydataset)
 
for x_i ,y_i in my_dataset:
    print(x_i,y_i)

💬输出:

python 复制代码
tensor([0]) tensor([0])
tensor([1]) tensor([2])
tensor([2]) tensor([4])
tensor([3]) tensor([6])
tensor([4]) tensor([8])
tensor([5]) tensor([10])
tensor([6]) tensor([12])
tensor([7]) tensor([14])
tensor([8]) tensor([16])
tensor([9]) tensor([18])

💬如果修改batch_size为2,则输出:

python 复制代码
tensor([0, 1]) tensor([0, 2])
tensor([2, 3]) tensor([4, 6])
tensor([4, 5]) tensor([ 8, 10])
tensor([6, 7]) tensor([12, 14])
tensor([8, 9]) tensor([16, 18])
  • 我们可以看出,这是管理每次输出的批次的
  • 还可以控制用多少个线程来加速读取数据(Num Workers),这参数和电脑cpu核心数有关系,尽量不超过电脑的核心数

我们看到可以不使用DataLoader,但这样就不能批次处理,只能for i in range(len(d))这样得到数据,也不能自动实现打乱逻辑,也不能串行加载。

python 复制代码
data_loader = DataLoader(dataset=Dataset,batch_size=10,shuffle=True,num_workers=2)
# 获取其中的每个batch的结果
for index, (label, context) in enumerate(data_loader):
    print(index,label,context)
    print("*"*100)

输出:

python 复制代码
555 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam') ("I forgot 2 ask ü all smth.. There's a card on da present lei... How? Ü all want 2 write smth or sign on it?", 'Am i that much dirty fellow?', 'have got * few things to do. may be in * pub later.', 'Ok lor. Anyway i thk we cant get tickets now cos like quite late already. U wan 2 go look 4 ur frens a not? Darren is wif them now...', 'When you came to hostel.', 'Well i know Z will take care of me. So no worries.', 'I REALLY NEED 2 KISS U I MISS U MY BABY FROM UR BABY 4EVA', 'Booked ticket for pongal?', 'Awww dat is sweet! We can think of something to do he he! Have a nice time tonight ill probably txt u later cos im lonely :( xxx.', 'We tried to call you re your reply to our sms for a video mobile 750 mins UNLIMITED TEXT + free camcorder Reply of call 08000930705 Now')
****************************************************************************************************
556 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam') (':-( sad puppy noise', 'G.W.R', 'Otherwise had part time job na-tuition..', 'They finally came to fix the ceiling.', 'The word "Checkmate" in chess comes from the Persian phrase "Shah Maat" which means; "the king is dead.." Goodmorning.. Have a good day..:)', 'Yup', 'I am real, baby! I want to bring out your inner tigress...', 'THANX4 TODAY CER IT WAS NICE 2 CATCH UP BUT WE AVE 2 FIND MORE TIME MORE OFTEN OH WELL TAKE CARE C U SOON.C', "She said,'' do u mind if I go into the bedroom for a minute ? '' ''OK'', I sed in a sexy mood. She came out 5 minuts latr wid a cake...n My Wife,", 'Ur cash-balance is currently 500 pounds - to maximize ur cash-in now send COLLECT to 83600 only 150p/msg. CC: 08718720201 PO BOX 114/14 TCR/W1')
****************************************************************************************************
557 ('ham', 'ham', 'ham', 'ham') ('It shall be fine. I have avalarr now. Will hollalater', "Nah it's straight, if you can just bring bud or drinks or something that's actually a little more useful than straight cash", 'U sleeping now.. Or you going to take? Haha.. I got spys wat.. Me online checking n replying mails lor..', 'In other news after hassling me to get him weed for a week andres has no money. HAUGHAIGHGTUJHYGUJ')
****************************************************************************************************

导入Excel数据到Dataset中

💥dataset只是一个类,因此数据可以从外部导入,我们也可以在dataset中规定数据在返回时进行更多的操作,数据在返回时也不一定是有两个。

python 复制代码
pip install pandas
pip install openpyxl
python 复制代码
class myDataset(Dataset):
    def __init__(self, data_loc):
        data = pd.read_ecl(data_loc)
        self.x1,self.x2,self.x3,self.x4,self.y = data['x1'],data['x2'],data['x3'] ,data['x4'],data['y']
 
    def __len__(self):
        return len(self.x1)
 
    def __getitem__(self, idx):
        return self.x1[idx],self.x2[idx],self.x3[idx],self.x4[idx],self.y[idx]
 
mydataset = myDataset(data_loc='e:\pythonProject Pytorch1\data.xls')
my_dataset = DataLoader(mydataset,batch_size=2)
for x1_i ,x2_i,x3_i,x4_i,y_i in my_dataset:
    print(x1_i,x2_i,x3_i,x4_i,y_i)

💯加载官方数据集

有一些数据集是PyTorch自带的,它被保存在TorchVisiontorchtext

  1. torchvision提供了对图片数据处理相关的api和数据

    • 数据位置:torchvision.datasets,例如:torchvision.datasets.MNIST(手写数字图片数据)
  2. torchtext提供了对文本数据处理相关的API和数据

    • 数据位置:torchtext.datasets,例如:torchtext.datasets.IMDB(电影评论文本数据)

我们以Mnist手写数字为例 ,看看pytorch如何加载其中自带的数据集

python 复制代码
torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)`
  1. root参数表示数据存放的位置

  2. train:bool类型,表示是使用训练集的数据还是测试集的数据

  3. download:bool类型,表示是否需要下载数据到root目录

  4. transform:实现的对图片的处理函数

相关推荐
余炜yw28 分钟前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐1 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
96771 小时前
对抗样本存在的原因
深度学习
如若1231 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr1 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络
ChaseDreamRunner2 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习
Guofu_Liao2 小时前
大语言模型---梯度的简单介绍;梯度的定义;梯度计算的方法
人工智能·语言模型·矩阵·llama
我爱学Python!2 小时前
大语言模型与图结构的融合: 推荐系统中的新兴范式
人工智能·语言模型·自然语言处理·langchain·llm·大语言模型·推荐系统
果冻人工智能2 小时前
OpenAI 是怎么“压力测试”大型语言模型的?
人工智能·语言模型·压力测试
日出等日落2 小时前
Windows电脑本地部署llamafile并接入Qwen大语言模型远程AI对话实战
人工智能·语言模型·自然语言处理