LeNet-5(fashion-mnist)

文章目录

前言

LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。

LeNet

LeNet-5由以下两个部分组成

  • 卷积编码器(2)
  • 全连接层(3)
    卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
    第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
    3个全连接层分别有120,84,10个输出。
    此处对原始模型做出部分修改,去除最后一层的高斯激活。
python 复制代码
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Flatten(),
                  nn.Linear(16*5*5,120),nn.Sigmoid(),
                  nn.Linear(120,84),nn.Sigmoid(),
                  nn.Linear(84,10))

模型训练

为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。

此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。

python 复制代码
batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)

将数据送入GPU进行计算测试集准确率

python 复制代码
def evaluate_accuracy_gpu(net,data_iter,device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net,torch.nn.Module):
        net.eval()
        if not device:
            device=next(iter(net.parameters())).device
    # 正确预测的数量,预测的总数
    eva = 0.0
    y_num = 0.0
    with torch.no_grad():
        for X,y in data_iter:
            if isinstance(X,list):
                X=[x.to(device) for x in X]
            else:
                X=X.to(device)
            y=y.to(device)
            eva += accuracy(net(X), y)
            y_num += y.numel()
    return eva/y_num

训练过程同样将数据送入GPU计算

python 复制代码
def train_epoch_gpu(net, train_iter, loss, updater,device):

    # 训练损失之和,训练准确数之和,样本数
    train_loss_sum = 0.0
    train_acc_sum = 0.0
    num_samples = 0.0
    # timer = d2l.torch.Timer()
    for i, (X, y) in enumerate(train_iter):
        # timer.start()
        updater.zero_grad()
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        l = loss(y_hat, y)
        l.backward()
        updater.step()
        with torch.no_grad():
            train_loss_sum += l * X.shape[0]
            train_acc_sum += evaluation.accuracy(y_hat, y)
            num_samples += X.shape[0]
        # timer.stop()
    return train_loss_sum/num_samples,train_acc_sum/num_samples


def train_gpu(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):
        if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:
            torch.nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    net.to(device)
    print('training on',device)
    optimizer=torch.optim.SGD(net.parameters(),lr=lr)
    loss=torch.nn.CrossEntropyLoss()
    # num_batches=len(train_iter)
    tr_l=[]
    tr_a=[]
    te_a=[]
    for epoch in range(num_epochs):
        net.train()
        train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)
        test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)
        train_loss, train_acc = train_metric
        train_loss = train_loss.cpu().detach().numpy()
        tr_l.append(train_loss)
        tr_a.append(train_acc)
        te_a.append(test_accuracy)
        print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')
    x = torch.arange(num_epochs)
    plt.plot((x + 1), tr_l, '-', label='train_loss')
    plt.plot(x + 1, tr_a, '--', label='train_acc')
    plt.plot(x + 1, te_a, '-.', label='test_acc')
    plt.legend()
    plt.show()
    print(f'on {str(device)}')
python 复制代码
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')


相关推荐
葫三生1 小时前
如何评价《论三生原理》在科技界的地位?
人工智能·算法·机器学习·数学建模·量子计算
m0_751336392 小时前
突破性进展:超短等离子体脉冲实现单电子量子干涉,为飞行量子比特奠定基础
人工智能·深度学习·量子计算·材料科学·光子器件·光子学·无线电电子
有Li5 小时前
通过具有一致性嵌入的大语言模型实现端到端乳腺癌放射治疗计划制定|文献速递-最新论文分享
论文阅读·深度学习·分类·医学生
IT古董5 小时前
【第二章:机器学习与神经网络概述】03.类算法理论与实践-(3)决策树分类器
神经网络·算法·机器学习
鱼摆摆拜拜9 小时前
第 3 章:神经网络如何学习
人工智能·神经网络·学习
张较瘦_9 小时前
[论文阅读] 人工智能 | 深度学习系统崩溃恢复新方案:DaiFu框架的原位修复技术
论文阅读·人工智能·深度学习
cver1239 小时前
野生动物检测数据集介绍-5,138张图片 野生动物保护监测 智能狩猎相机系统 生态研究与调查
人工智能·pytorch·深度学习·目标检测·计算机视觉·目标跟踪
学技术的大胜嗷9 小时前
离线迁移 Conda 环境到 Windows 服务器:用 conda-pack 摆脱硬路径限制
人工智能·深度学习·yolo·目标检测·机器学习
还有糕手9 小时前
西南交通大学【机器学习实验10】
人工智能·机器学习
Akttt12 小时前
【T2I】R&B: REGION AND BOUNDARY AWARE ZERO-SHOT GROUNDED TEXT-TO-IMAGE GENERATION
人工智能·深度学习·计算机视觉·text2img