TensorBoard

以下是 TensorBoard 在 PyTorch 中的使用指南,涵盖安装、基础操作和高级功能,帮助你高效监控和可视化模型训练过程。


1. 安装与验证

安装
bash 复制代码
pip install tensorboard
验证安装
bash 复制代码
tensorboard --version  # 输出版本号,例如:2.12.0

2. 基础使用流程

步骤 1:导入 SummaryWriter
python 复制代码
from torch.utils.tensorboard import SummaryWriter

# 创建 writer 对象,指定日志保存目录(默认:runs/当前时间)
writer = SummaryWriter("logs")  # 日志会保存在 ./logs 文件夹中
步骤 2:记录数据

记录标量(损失、准确率等)

python 复制代码
for epoch in range(100):
    loss = 0.1 * (100 - epoch)  # 模拟损失值
    accuracy = 0.01 * epoch     # 模拟准确率
    
    # 记录单指标
    writer.add_scalar("Loss/train", loss, epoch)
    # 记录多指标(同一图表)
    writer.add_scalars("Metrics", {"train_loss": loss, "train_acc": accuracy}, epoch)

记录模型结构

python 复制代码
model = ...  # 你的PyTorch模型
dummy_input = torch.randn(1, 3, 224, 224)  # 输入样例(batch_size=1, 3通道, 224x224图像)
writer.add_graph(model, dummy_input)  # 生成计算图

记录图像/特征图

python 复制代码
images = torch.randn(8, 3, 224, 224)  # 模拟一批图像
writer.add_images("Training Samples", images, epoch)  # 记录图像批次

# 记录卷积层特征图(假设features是中间层的输出)
features = model.conv_layers(dummy_input)
writer.add_image("Feature Maps", features[0], epoch, dataformats="HW")  # 单通道特征图

记录直方图(权重分布)

python 复制代码
for name, param in model.named_parameters():
    writer.add_histogram(f"Parameters/{name}", param, epoch)

3. 启动 TensorBoard 服务

在终端中运行以下命令(注意路径匹配):

bash 复制代码
tensorboard --logdir=logs --port=6006
  • --logdir:指定日志目录(与 SummaryWriter 的路径一致)。
  • --port:指定端口(默认6006,若冲突可改为其他端口如6007)。

访问浏览器:http://localhost:6006(或远程服务器IP:端口)。


4. 核心功能详解

Scalars(标量)
  • 监控训练/验证损失、准确率、学习率等指标。
  • 技巧 :使用 / 命名层级(如 Loss/trainLoss/val),TensorBoard会自动分组。
Graphs(模型结构)
  • 可视化模型计算图,检查数据流和层连接。
  • 注意 :确保 add_graph 的输入张量形状与实际数据一致。
Images(图像)
  • 查看输入数据、数据增强效果或中间特征图。
  • 支持格式 :单张图像(add_image)或批次图像(add_images)。
Histograms(直方图)
  • 分析权重/偏置的分布变化,检测梯度消失或爆炸。
PR Curves & ROC(分类任务)
  • 记录精确率-召回率曲线或ROC曲线:
python 复制代码
from torchmetrics import PrecisionRecallCurve
pr_curve = PrecisionRecallCurve(task="binary")
precision, recall, _ = pr_curve(predictions, labels)
writer.add_pr_curve("PR Curve", labels, predictions, epoch)

5. 高级功能

Embedding Projector(降维可视化)
python 复制代码
# 记录嵌入向量(如特征提取后的高维数据)
embeddings = model.get_embeddings(data)  # 假设输出形状 [N, 512]
writer.add_embedding(embeddings, metadata=labels, label_img=images, global_step=epoch)
  • 在 TensorBoard 的 Projector 标签页中查看PCA/t-SNE降维结果。
Hyperparameter Tuning(超参数对比)
python 复制代码
# 记录超参数和对应结果
writer.add_hparams(
    {"lr": 0.01, "batch_size": 32},
    {"hparam/accuracy": 0.95, "hparam/loss": 0.1},
)

6. 常见问题

Q1:TensorBoard 页面无数据?
  • 检查 --logdir 路径是否与 SummaryWriter 的路径一致。
  • 确保数据已写入日志(调用 writer.flush() 或关闭 writer)。
Q2:如何远程访问 TensorBoard?

在服务器运行:

bash 复制代码
tensorboard --logdir=logs --port=6006 --bind_all

本地通过 ssh 转发端口:

bash 复制代码
ssh -L 6006:localhost:6006 user@server_ip
Q3:日志文件过大?
  • 定期清理旧日志或按实验分目录保存(如 logs/exp1, logs/exp2)。
  • 使用 writer.close() 确保资源释放。

7. 完整代码示例

python 复制代码
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18

# 初始化
writer = SummaryWriter("logs")
model = resnet18(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224)

# 记录模型结构
writer.add_graph(model, dummy_input)

# 模拟训练循环
for epoch in range(100):
    loss = 0.1 * (100 - epoch)
    accuracy = 0.01 * epoch
    
    # 记录标量
    writer.add_scalar("Loss/train", loss, epoch)
    writer.add_scalars("Metrics", {"train_acc": accuracy}, epoch)
    
    # 记录直方图
    for name, param in model.named_parameters():
        writer.add_histogram(f"Params/{name}", param, epoch)

writer.close()

总结

  • 核心步骤 :安装 → 创建 SummaryWriter → 记录数据 → 启动服务。
  • 常用方法add_scalar(标量)、add_graph(模型结构)、add_image(图像)、add_histogram(权重分布)。
  • 进阶功能:嵌入投影、超参数对比、PR曲线等。

掌握这些操作,你可以轻松实现训练过程的可视化与深度分析!

相关推荐
蔗理苦1 小时前
2025-04-05 吴恩达机器学习5——逻辑回归(2):过拟合与正则化
人工智能·python·机器学习·逻辑回归
程序猿阿伟2 小时前
《SQL赋能人工智能:解锁特征工程的隐秘力量》
数据库·人工智能·sql
csssnxy2 小时前
叁仟数智指路机器人是否支持远程监控和管理?
大数据·人工智能
车斗3 小时前
win10 笔记本电脑安装 pytorch+cuda+gpu 大模型开发环境过程记录
人工智能·pytorch·电脑
KY_chenzhao3 小时前
数据驱动防灾:AI 大模型在地质灾害应急决策中的关键作用。基于DeepSeek/ChatGPT的AI智能体开发
人工智能·chatgpt·智能体·deepseek·本地化部署
大多_C3 小时前
量化方法分类
人工智能·分类·数据挖掘
www_pp_3 小时前
# 基于 OpenCV 的人脸识别实战:从基础到进阶
人工智能·opencv·计算机视觉
三月七(爱看动漫的程序员)4 小时前
LLM面试题六
数据库·人工智能·gpt·语言模型·自然语言处理·llama·milvus
蹦蹦跳跳真可爱5895 小时前
Python----计算机视觉处理(Opencv:道路检测之车道线拟合)
开发语言·人工智能·python·opencv·计算机视觉
deephub5 小时前
计算加速技术比较分析:GPU、FPGA、ASIC、TPU与NPU的技术特性、应用场景及产业生态
人工智能·深度学习·gpu·计算加速