自定义MyDataSet获取数据及对应label

自定义MyDataSet获取数据及对应label

实例化数据集需要用到 DataSet 类,我们可以自定义来实现对数据集的处理

MyDataSet类代码如下:

python 复制代码
from PIL import Image
import torch
from torch.utils.data import Dataset

class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)
        
	# 获取item对象图像和类别,只对img进行预处理,label不处理
    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        # zip(*batch):处理一个batch内的图片,图片为一组,标签为一组
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)	# 增加batch维度
        labels = torch.as_tensor(labels)	#将labels转化为tensor,images在__getitem__方法的transform已经转化为tensor
        return images, labels

定义好MyDataSet后,就可以在train类中引用了,具体代码如下:

python 复制代码
from my_dataset import MyDataSet

# 省略其他代码......

# 这里定义了train和val两种预处理方法
data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
                                   
# MyDataSet实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

# 省略其他代码......

从而实现了在train类中获取一个batch的数据,且该数据为图像一组和label一组,同时经过预处理的数据

相关推荐
惯导马工21 小时前
【论文导读】ORB-SLAM3:An Accurate Open-Source Library for Visual, Visual-Inertial and
深度学习·算法
隐语SecretFlow2 天前
国人自研开源隐私计算框架SecretFlow,深度拆解框架及使用【开发者必看】
深度学习
Billy_Zuo2 天前
人工智能深度学习——卷积神经网络(CNN)
人工智能·深度学习·cnn
羊羊小栈2 天前
基于「YOLO目标检测 + 多模态AI分析」的遥感影像目标检测分析系统(vue+flask+数据集+模型训练)
人工智能·深度学习·yolo·目标检测·毕业设计·大作业
l12345sy2 天前
Day24_【深度学习—广播机制】
人工智能·pytorch·深度学习·广播机制
九章云极AladdinEdu2 天前
超参数自动化调优指南:Optuna vs. Ray Tune 对比评测
运维·人工智能·深度学习·ai·自动化·gpu算力
研梦非凡2 天前
ICCV 2025|从粗到细:用于高效3D高斯溅射的可学习离散小波变换
人工智能·深度学习·学习·3d
通街市密人有2 天前
IDF: Iterative Dynamic Filtering Networks for Generalizable Image Denoising
人工智能·深度学习·计算机视觉
智数研析社2 天前
9120 部 TMDb 高分电影数据集 | 7 列全维度指标 (评分 / 热度 / 剧情)+API 权威源 | 电影趋势分析 / 推荐系统 / NLP 建模用
大数据·人工智能·python·深度学习·数据分析·数据集·数据清洗
七元权3 天前
论文阅读-Correlate and Excite
论文阅读·深度学习·注意力机制·双目深度估计