PyTorch概述(二)---MNIST

NIST Special Database 3

  • 指 NIST 规模较大的特殊数据库 3(SD-3);
  • 该数据库的内容为手写数字黑白图片;
  • 该数据库由美国人口普查局的雇员手写;

NIST Special Database 1

  • 指 NIST 特殊数据库 1(SD-1);
  • 该数据库的内容为手写数字黑白图片;
  • 该数据库的图片由高中学生手写;

MNIST

  • MNIST 数据库:Modified National Institute of Standards and Technology 数据库
  • 是一个大的手写数字的集合;
  • 训练集包含 60,000 个样本;
  • 测试集包含 10,000 个样本;
  • 是从 NIST SD-3 和 SD-1 中选取样本构建而成的子集;
  • 数字图片已经过尺寸标准化,并居中放置在固定尺寸的图像中;
  • 原始的黑白(二值)图像被缩放到 20x20 像素的方框内,同时保持宽高比;
  • 由于标准化算法使用了抗锯齿(反走样)技术,所得图像包含灰度级;
  • 通过计算像素的质心并平移图像,使质心位于 28x28 图像的中心,从而将手写数字居中;

MNIST 用法

python 复制代码
# Minimal MNIST loading example.
# NOTE: with num_workers > 0 on platforms that use the "spawn" start method
# (Windows, macOS), any code that iterates these loaders must live under an
# `if __name__ == "__main__":` guard.
import torch
import torchvision
from torchvision import transforms

# Commonly used mean / std of the MNIST training pixels (after ToTensor scales
# them to [0, 1]). Normalize([0], [1]) would subtract 0 and divide by 1, i.e.
# it would leave the tensor unchanged.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1]
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

trainset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,      # reshuffle training samples every epoch
    num_workers=2,
)

testset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=32,
    shuffle=False,     # evaluation does not need shuffling; keep order deterministic
    num_workers=2,
)

MNIST 源码(python)

python 复制代码
import codecs
import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any,Callable,Dict,List,Optional,Tuple
from urllib.error import URLError

import numpy as np
import torch
from PIL import Image

from .utils import _flip_byte_order,check_integrity,download_and_extract_archive,extract_archive,verify_str_arg
from .vision import VisionDataset

class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Loads the train or test split from the raw IDX files under
    ``<root>/<ClassName>/raw``. If both legacy cache files (``training.pt`` and
    ``test.pt``) are present under ``<root>/<ClassName>/processed``, the legacy
    cache is loaded instead and downloading is skipped.

    Each item is an ``(image, target)`` pair: ``image`` is a grayscale PIL image
    (passed through ``transform`` if given) and ``target`` is the integer class
    index (passed through ``target_transform`` if given).
    """

    # Base URLs tried in order by ``download``; if fetching a file from one
    # mirror raises ``URLError``, the next mirror is tried.
    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]
    # (archive filename, MD5 checksum) for every raw file. ``_check_exists`` and
    # ``download`` read this as ``self.resources``.
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]
    # Backward-compatible alias for the previous (misspelled) attribute name.
    resource = resources

    # Legacy cache filenames looked up inside ``processed_folder``.
    training_file = "training.pt"
    test_file = "test.pt"
    # Human-readable class names; index i corresponds to digit i.
    classes = [
        "0-zero",
        "1-one",
        "2-two",
        "3-three",
        "4-four",
        "5-five",
        "6-six",
        "7-seven",
        "8-eight",
        "9-nine",
    ]

    # Deprecated aliases kept for old callers; each emits a warning and
    # forwards to ``targets`` / ``data``.
    @property
    def train_labels(self) -> Any:
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self) -> Any:
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self) -> Any:
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self) -> Any:
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """Create the dataset.

        Args:
            root: Root directory of the dataset, where
                ``MNIST/raw/train-images-idx3-ubyte`` and
                ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
            train: If True, creates the dataset from ``train-images-idx3-ubyte``,
                otherwise from ``t10k-images-idx3-ubyte``.
            transform: A function/transform that takes in a PIL image and
                returns a transformed version, e.g. ``transforms.RandomCrop``.
            target_transform: A function/transform that takes in the target and
                transforms it.
            download: If True, downloads the dataset from the internet and puts
                it in the root directory. If the dataset is already downloaded,
                it is not downloaded again.

        Raises:
            RuntimeError: If the raw files are missing (and ``download`` is
                False or the download did not produce them).
        """
        super().__init__(root, transform, target_transform)
        self.train = train

        # Old versions cached the parsed tensors in processed/*.pt; prefer them
        # when present so existing installations keep working.
        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found.You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self) -> bool:
        """Return True if both legacy ``.pt`` cache files exist in ``processed_folder``."""
        if not os.path.exists(self.processed_folder):
            return False
        return all(
            check_integrity(os.path.join(self.processed_folder, file))
            for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        """Load ``(data, targets)`` from the legacy ``.pt`` cache for the selected split."""
        # This is for backward compatibility only: the data is no longer cached
        # in a custom binary but read directly from the raw files.
        # NOTE(review): torch.load unpickles the file; only use trusted caches.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        """Parse the raw image and label files of the selected split.

        ``read_image_file`` / ``read_label_file`` are not defined in this class;
        presumably they are module-level IDX parsers elsewhere in this file.
        """
        prefix = "train" if self.train else "t10k"

        image_file = f"{prefix}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{prefix}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the sample at ``index``.

        Args:
            index: Index of the sample.

        Returns:
            tuple: ``(image, target)`` where ``target`` is the index of the
            target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # Convert to a PIL Image ("L" = 8-bit grayscale) so that this dataset
        # is consistent with other datasets, which also return PIL Images.
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        """Directory holding the raw (decompressed) IDX files."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        """Directory where legacy ``.pt`` caches may live."""
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        """Map each class name in ``classes`` to its index."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        """Return True if every raw file (archive name without ``.gz``) is present."""
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already.

        Each archive is tried against every mirror in order. If all mirrors
        fail with ``URLError`` for a file, ``RuntimeError`` is raised.
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    print(f"Downloading {url}")
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as error:
                    print(f"Failed to download (trying next):\n{error}")
                    continue
                finally:
                    print()
                # Success: stop trying further mirrors for this file.
                break
            else:
                # Loop finished without ``break``: every mirror failed.
                raise RuntimeError(f"Error downloading {filename}")

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"
相关推荐
nancy_princess4 小时前
clip实验
人工智能·深度学习
飞哥数智坊4 小时前
TRAE Friends@济南第4次活动:100+极客集结,2小时极限编程燃爆全场!
人工智能
AI自动化工坊4 小时前
ProofShot实战:给AI编码助手添加可视化验证,提升前端开发效率3倍
人工智能·ai·开源·github
飞哥数智坊4 小时前
一场直播涨粉 2 万的背后!OpenClaw + 飞书,正在重塑软件交付的方式
人工智能
飞哥数智坊5 小时前
养虾记第3期:安装、调教、落地,这场沙龙我们全聊了
人工智能
再不会python就不礼貌了5 小时前
从工具到个人助理——AI Agent的原理、演进与安全风险
人工智能·安全·ai·大模型·transformer·ai编程
AI医影跨模态组学5 小时前
Radiother Oncol 空军军医大学西京医院等团队:基于纵向CT的亚区域放射组学列线图预测食管鳞状细胞癌根治性放化疗后局部无复发生存期
人工智能·深度学习·医学影像·影像组学
A尘埃5 小时前
神经网络的激活函数+损失函数
人工智能·深度学习·神经网络·激活函数
没有不重的名么5 小时前
Pytorch深度学习快速入门教程
人工智能·pytorch·深度学习
有为少年5 小时前
告别“唯语料论”:用合成抽象数据为大模型开智
人工智能·深度学习·神经网络·算法·机器学习·大模型·预训练