pytorch学习笔记7

getitem在进行索引取值的时候自动调用,也是一个魔法方法,就像列表索引取值那样,一个意思

c 复制代码
import torchvision
from torch.utils.data import DataLoader

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
for i in test_loader:
    img,target=i
    print(img.shape)
    print(target)

如图所示的输出的选中部分中:

分别为4张图片,三通道,32*32

tensor([3, 2, 3, 2])

这是每张图片的target

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
writer=SummaryWriter('dataloader')
step=0
for i in test_loader:
    img,target=i
    # print(img.shape)
    # print(target)
    writer.add_images('test_loader_data',img,step)
    step=step+1

writer.close()

debug设置断点可以查看当前断点的信息

add_images

功能:添加多张图像到TensorBoard。

用法:用于将多张图片添加到日志文件中,通常用于展示一批次的图像。

这里用的是dataloader,每批次4张图片因此用add_images而不是add_image

Epoch

定义:一个epoch表示使用整个训练数据集对模型进行一次完整的训练过程。换句话说,当所有的训练数据都被用来更新模型参数一次时,就完成了一个epoch。

用途:在训练神经网络时,单次遍历所有训练数据通常不足以使模型收敛。需要多次遍历数据集(即多个epoch)以逐渐优化模型参数,从而提高模型的性能。

Batch

定义:batch(也称为mini-batch)是指在一次参数更新过程中所使用的训练样本的一个子集。训练数据通常会被分成若干个batch,每个batch包含一定数量的样本。

用途:使用batch可以平衡训练速度和模型参数更新的稳定性。对于大型数据集,一次性使用全部数据进行参数更新可能会非常耗时且内存占用过高,而使用小的batch可以加速计算,同时还能使梯度估计更稳定。

两者的关系

在训练过程中,一个epoch通常会包含多个batch。每个batch会更新模型的参数一次,因此一个epoch会有多次参数更新。具体的关系可以用以下公式描述

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=64,shuffle=False,num_workers=0,drop_last=False)#numwork,采取多少进程
img,target=test_data[0]
print(img.shape)
print(target)
writer=SummaryWriter('dataloader')

for epoch in range(2):
    step=0
    for i in test_loader:
        img,target=i
        # print(img.shape)
        # print(target)
        writer.add_images('epoch{}'.format(epoch),img,step)
        step=step+1

writer.close()

外层循环是epoch循环,共进行2个epoch。

内层循环是DataLoader的迭代器,它会遍历整个数据集。每次迭代会返回一个批量的数据,其中data是一个包含img和target的元组。

在每个批量数据上,使用SummaryWriter的add_images方法将图片数据img写入TensorBoard。这里将每个epoch的图片放在名为epoch{}的文件夹中,并使用step作为其次级目录,以便于在TensorBoard中查看不同批次的图片

相关推荐
Diamond技术流5 分钟前
从0开始学习Linux——网络配置
linux·运维·网络·学习·安全·centos
密码小丑6 分钟前
11月4日(内网横向移动(一))
笔记
斑布斑布8 分钟前
【linux学习2】linux基本命令行操作总结
linux·运维·服务器·学习
鸭鸭梨吖1 小时前
产品经理笔记
笔记·产品经理
Chef_Chen1 小时前
从0开始学习机器学习--Day13--神经网络如何处理复杂非线性函数
神经网络·学习·机器学习
齐 飞1 小时前
MongoDB笔记01-概念与安装
前端·数据库·笔记·后端·mongodb
lulu_gh_yu1 小时前
数据结构之排序补充
c语言·开发语言·数据结构·c++·学习·算法·排序算法
丫头,冲鸭!!!2 小时前
B树(B-Tree)和B+树(B+ Tree)
笔记·算法
Re.不晚2 小时前
Java入门15——抽象类
java·开发语言·学习·算法·intellij-idea
听忆.2 小时前
手机屏幕上进行OCR识别方案
笔记