PyTorch数据加载与预处理

torch.utils.data

torch.utils.data是PyTorch数据处理核心模块,包含Dataset、DataLoader等关键组件。

torch.utils.data模块为PyTorch提供了完整的数据处理流水线,主要包含以下核心组件

核心组件

‌Dataset‌ - 数据容器抽象类

  • 必须实现 len() 返回数据集大小
  • 必须实现 getitem() 按索引获取样本
  • 支持自定义数据源和预处理逻辑

DataLoader‌ - 数据加载迭代器‌

  • batch_size: 批次大小
  • shuffle: 是否随机打乱
  • num_workers: 并行加载进程数
  • drop_last: 是否舍弃最后不足批次的数据

Sampler类‌ - 数据采样策略

SequentialSampler: 顺序采样

RandomSampler: 随机采样

WeightedRandomSampler: 加权随机采样

实用工具函数

‌随机分割‌ - random_split()

clike 复制代码
# 将数据集按比例分割
train_set, val_set = random_split(dataset, [800, 200])

‌数据拼接‌ - ConcatDataset()

clike 复制代码
# 合并多个数据集
combined = ConcatDataset([set1, set2])

📊 实战示例

clike 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        sample = self.df.iloc[idx].SalePrice
        if self.transform:
            sample = self.transform(sample)
        return sample

多进程加速‌

clike 复制代码
# 启用4个worker进程并行加载
dataloader = DataLoader(dataset, batch_size=32, 
                      shuffle=True, num_workers=4)

内存优化‌‌

  1. 使用惰性加载处理大数据集
  2. 启用 pin_memory=True 加速GPU数据传输
  3. 合理设置 prefetch_factor 预取数据

数据预处理

如何实现自定义数据集的预处理?

PyTorch中自定义数据集的预处理主要通过继承Dataset类和组合transforms模块来实现。

核心实现框架

自定义数据集预处理需要实现三个关键部分:‌Dataset类定义、transforms预处理管道、DataLoader数据加载‌。

基础Dataset类实现
clike 复制代码
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, csv_file, feature_cols, label_col, transform=None):
        self.data = pd.read_csv(csv_file)
        self.features = self.data[feature_cols]
        self.labels = self.data[label_col]
        self.transform = transform
        self._preprocess_data()
    
    def _preprocess_data(self):
        """数据预处理:处理缺失值和类型转换"""
        # 数值型特征:填充均值
        numeric_features = self.features.select_dtypes(include=[np.number])
        self.features[numeric_features.columns] = numeric_features.fillna(numeric_features.mean())
        
        # 类别型特征:转为one-hot编码
        categorical_features = self.features.select_dtypes(include=['object'])
        if not categorical_features.empty:
            self.features = pd.get_dummies(self.features, columns=categorical_features.columns)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = {
            'features': torch.tensor(self.features.iloc[idx].values, dtype=torch.float32),
            'label': torch.tensor(self.labels.iloc[idx], dtype=torch.long)
        }
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample
自定义预处理转换类
clike 复制代码
import torchvision.transforms as transforms

class Normalization:
    """数据归一化处理"""
    def __call__(self, sample):
        features = sample['features']
        # 最大最小归一化
        features = (features - features.min()) / (features.max() - features.min())
        sample['features'] = features
        return sample

class FeatureScaling:
    """特征标准化"""
    def __call__(self, sample):
        features = sample['features']
        features = (features - features.mean()) / features.std()
        sample['features'] = features
        return sample

class ToTensor:
    """确保数据转为Tensor格式"""
    def __call__(self, sample):
        if not isinstance(sample['features'], torch.Tensor):
            sample['features'] = torch.tensor(sample['features'], dtype=torch.float32)
        return sample
完整的预处理管道
clike 复制代码
# 构建预处理管道
from torchvision.transforms import Compose

# 训练集预处理(包含数据增强)
train_transform = Compose([
    Normalization(),
    FeatureScaling(),
    ToTensor()
])

# 测试集预处理(不包含数据增强)
test_transform = Compose([
    Normalization(),
    ToTensor()
])

# 创建数据集实例
train_dataset = CustomDataset(
    csv_file='train.csv',
    feature_cols=['feature1', 'feature2', 'feature3'],
    label_col='label',
    transform=train_transform
)

# 创建DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

图像数据预处理示例

对于图像数据,PyTorch提供了更丰富的预处理工具:

clike 复制代码
from torchvision import transforms

# 图像预处理管道
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整尺寸
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    transforms.RandomRotation(10),  # 随机旋转
    transforms.ToTensor(),  # 转为Tensor并归一化到[0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
])

常用的数据集

常用数据集分类与介绍

  1. 图像分类

    • CIFAR-10:6万张32×32彩色图像,涵盖飞机、汽车等10类,每类6000张。
    • ImageNet:超1400万标注图像,含2.2万类别,是计算机视觉基准数据集。
    • COCO:30万+图像,含200万+对象实例,支持80类目标检测任务。
  2. 自然语言处理

    • IMDB评论:5万条电影评论,用于情感分析。
    • Wikipedia:开放文本数据,支持NLP研究。
  3. 其他领域

    • COVID-19 X-Ray Dataset:6500张胸部X光片,含COVID-19病例标注。
    • MovieLens:2000万条电影评分数据,用于推荐系统。

MNIST数据集介绍

MNIST是包含6万张训练图像和1万张测试图像的手写数字数据集,每张为28×28像素的灰度图,标签为0-9的整数‌。数据需归一化到[0,1]范围,标签转为one-hot向量格式‌。作为深度学习入门基准,其简单性使其成为验证模型效果的理想选择‌。

相关推荐
极客BIM工作室1 小时前
从GAN到Sora:生成式AI在图像与视频领域的技术演进全景
人工智能·生成对抗网络·计算机视觉
xxxxxmy1 小时前
相向双指针—接雨水
python·相向双指针
skywalk81631 小时前
用Trae的sole模式来模拟文心快码comate的Spec Mode模式来做一个esp32操作系统的项目
人工智能·comate·trae·esp32c3
*星星之火*1 小时前
【大白话 AI 答疑】第5篇 从 “窄域专精” 到 “广谱通用”:传统机器学习与大模型的 6 大核心区别
人工智能·机器学习
roman_日积跬步-终至千里1 小时前
【模式识别与机器学习(7)】主要算法与技术(下篇:高级模型与集成方法)之 扩展线性模型(Extending Linear Models)
人工智能·算法·机器学习
张飞签名上架1 小时前
苹果TF签名:革新应用分发的解决方案
人工智能·安全·ios·苹果签名·企业签名·苹果超级签名
Sindy_he1 小时前
2025最新版微软GraphRAG 2.0.0本地部署教程:基于Ollama快速构建知识图谱
python·microsoft·大模型·知识图谱·rag
xcLeigh1 小时前
AI 绘制图表专栏:用豆包轻松实现 HTML 柱状图、折线图与饼图
前端·人工智能·html·折线图·柱状图·图表·豆包
玖日大大1 小时前
LongCat-Flash-Omni:5600 亿参数开源全模态模型的技术革命与产业实践
人工智能·microsoft·语言模型