丹摩 | 基于PyTorch的CIFAR-10图像分类实现


从创建实例开始的新项目流程

第一步:创建实例

  1. 登录 DAMODEL 平台。
  2. 创建一个 GPU 实例:
    • GPU 配置 :选择 NVIDIA H800 或其他可用高性能 GPU。

    • 系统配置:推荐使用 Ubuntu 20.04,内存 16GB,硬盘 50GB。

    • 启动实例后,获取实例的 IP 地址。

    • 选择镜像


第二步:连接实例

  1. 登录成功后,你会进入实例的终端界面。


第三步:更新系统和安装基础工具

  1. 更新系统:

    bash 复制代码
    sudo apt update && sudo apt upgrade -y
  2. 安装 Python 和基础工具:

    bash 复制代码
    sudo apt install python3 python3-pip git -y
  3. (可选)安装文本编辑器:

    bash 复制代码
    sudo apt install vim nano -y

第四步:创建项目目录并配置环境

  1. 创建项目目录:

    bash 复制代码
    mkdir ~/workspace/cifar10_project
    cd ~/workspace/cifar10_project
  2. 创建并激活虚拟环境:

    bash 复制代码
    python3 -m venv venv
    source venv/bin/activate

    前面出现venu则表示已经激活虚拟环境了

  3. 安装必要的 Python 包:

    bash 复制代码
    pip install torch torchvision matplotlib

第五步:下载数据并初始化项目代码

  1. 创建 Python 脚本:

    bash 复制代码
    vim train_cifar10.py
  2. 在文件中输入以下代码,加载 CIFAR-10 数据集并定义简单模型:

    python 复制代码
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import torch.nn as nn
    import torch.optim as optim
    
    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # 加载 CIFAR-10 数据集
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    
    # 定义简单卷积神经网络
    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(32 * 16 * 16, 10)
    
        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = x.view(-1, 32 * 16 * 16)
            x = self.fc1(x)
            return x
    
    # 初始化模型、损失函数和优化器
    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    # 模型训练
    for epoch in range(5):  # 训练 5 个周期
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader)}")
    
    print("Finished Training")
  3. 保存并退出(按下 Esc,然后输入 :wq)。


第六步:运行训练脚本

运行脚本进行模型训练:

bash 复制代码
python train_cifar10.py
  • 脚本会下载 CIFAR-10 数据集并训练模型。
  • 训练完成后会输出每个 epoch 的损失值。

第七步:保存和测试模型

  1. 保存模型:在脚本末尾添加代码以保存训练好的模型:

    python 复制代码
    torch.save(net.state_dict(), "cifar10_model.pth")
    print("Model saved as cifar10_model.pth")
  2. 重新运行脚本以保存模型:

    bash 复制代码
    python train_cifar10.py
  3. 检查是否生成了 cifar10_model.pth 文件:

    bash 复制代码
    ls
  4. 测试模型(可选):加载保存的模型并在测试集上评估准确率:

    python 复制代码
    net.load_state_dict(torch.load("cifar10_model.pth"))
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f"Accuracy on test dataset: {100 * correct / total}%")

第八步:清理和扩展

  1. 扩展功能

    • 使用更复杂的模型(如 ResNet)。
    • 尝试使用 Adam 优化器提高性能。
    • 可视化训练过程或模型预测结果。
  2. 清理资源

    • 如果完成训练并不再需要 GPU 计算,记得停止或删除实例以节省费用。

\

相关推荐
GIOTTO情3 小时前
媒介宣发的技术革命:Infoseek如何用AI重构企业传播全链路
大数据·人工智能·重构
阿里云大数据AI技术3 小时前
云栖实录 | 从多模态数据到 Physical AI,PAI 助力客户快速启动 Physical AI 实践
人工智能
小关会打代码3 小时前
计算机视觉进阶教学之颜色识别
人工智能·计算机视觉
IT小哥哥呀3 小时前
基于深度学习的数字图像分类实验与分析
人工智能·深度学习·分类
机器之心4 小时前
VAE时代终结?谢赛宁团队「RAE」登场,表征自编码器或成DiT训练新基石
人工智能·openai
机器之心4 小时前
Sutton判定「LLM是死胡同」后,新访谈揭示AI困境
人工智能·openai
大模型真好玩4 小时前
低代码Agent开发框架使用指南(四)—Coze大模型和插件参数配置最佳实践
人工智能·agent·coze
jerryinwuhan4 小时前
基于大语言模型(LLM)的城市时间、空间与情感交织分析:面向智能城市的情感动态预测与空间优化
人工智能·语言模型·自然语言处理
落雪财神意4 小时前
股指10月想法
大数据·人工智能·金融·区块链·期股
中杯可乐多加冰4 小时前
无代码开发实践|基于业务流能力快速开发市场监管系统,实现投诉处理快速响应
人工智能·低代码