深度学习数据预处理:Dataset类的全面解析与实战指南

前言

在深度学习项目中,数据预处理是模型训练前至关重要的一环。一个高效、灵活的数据预处理流程不仅能提升模型性能,还能大大加快开发效率。本文将深入探讨PyTorch中的Dataset类,介绍数据预处理的常见技巧,并通过实战示例展示如何构建自己的数据预处理流程。

一、Dataset作用

在深度学习项目中,原始数据通常需要经过一系列处理才能输入模型。Dataset类的主要作用包括:

  1. 数据统一接口:为不同类型的数据提供统一的访问接口

  2. 内存高效利用:实现按需加载,避免一次性加载所有数据

  3. 数据增强:方便集成各种数据增强技术

  4. 代码可维护性:使数据处理逻辑模块化,便于维护和复用

二、Dataset基础

PyTorch提供了两个核心类来处理数据:

  • torch.utils.data.Dataset :抽象类,所有自定义数据集应继承此类

  • torch.utils.data.DataLoader:数据加载器,负责批量生成数据

基本Dataset实现:

python 复制代码
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample, label

三、常见数据预处理技术

1. 图像数据预处理

python 复制代码
from torchvision import transforms

# 常见的图像预处理流程
image_transform = transforms.Compose([
    transforms.Resize(256),          # 调整大小
    transforms.CenterCrop(224),      # 中心裁剪
    transforms.ToTensor(),           # 转为Tensor
    transforms.Normalize(            # 标准化
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

2. 文本数据预处理

python 复制代码
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# 分词器
tokenizer = get_tokenizer('basic_english')

# 构建词汇表
def yield_tokens(data_iter):
    for text, _ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])

# 文本转tensor
def text_pipeline(text):
    return torch.tensor([vocab[token] for token in tokenizer(text)], dtype=torch.long)

3. 数值数据预处理

python 复制代码
from sklearn.preprocessing import StandardScaler

# 标准化数值特征
scaler = StandardScaler()
train_data = scaler.fit_transform(train_data)
test_data = scaler.transform(test_data)  # 使用相同的scaler

四、高级Dataset技巧

1. 懒加载大数据集

对于大型数据集(如图像数据集),我们通常不希望一次性加载所有数据:

python 复制代码
class LazyImageDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
    
    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = Image.open(img_path).convert('RGB')  # 按需加载
        
        if self.transform:
            image = self.transform(image)
            
        return image, self.labels[idx]

2. 多模态数据集处理

处理同时包含图像和文本的数据:

python 复制代码
class MultiModalDataset(Dataset):
    def __init__(self, image_paths, texts, labels, image_transform, text_transform):
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.text_transform = text_transform
    
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx])
        text = self.texts[idx]
        label = self.labels[idx]
        
        if self.image_transform:
            image = self.image_transform(image)
        
        if self.text_transform:
            text = self.text_transform(text)
            
        return {"image": image, "text": text}, label

3. 数据增强技巧

python 复制代码
# 训练和验证时使用不同的预处理
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

五、实战:构建图像分类Dataset

让我们实现一个完整的图像分类数据集:

python 复制代码
import os

import numpy as np
from PIL import Image

def train_test_file(root,dir):
    file_txt=open(dir+'.txt','w')
    path=os.path.join(root,dir)
    for roots,directories,files in os.walk(path):
        if len(directories) !=0:
            dirs=directories
        else:
            now_dir=roots.split('\\')
            for file in files:
                path_1=os.path.join(roots,file)
                print(path_1)
                file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')
    file_txt.close()
root=r'.\食物分类\food_dataset'
train_dir='train'
test_dir='test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)

import torch
from torch import nn   #导入神经网络模块,
from torch.utils.data import DataLoader   #数据包管理工具,打包数据,
from torchvision import transforms
from torch.utils.data import Dataset

data_transforms={
'train':
transforms.Compose([
    transforms.Resize([300, 300]),
    transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选
    transforms.CenterCrop(256),  # 从中心开始裁剪[256,256]
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率概率
    transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R=G=B
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'valid':
transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

}

food_type={0:"八宝粥",1:"巴旦木",2:"白萝卜",3:"板栗",4:"菠萝",5:"草莓",6:"蛋",7:"蛋挞",8:"骨肉相连",
           9:"瓜子",10:"哈密瓜",11:"汉堡",12:"胡萝卜",13:"火龙果",14:"鸡翅",15:"青菜",16:"生肉",17:"圣女果",18:"薯条",19:"炸鸡"}


class food_dataset(Dataset):
    def __init__(self,file_path,transform=None):
        self.file_path=file_path
        self.imgs=[]
        self.labels=[]
        self.transform=transform
        with open(self.file_path) as f:
            samples=[x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image=Image.open(self.imgs[idx])
        if self.transform:
            image=self.transform(image)

        label = self.labels[idx]
        label = torch.from_numpy(np.array(label,dtype=np.int64))
        return image,label

training_data=food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data=food_dataset(file_path='test.txt', transform=data_transforms['valid'])


train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

'''断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU。'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  #字符串的格式化

'''定义神经网络 类的继承'''
class CNN(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型nn.moudle
    def __init__(self):
        super().__init__()  # 继承父类的初始化
        self.conv1=nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),      #(16,28,28)
            nn.MaxPool2d(kernel_size=2) #(16,14,14)
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(16,32,5,1,2),  #32,14,14
            nn.ReLU(),

        )
        self.conv3=nn.Sequential(
            nn.Conv2d(32,64,5,1,2),    #128,7,7
            nn.ReLU()
        )
        self.out=nn.Linear(64*128*128,20)


    def forward(self, x):  # 前向传播,指明数据的流向,使神经网络连接起来,函数名称不能修改
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=x.view(x.size(0),-1)
        out=self.out(x)
        return out

model = CNN().to(device)
print(model)

def train(dataloader,model,loss_fn,optimizer):
    model.train()   #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
#pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
#一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()
    batch_size_num=1
    for X,y in dataloader:       #其中batch为每一个数据的编号
        X,y=X.to(device),y.to(device)    #把训练数据集和标签传入cpu或GPU
        pred=model.forward(X)    #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化
        loss=loss_fn(pred,y)     #通过交叉熵损失函数计算损失值loss
        # Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络
        optimizer.zero_grad()    #梯度值清零
        loss.backward()          #反向传播计算得到每个参数的梯度值w
        optimizer.step()         #根据梯度更新网络w参数

        loss_value=loss.item()   #从tensor数据中提取数据出来,tensor获取损失值
        if batch_size_num %1 ==0:
            print(f'loss:{loss:>7f} [number:{batch_size_num}]')
        batch_size_num+=1

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()

    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)

    test_loss /= num_batches
    correct /= size
    # print(food_type)
    # print(pred.argmax(1).tolist())
    # print(y.tolist())

    result=zip(pred.argmax(1).tolist(),y.tolist())
    for i in result:
        print(f"当前测试的结果为:{food_type[i[0]]},当前真实的结果为:{food_type[i[1]]}")


    print(f"Test result:\n Accurracy:{(100 * correct)}%,AVG loss:{test_loss}")

    test_loss /=num_batches
    correct /=size
    print(f'Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}')

loss_fn=nn.CrossEntropyLoss()   #创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)   #创建一个优化器,SGD为随机梯度下降算法
# #params:要训练的参数,一般我们传入的都是model.parameters()#
# lr:learning_rate学习率,也就是步长

#loss表示模型训练后的输出结果与,样本标签的差距。如果差距越小,就表示模型训练越好,越逼近干真实的模型。

# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)

epoch=10
for i in range(epoch):
    print(i + 1)
    train(train_dataloader, model, loss_fn, optimizer)

test(test_dataloader, model, loss_fn)

总结

数据预处理是深度学习项目成功的关键因素之一。通过合理设计Dataset类,我们可以:

  1. 实现高效的数据加载和预处理

  2. 方便地应用各种数据增强技术

  3. 保持代码的整洁和可维护性

  4. 轻松处理不同类型的数据(图像、文本、音频等)

相关推荐
yt948322 分钟前
基于GMM的语音识别
人工智能·语音识别
carpell8 分钟前
小白也能行【手撕ResNet代码篇(附代码)】:详解可复现
人工智能·深度学习·计算机视觉
yangmf204013 分钟前
私有知识库 Coco AI 实战(三):摄入 Elasticsearch 官方文档
人工智能·elasticsearch·搜索引擎·全文检索·coco ai
亚图跨际30 分钟前
从物理到预测:数据驱动的深度学习的结构化探索及AI推理
人工智能·深度学习
一RTOS一33 分钟前
鸿道操作系统Type 1虚拟化:破局AI机器人与智能汽车的“安全”与“算力”双刃剑
人工智能·机器人·汽车·鸿道intewell操作系统·工业os
搬砖的阿wei42 分钟前
Transformer:引领深度学习新时代的架构
人工智能·深度学习·transformer
lilye661 小时前
精益数据分析(6/126):深入理解精益分析的核心要点
前端·人工智能·数据分析
果冻人工智能1 小时前
直观讲解生成对抗网络背后的数学原理
人工智能
新智元1 小时前
刚刚,OpenAI 最强图像生成 API 上线,一张图 1 毛 5!
人工智能·openai
梓羽玩Python1 小时前
开源TTS领域迎来重磅新星!Dia-1.6B:超逼真对话生成,开源2天斩获6.5K Star!
人工智能·python·github