丹摩 | 基于PyTorch的CIFAR-10图像分类实现


从创建实例开始的新项目流程

第一步:创建实例

  1. 登录 DAMODEL 平台。
  2. 创建一个 GPU 实例:
    • GPU 配置 :选择 NVIDIA H800 或其他可用高性能 GPU。

    • 系统配置:推荐使用 Ubuntu 20.04,内存 16GB,硬盘 50GB。

    • 启动实例后,获取实例的 IP 地址。

    • 选择镜像


第二步:连接实例

  1. 登录成功后,你会进入实例的终端界面。


第三步:更新系统和安装基础工具

  1. 更新系统:

    bash 复制代码
    sudo apt update && sudo apt upgrade -y
  2. 安装 Python 和基础工具:

    bash 复制代码
    sudo apt install python3 python3-pip git -y
  3. (可选)安装文本编辑器:

    bash 复制代码
    sudo apt install vim nano -y

第四步:创建项目目录并配置环境

  1. 创建项目目录:

    bash 复制代码
    mkdir ~/workspace/cifar10_project
    cd ~/workspace/cifar10_project
  2. 创建并激活虚拟环境:

    bash 复制代码
    python3 -m venv venv
    source venv/bin/activate

    前面出现venu则表示已经激活虚拟环境了

  3. 安装必要的 Python 包:

    bash 复制代码
    pip install torch torchvision matplotlib

第五步:下载数据并初始化项目代码

  1. 创建 Python 脚本:

    bash 复制代码
    vim train_cifar10.py
  2. 在文件中输入以下代码,加载 CIFAR-10 数据集并定义简单模型:

    python 复制代码
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import torch.nn as nn
    import torch.optim as optim
    
    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # 加载 CIFAR-10 数据集
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    
    # 定义简单卷积神经网络
    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(32 * 16 * 16, 10)
    
        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = x.view(-1, 32 * 16 * 16)
            x = self.fc1(x)
            return x
    
    # 初始化模型、损失函数和优化器
    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    # 模型训练
    for epoch in range(5):  # 训练 5 个周期
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader)}")
    
    print("Finished Training")
  3. 保存并退出(按下 Esc,然后输入 :wq)。


第六步:运行训练脚本

运行脚本进行模型训练:

bash 复制代码
python train_cifar10.py
  • 脚本会下载 CIFAR-10 数据集并训练模型。
  • 训练完成后会输出每个 epoch 的损失值。

第七步:保存和测试模型

  1. 保存模型:在脚本末尾添加代码以保存训练好的模型:

    python 复制代码
    torch.save(net.state_dict(), "cifar10_model.pth")
    print("Model saved as cifar10_model.pth")
  2. 重新运行脚本以保存模型:

    bash 复制代码
    python train_cifar10.py
  3. 检查是否生成了 cifar10_model.pth 文件:

    bash 复制代码
    ls
  4. 测试模型(可选):加载保存的模型并在测试集上评估准确率:

    python 复制代码
    net.load_state_dict(torch.load("cifar10_model.pth"))
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f"Accuracy on test dataset: {100 * correct / total}%")

第八步:清理和扩展

  1. 扩展功能

    • 使用更复杂的模型(如 ResNet)。
    • 尝试使用 Adam 优化器提高性能。
    • 可视化训练过程或模型预测结果。
  2. 清理资源

    • 如果完成训练并不再需要 GPU 计算,记得停止或删除实例以节省费用。

\

相关推荐
浠寒AI2 小时前
智能体模式篇(上)- 深入 ReAct:LangGraph构建能自主思考与行动的 AI
人工智能·python
weixin_505154462 小时前
数字孪生在建设智慧城市中可以起到哪些作用或帮助?
大数据·人工智能·智慧城市·数字孪生·数据可视化
Best_Me072 小时前
深度学习模块缝合
人工智能·深度学习
YuTaoShao2 小时前
【论文阅读】YOLOv8在单目下视多车目标检测中的应用
人工智能·yolo·目标检测
算家计算3 小时前
字节开源代码模型——Seed-Coder 本地部署教程,模型自驱动数据筛选,让每行代码都精准落位!
人工智能·开源
伪_装3 小时前
大语言模型(LLM)面试问题集
人工智能·语言模型·自然语言处理
gs801403 小时前
Tavily 技术详解:为大模型提供实时搜索增强的利器
人工智能·rag
music&movie3 小时前
算法工程师认知水平要求总结
人工智能·算法
量子位4 小时前
苹果炮轰推理模型全是假思考!4 个游戏戳破神话,o3/DeepSeek 高难度全崩溃
人工智能·deepseek
黑鹿0224 小时前
机器学习基础(四) 决策树
人工智能·决策树·机器学习