Torchvision
torchvision
是pytorch
的一个图形库,它服务于PyTorch
深度学习框架的,主要用来构建计算机视觉模型。
- torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
- torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
- torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
- torchvision.utils: 其他的一些有用的方法。
数据集准备 - CIFAR10
此次代码中要用到的数据集,见附件有 介绍与中文的参数。
下载数据集-CIFAR10
通过py代码
python
# 使用CIFAR10数据集
# 训练集
# 如果下载比较慢,可以将控制台打印的下载链接放到专门的下载工具中下载
# 首先下载的是一个压缩包,会自动解压
train_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=True, download=True)
# 测试集
test_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=False, download=True)
运行代码,控制台显示如下信息
50000 -- 说明有5w张训练数据
10000 --说明有1w张测试数据
会自动下载数据集到
torchvision_dataset
文件夹已下载就不会继续下载,控制台会出输
Files already downloaded and verified
字样
代码
目标: torchvision和transform的联合使用
python
import torchvision.datasets
from torch.utils.tensorboard import SummaryWriter
# 将图片数据都转为tensor类型
# 可以对数据集做任何transforms范围内的操作,该例子只针对数据做toTensor
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor
])
# 使用CIFAR10数据集
train_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=True, transform=dataset_transform, download=True)
# 测试集
test_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=False, transform=dataset_transform, download=True)
# 用tensorboard显示前10张图片
# 运行tensorboard --logdir=p10
writer = SummaryWriter('p10')
for i in range(0):
img, target = test_set[i]
writer.add_image("test_set", img,i)
writer.close()
附:Torchvision介绍
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。
pytorch官网
Pytorch的官网:pytorch.org/
上图是官方的首页,图中同时标出了Docs菜单下常用的子菜单项目,其中torchvision与视觉相关,transform就出自于出此。
torchvison.dataset
进入 torchvision-> datasets 菜单可以找到相当的数据集链接,如下图
其中:
- COCO: 用于目标检测,语义分割
MNIST
MNIST 常用的入门级数据集,手写文字数据集
包名 torchvision.datasets.MNIST()
Fashion MNIST
该数据集与 MNIST 类似,但该数据集不是手写数字,而是 T 恤、裤子、包等服装项目。
包名 torchvision.datasets.FashionMNIST()
CIFAR10
CIFAR10由10个不同标签的图像组成。其中包括卡车、青蛙、船、汽车、鹿等常见图像。还有一个CIFAR100版本,有 100 个不同的类别
CIFAR10/CIFAR100一般用于物价识别,其广泛用于机器学习领域的计算机视觉算法基准测试。详情 官网地址 包名 torchvision.datasets.FashionMNIST()
包名 torchvision.datasets.CIFAR10()
参数说明:
- root: 数据集根路径,可以是相对路径
- train: = ture 训练集,否则为测试集
- transform: 对数据集进行的transform操作
- target_transform: 训练后的目标数据集执行指定的transform操作
- download:=true 自动下载数据集,false不会下载
COOC
目前有超过 100,000 个日常物品,如人、瓶子、文具、书籍等。这个图像数据集广泛用于对象检测和图像描述. (来自知乎...)
torchvision.models
提供神经网络常见的神经网络,有一些神经网络已经预训练好了。
torchvision.transform
图像处理与变形等
torchvision.utils
提供一些常用的工具,比如tensorboard
等
javascript
from torch.utils.tensorboard import SummaryWriter
创建于:2012/11/10
[参考]