使用 PyTorch 和 SwanLab 实时可视化模型训练

以下是一个使用 PyTorch 和 SwanLab 实现训练可视化监控的完整示例,以 MNIST 手写数字识别为例:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import swanlab

# 初始化 SwanLab 实验 (自动生成仪表盘)
swanlab.init(
    experiment_name="MNIST_CNN",
    description="Simple CNN on MNIST with SwanLab monitoring",
    config={
        "batch_size": 64,
        "epochs": 10,
        "learning_rate": 0.01,
        "model": "CNN"
    }
)

# 1. 数据准备
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=swanlab.config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 2. 定义 CNN 模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=swanlab.config.learning_rate)

# 3. 训练循环
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        # 实时记录每个batch的损失
        if batch_idx % 100 == 0:
            swanlab.log({"train_loss": loss.item()}, step=epoch * len(train_loader) + batch_idx)
            
            # 打印日志到控制台
            print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")

# 4. 测试函数
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    # 记录epoch级别的指标
    swanlab.log({
        "test_loss": test_loss,
        "accuracy": accuracy,
        "epoch": epoch
    })
    
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n")

# 5. 执行训练
for epoch in range(1, swanlab.config.epochs + 1):
    train(epoch)
    test(epoch)

print("训练完成!请在 https://swanlab.cn 查看可视化结果")

关键说明:

  1. SwanLab 初始化

    python 复制代码
    swanlab.init() # 创建实验并设置跟踪参数
  2. 实时日志记录

    python 复制代码
    swanlab.log({"train_loss": loss.item()}) # 记录每个batch的损失
  3. 指标可视化

    python 复制代码
    swanlab.log({"accuracy": accuracy, "test_loss": test_loss}) # 记录测试指标

使用步骤:

  1. 安装依赖:
bash 复制代码
pip install torch torchvision swanlab
  1. 运行脚本:
bash 复制代码
python mnist_example.py
  1. 查看结果:
    • 终端会自动打印监控链接(如:SwanLab Experiment: https://swanlab.cn/[username]/MNIST_CNN/runs/[run_id]
    • 或在 SwanLab 官网 登录查看

仪表盘功能:

  1. 实时监控

    • 训练损失曲线(每100个batch更新)
    • 测试精度/损失曲线(每个epoch更新)
  2. 实验管理

    • 记录所有超参数(batch_size, lr等)
    • 保存实验配置和系统环境
    • 对比多次运行结果
  3. 自动分析

    • 训练过程动态可视化
    • 指标变化趋势分析
    • 性能指标汇总统计

通过这个示例,你可以实时:

  • 监控训练损失下降趋势
  • 观察模型在验证集的性能变化
  • 分析不同超参数对结果的影响
  • 比较多次实验的结果差异

SwanLab 会自动保存所有实验数据,即使训练中断也能恢复可视化结果。

相关推荐
灏瀚星空几秒前
高频交易技术:订单簿分析与低延迟架构——从Level 2数据挖掘到FPGA硬件加速的全链路解决方案
人工智能·python·算法·信息可视化·fpga开发·架构·数据挖掘
kdniao202530 分钟前
快递接口调用选择:快递鸟、快递100、阿里云大对比
人工智能·阿里云·php
Hanson Huang43 分钟前
【Spring AI 1.0.0】Spring AI 1.0.0框架快速入门(2)——Prompt(提示词)
java·人工智能·spring·spring ai
诺亚凹凸曼1 小时前
用AI思维重塑人生:像训练神经网络一样优化自己
人工智能·机器学习
山有木兮木有枝_1 小时前
AI大模型幻觉问题的函数调用解决方案:DeepSeek 实战解析
前端·人工智能·deepseek
tony3651 小时前
强化学习 A2C算法
人工智能·算法
袋鼠云数栈1 小时前
当空间与数据联动,会展中心如何打造智慧运营新范式?
大数据·人工智能·信息可视化
HyperAI超神经1 小时前
在线教程丨刷新TTS模型SOTA,OpenAudio S1基于200万小时音频数据训练,深刻理解情感及语音细节
人工智能·深度学习·机器学习·文本转语音·语音处理·语音生成·在线教程
科技林总2 小时前
逻辑回归:给不确定性划界的分类大师
人工智能
Shining_Jiang2 小时前
打卡第44天:无人机数据集分类
人工智能·分类·数据挖掘