PyTorch 图像预处理 transforms 与 TensorBoard 可视化 (自己学习记录)

简介

本文基于 PyTorch 框架,详细讲解 torchvision.transforms 图像预处理工具TensorBoard 可视化Compose 组合操作 以及 CIFAR10 官方数据集加载,附带逐行代码解析、原理说明、路径规则、常见坑点,适合零基础入门、后期复习查阅。

运行环境:Python + PyTorch + torchvision + PIL 核心知识点:图像格式转换、尺寸缩放、标准化、预处理流水线、数据集加载、TensorBoard 图像可视化、相对路径 ./ ../


一、环境与模块导入说明

1. 核心依赖库作用

python

运行

复制代码
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
  1. from PIL import Image PIL 是 Python 主流图像处理库,负责读取本地图片 ,读取后得到 PIL.Image 类型图像,也是 transforms 绝大多数操作的输入对象。
  2. from torch.utils.tensorboard import SummaryWriter TensorBoard 日志记录器,作用是将图像、数值等数据写入日志文件,后续通过命令启动 TensorBoard 实现网页端可视化。
  3. from torchvision import transforms PyTorch 官方图像预处理工具箱,提供图像缩放、格式转换、标准化、裁剪、翻转等一系列深度学习常用图像操作。

二、第一部分:单张图片预处理 + TensorBoard 可视化(逐行代码详解)

完整代码

python

运行

复制代码
# 导入工具库
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# 1. 读取本地图片
img = Image.open("../images/12.png")

# ===================== 1. ToTensor:PIL图像 → Tensor张量 =====================
# 创建ToTensor转换器实例
tensor_tran = transforms.ToTensor()
# 执行转换:PIL图片 转为 PyTorch 张量
tensor_img = tensor_tran(img)

# 创建TensorBoard日志记录器,自动生成logs文件夹存放日志
writer = SummaryWriter("logs")
# 将转换后的张量图像写入日志,标签:tensor_img,默认step=0
writer.add_image("tensor_img", tensor_img)

# ===================== 2. Normalize:图像标准化 =====================
# 打印标准化前的像素值
print(tensor_img[0][0][0])

# 定义标准化工具:均值[R,G,B],标准差[R,G,B]
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# 执行标准化操作
img_norm = trans_norm(tensor_img)
# 打印标准化后的像素值
print(img_norm[0][0][0])

# 将标准化后的图像写入日志,step=2
writer.add_image("img_norm", img_norm, 2)

# ===================== 3. Resize:图像固定尺寸缩放(强制正方形) =====================
# 打印原始图片尺寸 (宽, 高)
print(img.size)

# 创建缩放工具:强制缩放为 512*512,会拉伸图像
trans_resize = transforms.Resize((512, 512))
# 执行缩放:输入PIL图像,输出依旧是PIL图像
resize_img = trans_resize(img)
# 调用之前的转换器,将缩放后的PIL图转为Tensor
resize_img = tensor_tran(resize_img)

# 将缩放后的图像写入日志,step=1
writer.add_image("resize_img", resize_img, 1)

# ===================== 4. Compose:组合多个预处理操作 =====================
# 等比例缩放:仅设置一个数值,最短边缩放到512,保持原图比例,不变形
trans_resize_2 = transforms.Resize(512)
# Compose:按【从左到右】顺序打包多个操作
trans_compose = transforms.Compose([trans_resize_2, tensor_tran])
# 一键执行所有打包操作
img_compose = trans_compose(img)

# 将组合处理后的图像写入日志,step=3
writer.add_image("img_compose", img_compose, 3)

# 关闭日志记录器,确保日志完整写入
writer.close()

2. 分模块深度解析

2.1 图片读取 & 相对路径规则

python

运行

复制代码
img = Image.open("../images/12.png")
  • 路径解释(重点)
    • ../返回上一级目录
    • ./当前目录(可省略)
    • 假设代码文件在 项目根目录/pytorch/../images/12.png 表示:退出 pytorch 文件夹 → 进入同级 images 文件夹 → 读取 12.png
  • 输出类型:imgPIL.Image 类型,神经网络无法直接使用,必须转为 Tensor。

2.2 transforms.ToTensor() 格式转换(核心必用)

python

运行

复制代码
tensor_tran = transforms.ToTensor()
tensor_img = tensor_tran(img)
功能:
  1. 数据类型转换:PIL.Image / numpy数组PyTorch Tensor 张量
  2. 像素值归一化:原始像素范围 [0, 255] → 映射为 [0, 1]
  3. 维度变换:PIL 维度 (高度, 宽度, 通道数) → Tensor 标准维度 (通道数, 高度, 宽度),适配卷积神经网络输入要求。

2.3 SummaryWriteradd_image 可视化

python

运行

复制代码
writer = SummaryWriter("logs")
writer.add_image("tensor_img", tensor_img)
  1. SummaryWriter("logs")

    • 作用:创建 TensorBoard 日志记录器,自动在当前目录生成 logs 文件夹,存放日志文件 events.out.tfevents.xxx
    • 规则:日志文件夹路径必须和启动 TensorBoard 命令的 --logdir 保持一致,否则无法读取数据。
  2. add_image(标签名, 图像张量, step)

    • 第一个参数:图像标签,TensorBoard 中用于区分不同图像;
    • 第二个参数:必须是 Tensor 类型图像(PIL 图像会报错);
    • 第三个参数 step:步骤序号,用于区分同标签下不同阶段数据,序号重复会覆盖图像,建议按顺序设置 0/1/2/3。

2.4 transforms.Normalize 图像标准化

python

运行

复制代码
# 公式:output = (input - mean) / std
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(tensor_img)
1. 计算公式

imgnorm​=stdimgtensor​−mean​

  • 传入两个列表:[R通道均值, G通道均值, B通道均值][R通道标准差, G通道标准差, B通道标准差]
2. 本例效果

输入 Tensor 像素范围 [0, 1]mean=0.5,std=0.5

xnew​=0.5x−0.5​

像素范围被映射为 [-1, 1]

3. 作用

神经网络对 [-1,1] 区间数据 拟合效果更好,加速训练、防止梯度爆炸,是深度学习标配预处理。


2.5 transforms.Resize 图像缩放(两种用法区分)

用法 1:Resize((512, 512)) 强制固定尺寸

python

运行

复制代码
trans_resize = transforms.Resize((512, 512))
resize_img = trans_resize(img)
resize_img = tensor_tran(resize_img)
  • 传入元组 (H, W):强制将图像高、宽统一设置为 512×512;
  • 缺点:会拉伸 / 压缩图像,破坏原始宽高比,图像变形
  • 注意:Resize 仅支持 PIL 图像,缩放后依旧是 PIL 类型,必须再通过 ToTensor() 转张量。
用法 2:Resize(512) 等比例缩放(推荐)

python

运行

复制代码
trans_resize_2 = transforms.Resize(512)
  • 传入单个数字 :将图像最短边缩放到指定尺寸,长边按原始比例自动计算;
  • 优点:保留图像原始比例,不会变形,工程中更常用。

2.6 transforms.Compose 组合预处理流水线(高频用法)

python

运行

复制代码
trans_compose = transforms.Compose([trans_resize_2, tensor_tran])
img_compose = trans_compose(img)
1. 底层原理

Compose 是一个容器类,核心逻辑伪代码:

python

运行

复制代码
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms  # 接收操作列表并保存
    def __call__(self, img):
        # 从左到右依次执行每一个预处理操作
        for t in self.transforms:
            img = t(img)
        return img
2. 执行规则(重中之重)
  • 执行顺序:列表从左 → 从右

  • 本例等价于手动分步执行: python

    运行

    复制代码
    img = trans_resize_2(img)   # 第一步:等比例缩放(PIL操作)
    img = tensor_tran(img)      # 第二步:转Tensor
3. 避坑规则

操作顺序不能颠倒: ❌ 错误:[ToTensor(), Resize()] Resize 只能处理 PIL 图像,不能处理 Tensor,颠倒会直接报错。 ✅ 标准顺序:PIL图像操作(Resize/翻转/裁剪) → ToTensor → Normalize


2.7 收尾 writer.close()

python

运行

复制代码
writer.close()
  • 作用:关闭日志流,强制将内存中的日志数据完整写入本地文件;
  • 不写的后果:程序异常退出时,日志丢失,TensorBoard 看不到数据。

三、第二部分:CIFAR10 官方数据集加载 + 批量图像可视化

完整代码

python

运行

复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter

# 定义数据集预处理流水线
dataset_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

# 加载CIFAR10训练集
train_set = torchvision.datasets.CIFAR10(
    root="../dataset",    # 数据集存放路径
    train=True,           # True=训练集,False=测试集
    transform=dataset_transform,  # 绑定预处理操作
    download=True         # 本地无数据集则自动下载
)

# 加载CIFAR10测试集
test_set = torchvision.datasets.CIFAR10(
    root="../dataset",
    train=False,
    transform=dataset_transform,
    download=True
)

# 新建日志记录器,日志存到logs2文件夹
writer = SummaryWriter("logs2")

# 遍历测试集前10张图片,批量可视化
for i in range(10):
    img, target = test_set[i]  # 取出:图像张量 + 标签
    writer.add_image("test_img", img, i)

# 关闭记录器
writer.close()

逐模块解析

3.1 导入方式区别

python

运行

复制代码
import torchvision
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
  1. import torchvision:导入整个库,调用内部模块必须加前缀 torchvision.
  2. 对比前文 from torchvision import transforms:直接导入子模块,无需前缀,两种写法功能完全一致,仅语法习惯不同。

3.2 torchvision.datasets.CIFAR10 数据集参数详解

  1. root="../dataset" 数据集本地存放路径,结合相对路径规则,数据集会自动下载 / 读取到项目根目录 /dataset 文件夹下。

  2. train=True/False

    • train=True:加载 训练集(50000 张图像);
    • train=False:加载 测试集(10000 张图像)。
  3. transform=dataset_transform 绑定预处理流水线,数据集读取图像时,自动执行 Compose 内的所有操作(本例自动转为 Tensor)。

  4. download=True

    • 本地 root 路径下无数据集:自动联网下载并解压;
    • 本地已存在完整数据集:跳过下载,直接加载。

补充:官方源下载速度慢解决方案

  1. 手动下载数据集压缩包,放入 ../dataset 目录,代码会自动识别;
  2. 替换为清华镜像地址加速下载。

3.3 数据集取值与批量可视化

python

运行

复制代码
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_img", img, i)
  1. test_set[i]:通过下标取值,返回二元组 (图像张量, 标签)
    • img:预处理后的 Tensor 图像,可直接用于 add_image
    • target:图像分类标签(CIFAR10 共 10 个类别,标签 0~9)。
  2. 循环遍历前 10 张图,step=i 按序号区分,在 TensorBoard 中滑动查看批量图片。

四、TensorBoard 启动使用教程

1. 启动命令(两种写法,推荐第一种)

方式 1(稳定无路径问题,必学)

激活你的 conda 虚拟环境,进入代码所在目录,执行:

bash

运行

复制代码
# 查看第一部分图片日志
python -m tensorboard.main --logdir=logs --port=6006

# 查看第二部分数据集日志
python -m tensorboard.main --logdir=logs2 --port=6007

方式 2(易受环境 PATH 影响,不推荐)

bash

运行

复制代码
tensorboard --logdir=logs

2. 访问方式

复制终端输出的 http://localhost:端口号,在浏览器打开,切换到 Images 标签页即可查看所有图像。

3. 常见问题汇总

  1. 页面空白、无图像

    • 原因 1:logdir 路径和代码中 SummaryWriter 路径不匹配;
    • 原因 2:step 序号重复,图像被覆盖;
    • 解决:删除 logs/logs2 旧日志文件,重新运行代码,重启 TensorBoard。
  2. 提示 无法将"tensorboard"项识别为 cmdlet 原因:虚拟环境未识别到 tensorboard 命令; 解决:统一使用 python -m tensorboard.main 方式启动。

  3. Resize 后图像变形 解决:将 (H,W) 固定尺寸写法,改为单数值 Resize(边长) 等比例缩放。


五、核心知识点总结(复习速记)

  1. 三大基础预处理

    • ToTensor():PIL → Tensor,像素 [0,255]→[0,1],转换维度;
    • Normalize(mean, std):标准化,常用 mean=0.5,std=0.5,映射到 [-1,1]
    • Resize():双参数 = 强制固定尺寸(变形),单参数 = 等比例缩放(不变形)。
  2. Compose 组合器 按列表从左到右 顺序执行操作,是项目中标准预处理写法,严格遵守 PIL操作 → ToTensor → Normalize 顺序。

  3. 相对路径

    • ../:上一级目录;./:当前目录;路径错误会导致图片 / 数据集读取失败。
  4. TensorBoard 规则

    • 仅支持 Tensor 图像;
    • step 序号唯一,避免图像覆盖;
    • 代码日志路径 与 启动命令 logdir 必须一致;
    • 代码末尾必须执行 writer.close()
  5. CIFAR10 数据集

相关推荐
莱歌数字1 天前
换热器计算方法与步骤:从热平衡到性能校核
人工智能·科技·制造·cae·散热
小鹿研究点东西1 天前
AI直播工具实操:从直播录制、AI剪辑去重到直播伴侣开播完整流程
人工智能·自动化·音视频·语音识别
碳基硅坊1 天前
Spring AI:把大模型接进 Spring 应用
java·人工智能·spring ai
才兄说1 天前
机器人二次开发机器狗巡检?全环境稳定感知
人工智能·机器人
一一哥Sun1 天前
第06课:Transformer与注意力机制——大模型背后的秘密武器
人工智能·深度学习·transformer
landyjzlai1 天前
蓝迪哥玩转Ai(10)---Harness工程说透1。
人工智能·harness
onething3651 天前
Spring Boot + Spring AI 从入门到实战:7天转型计划 Day 3 —— 消息表设计 + 级联删除 + 事务管理
人工智能·后端
王某某人1 天前
LangChain4j 入门:Java 程序员的第一个 AI 对话程序
人工智能·后端
海兰1 天前
【实用程序】电商销售分析仪表盘 — 从零搭建一个AI参与的全栈数据洞察系统
人工智能·学习·算法
枫糖浆AI1 天前
openclaw页面无法访问解决方法
人工智能