PyTorch数据加载与预处理

torch.utils.data

torch.utils.data是PyTorch数据处理核心模块,包含Dataset、DataLoader等关键组件。

torch.utils.data模块为PyTorch提供了完整的数据处理流水线,主要包含以下核心组件

核心组件

‌Dataset‌ - 数据容器抽象类

  • 必须实现 len() 返回数据集大小
  • 必须实现 getitem() 按索引获取样本
  • 支持自定义数据源和预处理逻辑

DataLoader‌ - 数据加载迭代器‌

  • batch_size: 批次大小
  • shuffle: 是否随机打乱
  • num_workers: 并行加载进程数
  • drop_last: 是否舍弃最后不足批次的数据

Sampler类‌ - 数据采样策略

SequentialSampler: 顺序采样

RandomSampler: 随机采样

WeightedRandomSampler: 加权随机采样

实用工具函数

‌随机分割‌ - random_split()

clike 复制代码
# 将数据集按比例分割
train_set, val_set = random_split(dataset, [800, 200])

‌数据拼接‌ - ConcatDataset()

clike 复制代码
# 合并多个数据集
combined = ConcatDataset([set1, set2])

📊 实战示例

clike 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        sample = self.df.iloc[idx].SalePrice
        if self.transform:
            sample = self.transform(sample)
        return sample

多进程加速‌

clike 复制代码
# 启用4个worker进程并行加载
dataloader = DataLoader(dataset, batch_size=32, 
                      shuffle=True, num_workers=4)

内存优化‌‌

  1. 使用惰性加载处理大数据集
  2. 启用 pin_memory=True 加速GPU数据传输
  3. 合理设置 prefetch_factor 预取数据

数据预处理

如何实现自定义数据集的预处理?

PyTorch中自定义数据集的预处理主要通过继承Dataset类和组合transforms模块来实现。

核心实现框架

自定义数据集预处理需要实现三个关键部分:‌Dataset类定义、transforms预处理管道、DataLoader数据加载‌。

基础Dataset类实现
clike 复制代码
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, csv_file, feature_cols, label_col, transform=None):
        self.data = pd.read_csv(csv_file)
        self.features = self.data[feature_cols]
        self.labels = self.data[label_col]
        self.transform = transform
        self._preprocess_data()
    
    def _preprocess_data(self):
        """数据预处理:处理缺失值和类型转换"""
        # 数值型特征:填充均值
        numeric_features = self.features.select_dtypes(include=[np.number])
        self.features[numeric_features.columns] = numeric_features.fillna(numeric_features.mean())
        
        # 类别型特征:转为one-hot编码
        categorical_features = self.features.select_dtypes(include=['object'])
        if not categorical_features.empty:
            self.features = pd.get_dummies(self.features, columns=categorical_features.columns)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = {
            'features': torch.tensor(self.features.iloc[idx].values, dtype=torch.float32),
            'label': torch.tensor(self.labels.iloc[idx], dtype=torch.long)
        }
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample
自定义预处理转换类
clike 复制代码
import torchvision.transforms as transforms

class Normalization:
    """数据归一化处理"""
    def __call__(self, sample):
        features = sample['features']
        # 最大最小归一化
        features = (features - features.min()) / (features.max() - features.min())
        sample['features'] = features
        return sample

class FeatureScaling:
    """特征标准化"""
    def __call__(self, sample):
        features = sample['features']
        features = (features - features.mean()) / features.std()
        sample['features'] = features
        return sample

class ToTensor:
    """确保数据转为Tensor格式"""
    def __call__(self, sample):
        if not isinstance(sample['features'], torch.Tensor):
            sample['features'] = torch.tensor(sample['features'], dtype=torch.float32)
        return sample
完整的预处理管道
clike 复制代码
# 构建预处理管道
from torchvision.transforms import Compose

# 训练集预处理(包含数据增强)
train_transform = Compose([
    Normalization(),
    FeatureScaling(),
    ToTensor()
])

# 测试集预处理(不包含数据增强)
test_transform = Compose([
    Normalization(),
    ToTensor()
])

# 创建数据集实例
train_dataset = CustomDataset(
    csv_file='train.csv',
    feature_cols=['feature1', 'feature2', 'feature3'],
    label_col='label',
    transform=train_transform
)

# 创建DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

图像数据预处理示例

对于图像数据,PyTorch提供了更丰富的预处理工具:

clike 复制代码
from torchvision import transforms

# 图像预处理管道
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整尺寸
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    transforms.RandomRotation(10),  # 随机旋转
    transforms.ToTensor(),  # 转为Tensor并归一化到[0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
])

常用的数据集

常用数据集分类与介绍

  1. 图像分类

    • CIFAR-10:6万张32×32彩色图像,涵盖飞机、汽车等10类,每类6000张。
    • ImageNet:超1400万标注图像,含2.2万类别,是计算机视觉基准数据集。
    • COCO:30万+图像,含200万+对象实例,支持80类目标检测任务。
  2. 自然语言处理

    • IMDB评论:5万条电影评论,用于情感分析。
    • Wikipedia:开放文本数据,支持NLP研究。
  3. 其他领域

    • COVID-19 X-Ray Dataset:6500张胸部X光片,含COVID-19病例标注。
    • MovieLens:2000万条电影评分数据,用于推荐系统。

MNIST数据集介绍

MNIST是包含6万张训练图像和1万张测试图像的手写数字数据集,每张为28×28像素的灰度图,标签为0-9的整数‌。数据需归一化到[0,1]范围,标签转为one-hot向量格式‌。作为深度学习入门基准,其简单性使其成为验证模型效果的理想选择‌。

相关推荐
凌杰11 小时前
AI 学习笔记:LLM 的部署与测试
人工智能
心易行者11 小时前
在 Claude 4.6 发布的当下,一个不懂编程的人聊聊 Claude Code:当 AI 终于学会自己动手干活
人工智能
子榆.11 小时前
CANN 性能分析与调优实战:使用 msprof 定位瓶颈,榨干硬件每一分算力
大数据·网络·人工智能
爱喝白开水a11 小时前
前端AI自动化测试:brower-use调研让大模型帮你做网页交互与测试
前端·人工智能·大模型·prompt·交互·agent·rag
学易11 小时前
第十五节.别人的工作流,如何使用和调试(上)?(2类必现报错/缺失节点/缺失模型/思路/实操/通用调试步骤)
人工智能·ai作画·stable diffusion·报错·comfyui·缺失节点
空白诗11 小时前
CANN ops-nn 算子解读:大语言模型推理中的 MatMul 矩阵乘实现
人工智能·语言模型·矩阵
空白诗11 小时前
CANN ops-nn 算子解读:AIGC 风格迁移中的 BatchNorm 与 InstanceNorm 实现
人工智能·ai
新芒11 小时前
暖通行业两位数下滑,未来靠什么赢?
大数据·人工智能
B站_计算机毕业设计之家11 小时前
豆瓣电影数据采集分析推荐系统 | Python Vue Flask框架 LSTM Echarts多技术融合开发 毕业设计源码 计算机
vue.js·python·机器学习·flask·echarts·lstm·推荐算法
weixin_4462608511 小时前
掌握 Claude Code Hooks:让 AI 变得更聪明!
人工智能