PyTorch数据加载与预处理

torch.utils.data

torch.utils.data是PyTorch数据处理核心模块,包含Dataset、DataLoader等关键组件。

torch.utils.data模块为PyTorch提供了完整的数据处理流水线,主要包含以下核心组件

核心组件

‌Dataset‌ - 数据容器抽象类

  • 必须实现 len() 返回数据集大小
  • 必须实现 getitem() 按索引获取样本
  • 支持自定义数据源和预处理逻辑

DataLoader‌ - 数据加载迭代器‌

  • batch_size: 批次大小
  • shuffle: 是否随机打乱
  • num_workers: 并行加载进程数
  • drop_last: 是否舍弃最后不足批次的数据

Sampler类‌ - 数据采样策略

SequentialSampler: 顺序采样

RandomSampler: 随机采样

WeightedRandomSampler: 加权随机采样

实用工具函数

‌随机分割‌ - random_split()

clike 复制代码
# 将数据集按比例分割
train_set, val_set = random_split(dataset, [800, 200])

‌数据拼接‌ - ConcatDataset()

clike 复制代码
# 合并多个数据集
combined = ConcatDataset([set1, set2])

📊 实战示例

clike 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        sample = self.df.iloc[idx].SalePrice
        if self.transform:
            sample = self.transform(sample)
        return sample

多进程加速‌

clike 复制代码
# 启用4个worker进程并行加载
dataloader = DataLoader(dataset, batch_size=32, 
                      shuffle=True, num_workers=4)

内存优化‌‌

  1. 使用惰性加载处理大数据集
  2. 启用 pin_memory=True 加速GPU数据传输
  3. 合理设置 prefetch_factor 预取数据

数据预处理

如何实现自定义数据集的预处理?

PyTorch中自定义数据集的预处理主要通过继承Dataset类和组合transforms模块来实现。

核心实现框架

自定义数据集预处理需要实现三个关键部分:‌Dataset类定义、transforms预处理管道、DataLoader数据加载‌。

基础Dataset类实现
clike 复制代码
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, csv_file, feature_cols, label_col, transform=None):
        self.data = pd.read_csv(csv_file)
        self.features = self.data[feature_cols]
        self.labels = self.data[label_col]
        self.transform = transform
        self._preprocess_data()
    
    def _preprocess_data(self):
        """数据预处理:处理缺失值和类型转换"""
        # 数值型特征:填充均值
        numeric_features = self.features.select_dtypes(include=[np.number])
        self.features[numeric_features.columns] = numeric_features.fillna(numeric_features.mean())
        
        # 类别型特征:转为one-hot编码
        categorical_features = self.features.select_dtypes(include=['object'])
        if not categorical_features.empty:
            self.features = pd.get_dummies(self.features, columns=categorical_features.columns)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = {
            'features': torch.tensor(self.features.iloc[idx].values, dtype=torch.float32),
            'label': torch.tensor(self.labels.iloc[idx], dtype=torch.long)
        }
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample
自定义预处理转换类
clike 复制代码
import torchvision.transforms as transforms

class Normalization:
    """数据归一化处理"""
    def __call__(self, sample):
        features = sample['features']
        # 最大最小归一化
        features = (features - features.min()) / (features.max() - features.min())
        sample['features'] = features
        return sample

class FeatureScaling:
    """特征标准化"""
    def __call__(self, sample):
        features = sample['features']
        features = (features - features.mean()) / features.std()
        sample['features'] = features
        return sample

class ToTensor:
    """确保数据转为Tensor格式"""
    def __call__(self, sample):
        if not isinstance(sample['features'], torch.Tensor):
            sample['features'] = torch.tensor(sample['features'], dtype=torch.float32)
        return sample
完整的预处理管道
clike 复制代码
# 构建预处理管道
from torchvision.transforms import Compose

# 训练集预处理(包含数据增强)
train_transform = Compose([
    Normalization(),
    FeatureScaling(),
    ToTensor()
])

# 测试集预处理(不包含数据增强)
test_transform = Compose([
    Normalization(),
    ToTensor()
])

# 创建数据集实例
train_dataset = CustomDataset(
    csv_file='train.csv',
    feature_cols=['feature1', 'feature2', 'feature3'],
    label_col='label',
    transform=train_transform
)

# 创建DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

图像数据预处理示例

对于图像数据,PyTorch提供了更丰富的预处理工具:

clike 复制代码
from torchvision import transforms

# 图像预处理管道
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整尺寸
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    transforms.RandomRotation(10),  # 随机旋转
    transforms.ToTensor(),  # 转为Tensor并归一化到[0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
])

常用的数据集

常用数据集分类与介绍

  1. 图像分类

    • CIFAR-10:6万张32×32彩色图像,涵盖飞机、汽车等10类,每类6000张。
    • ImageNet:超1400万标注图像,含2.2万类别,是计算机视觉基准数据集。
    • COCO:30万+图像,含200万+对象实例,支持80类目标检测任务。
  2. 自然语言处理

    • IMDB评论:5万条电影评论,用于情感分析。
    • Wikipedia:开放文本数据,支持NLP研究。
  3. 其他领域

    • COVID-19 X-Ray Dataset:6500张胸部X光片,含COVID-19病例标注。
    • MovieLens:2000万条电影评分数据,用于推荐系统。

MNIST数据集介绍

MNIST是包含6万张训练图像和1万张测试图像的手写数字数据集,每张为28×28像素的灰度图,标签为0-9的整数‌。数据需归一化到[0,1]范围,标签转为one-hot向量格式‌。作为深度学习入门基准,其简单性使其成为验证模型效果的理想选择‌。

相关推荐
智驱力人工智能3 分钟前
仓库园区无人机烟雾识别:构建立体化、智能化的早期火灾预警体系 无人机烟雾检测 无人机动态烟雾分析AI系统 无人机辅助火灾救援系统
人工智能·opencv·算法·目标检测·架构·无人机·边缘计算
未来之窗软件服务3 分钟前
幽冥大陆(六十) SmolVLM 本地部署 轻量 AI 方案—东方仙盟筑基期
人工智能·本地部署·轻量模型·东方仙盟·东方仙盟自动化
世界唯一最大变量4 分钟前
自创的机械臂新算法,因为是AI写的,暂时,并不智能,但目前支持任何段数
python·排序算法
今天也要学习吖5 分钟前
【开源客服系统推荐】AI-CS:一个开源的智能客服系统
人工智能·开源·客服系统·ai大模型·ai客服·开源客服系统
Christo38 分钟前
2022-《Deep Clustering: A Comprehensive Survey》
人工智能·算法·机器学习·数据挖掘
jqpwxt11 分钟前
启点创新智慧景区服务平台,智慧景区数字驾驶舱建设
大数据·人工智能
C+++Python12 分钟前
如何选择合适的锁机制来提高 Java 程序的性能?
java·前端·python
weisian15112 分钟前
入门篇--人工智能发展史-2-什么是深度学习,深度学习的前世今生?
人工智能·深度学习
阿里云大数据AI技术13 分钟前
Hologres Dynamic Table:高效增量刷新,构建实时统一数仓的核心利器
大数据·人工智能·阿里云·实时数仓·hologres