PyTorch重写DataSet类

PyTorch重写DataSet类


文章目录


前言

在之前沐神的Cifar-10分类 课程学习中,沐神是用的将每一类创建一个文件夹去完成图片的导入。此外我们还可以通过重写DataSet类来完成!

一、如何重写?

通过查看官方文档我们可知。

需要去重写__getitem__这个方法,去以一种特定的方法拿到一个数据。并且选择性的重写__len__这个方法,去返回整个数据集的大小。

二、具体代码

1.数据集格式

这个数据集是沐神课程上讲过的cifar-10数据集。

train和test文件夹分别为要进行训练和测试的图片。而训练数据的标签以csv文件存在trainLabels.csv文件中。

2.获取标签

python 复制代码
def read_csv_labels(fname):
    with open(fname,'r') as f:
        lines = f.readlines()[1:]
    tokens = [l.rstrip().split(',') for l in lines]
    return dict(((name,label) for name,label in tokens))

这里通过一个read_csv_labels的方法 将图片名字和标签以一个字典的方式返回

3.重写dataset

python 复制代码
class MyDateset(Dataset):
    def __init__(self,root_dir,state,label_dict=None):
        self.root_dir = root_dir
        self.state = state
        if label_dict is not None:
            self.label_dict = label_dict
        self.img_path = os.listdir(os.path.join(root_dir,state))
        # os.listdir 将当前文件夹下的图片名称按列表返回

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.root_dir,self.state,self.img_path[idx]))
        if self.state == 'train':
            img_num =self.img_path[idx].split('.')[0]
            # 这个取出来是数字.jpg 所以需要将.jpg舍去
            label = self.label_dict[img_num]
            return img,label
        else:
            return img

    def __len__(self):
        return len(self.img_path)

state参数表示此时是训练数据集还是测试数据集。

4.调用

python 复制代码
root_dir = "D:\\PytorchLearn\\cifar-10"
label_dict = read_csv_labels(os.path.join(root_dir,"trainLabels.csv"))

train_dataset = MyDateset(root_dir,'train',label_dict)

test_dataset = MyDateset(root_dir,'test')

train_iter = torch.utils.data.DataLoader(train_dataset,batch_size=8,shuffle=True)

总结

以上就是重写DataSet的方法,有不足之处还望各位指出。

相关推荐
千澜空15 分钟前
celery在django项目中实现并发任务和定时任务
python·django·celery·定时任务·异步任务
学习前端的小z18 分钟前
【AIGC】如何通过ChatGPT轻松制作个性化GPTs应用
人工智能·chatgpt·aigc
斯凯利.瑞恩23 分钟前
Python决策树、随机森林、朴素贝叶斯、KNN(K-最近邻居)分类分析银行拉新活动挖掘潜在贷款客户附数据代码
python·决策树·随机森林
yannan2019031344 分钟前
【算法】(Python)动态规划
python·算法·动态规划
埃菲尔铁塔_CV算法1 小时前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR1 小时前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️1 小时前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
蒙娜丽宁1 小时前
《Python OpenCV从菜鸟到高手》——零基础进阶,开启图像处理与计算机视觉的大门!
python·opencv·计算机视觉
光芒再现dev1 小时前
已解决,部署GPTSoVITS报错‘AsyncRequest‘ object has no attribute ‘_json_response_data‘
运维·python·gpt·语言模型·自然语言处理
好喜欢吃红柚子1 小时前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn