CIFAR10彩色图片识别

CIFAR10彩色图片识别

这是我参加训练营的第二周

数据处理

好多项目的数据处理部分思路是相同的。

带入库函数

js 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

加上下面的这一句,,matplotlib 绘制的图形会直接嵌入在 Notebook 的输出单元格中显示,而不是弹出独立窗口。

js 复制代码
%matplotlib inline

在我的电脑上pytorch和matplotlib容易冲突,加上下面的三行才能在pytorch环境中运行matplotlib

lua 复制代码
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "4"          # 限制线程数
os.environ["MKL_NUM_THREADS"] = "4"

下载数据集

ini 复制代码
train_ds = torchvision.datasets.CIFAR10('data', 
                                      train=True, 
                                      transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensor
                                      download=True)

test_ds  = torchvision.datasets.CIFAR10('data', 
                                      train=False, 
                                      transform=torchvision.transforms.ToTensor(), # 将数据类型转化为Tensor
                                      download=True)

下载好数据集,得取数据集,先用torch.utils.data.DataLoader取出,以32张图片为一组,取完所有的图片为一轮。训练集要取很多轮,所以要设置shuffle=True,每次取完一轮,顺序不一样。测试集不用取很多轮,所以不用设置。

ini 复制代码
batch_size=32
train_dl=torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl=torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

训练集有图片和标签,分别赋值给imgs,labels

ini 复制代码
imgs,labels=next(iter(train_dl))
imgs.shape

看看训练集长什么样子的

ini 复制代码
imgs,labels=next(iter(train_dl))
imgs.shape

画图

css 复制代码
import numpy as np
plt.figure(figsize=(20,5))
for i,img in enumerate(imgs[:20]):
    #进行轴变换
    npimg=img.numpy().transpose((1,2,0))
    plt.subplot(2,10,i+1)
    plt.imshow(npimg,cmap=plt.cm.binary)

    plt.axis('off')
plt.show()

继承torch的类,创建一个Model的类

ini 复制代码
import torch.nn.functional as F
num_classes=10
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        #特征提取网络
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.pool1=nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool2=nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.pool3=nn.MaxPool2d(2)

        #分类网络
        self.fc1 = nn.Linear(512,256)
        self.fc2 = nn.Linear(256,num_classes)

        #前向传播
    def forward(self, x):
        x=self.pool1(self.conv1(x))
        x=self.pool2(self.conv2(x))
        x=self.pool3(self.conv3(x))

        x=F.relu(self.fc1(x))
        x=self.fc2(x)
        return x

采用模型训练

scss 复制代码
from torchinfo import summary
model = Model().to(device)
summary(model)

编写训练的函数

scss 复制代码
def train(dataloader, model, loss_fn, optimizer):
    size=len(dataloader.dataset)
    num_batches=len(dataloader)
    train_loss,train_acc=0,0
    for X,y in dataloader:
        X,y=X.to(device),y.to(device)

        pred=model(X)
        loss=loss_fn(pred,y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()
    train_acc/=size
    train_loss/=size
    return train_loss,train_acc

测试函数

ini 复制代码
def test(dataloader, model, loss_fn):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    total_loss, correct = 0, 0

    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)

            pred = model(imgs)
            loss = loss_fn(pred, labels)

            total_loss += loss.item()
            correct += (pred.argmax(1) == labels).type(torch.float).sum().item()

    avg_loss = total_loss / num_batches
    accuracy = correct / size
    return avg_loss, accuracy

开始训练,训练5轮

ini 复制代码
# 训练循环
epochs = 5  # 增加训练轮数
train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0.0  # 保存最佳模型

for epoch in range(epochs):
    # 训练
    epoch_train_loss, epoch_train_acc = train(train_dl, model, loss_fn, optimizer)

    # 测试
    epoch_test_loss, epoch_test_acc = test(test_dl, model, loss_fn)

    # 更新学习率
    scheduler.step()

    # 记录结果
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

    # 保存最佳模型
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        torch.save(model.state_dict(), 'best_model.pth')

    # 打印结果
    # if (epoch + 1) % 5 == 0 or epoch == 0:
    template = 'Epoch:{:3d}, LR:{:.4f}, Train Acc:{:.2f}%, Train Loss:{:.4f}, Test Acc:{:.2f}%, Test Loss:{:.4f}'
    current_lr = optimizer.param_groups[0]['lr']
    print(template.format(
        epoch + 1,
        current_lr,
        epoch_train_acc * 100,
        epoch_train_loss,
        epoch_test_acc * 100,
        epoch_test_loss
    ))

print(f'Finished Training. Best Test Accuracy: {best_acc * 100:.2f}%')

画图

scss 复制代码
from datetime import datetime
current_time = datetime.now()

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_acc, label='Train Accuracy')
plt.plot(train_acc, 'ro')
plt.plot(test_acc, label='Test Accuracy')
plt.plot(test_acc,'go')
plt.title('Accuracy')
plt.xlabel(current_time)
plt.grid(True)
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_loss, label='Train Loss')
plt.plot(train_loss, 'ro')
plt.plot(test_loss, label='Test Loss')
plt.plot(test_loss, 'go')
plt.title('Loss')
plt.xlabel(current_time)
plt.grid(True)
plt.legend()
plt.tight_layout()
# plt.savefig('training_curve.png')
plt.show()

运行截图

相关推荐
zkmall3 小时前
企业电商解决方案哪家好?ZKmall模块商城全渠道支持 + 定制化服务更省心
大数据·运维·重构·架构·开源
青阳流月13 小时前
1.vue权衡的艺术
前端·vue.js·开源
小小鱼儿小小林13 小时前
免费一键自动化申请、续期、部署、监控所有 SSL/TLS 证书,ALLinSSL开源免费的 SSL 证书自动化管理平台
开源·自动化·ssl
三花AI14 小时前
阿里开源 OmniAvatar:音频驱动数字人模型
开源·资讯
说私域14 小时前
基于开源AI智能客服、AI智能名片与S2B2C商城小程序的微商服务质量提升路径研究
人工智能·小程序·开源
蚂蚁数据AntData14 小时前
从性能优化赛到社区Committer,走进赵宇捷在Apache Fory的成长之路
大数据·开源·apache·数据库架构
阿里云云原生15 小时前
Spring AI Alibaba 游乐场开放!一站式体验AI 应用开发全流程
开源
NocoBase16 小时前
为什么越来越多 Airtable 用户开始尝试 NocoBase?
低代码·开源·资讯
算家计算16 小时前
4 位量化 + FP8 混合精度:ERNIE-4.5-0.3B-Paddle本地部署,重新定义端侧推理效率
人工智能·开源
于顾而言16 小时前
【开源品鉴】FRP源码阅读
后端·网络协议·开源