使用 PyTorch 和 SwanLab 实时可视化模型训练

以下是一个使用 PyTorch 和 SwanLab 实现训练可视化监控的完整示例,以 MNIST 手写数字识别为例:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import swanlab

# 初始化 SwanLab 实验 (自动生成仪表盘)
swanlab.init(
    experiment_name="MNIST_CNN",
    description="Simple CNN on MNIST with SwanLab monitoring",
    config={
        "batch_size": 64,
        "epochs": 10,
        "learning_rate": 0.01,
        "model": "CNN"
    }
)

# 1. 数据准备
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=swanlab.config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 2. 定义 CNN 模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=swanlab.config.learning_rate)

# 3. 训练循环
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        # 实时记录每个batch的损失
        if batch_idx % 100 == 0:
            swanlab.log({"train_loss": loss.item()}, step=epoch * len(train_loader) + batch_idx)
            
            # 打印日志到控制台
            print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")

# 4. 测试函数
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    # 记录epoch级别的指标
    swanlab.log({
        "test_loss": test_loss,
        "accuracy": accuracy,
        "epoch": epoch
    })
    
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n")

# 5. 执行训练
for epoch in range(1, swanlab.config.epochs + 1):
    train(epoch)
    test(epoch)

print("训练完成!请在 https://swanlab.cn 查看可视化结果")

关键说明:

  1. SwanLab 初始化

    python 复制代码
    swanlab.init() # 创建实验并设置跟踪参数
  2. 实时日志记录

    python 复制代码
    swanlab.log({"train_loss": loss.item()}) # 记录每个batch的损失
  3. 指标可视化

    python 复制代码
    swanlab.log({"accuracy": accuracy, "test_loss": test_loss}) # 记录测试指标

使用步骤:

  1. 安装依赖:
bash 复制代码
pip install torch torchvision swanlab
  1. 运行脚本:
bash 复制代码
python mnist_example.py
  1. 查看结果:
    • 终端会自动打印监控链接(如:SwanLab Experiment: https://swanlab.cn/[username]/MNIST_CNN/runs/[run_id]
    • 或在 SwanLab 官网 登录查看

仪表盘功能:

  1. 实时监控

    • 训练损失曲线(每100个batch更新)
    • 测试精度/损失曲线(每个epoch更新)
  2. 实验管理

    • 记录所有超参数(batch_size, lr等)
    • 保存实验配置和系统环境
    • 对比多次运行结果
  3. 自动分析

    • 训练过程动态可视化
    • 指标变化趋势分析
    • 性能指标汇总统计

通过这个示例,你可以实时:

  • 监控训练损失下降趋势
  • 观察模型在验证集的性能变化
  • 分析不同超参数对结果的影响
  • 比较多次实验的结果差异

SwanLab 会自动保存所有实验数据,即使训练中断也能恢复可视化结果。

相关推荐
智星云算力2 分钟前
本地GPU与租用GPU混合部署:混合算力架构搭建指南
人工智能·架构·gpu算力·智星云·gpu租用
jinanwuhuaguo3 分钟前
截止到4月8日,OpenClaw 2026年4月更新深度解读剖析:从“能力回归”到“信任内建”的范式跃迁
android·开发语言·人工智能·深度学习·kotlin
xiaozhazha_7 分钟前
效率提升80%:2026年AI CRM与ERP深度集成的架构设计与实现
人工智能
枫叶林FYL8 分钟前
【自然语言处理 NLP】7.2.2 安全性评估与Constitutional AI
人工智能·自然语言处理
AI人工智能+15 分钟前
基于高精度身份证OCR识别、炫彩活体检测及人脸比对技术的人脸核身系统,为通信行业数字化转型提供了坚实的安全底座
人工智能·计算机视觉·人脸识别·ocr·人脸核身
AI人工智能+24 分钟前
一种以深度学习与计算机视觉技术为核心的表格识别系统,实现了结构化、半结构化表格的精准文字提取、布局解析与版面完整还原
深度学习·计算机视觉·ocr·表格识别
小敬爱吃饭25 分钟前
Ragflow Docker部署及问题解决方案(界面为Welcome to nginx,ragflow上传文件失败,Docker中的ragflow-cpu-1一直重启)
人工智能·python·nginx·docker·语言模型·容器·数据挖掘
宸津-代码粉碎机32 分钟前
Spring Boot 4.0虚拟线程实战调优技巧,最大化发挥并发优势
java·人工智能·spring boot·后端·python
老兵发新帖41 分钟前
Hermes:比openclaw更好用的智能体?
人工智能
俊哥V1 小时前
每日 AI 研究简报 · 2026-04-09
人工智能·ai