【Pytorch笔记】5.DataLoader、Dataset、自定义Dataset

参考
深度之眼官方账号 - 02-01 Dataloader与Dataset.mp4

torch.utils.data.DataLoader

功能:构建可迭代的数据装载器。

python 复制代码
data.DataLoader(dataset,
                batch_size=1,
                shuffle=False,
                sampler=None,
                batch_sampler=None,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                timeout=0,
                worker_init_fn=None,
                multiprocessing_context=None)

dataset:Dataset类,决定数据从哪里读取;

batch_size:批大小;

shuffle:每个epoch是否乱序;

sampler:定义从数据集中抽取数据的策略,指明sampler时不能指明shuffle。

batch_sampler:和sampler相似,但是一次返回一个批大小的下标。和batch_sizeshufflesamplerdrop_last互斥;

num_workers:使用该DataLoader加载数据的子进程 的数量。如果该值为0,意味着只有主进程使用该DataLoader加载数据。

collate_fn:合并样本列表,形成一个小型tensor批次。在使用map-style dataset批量加载时使用。

pin_memory:True时,DataLoader在返回tensor数据时会先将这些数据复制到device/CUDA的pinned memory中。

drop_last:当样本数不能被batch_size整除时,是否舍弃最后一批数据。

timeout:如果是正数,则表示获取批次的超时值。要求设置为非负数。

worker_init_fn:如果不是None,那么这个函数会调用所有带有worker_id的worker的子进程。

Epoch :所有训练样本都已输入到模型中,称为一个Epoch。
Iteration :一批样本输入到模型中,称为一个Iteration。
BatchSize:批大小,决定一个Epoch有多少个Iteration。

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承他,并且复写__getitem__()

getitem():接收一个索引,返回一个样本。

torch.utils.data.TensorDataset

由tensor封装的数据集,传入一些在第1维长度相同的tensor,形成dataset。

DataLoader如何使用?

python 复制代码
import torch
from torch.utils import data


def data_gen(num_examples):
    X = torch.normal(0, 1, (num_examples, 2))
    y = torch.normal(0, 1, (num_examples, 1))
    return X, y


def load_array(data_arrays, batch_size, need_shuffle=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=need_shuffle)


data_A, data_B = data_gen(30)
# 定义一个数据迭代器data_iter
data_iter = load_array((data_A, data_B), 10, False)

# 使用for循环,iter每次会指向下一组数据
# 每组数据返回我们封装进去的那些tensor,封装几个就返回几个
for X, y in iter(data_iter):
    print("X\n", X)
    print("y\n", y)

输出:

python 复制代码
X
 tensor([[-0.2338, -0.4072],
        [ 1.6080, -0.0880],
        [-1.0243, -0.6245],
        [ 0.6012,  1.9620],
        [ 1.1876,  0.9539],
        [-0.5972,  0.9251],
        [-0.4918, -0.1340],
        [ 2.3297,  0.1833],
        [-0.8487,  0.3370],
        [-0.0600,  2.3769]])
y
 tensor([[ 1.0454],
        [ 0.5992],
        [ 2.0075],
        [ 0.2727],
        [-1.6845],
        [ 0.0845],
        [ 1.0992],
        [ 0.5103],
        [-0.6727],
        [-0.1900]])
X
 tensor([[-0.1419, -1.5535],
        [-1.6436,  0.8680],
        [-0.4432, -0.7703],
        [ 0.3822,  0.4675],
        [ 0.4000,  1.3471],
        [ 0.9776,  2.0103],
        [ 0.1298,  2.7382],
        [ 0.2664, -0.6223],
        [-1.0774,  0.0734],
        [-0.1904, -1.3299]])
y
 tensor([[-0.5979],
        [-0.5432],
        [ 0.2951],
        [ 0.2811],
        [-0.5997],
        [ 0.8073],
        [ 1.4356],
        [ 1.1555],
        [-0.3368],
        [-0.0626]])
X
 tensor([[-1.4326, -0.3407],
        [-1.1878, -1.5619],
        [ 0.3498,  1.5307],
        [-0.8174,  0.6017],
        [ 0.8076,  0.8295],
        [ 2.6239,  1.1669],
        [-1.2598,  1.4309],
        [ 0.3365,  0.1765],
        [-0.4472, -0.6882],
        [ 0.6732, -0.0742]])
y
 tensor([[ 0.5114],
        [ 1.0669],
        [-1.5565],
        [ 0.4512],
        [ 3.2071],
        [ 0.4752],
        [-1.5981],
        [ 0.0035],
        [-0.2723],
        [-1.3634]])

自定义Dataset

python 复制代码
import torch
from torch.utils import data

# 自定义的Dataset
# 每一条数据包含X的一行数据(1x2的tensor)和y的一行数据(1x1的tensor)
class MyTensorDataset(data.Dataset):
    def __init__(self, tensor_list):
        # 初始化超类 (data.Dataset)
        super(MyTensorDataset, self).__init__()
        self.data = tensor_list

    
    def __getitem__(self, index):
        # 这个函数必须重写
        # 返回dataset中下标编号为index的数据
        sample = {'X': self.data[0][index], 'y': self.data[1][index]}
        return sample
    

    def __len__(self):
        # 返回这个dataset中的数据个数
        return len(self.data[0])


# 随机生成数据
def data_gen(num_examples):
    X = torch.normal(0, 1, (num_examples, 2))
    y = torch.normal(0, 1, (num_examples, 1))
    return X, y


# 使用自己的Dataset
def load_array(data_arrays, batch_size, need_shuffle=True):
    X, y = data_arrays  
    dataset = MyTensorDataset((X, y)) 
    return data.DataLoader(dataset, batch_size, shuffle=need_shuffle)


data_A, data_B = data_gen(30)
data_iter = load_array((data_A, data_B), 10, False)

for sample in iter(data_iter):
    print("sample:\n", sample)

输出:

python 复制代码
sample:
 {'X': tensor([[ 0.2149,  0.6216],
        [ 0.4691,  0.1862],
        [-1.7705, -1.5983],
        [-0.0196, -0.5903],
        [-0.1313,  0.3206],
        [ 0.1898, -0.2575],
        [ 1.5934, -0.4720],
        [-0.8343, -0.2181],
        [ 1.6159, -0.5473],
        [-1.2662, -0.0218]]), 'y': tensor([[ 0.8886],
        [ 0.1653],
        [ 0.8054],
        [-0.0725],
        [-0.4806],
        [-0.4661],
        [-0.4040],
        [ 0.6192],
        [-0.2522],
        [ 1.4091]])}
sample:
 {'X': tensor([[ 0.6441, -0.5759],
        [-0.7285, -1.0021],
        [ 0.1250, -0.2333],
        [ 0.3196,  0.7762],
        [ 0.1429,  0.4667],
        [ 1.0751, -0.4867],
        [ 0.1664,  0.3489],
        [ 0.1616, -0.1998],
        [ 0.6707, -0.4678],
        [-1.7778, -2.4658]]), 'y': tensor([[-0.2342],
        [ 0.1402],
        [ 0.5768],
        [ 1.7898],
        [-0.6802],
        [-2.3584],
        [ 0.7048],
        [ 0.1848],
        [ 0.1225],
        [ 0.5535]])}
sample:
 {'X': tensor([[ 0.1694,  0.2863],
        [ 0.3062, -0.7494],
        [ 0.6844,  1.9278],
        [ 0.9141,  0.3842],
        [-1.2314, -1.4933],
        [ 0.1568, -1.4182],
        [ 1.7723, -0.4890],
        [ 0.5734,  1.0614],
        [ 0.9536, -0.2866],
        [ 0.2510,  0.3375]]), 'y': tensor([[ 0.9802],
        [-0.5557],
        [ 0.7763],
        [ 1.1688],
        [-1.0067],
        [-0.4044],
        [-0.2745],
        [-0.3661],
        [-0.6058],
        [ 0.6905]])}
相关推荐
L.fountain14 小时前
强化学习2.2 MDP实践——Frozen lake
人工智能·强化学习
_dindong14 小时前
牛客101:链表
数据结构·c++·笔记·学习·算法·链表
JJJJ_iii14 小时前
【机器学习06】神经网络的实现、训练与向量化
人工智能·笔记·深度学习·神经网络·学习·机器学习·线性回归
倔强的石头10614 小时前
AI协作天花板!CherryStudio让多模型协同像搭积木
人工智能·cpolar
IT_陈寒14 小时前
Vite 3.0 性能优化实战:5个技巧让你的构建速度提升200% 🚀
前端·人工智能·后端
瑞士卷@14 小时前
spring从入门到精通(spring学习笔记,持续更新中)
笔记·学习·spring
说私域14 小时前
从工具理性到价值共生:开源链动2+1模式、AI智能名片与S2B2C商城系统的社会连接重构研究
人工智能·重构·开源
heisd_114 小时前
OpenCV计算机视觉库
人工智能·opencv·计算机视觉
我先去打把游戏先14 小时前
ESP32学习笔记(基于IDF):SmartConfig一键配网
笔记·嵌入式硬件·mcu·物联网·学习·esp32·硬件工程
摇滚侠18 小时前
Spring Boot 3零基础教程,WEB 开发 静态资源默认配置 笔记27
spring boot·笔记·后端