【PyTorch】数据集

文章目录

  • [1. 创建数据集](#1. 创建数据集)
    • [1.1. 直接继承Dataset类](#1.1. 直接继承Dataset类)
    • [1.2. 使用TensorDataset类](#1.2. 使用TensorDataset类)
  • [2. 数据集的划分](#2. 数据集的划分)
  • [3. 加载数据集](#3. 加载数据集)
  • [4. 将数据转移到GPU](#4. 将数据转移到GPU)

1. 创建数据集

主要是将数据集读入内存,并用Dataset类封装。

1.1. 直接继承Dataset类

必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。

python 复制代码
from torch.utils.data import Dataset

class BostonHousingDataset(Dataset):
    """定义波士顿房价数据集"""
    def __init__(self):
        self.data = np.load('../dataset/boston_housing/boston_housing.npz')

    def __getitem__(self, index):
        return self.data['x'][index], self.data['y'][index]

    def __len__(self):
        return self.data['x'].shape[0]

1.2. 使用TensorDataset类

将多个张量组合成一个数据集,要保证所有张量的第一个维度相等,保证每批样本数据格式相同。

python 复制代码
import torch
from torch.utils.data import TensorDataset

data = np.load('../dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'])
y = torch.tensor(data['y'])
dataset = TensorDataset(X, y)

2. 数据集的划分

数据集可以划分为训练集、验证集和测试集。

  • 训练集:用于模型拟合的数据样本集合。
  • 验证集:通常被用来调整模型的参数,以找出效果最佳的模型。
  • 测试集:用于训练好的模型性能评估的数据样本集合。
python 复制代码
from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

3. 加载数据集

使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:

  • dataset
    要加载的数据集。
  • batch_size
    每个数据批次中包含的样本数。默认为1。
  • shuffle
    是否打乱数据集。默认为False。
  • num_workers
    使用几个进程来加载数据。默认为0,即在主进程中加载数据。
  • drop_last
    当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。
python 复制代码
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

4. 将数据转移到GPU

一般在要运算时才将数据转移到GPU,有以下两种方法:

  1. var.to(device)
  2. var.cuda()
python 复制代码
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:
    # 将数据转移到GPU
    X = X.to(device)
    y = y.to(device)
    # 也可以
    X = X.cuda()
    y = y.cuda()
相关推荐
三无推导7 小时前
ComfyUI 安装部署教程:Windows 下快速搭建可视化 AI 绘图工作流,零基础也能跑通
人工智能·pytorch·windows·stable diffusion·aigc·ai绘画·持续部署
独隅9 小时前
PyTorch自动微分模块:从原理到实战一
人工智能·pytorch·python
不羁的木木16 小时前
HarmonyOS文件基础服务(Core File Kit)实战演练03-文件增删改查与目录操作
pytorch·华为·harmonyos
盼小辉丶17 小时前
PyTorch深度学习实战(55)——在Android上部署PyTorch模型
android·pytorch·python·模型部署
zhendianluli1 天前
PyTorch 复杂模型转 ONNX 踩坑纪实:从 diff 到 nan_to_num 的三关突破
人工智能·pytorch·python
weixin_468466852 天前
PyTorch 与 TensorFlow 实战选型与应用场景指南
人工智能·pytorch·深度学习·算法·机器学习·tensorflow·深度学习框架
独隅2 天前
PyTorch 新手从零搭建深度学习环境实战指南
人工智能·pytorch·深度学习
keineahnung23452 天前
在 Google Colab 中安裝 PyTorch 2.2.0
人工智能·pytorch·python·深度学习
AI算法沐枫2 天前
机器学习经典小项目1:鸢尾花分类
人工智能·pytorch·深度学习·神经网络·机器学习·分类·数据挖掘
weixin_468466852 天前
PyTorch 深度学习框架核心能力与实战评测
人工智能·pytorch·深度学习·神经网络·计算机视觉·动态图·模型训练