TensorBoard 与 WandB 在 PyTorch Lightning 中的完整指南
📑 目录
- 背景与动机
- 核心概念与工具介绍
- [TensorBoard 在 PyTorch Lightning 中的使用](#TensorBoard 在 PyTorch Lightning 中的使用)
- [WandB 在 PyTorch Lightning 中的使用](#WandB 在 PyTorch Lightning 中的使用)
- 高级用法与最佳实践
- [TensorBoard vs WandB 详细对比](#TensorBoard vs WandB 详细对比)
- 常见问题与调试技巧
- 扩展阅读与进阶方向
1. 背景与动机
1.1 为什么需要实验追踪工具?
在深度学习研究与开发中,我们面临以下挑战:
- 实验数量庞大:一个项目可能运行数百次实验
- 超参数复杂:学习率、batch size、模型结构等需要精细调优
- 结果难以复现:需要记录完整的配置与环境
- 团队协作困难:需要共享实验结果与模型
实验追踪工具的核心价值:
- 自动记录训练指标(loss、accuracy、学习率等)
- 可视化训练过程(曲线图、直方图、图像等)
- 对比不同实验的结果
- 保存模型检查点与配置
- 支持团队协作与知识沉淀
1.2 PyTorch Lightning 的优势
PyTorch Lightning 是 PyTorch 的高层封装,它:
- 自动处理训练循环、验证、测试
- 原生支持多种 Logger(TensorBoard、WandB、MLflow 等)
- 通过 插件化设计,无需修改核心代码即可切换追踪工具
2. 核心概念与工具介绍
2.1 TensorBoard
定义:TensorFlow 官方开发的可视化工具,但完全支持 PyTorch。
核心特点:
- ✅ 本地运行:无需联网,数据完全私有
- ✅ 轻量级:无额外依赖,启动快速
- ✅ 标准化:工业界广泛使用
- ❌ 功能有限:缺乏实验管理、超参数扫描等高级功能
- ❌ 团队协作弱:需要手动共享日志文件
适用场景:
- 个人学习与小型项目
- 对数据隐私有严格要求
- 快速调试与可视化
2.2 WandB (Weights & Biases)
定义:专业的实验追踪与管理平台,支持云端协作。
核心特点:
- ✅ 功能强大:实验对比、超参数扫描、模型版本管理
- ✅ 团队协作:云端共享,支持实时查看
- ✅ 生态丰富:集成 Hugging Face、Ray Tune 等
- ❌ 需要联网:数据默认上传到云端(可自建私有服务器)
- ❌ 学习曲线:功能多,需要一定时间掌握
适用场景:
- 科研实验管理(多实验对比、论文复现)
- 团队协作项目
- 需要超参数优化或模型选择
2.3 PyTorch Lightning 的 Logger 机制
PyTorch Lightning 通过 Logger 抽象层统一管理日志记录:
python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
# 所有 Logger 都遵循统一接口
logger = TensorBoardLogger(...) # 或 WandbLogger(...)
trainer = Trainer(logger=logger)
关键优势:
- 无需修改模型代码,只需更换
logger参数 - 支持同时使用多个 Logger
- 自动记录超参数、系统信息等元数据
3. TensorBoard 在 PyTorch Lightning 中的使用
3.1 安装与环境配置
bash
# 安装 PyTorch Lightning
pip install pytorch-lightning
# 安装 TensorBoard
pip install tensorboard
# 验证安装
tensorboard --version
3.2 基础用法:最小示例
python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
# 1. 定义 LightningModule
class SimpleModel(pl.LightningModule):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.loss_fn(y_hat, y)
# 核心:记录指标到 TensorBoard
self.log('train_loss', loss, on_step=True, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.loss_fn(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
# 记录验证指标
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# 2. 准备数据
X = torch.randn(1000, 10)
y = torch.randint(0, 3, (1000,))
dataset = TensorDataset(X, y)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset, batch_size=32)
# 3. 配置 TensorBoard Logger
logger = TensorBoardLogger(
save_dir='logs/', # 日志保存目录
name='my_experiment', # 实验名称
version='v1', # 版本号(可选)
default_hp_metric=False # 不记录默认超参数指标
)
# 4. 训练模型
model = SimpleModel(input_dim=10, hidden_dim=64, output_dim=3)
trainer = pl.Trainer(
max_epochs=10,
logger=logger,
log_every_n_steps=10 # 每 10 步记录一次
)
trainer.fit(model, train_loader, val_loader)
# 5. 启动 TensorBoard
# 在命令行运行:tensorboard --logdir=logs/
# 然后在浏览器打开 http://localhost:6006
3.3 核心 API:self.log() 详解
python
self.log(
name='metric_name', # 指标名称
value=tensor_or_scalar, # 值(Tensor 或标量)
on_step=True, # 是否记录每个 step
on_epoch=True, # 是否记录每个 epoch
prog_bar=False, # 是否显示在进度条
logger=True, # 是否记录到 logger
reduce_fx='mean', # 聚合方式(mean/sum/max/min)
sync_dist=False # 分布式训练时同步
)
常用模式:
| 场景 | 配置 |
|---|---|
| 训练 loss(每步记录) | on_step=True, on_epoch=False |
| 验证 loss(每轮记录) | on_step=False, on_epoch=True |
| 学习率(每步记录) | on_step=True, on_epoch=False |
| 准确率(显示在进度条) | on_epoch=True, prog_bar=True |
3.4 高级功能
3.4.1 记录图像
python
from torchvision.utils import make_grid
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
# 每 5 个 epoch 记录一次图像
if self.current_epoch % 5 == 0 and batch_idx == 0:
# 假设 x 是图像数据
grid = make_grid(x[:8], normalize=True)
self.logger.experiment.add_image(
'val_images', grid, self.global_step
)
3.4.2 记录直方图(参数分布)
python
def on_train_epoch_end(self):
for name, param in self.named_parameters():
self.logger.experiment.add_histogram(
name, param, self.current_epoch
)
3.4.3 记录超参数
python
# 在模型初始化时保存超参数
def __init__(self, input_dim, hidden_dim, lr):
super().__init__()
self.save_hyperparameters() # 自动保存所有参数
# TensorBoard 会在 HPARAMS 标签页显示
3.4.4 记录学习率
python
def on_train_batch_end(self, outputs, batch, batch_idx):
# 获取当前学习率
sch = self.lr_schedulers()
current_lr = sch.get_last_lr()[0]
self.log('learning_rate', current_lr, on_step=True)
3.5 启动 TensorBoard
bash
# 方式1:基础启动
tensorboard --logdir=logs/
# 方式2:指定端口
tensorboard --logdir=logs/ --port=6007
# 方式3:对比多个实验
tensorboard --logdir_spec=exp1:logs/exp1,exp2:logs/exp2
# 方式4:远程服务器(允许外部访问)
tensorboard --logdir=logs/ --host=0.0.0.0 --port=6006
TensorBoard 界面说明:
- SCALARS:查看 loss、accuracy 等曲线
- IMAGES:查看记录的图像
- GRAPHS:查看模型计算图
- DISTRIBUTIONS / HISTOGRAMS:查看参数分布
- HPARAMS:对比不同超参数的实验结果
4. WandB 在 PyTorch Lightning 中的使用
4.1 安装与账号配置
bash
# 安装 wandb
pip install wandb
# 登录账号(需要先在 https://wandb.ai 注册)
wandb login
# 或使用 API Key
# wandb login --relogin YOUR_API_KEY
4.2 基础用法:最小示例
python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
# 1. 配置 WandB Logger
logger = WandbLogger(
project='my_project', # 项目名称
name='exp_001', # 实验名称
save_dir='logs/', # 本地日志目录
log_model='all' # 保存模型检查点('all' 或 True)
)
# 2. 训练模型(模型定义同 TensorBoard 示例)
model = SimpleModel(input_dim=10, hidden_dim=64, output_dim=3)
trainer = pl.Trainer(
max_epochs=10,
logger=logger
)
trainer.fit(model, train_loader, val_loader)
# 3. 训练结束后自动上传到云端
# 访问 https://wandb.ai/YOUR_USERNAME/my_project 查看结果
wandb.finish() # 可选:显式结束运行
4.3 核心功能
4.3.1 自动记录配置与代码
python
# WandB 会自动记录:
# - Python 版本、依赖库版本
# - Git commit hash(如果在 Git 仓库中)
# - 硬件信息(GPU、CPU、内存)
# - 命令行参数
# 手动记录额外配置
logger.experiment.config.update({
"architecture": "ResNet50",
"dataset": "CIFAR-10",
"custom_param": 123
})
4.3.2 记录图像与媒体
python
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
if batch_idx == 0:
# 记录图像
self.logger.log_image(
key='predictions',
images=[x[0]],
caption=[f"True: {y[0]}, Pred: {y_hat[0].argmax()}"]
)
# 记录表格数据
columns = ["id", "prediction", "truth"]
data = [[i, y_hat[i].argmax().item(), y[i].item()]
for i in range(8)]
self.logger.log_table(key="predictions_table",
columns=columns, data=data)
4.3.3 记录混淆矩阵
python
from sklearn.metrics import confusion_matrix
import wandb
def on_validation_epoch_end(self):
# 假设已收集所有预测结果
y_true = [...] # 真实标签
y_pred = [...] # 预测标签
cm = confusion_matrix(y_true, y_pred)
self.logger.experiment.log({
"confusion_matrix": wandb.plot.confusion_matrix(
probs=None,
y_true=y_true,
preds=y_pred,
class_names=['Class0', 'Class1', 'Class2']
)
})
4.3.4 保存模型 Artifact
python
# 在 Trainer 中配置
logger = WandbLogger(
project='my_project',
log_model='all' # 自动保存所有检查点
# log_model=True # 仅保存最佳模型
)
# 手动保存自定义 Artifact
artifact = wandb.Artifact('model_weights', type='model')
artifact.add_file('best_model.ckpt')
logger.experiment.log_artifact(artifact)
4.4 高级功能:超参数扫描(Sweeps)
WandB Sweeps 支持自动化超参数优化(类似 Optuna)。
配置文件:sweep_config.yaml
yaml
program: train.py
method: bayes # 优化方法:grid/random/bayes
metric:
name: val_loss
goal: minimize
parameters:
learning_rate:
distribution: log_uniform_values
min: 0.0001
max: 0.1
batch_size:
values: [16, 32, 64]
hidden_dim:
values: [64, 128, 256]
启动 Sweep
bash
# 初始化 sweep
wandb sweep sweep_config.yaml
# 运行 agent(可启动多个并行)
wandb agent YOUR_SWEEP_ID
在代码中集成 Sweep
python
import wandb
def train():
# 初始化 wandb(从 sweep 获取配置)
wandb.init()
config = wandb.config
# 使用 sweep 配置
model = SimpleModel(
input_dim=10,
hidden_dim=config.hidden_dim,
output_dim=3
)
logger = WandbLogger()
trainer = pl.Trainer(max_epochs=10, logger=logger)
trainer.fit(model, train_loader, val_loader)
# 主程序
if __name__ == '__main__':
train()
4.5 离线模式(无网络环境)
python
import os
os.environ['WANDB_MODE'] = 'offline' # 本地保存,稍后手动上传
# 训练完成后手动同步
# wandb sync logs/wandb/run-xxx
5. 高级用法与最佳实践
5.1 同时使用多个 Logger
python
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
# 创建多个 logger
tb_logger = TensorBoardLogger('logs/', name='tb')
wandb_logger = WandbLogger(project='my_project', name='wandb')
# 传递 logger 列表
trainer = pl.Trainer(
max_epochs=10,
logger=[tb_logger, wandb_logger] # 同时记录到两个平台
)
适用场景:
- 本地调试用 TensorBoard,云端备份用 WandB
- 团队协作(共享 WandB)+ 个人分析(TensorBoard)
5.2 自定义 Logger 行为
仅记录特定指标到特定 Logger
python
def training_step(self, batch, batch_idx):
loss = ...
# 仅记录到 TensorBoard
if isinstance(self.logger, TensorBoardLogger):
self.log('debug_metric', some_value)
# 仅记录到 WandB
if isinstance(self.logger, WandbLogger):
self.logger.experiment.log({'custom_metric': value})
return loss
5.3 记录自定义可视化
TensorBoard:记录 PR 曲线
python
from torch.utils.tensorboard import SummaryWriter
def on_validation_epoch_end(self):
writer = self.logger.experiment
# 假设有预测概率和标签
predictions = torch.cat(self.validation_predictions)
labels = torch.cat(self.validation_labels)
writer.add_pr_curve('pr_curve', labels, predictions, self.current_epoch)
WandB:记录自定义图表
python
import matplotlib.pyplot as plt
import wandb
def on_validation_epoch_end(self):
# 创建 matplotlib 图表
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 5, 6])
# 记录到 WandB
self.logger.experiment.log({"custom_plot": wandb.Image(fig)})
plt.close()
5.4 分布式训练中的日志记录
python
# PyTorch Lightning 自动处理分布式同步
self.log(
'train_loss',
loss,
sync_dist=True, # 关键:在多 GPU 间同步指标
rank_zero_only=True # 仅主进程记录(避免重复)
)
5.5 性能优化技巧
python
# 1. 减少记录频率
trainer = pl.Trainer(
log_every_n_steps=50 # 每 50 步记录一次(默认 50)
)
# 2. 异步记录(WandB 默认支持)
os.environ['WANDB_CONSOLE'] = 'off' # 关闭控制台输出
# 3. 避免记录大张量
self.log('loss', loss.item()) # 使用 .item() 转为标量
6. TensorBoard vs WandB 详细对比
6.1 功能对比表
| 功能 | TensorBoard | WandB | 说明 |
|---|---|---|---|
| 基础功能 | |||
| 记录标量指标 | ✅ | ✅ | 两者均支持 |
| 可视化曲线 | ✅ | ✅ | WandB 交互性更强 |
| 记录图像 | ✅ | ✅ | WandB 支持更多格式 |
| 记录音频/视频 | ✅ | ✅ | WandB 更易用 |
| 模型结构可视化 | ✅ | ❌ | TensorBoard 专有 |
| 参数分布直方图 | ✅ | ✅ | 两者相当 |
| 高级功能 | |||
| 实验对比 | 🟡 需手动 | ✅ 原生支持 | WandB 可交互式对比 |
| 超参数扫描 | ❌ | ✅ Sweeps | WandB 内置优化算法 |
| 模型版本管理 | ❌ | ✅ Artifacts | WandB 支持模型/数据追踪 |
| 团队协作 | 🟡 需手动共享文件 | ✅ 云端共享 | WandB 天然支持 |
| API 调用 | 🟡 复杂 | ✅ 简洁 | WandB API 更友好 |
| 报告生成 | ❌ | ✅ Reports | WandB 可生成可分享报告 |
| 部署与隐私 | |||
| 本地运行 | ✅ | 🟡 需 offline 模式 | TensorBoard 完全本地 |
| 云端托管 | ❌ | ✅ | WandB 默认云端 |
| 私有化部署 | ✅ | 🟡 需企业版 | TensorBoard 开源 |
| 数据隐私 | ✅ 完全本地 | 🟡 可配置(默认云端) | 根据需求选择 |
| 学习与使用 | |||
| 学习曲线 | 🟢 简单 | 🟡 中等 | TensorBoard 更易上手 |
| 文档质量 | ✅ | ✅ | 两者均完善 |
| 社区支持 | ✅ | ✅ | 两者均活跃 |
| 开源/商业 | 开源 | 商业(免费版有限制) | TensorBoard 完全免费 |
图例:
- ✅ 完全支持
- 🟡 部分支持或需额外配置
- 🟢 优势明显
- ❌ 不支持
6.2 选择建议
选择 TensorBoard 的场景
-
个人学习与小型项目
- 快速调试,无需复杂功能
- 对数据隐私要求高
-
无法联网的环境
- 内网服务器、离线开发
-
已有 TensorBoard 基础
- 团队已熟悉 TensorBoard 工作流
-
需要模型结构可视化
- TensorBoard 的 GRAPHS 功能独有
选择 WandB 的场景
-
科研实验管理
- 需要对比大量实验
- 论文实验复现与记录
-
团队协作项目
- 多人共享实验结果
- 实时查看训练进度
-
超参数优化
- 使用 Sweeps 自动化调优
-
模型/数据版本管理
- 需要追踪模型演进历史
同时使用两者
- 本地调试 → TensorBoard(快速启动)
- 实验记录 → WandB(长期存储)
- 代码配置:通过 Lightning 的多 logger 机制轻松切换
7. 常见问题与调试技巧
7.1 TensorBoard 常见问题
Q1: TensorBoard 界面显示 "No dashboards are active"
原因:日志目录错误或无数据
解决:
bash
# 检查日志目录结构
ls -R logs/
# 确保目录包含 events.out.tfevents.* 文件
# 正确的目录结构示例:
# logs/
# └── my_experiment/
# └── version_0/
# └── events.out.tfevents.xxx
Q2: 曲线不更新或刷新慢
解决:
bash
# 强制刷新(默认 30 秒更新一次)
tensorboard --logdir=logs/ --reload_interval=5
Q3: 多实验对比时曲线颜色混乱
解决:
python
# 使用明确的实验名称和版本号
logger = TensorBoardLogger(
save_dir='logs/',
name='baseline', # 清晰的实验名
version='lr_1e-3' # 描述性版本号
)
7.2 WandB 常见问题
Q1: wandb: ERROR Error uploading
原因:网络问题或 API Key 失效
解决:
bash
# 重新登录
wandb login --relogin
# 或使用离线模式
export WANDB_MODE=offline
Q2: 训练速度明显变慢
原因:记录频率过高或记录大张量
解决:
python
# 1. 减少记录频率
trainer = pl.Trainer(log_every_n_steps=100)
# 2. 避免记录大张量
self.log('loss', loss.item()) # 而非 loss
# 3. 异步上传
os.environ['WANDB_CONSOLE'] = 'off'
Q3: 如何删除或隐藏失败的实验?
解决:
- 在 WandB 网页界面,进入 Runs → 选择实验 → Settings → Delete
Q4: 离线模式的日志如何上传?
解决:
bash
# 找到离线日志目录
wandb sync wandb/offline-run-xxx
# 批量同步
wandb sync --sync-all
7.3 PyTorch Lightning 集成问题
Q1: self.log() 不记录指标
检查清单:
python
# 1. 确保 logger=True
self.log('metric', value, logger=True)
# 2. 检查 Trainer 是否传入 logger
trainer = pl.Trainer(logger=logger)
# 3. 验证指标类型
assert isinstance(value, (int, float, torch.Tensor))
Q2: 分布式训练时指标重复或错误
解决:
python
# 多 GPU 时必须同步
self.log('val_loss', loss, sync_dist=True, rank_zero_only=True)
8. 扩展阅读与进阶方向
8.1 官方文档
-
PyTorch Lightning 文档 :
-
TensorBoard 文档 :
-
WandB 文档 :
8.2 进阶主题
8.2.1 其他 Logger 选项
PyTorch Lightning 还支持:
- MLflow:适合 MLOps 工作流
- Comet.ml:类似 WandB,提供实验追踪
- Neptune.ai:企业级实验管理
python
from pytorch_lightning.loggers import MLFlowLogger, CometLogger
logger = MLFlowLogger(experiment_name='my_exp', tracking_uri='http://localhost:5000')
8.2.2 自定义 Logger
python
from pytorch_lightning.loggers import Logger
class CustomLogger(Logger):
def log_metrics(self, metrics, step):
# 自定义日志逻辑(如写入数据库)
print(f"Step {step}: {metrics}")
def log_hyperparams(self, params):
print(f"Hyperparams: {params}")
@property
def name(self):
return "CustomLogger"
@property
def version(self):
return "0.1"
8.2.3 与 Hydra 配置管理集成
python
import hydra
from omegaconf import DictConfig
@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig):
logger = WandbLogger(
project=cfg.wandb.project,
name=cfg.wandb.name,
config=dict(cfg) # 自动记录所有配置
)
model = hydra.utils.instantiate(cfg.model)
trainer = pl.Trainer(logger=logger, **cfg.trainer)
trainer.fit(model)
8.3 实战项目推荐
- 图像分类:在 CIFAR-10 上对比不同优化器效果
- NLP 任务:使用 WandB Sweeps 优化 Transformer 超参数
- GAN 训练:用 TensorBoard 可视化生成图像质量演进
- 强化学习:记录 reward 曲线和智能体行为
8.4 相关工具生态
- DVC (Data Version Control):数据集版本管理
- Optuna:超参数优化(可与 WandB 集成)
- Ray Tune:分布式超参数搜索
- Hugging Face Hub:模型与数据集托管
📌 总结
核心要点回顾
- PyTorch Lightning 的 Logger 机制提供了统一接口,降低切换成本
- TensorBoard 适合快速本地调试,WandB 适合团队协作与实验管理
- 使用
self.log()即可无缝集成,无需修改核心代码 - 高级功能(如 Sweeps、Artifacts)可显著提升研究效率
最佳实践建议
- ✅ 实验命名规范 :使用描述性名称(如
resnet50_lr1e-3_bs64) - ✅ 定期清理日志:避免磁盘占用过大
- ✅ 记录关键超参数 :使用
self.save_hyperparameters() - ✅ 结合 Git 使用:WandB 可自动记录 commit hash
- ✅ 团队协作时约定规范:统一 project 名称、指标命名等
附录:快速参考卡片
TensorBoard 快速启动
python
from pytorch_lightning.loggers import TensorBoardLogger
logger = TensorBoardLogger('logs/', name='exp')
trainer = pl.Trainer(logger=logger)
# 终端运行:tensorboard --logdir=logs/
WandB 快速启动
python
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(project='my_project', name='exp')
trainer = pl.Trainer(logger=logger)
# 自动上传到 wandb.ai
常用记录模式
python
# 训练 loss
self.log('train_loss', loss, on_step=True, on_epoch=True)
# 验证指标
self.log('val_acc', acc, on_epoch=True, prog_bar=True)
# 学习率
self.log('lr', optimizer.param_groups[0]['lr'], on_step=True)
# 图像(TensorBoard)
self.logger.experiment.add_image('name', img, step)
# 图像(WandB)
self.logger.log_image(key='name', images=[img])