基于PyTorch的CIFAR10加载与TensorBoard可视化实践

视频学习来源:https://www.bilibili.com/video/BV1hE411t7RN?t=1.1&p=15

python 复制代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from test_03 import writer

# 添加 添加 download=True 参数来下载数据集
test_data = torchvision.datasets.CIFAR10(
    root=".dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True  # 新增此行,用于下载数据集
)

test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=False
)

img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        writer.add_image("Epoch:{}".format(epoch), imgs, step)
        step =step +1

writer.close()

这段代码结合了 PyTorch 的数据加载、图像处理和 TensorBoard 可视化功能,是深度学习中数据预处理和可视化的典型流程。


一、整体功能概览

这段代码的核心作用是:

  1. 加载 CIFAR10 测试数据集
  2. 用 DataLoader 按批次组织数据
  3. 通过 TensorBoard 可视化不同批次的图像数据
  4. 对比不同训练轮次(epoch)的数据分布

二、逐行代码详解

1. 导入库

python 复制代码
import torchvision  # 计算机视觉工具库
from torch.utils.data import DataLoader  # 数据加载工具
from torch.utils.tensorboard import SummaryWriter  # TensorBoard可视化工具

# 从test_03.py文件导入writer(这里实际被后面的代码覆盖了,暂时忽略)
from test_03 import writer

基础知识拓展

  • torchvision:PyTorch 官方的计算机视觉库,包含常用数据集(如 CIFAR10、MNIST)、预训练模型(如 ResNet、VGG)和图像预处理工具。
  • DataLoader:PyTorch 的核心数据加载工具,负责将数据集按批次加载,支持并行处理和数据打乱。
  • SummaryWriter:TensorBoard 的 PyTorch 接口,用于记录和可视化训练过程(图像、损失值、权重分布等)。

2. 加载 CIFAR10 测试数据集

python 复制代码
test_data = torchvision.datasets.CIFAR10(
    root=".dataset",  # 数据集保存路径
    train=False,      # 是否为训练集(False表示测试集)
    transform=torchvision.transforms.ToTensor(),  # 数据转换
    download=True     # 自动下载数据集
)

参数详解

  • root:数据集存储的本地路径(这里是当前文件夹下的.dataset文件夹)。如果该路径不存在,会自动创建。
  • train=False:CIFAR10 分为训练集(50000 张图片)和测试集(10000 张图片),train=False表示加载测试集。
  • transform=ToTensor():将图像从 PIL 格式(Python 图像库格式)转换为 PyTorch 的 Tensor 格式,同时完成两个关键操作:
    • 像素值从[0, 255]归一化到[0, 1](神经网络对小范围数值更敏感)
    • 维度从(高度, 宽度, 通道)转换为(通道, 高度, 宽度)(PyTorch 的默认格式)
  • download=True:如果root路径下没有数据集,自动从官方地址下载(约 160MB)。

CIFAR10 数据集细节

  • 包含 10 个类别:飞机(0)、汽车(1)、鸟(2)、猫(3)、鹿(4)、狗(5)、青蛙(6)、马(7)、船(8)、卡车(9)。
  • 每张图片都是 32x32 像素的彩色图像(RGB 三通道)。

3. 创建 DataLoader

python 复制代码
test_loader = DataLoader(
    dataset=test_data,    # 传入数据集
    batch_size=64,        # 每批次64个样本
    shuffle=True,         # 打乱数据顺序
    num_workers=0,        # 单进程加载(Windows推荐0)
    drop_last=False       # 保留最后一个不完整批次
)

核心作用 :将test_data这个数据集转换为可迭代的批次数据,方便模型批量处理。

参数详解

  • dataset:要加载的数据集(必须是Dataset类的实例)。
  • batch_size=64:每次迭代返回 64 张图片和对应的 64 个标签。为什么用批次?
    • 单次处理太多样本会占用过多内存
    • 批次梯度下降比单样本梯度下降更稳定
  • shuffle=True:每个 epoch(轮次)前打乱数据顺序。测试集一般不需要打乱(设为False),这里可能是为了演示效果。
  • num_workers=0:数据加载的进程数。0 表示在主进程中加载(Windows 系统设为非 0 可能会报错),Linux/Mac 可设为 4、8 等加速加载。
  • drop_last=False:如果数据集总样本数不能被batch_size整除,是否丢弃最后一个不完整的批次。例如 CIFAR10 测试集有 10000 张,10000 ÷ 64 = 156 余 16,drop_last=False会保留最后 16 张的批次。

4. 查看单个样本

python 复制代码
img, target = test_data[0]  # 获取第1个样本(索引从0开始)
print(img.shape)  # 输出图片形状
print(target)     # 输出标签

输出解释

  • img.shape的结果是 torch.Size([3, 32, 32]),表示:
    • 3:RGB 三通道
    • 32:图像高度(像素)
    • 32:图像宽度(像素)
  • target的结果是一个整数(例如 3),对应 CIFAR10 的类别标签(3 代表 "猫")。

python 复制代码
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        writer.add_image("Epoch:{}".format(epoch), imgs, step)
        step =step +1

为什么用test_data[0]而不是test_loader[0]

  • test_dataDataset对象,支持按索引直接访问单个样本。
  • test_loaderDataLoader对象,不支持直接按索引访问,必须通过迭代器(for循环)访问。

5. TensorBoard 可视化设置

python 复制代码
writer = SummaryWriter("dataloader")  # 创建日志写入器,日志保存到"dataloader"文件夹

TensorBoard 作用

  • 实时可视化训练过程中的图像、损失值、准确率等指标。
  • 支持对比不同实验的结果(如不同批次大小、不同学习率的效果)。

使用方法

  1. 代码运行后,会在当前目录生成 "dataloader" 文件夹,里面包含日志文件。

  2. 打开终端,运行命令:tensorboard --logdir=dataloader

    python 复制代码
    tensorboard --logdir=dataloader
  3. 在浏览器中访问提示的地址(通常是http://localhost:6006),即可查看可视化结果。

6. 多轮次可视化批次数据

python 复制代码
for epoch in range(2):  # 循环2个轮次
    step = 0  # 记录每个轮次内的步数
    for data in test_loader:  # 迭代加载批次数据
        imgs, targets = data  # 拆分批次数据为图片和标签
        # 向TensorBoard写入图像,标签为"Epoch:0"或"Epoch:1",步数为step
        writer.add_image("Epoch:{}".format(epoch), imgs, step)
        step += 1  # 步数递增

writer.close()  # 关闭写入器,释放资源

核心逻辑

  • 循环 2 个 epoch(轮次),模拟模型训练时多轮次处理数据的场景。
  • 每个 epoch 内,通过test_loader按批次加载数据,并用writer.add_image将批次图像写入 TensorBoard。

add_image参数详解

  • 第 1 个参数:图像标签(字符串),用于在 TensorBoard 中区分不同类别的图像。这里用"Epoch:0""Epoch:1"区分两个轮次。
  • 第 2 个参数:要显示的图像数据,必须是 Tensor 格式,形状可以是:
    • 单张图片:(通道数, 高度, 宽度)
    • 批次图片:(批次大小, 通道数, 高度, 宽度)(这里用的是这种格式,会自动显示网格状排列的多张图片)
  • 第 3 个参数:全局步数(step),用于在 TensorBoard 中按顺序展示。

为什么要分多个 epoch?

  • 在模型训练中,一个 epoch 表示遍历完所有训练数据一次。
  • 通常需要多个 epoch 才能让模型充分学习数据中的规律(如 10、20、50 个 epoch)。
  • 这里可视化不同 epoch 的批次数据,是为了观察数据打乱后的分布差异(因为shuffle=True)。

三、运行结果与 TensorBoard 查看

1. 控制台输出

python 复制代码
torch.Size([3, 32, 32])  # 第1张图片的形状
3                       # 第1张图片的标签(例如"猫")

2. TensorBoard 可视化

打开 TensorBoard 后,在 "IMAGES" 标签页可以看到:

  • 两个类别:Epoch:0Epoch:1
  • 每个类别下有 157 张网格图片(因为 10000 ÷ 64 = 156 余 16,共 157 个批次)
  • 每张网格图包含 64 张(或最后一批 16 张)32x32 的彩色图像
  • 对比Epoch:0Epoch:1的同一 step,会发现图像顺序不同(因为shuffle=True

四、关键知识点拓展

1. Dataset 与 DataLoader 的关系

  • Dataset:负责 "数据在哪""怎么读"(存储数据路径、定义单样本读取方式)。
  • DataLoader:负责 "怎么喂给模型"(批处理、打乱、并行加载)。
  • 类比:Dataset像仓库管理员(负责找到并取出单个商品),DataLoader像快递员(负责把商品打包成批,高效配送)。

2. 为什么需要 TensorBoard?

  • 深度学习训练周期长,需要实时监控模型状态。
  • 可以直观对比不同参数(如学习率、批次大小)对结果的影响。
  • 支持可视化图像、损失曲线、模型结构、梯度分布等,帮助调试模型。

3. 常见错误与解决

  • 数据集下载失败 :检查网络连接,或手动下载数据集放到root路径下。
  • num_workers 报错 :Windows 系统将num_workers设为 0(多进程在 Windows 上支持不好)。
  • TensorBoard 无法打开 :确保日志路径正确,或尝试更换端口(tensorboard --logdir=dataloader --port=6007)。

五、实际应用场景

这段代码是深度学习的基础流程,实际训练时会在此基础上添加:

  1. 定义模型(如 CNN、ResNet)
  2. 定义损失函数(如交叉熵损失)
  3. 定义优化器(如 Adam、SGD)
  4. 在循环中加入模型训练逻辑(前向传播→计算损失→反向传播→参数更新)
  5. 用 TensorBoard 记录损失值、准确率等指标
相关推荐
深蓝电商API3 小时前
实战破解前端渲染:当 Requests 无法获取数据时(Selenium/Playwright 入门)
前端·python·selenium·playwright
肖书婷3 小时前
人工智能-机器学习day4
人工智能·机器学习
Sui_Network3 小时前
CUDIS 健康协议在 Sui 上打造更健康的未来
人工智能·科技·web3·去中心化·区块链
飞哥数智坊3 小时前
Claude 4.5 升级解析:很强,但请别跳过“Imagine”
人工智能·ai编程·claude
星期天要睡觉3 小时前
计算机视觉(opencv)——基于 dlib 关键点定位
人工智能·opencv·计算机视觉
程序边界4 小时前
AI时代如何高效学习Python:从零基础到项目实战de封神之路(2025升级版)
人工智能·python·学习
研梦非凡4 小时前
探索3D空间的视觉基础模型系列
人工智能·深度学习·神经网络·机器学习·计算机视觉·3d
gooxi_hui4 小时前
国鑫发布新一代「海擎」服务器 全面兼容国内外主流OAM GPU
人工智能
Gerlat小智4 小时前
【手撕机器学习 04】手撕线性回归:从“蒙眼下山”彻底理解梯度下降
人工智能·机器学习·线性回归