丹摩 | 基于PyTorch的CIFAR-10图像分类实现


从创建实例开始的新项目流程

第一步:创建实例

  1. 登录 DAMODEL 平台。
  2. 创建一个 GPU 实例:
    • GPU 配置 :选择 NVIDIA H800 或其他可用高性能 GPU。

    • 系统配置:推荐使用 Ubuntu 20.04,内存 16GB,硬盘 50GB。

    • 启动实例后,获取实例的 IP 地址。

    • 选择镜像


第二步:连接实例

  1. 登录成功后,你会进入实例的终端界面。


第三步:更新系统和安装基础工具

  1. 更新系统:

    bash 复制代码
    sudo apt update && sudo apt upgrade -y
  2. 安装 Python 和基础工具:

    bash 复制代码
    sudo apt install python3 python3-pip git -y
  3. (可选)安装文本编辑器:

    bash 复制代码
    sudo apt install vim nano -y

第四步:创建项目目录并配置环境

  1. 创建项目目录:

    bash 复制代码
    mkdir ~/workspace/cifar10_project
    cd ~/workspace/cifar10_project
  2. 创建并激活虚拟环境:

    bash 复制代码
    python3 -m venv venv
    source venv/bin/activate

    前面出现venu则表示已经激活虚拟环境了

  3. 安装必要的 Python 包:

    bash 复制代码
    pip install torch torchvision matplotlib

第五步:下载数据并初始化项目代码

  1. 创建 Python 脚本:

    bash 复制代码
    vim train_cifar10.py
  2. 在文件中输入以下代码,加载 CIFAR-10 数据集并定义简单模型:

    python 复制代码
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import torch.nn as nn
    import torch.optim as optim
    
    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # 加载 CIFAR-10 数据集
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    
    # 定义简单卷积神经网络
    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(32 * 16 * 16, 10)
    
        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = x.view(-1, 32 * 16 * 16)
            x = self.fc1(x)
            return x
    
    # 初始化模型、损失函数和优化器
    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    # 模型训练
    for epoch in range(5):  # 训练 5 个周期
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader)}")
    
    print("Finished Training")
  3. 保存并退出(按下 Esc,然后输入 :wq)。


第六步:运行训练脚本

运行脚本进行模型训练:

bash 复制代码
python train_cifar10.py
  • 脚本会下载 CIFAR-10 数据集并训练模型。
  • 训练完成后会输出每个 epoch 的损失值。

第七步:保存和测试模型

  1. 保存模型:在脚本末尾添加代码以保存训练好的模型:

    python 复制代码
    torch.save(net.state_dict(), "cifar10_model.pth")
    print("Model saved as cifar10_model.pth")
  2. 重新运行脚本以保存模型:

    bash 复制代码
    python train_cifar10.py
  3. 检查是否生成了 cifar10_model.pth 文件:

    bash 复制代码
    ls
  4. 测试模型(可选):加载保存的模型并在测试集上评估准确率:

    python 复制代码
    net.load_state_dict(torch.load("cifar10_model.pth"))
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f"Accuracy on test dataset: {100 * correct / total}%")

第八步:清理和扩展

  1. 扩展功能

    • 使用更复杂的模型(如 ResNet)。
    • 尝试使用 Adam 优化器提高性能。
    • 可视化训练过程或模型预测结果。
  2. 清理资源

    • 如果完成训练并不再需要 GPU 计算,记得停止或删除实例以节省费用。

\

相关推荐
码蜂窝编程官方22 分钟前
【含开题报告+文档+PPT+源码】基于SSM的电影数据挖掘与分析可视化系统设计与实现
java·vue.js·人工智能·后端·spring·数据挖掘·maven
遗落凡尘的萤火-生信小白24 分钟前
转录组数据挖掘(生物技能树)(第11节)下游分析
人工智能·数据挖掘
XinZong33 分钟前
【OpenAI】获取OpenAI API Key的多种方式全攻略:从入门到精通,再到详解教程!
人工智能
没有余地 EliasJie35 分钟前
深度学习图像视觉 RKNN Toolkit2 部署 RK3588S边缘端 过程全记录
人工智能·嵌入式硬件·深度学习
HelpLook HelpLook1 小时前
高新技术行业中的知识管理:关键性、挑战、策略及工具应用
人工智能·科技·aigc·客服·知识库搭建
青松@FasterAI2 小时前
【RAG 项目实战 05】重构:封装代码
人工智能·深度学习·自然语言处理·nlp
chnyi6_ya2 小时前
论文笔记:Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
论文阅读·人工智能·自然语言处理
&黄昏的乐师2 小时前
Opencv+ROS实现摄像头读取处理画面信息
linux·人工智能·opencv·计算机视觉·ros
默凉2 小时前
opencv-python 分离边缘粘连的物体(距离变换)
人工智能·python·opencv
xiandong202 小时前
241123_基于MindSpore学习Bert
人工智能·学习·bert