A Beginner Deep Learning Exercise: CIFAR-10 Image Classification

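CIFAR-10 is a classic computer vision dataset that is widely used for image classification. It contains 60,000 color images in 10 classes, and each image is 32x32 pixels. The dataset is split into 50,000 training images and 10,000 test images. Each class has 6,000 images. The classes are: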

  • airplane
  • automobile
  • bird
  • cat
  • deer
  • dog
  • frog
  • horse
  • ship
  • truck
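As a quick sanity check on these numbers, the sketch below recomputes the split sizes and the raw storage needed for the pixels. It assumes one byte (uint8) per color channel value, and it uses CIFAR-10's balanced split of 5,000 training images and 1,000 test images per class.

```python
NUM_CLASSES = 10
IMAGES_PER_CLASS = 6_000
TRAIN_PER_CLASS = 5_000  # CIFAR-10 is balanced: 5,000 training images per class
TEST_PER_CLASS = IMAGES_PER_CLASS - TRAIN_PER_CLASS  # leaves 1,000 test images per class
HEIGHT, WIDTH, CHANNELS = 32, 32, 3

total_images = NUM_CLASSES * IMAGES_PER_CLASS
train_images = NUM_CLASSES * TRAIN_PER_CLASS
test_images = NUM_CLASSES * TEST_PER_CLASS
bytes_per_image = HEIGHT * WIDTH * CHANNELS  # assumption: one uint8 per channel value
total_bytes = total_images * bytes_per_image

print(total_images, train_images, test_images)  # 60000 50000 10000
print(bytes_per_image, total_bytes)  # 3072 184320000
```

So the raw pixels of the whole dataset take about 176 MiB, which is small enough to keep in memory on most machines.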

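Classifying CIFAR-10 images is a multi-class classification task: the goal is to predict which class an image belongs to from its content (for example, the shape and color of the object). Image classification models such as convolutional neural networks (CNNs) are commonly used for this task. They learn spatial features from the images and use them to make predictions.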
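For a 10-class problem the network outputs one raw score (logit) per class, and the predicted class is the one with the highest score. The training code later in this post uses nn.CrossEntropyLoss, which applies softmax to the scores and takes the negative log-probability of the true class. The plain-Python sketch below mirrors that computation; the helper names softmax, cross_entropy and predict are illustrative only, not PyTorch APIs. It also shows why an untrained model starts near a loss of ln 10 (about 2.30): equal scores give every class probability 0.1.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    # Negative log-probability assigned to the true class index
    return -math.log(softmax(logits)[target])

def predict(logits):
    # The predicted class is the index of the largest score (argmax)
    return max(range(len(logits)), key=lambda i: logits[i])

uniform = [0.0] * 10
print(round(cross_entropy(uniform, 3), 4))  # 2.3026, i.e. ln(10)

scores = [0.1, 0.2, 3.0, 0.5, 0.0, 0.3, 0.1, 0.2, 0.0, 0.1]
print(predict(scores))  # 2
print(cross_entropy(scores, 2) < cross_entropy(uniform, 2))  # True
```

This is a useful reference point for the training log: the first logged loss (about 2.02 at step 100) is already somewhat below the chance-level value of 2.30.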

Here is the implementation:

```python
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn

# ToTensor converts each PIL image into a float tensor of shape (3, 32, 32) with values in [0, 1]
transform = torchvision.transforms.ToTensor()

# On Kaggle the dataset is already under ../input/cifar10-python;
# download=True only downloads it if the files are missing
train_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python", train=True,
                                          transform=transform, download=True)
test_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python", train=False,
                                         transform=transform, download=True)
print(f"train length: {len(train_data)}")
print(f"test length: {len(test_data)}")
```

Output:

```text
Files already downloaded and verified
Files already downloaded and verified
train length: 50000
test length: 10000
```
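ToTensor changes both the layout and the value range: a PIL image is stored as height x width x channel with integer values 0 to 255, while the tensor is channel x height x width with floats in [0, 1]. The pure-Python sketch below imitates that conversion on a tiny 2x2 RGB image stored as nested lists; the helper to_chw_unit_range is hypothetical and only illustrates the idea.

```python
def to_chw_unit_range(hwc):
    """Convert an H x W x C nested list of 0-255 ints into C x H x W floats in [0, 1]."""
    height, width, channels = len(hwc), len(hwc[0]), len(hwc[0][0])
    return [
        [[hwc[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]

# A 2x2 image: top row red, green; bottom row blue, white
image = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]
chw = to_chw_unit_range(image)
print(len(chw), len(chw[0]), len(chw[0][0]))  # 3 2 2
print(chw[0])  # [[1.0, 0.0], [0.0, 1.0]] (red channel)
```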

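With the CIFAR-10 data loaded, the next step trains a network with three convolutional layers for 10 epochs, using Adam with a learning rate of 0.0001 and a batch size of 64. After each epoch the model is evaluated on the test set and saved to disk.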
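The 64*4*4 input size of the first linear layer follows from the layer settings. nn.Conv2d(in, out, 5, 1, 2) means kernel 5, stride 1, padding 2, and nn.MaxPool2d(2) uses a 2x2 window with stride 2 by default. The sketch below traces the spatial size with the standard output-size formula floor((n + 2p - k) / s) + 1, assuming dilation 1 (PyTorch's default).

```python
def conv_out(n, kernel, stride=1, padding=0):
    # Output length of a convolution (dilation 1) along one spatial dimension
    return (n + 2 * padding - kernel) // stride + 1

def pool_out(n, kernel):
    # MaxPool2d(kernel) uses stride=kernel and no padding by default
    return (n - kernel) // kernel + 1

size = 32
channels = 3
for out_channels in (32, 32, 64):  # output channels of the three Conv2d layers
    size = conv_out(size, kernel=5, stride=1, padding=2)  # 5x5, stride 1, padding 2 keeps the size
    size = pool_out(size, kernel=2)                       # 2x2 max-pooling halves it
    channels = out_channels
    print(channels, size, size)  # 32 16 16, then 32 8 8, then 64 4 4

flat = channels * size * size
print(flat)  # 1024 == 64*4*4, the in_features of the first nn.Linear
```

If you change the input resolution or add another pooling layer, rerunning this arithmetic tells you the new in_features value for the first linear layer.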

```python
# Batches of 64 images. The training loader uses the default shuffle=False here;
# shuffle=True is the usual choice for training data.
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Three 5x5 convolutions (stride 1, padding 2, so the spatial size is preserved),
        # each followed by 2x2 max-pooling: 32x32 -> 16x16 -> 8x8 -> 4x4.
        # Note: there are no activation functions (e.g. nn.ReLU) between layers,
        # so max-pooling is the only nonlinearity in this model.
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),               # (N, 64, 4, 4) -> (N, 1024)
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)           # one score (logit) per class
        )

    def forward(self, x):
        return self.model(x)

# Use the GPU when available (the run logged below used a GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mynet = CNN().to(device)

loss_func = nn.CrossEntropyLoss()
learning_rate = 0.0001
optimizer = torch.optim.Adam(mynet.parameters(), lr=learning_rate)
total_train = 0  # number of optimizer steps so far, across all epochs
epoch = 10

for i in range(epoch):
    print(f"----No.{i+1} training...-----")
    mynet.train()
    for imgs, targets in train_dataloader:
        imgs, targets = imgs.to(device), targets.to(device)
        outputs = mynet(imgs)
        loss = loss_func(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train += 1
        if total_train % 100 == 0:
            print(f"Training step: {total_train}, loss: {loss.item()}")

    # Evaluate on the test set
    mynet.eval()
    total_test_loss = 0.0  # sum of per-batch mean losses (not a per-sample average)
    total_accuracy = 0     # number of correctly classified test images
    with torch.no_grad():
        for imgs, targets in test_dataloader:
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = mynet(imgs)
            loss = loss_func(outputs, targets)
            total_test_loss += loss.item()
            total_accuracy += (outputs.argmax(1) == targets).sum()
    print(f"Test loss: {total_test_loss}, accuracy: {(total_accuracy / len(test_data)).item()}")
    # Saves the whole model object (pickled), one file per epoch
    torch.save(mynet, f"myCNN_{i+1}p.pth")
    print("Model saved")
```

Output:

```text
----No.1 training...-----
Training step: 100, loss: 2.0156445503234863
Training step: 200, loss: 1.999146580696106
Training step: 300, loss: 1.860052466392517
Training step: 400, loss: 1.7510318756103516
Training step: 500, loss: 1.7712416648864746
Training step: 600, loss: 1.6994789838790894
Training step: 700, loss: 1.7278780937194824
Test loss: 257.74497163295746, accuracy: 0.41990000009536743
Model saved
----No.2 training...-----
Training step: 800, loss: 1.515326976776123
Training step: 900, loss: 1.485555648803711
Training step: 1000, loss: 1.6138449907302856
Training step: 1100, loss: 1.7650551795959473
Training step: 1200, loss: 1.4380264282226562
Training step: 1300, loss: 1.3843588829040527
Training step: 1400, loss: 1.5849156379699707
Training step: 1500, loss: 1.5038520097732544
Test loss: 236.6359145641327, accuracy: 0.47110000252723694
Model saved
----No.3 training...-----
Training step: 1600, loss: 1.4474828243255615
Training step: 1700, loss: 1.4474865198135376
Training step: 1800, loss: 1.7310973405838013
Training step: 1900, loss: 1.5719612836837769
Training step: 2000, loss: 1.6212022304534912
Training step: 2100, loss: 1.2924069166183472
Training step: 2200, loss: 1.256321907043457
Training step: 2300, loss: 1.560215711593628
Test loss: 221.27214550971985, accuracy: 0.5011000037193298
Model saved
----No.4 training...-----
Training step: 2400, loss: 1.4557472467422485
Training step: 2500, loss: 1.2620049715042114
Training step: 2600, loss: 1.4703019857406616
Training step: 2700, loss: 1.4131494760513306
Training step: 2800, loss: 1.303225040435791
Training step: 2900, loss: 1.4961038827896118
Training step: 3000, loss: 1.2810102701187134
Training step: 3100, loss: 1.337519645690918
Test loss: 210.63251876831055, accuracy: 0.5252999663352966
Model saved
----No.5 training...-----
Training step: 3200, loss: 1.1311390399932861
Training step: 3300, loss: 1.2354803085327148
Training step: 3400, loss: 1.2415772676467896
Training step: 3500, loss: 1.4213279485702515
Training step: 3600, loss: 1.4151396751403809
Training step: 3700, loss: 1.2579320669174194
Training step: 3800, loss: 1.201486349105835
Training step: 3900, loss: 1.287066102027893
Test loss: 202.65885722637177, accuracy: 0.5475999712944031
Model saved
----No.6 training...-----
Training step: 4000, loss: 1.2759090662002563
Training step: 4100, loss: 1.3534283638000488
Training step: 4200, loss: 1.4388338327407837
Training step: 4300, loss: 1.1126259565353394
Training step: 4400, loss: 1.072700023651123
Training step: 4500, loss: 1.2942607402801514
Training step: 4600, loss: 1.3078550100326538
Test loss: 195.93554836511612, accuracy: 0.5615000128746033
Model saved
----No.7 training...-----
Training step: 4700, loss: 1.3510404825210571
Training step: 4800, loss: 1.3887534141540527
Training step: 4900, loss: 1.2628172636032104
Training step: 5000, loss: 1.3063734769821167
Training step: 5100, loss: 0.9366315007209778
Training step: 5200, loss: 1.208983063697815
Training step: 5300, loss: 1.0933520793914795
Training step: 5400, loss: 1.2654058933258057
Test loss: 190.015959918499, accuracy: 0.5735999941825867
Model saved
----No.8 training...-----
Training step: 5500, loss: 1.1543941497802734
Training step: 5600, loss: 1.0732381343841553
Training step: 5700, loss: 1.179479718208313
Training step: 5800, loss: 1.0669857263565063
Training step: 5900, loss: 1.3145105838775635
Training step: 6000, loss: 1.4563915729522705
Training step: 6100, loss: 1.0026252269744873
Training step: 6200, loss: 0.9769096374511719
Test loss: 184.76930475234985, accuracy: 0.5831999778747559
Model saved
----No.9 training...-----
Training step: 6300, loss: 1.2531676292419434
Training step: 6400, loss: 1.0582406520843506
Training step: 6500, loss: 1.467718482017517
Training step: 6600, loss: 0.9885475635528564
Training step: 6700, loss: 0.9887412190437317
Training step: 6800, loss: 1.1251451969146729
Training step: 6900, loss: 1.0831143856048584
Training step: 7000, loss: 0.8735517263412476
Test loss: 180.18007707595825, accuracy: 0.5949000120162964
Model saved
----No.10 training...-----
Training step: 7100, loss: 1.1680148839950562
Training step: 7200, loss: 0.9758849740028381
Training step: 7300, loss: 1.1076891422271729
Training step: 7400, loss: 0.8192071914672852
Training step: 7500, loss: 1.2766807079315186
Training step: 7600, loss: 1.2046217918395996
Training step: 7700, loss: 0.8206453323364258
Training step: 7800, loss: 1.1484739780426025
Test loss: 176.2480058670044, accuracy: 0.6036999821662903
Model saved
```
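Two details help when reading this log. With a batch size of 64 and DataLoader's default drop_last=False, one epoch is ceil(50000 / 64) = 782 steps, so the step counter reaches 7,820 after 10 epochs (the last printed value is 7,800). Also, the printed test loss is the sum of the mean losses of ceil(10000 / 64) = 157 test batches; dividing by 157 gives roughly the per-image average (only roughly, because the last batch holds just 16 images). The sketch below checks this arithmetic against the epoch-10 numbers from the log.

```python
import math

TRAIN_SIZE, TEST_SIZE, BATCH_SIZE, EPOCHS = 50_000, 10_000, 64, 10

train_steps_per_epoch = math.ceil(TRAIN_SIZE / BATCH_SIZE)  # last batch is partial (drop_last=False)
test_batches = math.ceil(TEST_SIZE / BATCH_SIZE)
total_steps = train_steps_per_epoch * EPOCHS

summed_test_loss = 176.2480058670044  # epoch-10 value from the log
accuracy = 0.6036999821662903         # epoch-10 value from the log

mean_test_loss = summed_test_loss / test_batches  # average of the per-batch means
correct = round(accuracy * TEST_SIZE)             # number of correctly classified test images

print(train_steps_per_epoch, test_batches, total_steps)  # 782 157 7820
print(round(mean_test_loss, 3))  # 1.123
print(correct)  # 6037
```

So after 10 epochs the model classifies 6,037 of the 10,000 test images correctly, with an average test loss of about 1.12 per batch.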

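Finally, let's test the trained model on a few images downloaded from the web. Adjust the file paths to match your own environment.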

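```python
import torch
import torchvision
from PIL import Image
from torch import nn

# 10 classes: airplane=0, automobile=1, bird=2, cat=3, deer=4, dog=5, frog=6, horse=7, ship=8, truck=9
image_path = "/kaggle/input/testdata/bird.jpg"
# convert("RGB") drops any alpha channel (e.g. in PNG files) so the image always has 3 channels
image = Image.open(image_path).convert("RGB")
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),  # match the 32x32 CIFAR-10 input size
    torchvision.transforms.ToTensor(),        # same preprocessing as in training
])
image = transform(image)
image = torch.reshape(image, (1, 3, 32, 32))  # add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32)

# The class definition must be available here: torch.save(mynet, ...) pickled the whole
# model object, and unpickling looks up the CNN class by name.
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

# weights_only=False is required on PyTorch 2.6+, where torch.load loads only weights by default;
# it unpickles arbitrary objects, so use it only on files you trust
model = torch.load("/kaggle/working/myCNN_10p.pth", map_location=torch.device("cpu"),
                   weights_only=False)
model.eval()
with torch.no_grad():
    output = model(image)
print(output.argmax(1))
```

Output:

```text
tensor([2])
```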
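The printed tensor([2]) is a class index, not a name. To turn it into a label, look it up in CIFAR-10's class order, which is the order listed in the comment at the top of the script (torchvision's CIFAR10 dataset also exposes this list as test_data.classes). A minimal sketch with a hard-coded tuple; the helper class_name is hypothetical:

```python
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)

def class_name(index):
    # Map a predicted index (e.g. the value inside tensor([2])) to its label
    return CIFAR10_CLASSES[index]

print(class_name(2))  # bird
```

In the PyTorch script this would be called as class_name(output.argmax(1).item()), which here gives "bird", matching the test image.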