pytorch学习笔记7

getitem在进行索引取值的时候自动调用,也是一个魔法方法,就像列表索引取值那样,一个意思

c 复制代码
import torchvision
from torch.utils.data import DataLoader

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
for i in test_loader:
    img,target=i
    print(img.shape)
    print(target)

如图所示的输出的选中部分中:

分别为4张图片,三通道,32*32

tensor([3, 2, 3, 2])

这是每张图片的target

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
writer=SummaryWriter('dataloader')
step=0
for i in test_loader:
    img,target=i
    # print(img.shape)
    # print(target)
    writer.add_images('test_loader_data',img,step)
    step=step+1

writer.close()

debug设置断点可以查看当前断点的信息

add_images

功能:添加多张图像到TensorBoard。

用法:用于将多张图片添加到日志文件中,通常用于展示一批次的图像。

这里用的是dataloader,每批次4张图片因此用add_images而不是add_image

Epoch

定义:一个epoch表示使用整个训练数据集对模型进行一次完整的训练过程。换句话说,当所有的训练数据都被用来更新模型参数一次时,就完成了一个epoch。

用途:在训练神经网络时,单次遍历所有训练数据通常不足以使模型收敛。需要多次遍历数据集(即多个epoch)以逐渐优化模型参数,从而提高模型的性能。

Batch

定义:batch(也称为mini-batch)是指在一次参数更新过程中所使用的训练样本的一个子集。训练数据通常会被分成若干个batch,每个batch包含一定数量的样本。

用途:使用batch可以平衡训练速度和模型参数更新的稳定性。对于大型数据集,一次性使用全部数据进行参数更新可能会非常耗时且内存占用过高,而使用小的batch可以加速计算,同时还能使梯度估计更稳定。

两者的关系

在训练过程中,一个epoch通常会包含多个batch。每个batch会更新模型的参数一次,因此一个epoch会有多次参数更新。具体的关系可以用以下公式描述

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=64,shuffle=False,num_workers=0,drop_last=False)#numwork,采取多少进程
img,target=test_data[0]
print(img.shape)
print(target)
writer=SummaryWriter('dataloader')

for epoch in range(2):
    step=0
    for i in test_loader:
        img,target=i
        # print(img.shape)
        # print(target)
        writer.add_images('epoch{}'.format(epoch),img,step)
        step=step+1

writer.close()

外层循环是epoch循环,共进行2个epoch。

内层循环是DataLoader的迭代器,它会遍历整个数据集。每次迭代会返回一个批量的数据,其中data是一个包含img和target的元组。

在每个批量数据上,使用SummaryWriter的add_images方法将图片数据img写入TensorBoard。这里将每个epoch的图片放在名为epoch{}的文件夹中,并使用step作为其次级目录,以便于在TensorBoard中查看不同批次的图片

相关推荐
zmd-zk40 分钟前
flink学习(2)——wordcount案例
大数据·开发语言·学习·flink
不高明的骗子43 分钟前
【深度学习之一】2024最新pytorch+cuda+cudnn下载安装搭建开发环境
人工智能·pytorch·深度学习·cuda
Chef_Chen1 小时前
从0开始学习机器学习--Day33--机器学习阶段总结
人工智能·学习·机器学习
Sxiaocai1 小时前
使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类
pytorch·深度学习·分类
hopetomorrow2 小时前
学习路之压力测试--jmeter安装教程
学习·jmeter·压力测试
hopetomorrow2 小时前
学习路之PHP--使用GROUP BY 发生错误 SELECT list is not in GROUP BY clause .......... 解决
开发语言·学习·php
/**书香门第*/2 小时前
Cocos creator 3.8 支持的动画 7
学习·游戏·游戏引擎·游戏程序·cocos2d
美式小田2 小时前
单片机学习笔记 9. 8×8LED点阵屏
笔记·单片机·嵌入式硬件·学习
猫爪笔记3 小时前
前端:HTML (学习笔记)【2】
前端·笔记·学习·html
_不会dp不改名_3 小时前
HCIA笔记3--TCP-UDP-交换机工作原理
笔记·tcp/ip·udp