pytorch学习笔记7

getitem在进行索引取值的时候自动调用,也是一个魔法方法,就像列表索引取值那样,一个意思

c 复制代码
import torchvision
from torch.utils.data import DataLoader

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
for i in test_loader:
    img,target=i
    print(img.shape)
    print(target)

如图所示的输出的选中部分中:

分别为4张图片,三通道,32*32

tensor([3, 2, 3, 2])

这是每张图片的target

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
writer=SummaryWriter('dataloader')
step=0
for i in test_loader:
    img,target=i
    # print(img.shape)
    # print(target)
    writer.add_images('test_loader_data',img,step)
    step=step+1

writer.close()

debug设置断点可以查看当前断点的信息

add_images

功能:添加多张图像到TensorBoard。

用法:用于将多张图片添加到日志文件中,通常用于展示一批次的图像。

这里用的是dataloader,每批次4张图片因此用add_images而不是add_image

Epoch

定义:一个epoch表示使用整个训练数据集对模型进行一次完整的训练过程。换句话说,当所有的训练数据都被用来更新模型参数一次时,就完成了一个epoch。

用途:在训练神经网络时,单次遍历所有训练数据通常不足以使模型收敛。需要多次遍历数据集(即多个epoch)以逐渐优化模型参数,从而提高模型的性能。

Batch

定义:batch(也称为mini-batch)是指在一次参数更新过程中所使用的训练样本的一个子集。训练数据通常会被分成若干个batch,每个batch包含一定数量的样本。

用途:使用batch可以平衡训练速度和模型参数更新的稳定性。对于大型数据集,一次性使用全部数据进行参数更新可能会非常耗时且内存占用过高,而使用小的batch可以加速计算,同时还能使梯度估计更稳定。

两者的关系

在训练过程中,一个epoch通常会包含多个batch。每个batch会更新模型的参数一次,因此一个epoch会有多次参数更新。具体的关系可以用以下公式描述

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=64,shuffle=False,num_workers=0,drop_last=False)#numwork,采取多少进程
img,target=test_data[0]
print(img.shape)
print(target)
writer=SummaryWriter('dataloader')

for epoch in range(2):
    step=0
    for i in test_loader:
        img,target=i
        # print(img.shape)
        # print(target)
        writer.add_images('epoch{}'.format(epoch),img,step)
        step=step+1

writer.close()

外层循环是epoch循环,共进行2个epoch。

内层循环是DataLoader的迭代器,它会遍历整个数据集。每次迭代会返回一个批量的数据,其中data是一个包含img和target的元组。

在每个批量数据上,使用SummaryWriter的add_images方法将图片数据img写入TensorBoard。这里将每个epoch的图片放在名为epoch{}的文件夹中,并使用step作为其次级目录,以便于在TensorBoard中查看不同批次的图片

相关推荐
deng-c-f2 小时前
Linux C/C++ 学习日记(29):IO密集型与CPU密集型、CPU的调度与线程切换
linux·学习·线程·cpu·io密集·cpu密集
四谎真好看4 小时前
Java 黑马程序员学习笔记(进阶篇18)
java·笔记·学习·学习笔记
洋洋的笔记4 小时前
银行测试学习计划
学习
IT_Octopus4 小时前
triton backend 模式docker 部署 pytorch gpu模型 镜像选择
pytorch·docker·triton·模型推理
Allan_20255 小时前
数据库学习
数据库·学习
报错小能手5 小时前
linux学习笔记(43)网络编程——HTTPS (补充)
linux·网络·学习
报错小能手5 小时前
linux学习笔记(45)git详解
linux·笔记·学习
百锦再6 小时前
Vue Scoped样式混淆问题详解与解决方案
java·前端·javascript·数据库·vue.js·学习·.net
Larry_Yanan6 小时前
QML学习笔记(四十四)QML与C++交互:对QML对象设置objectName
开发语言·c++·笔记·qt·学习·ui·交互
摇滚侠7 小时前
Spring Boot 3零基础教程,WEB 开发 默认页签图标 Favicon 笔记29
java·spring boot·笔记