以下是 TensorBoard 在 PyTorch 中的使用指南,涵盖安装、基础操作和高级功能,帮助你高效监控和可视化模型训练过程。
1. 安装与验证
安装
bash
pip install tensorboard
验证安装
bash
tensorboard --version # 输出版本号,例如:2.12.0
2. 基础使用流程
步骤 1:导入 SummaryWriter
python
from torch.utils.tensorboard import SummaryWriter
# 创建 writer 对象,指定日志保存目录(默认:runs/当前时间)
writer = SummaryWriter("logs") # 日志会保存在 ./logs 文件夹中
步骤 2:记录数据
记录标量(损失、准确率等)
python
for epoch in range(100):
loss = 0.1 * (100 - epoch) # 模拟损失值
accuracy = 0.01 * epoch # 模拟准确率
# 记录单指标
writer.add_scalar("Loss/train", loss, epoch)
# 记录多指标(同一图表)
writer.add_scalars("Metrics", {"train_loss": loss, "train_acc": accuracy}, epoch)
记录模型结构
python
model = ... # 你的PyTorch模型
dummy_input = torch.randn(1, 3, 224, 224) # 输入样例(batch_size=1, 3通道, 224x224图像)
writer.add_graph(model, dummy_input) # 生成计算图
记录图像/特征图
python
images = torch.randn(8, 3, 224, 224) # 模拟一批图像
writer.add_images("Training Samples", images, epoch) # 记录图像批次
# 记录卷积层特征图(假设features是中间层的输出)
features = model.conv_layers(dummy_input)
writer.add_image("Feature Maps", features[0], epoch, dataformats="HW") # 单通道特征图
记录直方图(权重分布)
python
for name, param in model.named_parameters():
writer.add_histogram(f"Parameters/{name}", param, epoch)
3. 启动 TensorBoard 服务
在终端中运行以下命令(注意路径匹配):
bash
tensorboard --logdir=logs --port=6006
--logdir
:指定日志目录(与SummaryWriter
的路径一致)。--port
:指定端口(默认6006,若冲突可改为其他端口如6007)。
访问浏览器:http://localhost:6006
(或远程服务器IP:端口)。
4. 核心功能详解
Scalars(标量)
- 监控训练/验证损失、准确率、学习率等指标。
- 技巧 :使用
/
命名层级(如Loss/train
和Loss/val
),TensorBoard会自动分组。
Graphs(模型结构)
- 可视化模型计算图,检查数据流和层连接。
- 注意 :确保
add_graph
的输入张量形状与实际数据一致。
Images(图像)
- 查看输入数据、数据增强效果或中间特征图。
- 支持格式 :单张图像(
add_image
)或批次图像(add_images
)。
Histograms(直方图)
- 分析权重/偏置的分布变化,检测梯度消失或爆炸。
PR Curves & ROC(分类任务)
- 记录精确率-召回率曲线或ROC曲线:
python
from torchmetrics import PrecisionRecallCurve
pr_curve = PrecisionRecallCurve(task="binary")
precision, recall, _ = pr_curve(predictions, labels)
writer.add_pr_curve("PR Curve", labels, predictions, epoch)
5. 高级功能
Embedding Projector(降维可视化)
python
# 记录嵌入向量(如特征提取后的高维数据)
embeddings = model.get_embeddings(data) # 假设输出形状 [N, 512]
writer.add_embedding(embeddings, metadata=labels, label_img=images, global_step=epoch)
- 在 TensorBoard 的
Projector
标签页中查看PCA/t-SNE降维结果。
Hyperparameter Tuning(超参数对比)
python
# 记录超参数和对应结果
writer.add_hparams(
{"lr": 0.01, "batch_size": 32},
{"hparam/accuracy": 0.95, "hparam/loss": 0.1},
)
6. 常见问题
Q1:TensorBoard 页面无数据?
- 检查
--logdir
路径是否与SummaryWriter
的路径一致。 - 确保数据已写入日志(调用
writer.flush()
或关闭writer
)。
Q2:如何远程访问 TensorBoard?
在服务器运行:
bash
tensorboard --logdir=logs --port=6006 --bind_all
本地通过 ssh
转发端口:
bash
ssh -L 6006:localhost:6006 user@server_ip
Q3:日志文件过大?
- 定期清理旧日志或按实验分目录保存(如
logs/exp1
,logs/exp2
)。 - 使用
writer.close()
确保资源释放。
7. 完整代码示例
python
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18
# 初始化
writer = SummaryWriter("logs")
model = resnet18(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224)
# 记录模型结构
writer.add_graph(model, dummy_input)
# 模拟训练循环
for epoch in range(100):
loss = 0.1 * (100 - epoch)
accuracy = 0.01 * epoch
# 记录标量
writer.add_scalar("Loss/train", loss, epoch)
writer.add_scalars("Metrics", {"train_acc": accuracy}, epoch)
# 记录直方图
for name, param in model.named_parameters():
writer.add_histogram(f"Params/{name}", param, epoch)
writer.close()
总结
- 核心步骤 :安装 → 创建
SummaryWriter
→ 记录数据 → 启动服务。 - 常用方法 :
add_scalar
(标量)、add_graph
(模型结构)、add_image
(图像)、add_histogram
(权重分布)。 - 进阶功能:嵌入投影、超参数对比、PR曲线等。
掌握这些操作,你可以轻松实现训练过程的可视化与深度分析!