DAY45 TensorBoard深度学习可视化工具

一、TensorBoard 简介与核心原理

1.1 什么是 TensorBoard

TensorBoard 是 TensorFlow 生态系统中的官方可视化工具,目前已完美兼容 PyTorch。它就像给深度学习模型装上了一个"监控仪表盘",将枯燥的训练日志(如 Loss 值、准确率、权重分布等)转化为直观的图表和图像。

它的核心价值在于:

  1. 实时监控:在训练过程中实时查看 Loss 曲线,快速判断模型是否收敛或过拟合。
  2. 可视化调试:展示模型计算图(Graph),检查网络结构是否符合预期。
  3. 数据分析:查看输入图像、数据增强效果以及预测错误的样本。
  4. 参数诊断:通过直方图监控权重和梯度的分布,诊断梯度消失或梯度爆炸问题。

1.2 核心工作原理

TensorBoard 的工作流程主要分为两步:

  1. 数据写入(Logging) :在 Python 代码中,使用 SummaryWriter 将训练过程中的各类数据(标量、图像、直方图等)写入到特定的日志文件(.tfevents 文件)中。
  2. 前端展示(Visualization):启动 TensorBoard 本地网页服务,该服务自动读取日志文件,并在浏览器中渲染出可视化界面。

二、TensorBoard 基础操作与代码实现

2.1 安装与启动

首先确保已安装 TensorBoard:

复制代码
pip install tensorboard

在训练脚本所在目录下,通过终端启动 TensorBoard 服务:

复制代码
tensorboard --logdir=runs
  • --logdir 参数指定日志文件的根目录(通常命名为 runs)。
  • 启动后,访问终端提示的 URL(通常是 http://localhost:6006)即可查看。

2.2 初始化 SummaryWriter 与日志目录管理

为了避免不同实验的日志相互覆盖,建议实现自动递增的日志目录命名策略。

复制代码
import os
from torch.utils.tensorboard import SummaryWriter

# 自动管理日志目录,防止覆盖
log_dir = 'runs/cifar10_cnn_experiment'
if os.path.exists(log_dir):
    version = 1
    while os.path.exists(f"{log_dir}_v{version}"):
        version += 1
    log_dir = f"{log_dir}_v{version}"

# 初始化写入器
writer = SummaryWriter(log_dir)
print(f"日志将保存在: {log_dir}")

2.3 核心功能详解

1. 记录标量数据 (add_scalar)

最常用的功能,用于记录损失值(Loss)、准确率(Accuracy)、学习率(Learning Rate)等随迭代次数变化的数值。

复制代码
# 记录每个 Batch 的损失和准确率
# global_step 通常是当前的迭代次数(batch_idx + epoch * len(train_loader))
writer.add_scalar('Train/Batch_Loss', loss.item(), global_step)
writer.add_scalar('Train/Batch_Accuracy', accuracy, global_step)

# 记录每个 Epoch 的汇总指标
writer.add_scalar('Train/Epoch_Loss', epoch_loss, epoch)
writer.add_scalar('Test/Accuracy', test_accuracy, epoch)
  • 界面位置:SCALARS 选项卡。
2. 可视化模型结构 (add_graph)

将模型的计算图结构可视化,帮助检查层与层之间的连接和张量形状。

复制代码
# 需要提供模型实例和一个样例输入
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.to(device)

writer.add_graph(model, images)
  • 界面位置:GRAPHS 选项卡。双击节点可展开查看详细结构。
3. 可视化图像 (add_image)

用于查看输入数据(检查数据增强是否正确)或模型预测结果(查看错误样本)。通常配合 torchvision.utils.make_grid 使用,将多张图拼接为一张网格图。

复制代码
import torchvision

# 可视化原始训练图像(检查数据增强效果)
img_grid = torchvision.utils.make_grid(images[:8].cpu()) # 取前8张拼接
writer.add_image('Input_Images', img_grid, global_step=0)

# 可视化错误预测样本(在测试阶段)
if wrong_images:
    wrong_grid = torchvision.utils.make_grid(wrong_images[:8])
    writer.add_image('Error_Analysis', wrong_grid, epoch)
  • 界面位置:IMAGES 选项卡。
4. 记录参数直方图 (add_histogram)

监控模型参数(Weights)和梯度(Gradients)的数值分布。这是诊断训练停滞、梯度消失/爆炸的重要工具。

复制代码
# 通常不需要每个 batch 都记录,可以在每个 Epoch 结束或每隔几百个 step 记录一次
for name, param in model.named_parameters():
    writer.add_histogram(f'Weights/{name}', param, global_step)
    if param.grad is not None:
        writer.add_histogram(f'Gradients/{name}', param.grad, global_step)
  • 界面位置:HISTOGRAMS 选项卡。

三、实战:集成 TensorBoard 到 CNN 训练流程

以下代码展示了如何将 TensorBoard 完整集成到一个标准的 PyTorch 训练循环中。

3.1 完整训练函数示例

复制代码
def train_with_tensorboard(model, train_loader, test_loader, criterion, optimizer, device, epochs, writer):
    model.train()
    global_step = 0  # 全局步数计数器
    
    # 1. 记录模型结构(仅需一次)
    dataiter = iter(train_loader)
    images, _ = next(dataiter)
    writer.add_graph(model, images.to(device))
    
    # 2. 记录样例输入图像
    img_grid = torchvision.utils.make_grid(images[:8])
    writer.add_image('Train/Input_Samples', img_grid, 0)
    
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            # 统计指标
            running_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            # 3. 记录 Batch 级标量数据
            writer.add_scalar('Train/Batch_Loss', loss.item(), global_step)
            writer.add_scalar('Train/Batch_Accuracy', 100. * correct / total, global_step)
            writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
            
            # 4. 记录参数直方图(每 200 个 batch)
            if (batch_idx + 1) % 200 == 0:
                for name, param in model.named_parameters():
                    writer.add_histogram(f'Weights/{name}', param, global_step)
                    if param.grad is not None:
                        writer.add_histogram(f'Gradients/{name}', param.grad, global_step)
            
            global_step += 1
            
        # 5. 记录 Epoch 级训练指标
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        writer.add_scalar('Train/Epoch_Loss', epoch_loss, epoch)
        writer.add_scalar('Train/Epoch_Accuracy', epoch_acc, epoch)
        
        # 6. 测试集评估与错误样本分析
        test_acc, wrong_images, wrong_labels, wrong_preds = evaluate(model, test_loader, device)
        writer.add_scalar('Test/Accuracy', test_acc, epoch)
        
        # 7. 可视化错误样本
        if wrong_images:
            # 限制显示数量
            display_count = min(8, len(wrong_images))
            wrong_grid = torchvision.utils.make_grid(wrong_images[:display_count])
            writer.add_image('Test/Error_Cases', wrong_grid, epoch)
            
            # 添加文本标签说明
            # 假设 classes 是类别名称列表
            # text_info = [f"True: {classes[l]} Pred: {classes[p]}" for l, p in zip(wrong_labels[:display_count], wrong_preds[:display_count])]
            # writer.add_text('Test/Error_Labels', "  \n".join(text_info), epoch)

        print(f"Epoch {epoch+1} | Train Acc: {epoch_acc:.2f}% | Test Acc: {test_acc:.2f}%")
        
    # 训练结束关闭 writer
    writer.close()

3.2 辅助评估函数

为了保持主循环整洁,建议将评估逻辑封装:

复制代码
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    wrong_imgs = []
    wrong_lbls = []
    wrong_prds = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)
            
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            # 收集错误样本
            wrong_mask = predicted != target
            if wrong_mask.sum() > 0:
                wrong_imgs.extend(data[wrong_mask].cpu())
                wrong_lbls.extend(target[wrong_mask].cpu())
                wrong_prds.extend(predicted[wrong_mask].cpu())
                
    model.train() # 恢复训练模式
    return 100. * correct / total, wrong_imgs, wrong_lbls, wrong_prds
相关推荐
小毅&Nora3 小时前
【AI微服务】【Spring AI Alibaba】 ③ Spring AI Alibaba Agent 核心执行流程源码解析
人工智能·微服务·spring ai
轻竹办公PPT3 小时前
PPT生成效率提升的方法:AI生成PPT实战说明
人工智能·python·powerpoint
YJlio3 小时前
Python 一键拆分 PDF:按“目录/章节”建文件夹 + 每页单独导出(支持书签识别&正文识别)
开发语言·python·pdf
Das13 小时前
【计算机视觉】04_角点
人工智能·计算机视觉
Amelia1111113 小时前
day30
python
SEO_juper3 小时前
零基础快速上手:亚马逊CodeWhisperer实战入门指南
人工智能·机器学习·工具·亚马逊·codewhisperer
RanceGru3 小时前
LLM学习笔记7——unsloth微调Qwen3-4B模型与vllm部署测试
人工智能·笔记·学习·语言模型·vllm
如意鼠3 小时前
大模型教我成为大模型算法工程师之day20: 预训练语言模型 (Pre-trained Language Models)
人工智能·算法·语言模型
囊中之锥.3 小时前
机器学习第二部分----逻辑回归
人工智能·机器学习·逻辑回归