使用 SwanLab 进行可视化 MNIST 手写体识别训练

使用 SwanLab 进行可视化 MNIST 手写体识别训练

在线演示demo

本案例主要:

  • 使用pytorch进行CNN(卷积神经网络)的构建、模型训练与评估
  • 使用swanlab跟踪超参数、记录指标和可视化监控整个训练周期

一、相关简介

SwanLab

SwanLab是一款开源、轻量级的AI实验跟踪工具,提供了一个跟踪、比较、和协作实验的平台,旨在加速AI研发团队100倍的研发效率。其提供了友好的API和漂亮的界面,结合了超参数跟踪、指标记录、在线协作、实验链接分享、实时消息通知等功能,让您可以快速跟踪ML实验、可视化过程、分享给同伴。

SwanLab提供了一套云端AI实验跟踪方案,面向训练过程,提供了训练可视化、实验跟踪、超参数记录、日志记录、多人协同等功能,研究者能轻松通过直观的可视化图表找到迭代灵感,并且通过在线链接的分享与基于组织的多人协同训练,打破团队沟通的壁垒。

可视化界面截图:

MNIST

MNIST手写体识别是深度学习最经典的入门任务之一,由 LeCun 等人提出。

该任务基于MNIST数据集,研究者通过构建机器学习模型,来识别10个手写数字(0~9)。

二、环境配置

本案例基于Python>=3.8,请在您的计算机上安装好Python。

环境依赖:

torch
torchvision
swanlab

快速安装命令:

pip install torch torchvision swanlab

MNIST 数据集已经被 torch 自动集成了,所以不需要额外下载,很方便。

三、训练代码

复制以下代码,创建 app.py 并粘贴代码,保存后直接使用 python 或 IDE 运行:python app.py

python 复制代码
import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.models import ResNet18_Weights
import swanlab

# CNN网络构建
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 1,28x28
        self.conv1 = nn.Conv2d(1, 10, 5)  # 10, 24x24
        self.conv2 = nn.Conv2d(10, 20, 3)  # 128, 10x10
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)  # 24
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)  # 12
        out = self.conv2(out)  # 10
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out


# 捕获并可视化前20张图像
def log_images(loader, num_images=16):
    images_logged = 0
    logged_images = []
    for images, labels in loader:
        # images: batch of images, labels: batch of labels
        for i in range(images.shape[0]):
            if images_logged < num_images:
                # 使用swanlab.Image将图像转换为wandb可视化格式
                logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))
                images_logged += 1
            else:
                break
        if images_logged >= num_images:
            break
    swanlab.log({"MNIST-Preview": logged_images})
    

def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs):
    model.train()
    # 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签
    for iter, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # 2. 传入到resnet18模型中得到预测结果
        outputs = model(inputs)
        # 3. 将结果和标签传入损失函数中计算交叉熵损失
        loss = criterion(outputs, labels)
        # 4. 根据损失计算反向传播
        loss.backward()
        # 5. 优化器执行模型参数更新
        optimizer.step()
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader),
                                                                      loss.item()))
        # 6. 每20次迭代,用SwanLab记录一下loss的变化
        if iter % 20 == 0:
            swanlab.log({"train/loss": loss.item()})

def test(model, device, val_dataloader, epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签
        for inputs, labels in val_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # 2. 传入到resnet18模型中得到预测结果
            outputs = model(inputs)
            # 3. 获得预测的数字
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # 4. 计算与标签一致的预测结果的数量
            correct += (predicted == labels).sum().item()
    
        # 5. 得到最终的测试准确率
        accuracy = correct / total
        # 6. 用SwanLab记录一下准确率的变化
        swanlab.log({"val/accuracy": accuracy}, step=epoch)
    

if __name__ == "__main__":

    #检测是否支持mps
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    #检测是否支持cuda
    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    # 初始化swanlab
    run = swanlab.init(
        project="MNIST-example",
        experiment_name="PlainCNN",
        config={
            "model": "ResNet18",
            "optim": "Adam",
            "lr": 1e-4,
            "batch_size": 256,
            "num_epochs": 10,
            "device": device,
        },
    )

    # 设置MNIST训练集和验证集
    dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
    train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])

    train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
    val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)
    
    # (可选)看一下数据集的前16张图像
    log_images(train_dataloader, 16)

    # 初始化模型
    model = ConvNet()
    model.to(torch.device(device))

    # 打印模型
    print(model)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=run.config.lr)

    # 开始训练和测试循环
    for epoch in range(1, run.config.num_epochs+1):
        swanlab.log({"train/epoch": epoch}, step=epoch)
        train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs)
        if epoch % 2 == 0: 
            test(model, device, val_dataloader, epoch)

    # 保存模型
    # 如果不存在checkpoint文件夹,则自动创建一个
    if not os.path.exists("checkpoint"):
        os.makedirs("checkpoint")
    torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')

四、注意事项

在运行代码的时候,可能会出现如上提示,需要输入一个凭证,这个时候我们只需要去 SwanLab 云端版登录并获取,复制后粘贴到终端,回车后继续运行即可:

当然,有云端版肯定也有本地版。

上面的训练会将训练数据上传到云端,让我们可以直接通过在线链接的方式访问自己的实验数据和实验进度 。但是还可以选择不上传,而通过本地命令在本机开启一个面板服务,其前端界面与云端版基本一致,同样能查看实验数据和详细信息。

相关推荐
秃头佛爷3 分钟前
Python学习大纲总结及注意事项
开发语言·python·学习
深度学习lover1 小时前
<项目代码>YOLOv8 苹果腐烂识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·苹果腐烂识别
API快乐传递者2 小时前
淘宝反爬虫机制的主要手段有哪些?
爬虫·python
阡之尘埃4 小时前
Python数据分析案例61——信贷风控评分卡模型(A卡)(scorecardpy 全面解析)
人工智能·python·机器学习·数据分析·智能风控·信贷风控
丕羽7 小时前
【Pytorch】基本语法
人工智能·pytorch·python
bryant_meng8 小时前
【python】Distribution
开发语言·python·分布函数·常用分布
m0_594526309 小时前
Python批量合并多个PDF
java·python·pdf
工业互联网专业9 小时前
Python毕业设计选题:基于Hadoop的租房数据分析系统的设计与实现
vue.js·hadoop·python·flask·毕业设计·源码·课程设计
钱钱钱端9 小时前
【压力测试】如何确定系统最大并发用户数?
自动化测试·软件测试·python·职场和发展·压力测试·postman
慕卿扬9 小时前
基于python的机器学习(二)—— 使用Scikit-learn库
笔记·python·学习·机器学习·scikit-learn