pytorch学习笔记7

getitem在进行索引取值的时候自动调用,也是一个魔法方法,就像列表索引取值那样,一个意思

c 复制代码
import torchvision
from torch.utils.data import DataLoader

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
for i in test_loader:
    img,target=i
    print(img.shape)
    print(target)

如图所示的输出的选中部分中:

分别为4张图片,三通道,32*32

tensor([3, 2, 3, 2])

这是每张图片的target

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)#numwork,采取多少进程
# img,target=test_data[0]
# print(img.shape)
# print(target)
writer=SummaryWriter('dataloader')
step=0
for i in test_loader:
    img,target=i
    # print(img.shape)
    # print(target)
    writer.add_images('test_loader_data',img,step)
    step=step+1

writer.close()

debug设置断点可以查看当前断点的信息

add_images

功能:添加多张图像到TensorBoard。

用法:用于将多张图片添加到日志文件中,通常用于展示一批次的图像。

这里用的是dataloader,每批次4张图片因此用add_images而不是add_image

Epoch

定义:一个epoch表示使用整个训练数据集对模型进行一次完整的训练过程。换句话说,当所有的训练数据都被用来更新模型参数一次时,就完成了一个epoch。

用途:在训练神经网络时,单次遍历所有训练数据通常不足以使模型收敛。需要多次遍历数据集(即多个epoch)以逐渐优化模型参数,从而提高模型的性能。

Batch

定义:batch(也称为mini-batch)是指在一次参数更新过程中所使用的训练样本的一个子集。训练数据通常会被分成若干个batch,每个batch包含一定数量的样本。

用途:使用batch可以平衡训练速度和模型参数更新的稳定性。对于大型数据集,一次性使用全部数据进行参数更新可能会非常耗时且内存占用过高,而使用小的batch可以加速计算,同时还能使梯度估计更稳定。

两者的关系

在训练过程中,一个epoch通常会包含多个batch。每个batch会更新模型的参数一次,因此一个epoch会有多次参数更新。具体的关系可以用以下公式描述

c 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
test_data=torchvision.datasets.CIFAR10('./dataset',train=False,transform=data_transform)
test_loader=DataLoader(dataset=test_data,batch_size=64,shuffle=False,num_workers=0,drop_last=False)#numwork,采取多少进程
img,target=test_data[0]
print(img.shape)
print(target)
writer=SummaryWriter('dataloader')

for epoch in range(2):
    step=0
    for i in test_loader:
        img,target=i
        # print(img.shape)
        # print(target)
        writer.add_images('epoch{}'.format(epoch),img,step)
        step=step+1

writer.close()

外层循环是epoch循环,共进行2个epoch。

内层循环是DataLoader的迭代器,它会遍历整个数据集。每次迭代会返回一个批量的数据,其中data是一个包含img和target的元组。

在每个批量数据上,使用SummaryWriter的add_images方法将图片数据img写入TensorBoard。这里将每个epoch的图片放在名为epoch{}的文件夹中,并使用step作为其次级目录,以便于在TensorBoard中查看不同批次的图片

相关推荐
ysy164806723922 分钟前
03算法学习_977、有序数组的平方
学习·算法
FAREWELL0007530 分钟前
Unity学习总结篇(1)关于各种坐标系
学习·unity·c#·游戏引擎
龙湾开发30 分钟前
计算机图形学编程(使用OpenGL和C++)(第2版)学习笔记 12.曲面细分
c++·笔记·学习·3d·图形渲染
霸王蟹1 小时前
React中巧妙使用异步组件Suspense优化页面性能。
前端·笔记·学习·react.js·前端框架
ljt27249606611 小时前
Compose笔记(二十四)--Canvas
笔记·android jetpack
jz_ddk1 小时前
[学习] RTKLib详解:rtcm2.c、rtcm3.c、rtcm3e与rtcmn.c
c语言·学习·算法
霸王蟹1 小时前
React 19 中的useRef得到了进一步加强。
前端·javascript·笔记·学习·react.js·ts
霸王蟹1 小时前
React 19版本refs也支持清理函数了。
前端·javascript·笔记·react.js·前端框架·ts
Generalzy2 小时前
学习!FastAPI
学习·sqlite·fastapi
ha20428941942 小时前
c++学习之--- list
c语言·c++·学习·list