LeNet训练集详细实现

一、下载训练集

导包

python 复制代码
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

ToTensor()函数:

把图像heigh x width x channels 转换为 channels x height x width

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

python 复制代码
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

参数解释:

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为:

下载完成后生成了data文件夹

二、导入训练集

python 复制代码
# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,
                                          shuffle=True, num_workers=0)

参数解释:

trainset:把刚刚下载的数据导入进来

batch_size=36:一批数据的大小

shuffle=True:训练集中的数据是否打乱(一般默认打乱)

num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

python 复制代码
# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                         shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()

classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器

test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

python 复制代码
def imshow(img):
    img = img / 2 + 0.5
    nping = img.numpy()
    plt.imshow(np.transpose(nping, (1, 2, 0)))
    plt.show()


# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

图片很模糊,因为像素很低。

上面识别出来的结果都对了。

我遇到的问题:

一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

python 复制代码
pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。

五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

python 复制代码
for epoch in range(5):

    running_loss = 0.0
    for step, data in enumerate(trainloader, start=0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_fuction(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if step % 500 == 499:
            with torch.no_grad():
                outputs = net(test_image)
                predict_y = torch.max(outputs, dim=1)[1]
                accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)

                print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                       (epoch + 1, step + 1, running_loss / 500, accuracy))
                running_loss = 0.0

print('Finished Training')

运行结果

相关推荐
Raink老师4 小时前
【AI面试临阵磨枪-79】实时数据 RAG:订单、商家、物流、天气、动态库存
人工智能·面试·职场和发展
是一个Bug4 小时前
Agent(智能体)应用 的入门学习路径
学习·机器学习
脑极体4 小时前
点亮星河AI+鸿蒙,一座艺术场馆的日神觉醒
人工智能·华为·harmonyos
Cosolar4 小时前
Chroma向量库面试学习指南
数据库·人工智能·面试·职场和发展·数据库架构
BUG指挥官4 小时前
Claude Code的自动化编程
人工智能
意图共鸣5 小时前
意图共鸣科技《认知智能白皮书》——感知与执行分离:认知架构(CA)如何重塑大模型底层结构
人工智能·架构
等一个人的@5 小时前
让数据自己开口:数睿通智库新增智能问数模块
人工智能·自然语言处理
ZGi.ai5 小时前
人工审查节点:让自动化工作流多一步人工把关
运维·人工智能·自动化·人机协同·智能体工作流·人工审查
王莎莎-MinerU5 小时前
MinerU 深度技术解析:从架构原理到生产部署的全面指南
css·人工智能·自然语言处理·架构·ocr·个人开发
盘古信息IMS5 小时前
盘古信息IMS V6 8.0重磅发布:以薪火AI数智平台点燃离散制造数智化引擎
大数据·人工智能·制造