LeNet-5(fashion-mnist)

文章目录

前言

LeNet是最早发布的卷积神经网络之一。该模型被提出用于识别图像中的手写数字。

LeNet

LeNet-5由以下两个部分组成

  • 卷积编码器(2)
  • 全连接层(3)
    卷积块由一个卷积层、一个sigmoid激活函数和一个平均汇聚层组成。
    第一个卷积层有6个输出通道,第二个卷积层有16个输出通道。采用2×2的汇聚操作,且步幅为2.
    3个全连接层分别有120,84,10个输出。
    此处对原始模型做出部分修改,去除最后一层的高斯激活。
python 复制代码
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                  nn.AvgPool2d(kernel_size=2,stride=2),
                  nn.Flatten(),
                  nn.Linear(16*5*5,120),nn.Sigmoid(),
                  nn.Linear(120,84),nn.Sigmoid(),
                  nn.Linear(84,10))

模型训练

为了加快训练,使用GPU计算测试集上的精度以及训练过程中的计算。

此处采用xavier初始化模型参数以及交叉熵损失函数和小批量梯度下降。

python 复制代码
batch_size=256
train_iter,test_iter=data_iter.load_data_fashion_mnist(batch_size)

将数据送入GPU进行计算测试集准确率

python 复制代码
def evaluate_accuracy_gpu(net,data_iter,device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net,torch.nn.Module):
        net.eval()
        if not device:
            device=next(iter(net.parameters())).device
    # 正确预测的数量,预测的总数
    eva = 0.0
    y_num = 0.0
    with torch.no_grad():
        for X,y in data_iter:
            if isinstance(X,list):
                X=[x.to(device) for x in X]
            else:
                X=X.to(device)
            y=y.to(device)
            eva += accuracy(net(X), y)
            y_num += y.numel()
    return eva/y_num

训练过程同样将数据送入GPU计算

python 复制代码
def train_epoch_gpu(net, train_iter, loss, updater,device):

    # 训练损失之和,训练准确数之和,样本数
    train_loss_sum = 0.0
    train_acc_sum = 0.0
    num_samples = 0.0
    # timer = d2l.torch.Timer()
    for i, (X, y) in enumerate(train_iter):
        # timer.start()
        updater.zero_grad()
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        l = loss(y_hat, y)
        l.backward()
        updater.step()
        with torch.no_grad():
            train_loss_sum += l * X.shape[0]
            train_acc_sum += evaluation.accuracy(y_hat, y)
            num_samples += X.shape[0]
        # timer.stop()
    return train_loss_sum/num_samples,train_acc_sum/num_samples


def train_gpu(net,train_iter,test_iter,num_epochs,lr,device):
    def init_weights(m):
        if type(m)==torch.nn.Linear or type(m)==torch.nn.Conv2d:
            torch.nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    net.to(device)
    print('training on',device)
    optimizer=torch.optim.SGD(net.parameters(),lr=lr)
    loss=torch.nn.CrossEntropyLoss()
    # num_batches=len(train_iter)
    tr_l=[]
    tr_a=[]
    te_a=[]
    for epoch in range(num_epochs):
        net.train()
        train_metric=train_epoch_gpu(net,train_iter,loss,optimizer,device)
        test_accuracy = evaluation.evaluate_accuracy_gpu(net, test_iter)
        train_loss, train_acc = train_metric
        train_loss = train_loss.cpu().detach().numpy()
        tr_l.append(train_loss)
        tr_a.append(train_acc)
        te_a.append(test_accuracy)
        print(f'epoch: {epoch + 1}, train_loss: {train_loss}, train_acc: {train_acc}, test_acc:{test_accuracy}')
    x = torch.arange(num_epochs)
    plt.plot((x + 1), tr_l, '-', label='train_loss')
    plt.plot(x + 1, tr_a, '--', label='train_acc')
    plt.plot(x + 1, te_a, '-.', label='test_acc')
    plt.legend()
    plt.show()
    print(f'on {str(device)}')
python 复制代码
lr,num_epochs=0.9,10
Train.train_gpu(net,train_iter,test_iter,num_epochs,lr,device='cuda')


相关推荐
island13143 分钟前
CANN HIXL 高性能单边通信库深度解析:PGAS 模型在异构显存上的地址映射与异步传输机制
人工智能·神经网络·架构
renhongxia122 分钟前
如何基于知识图谱进行故障原因、事故原因推理,需要用到哪些算法
人工智能·深度学习·算法·机器学习·自然语言处理·transformer·知识图谱
深鱼~28 分钟前
ops-transformer算子库:解锁昇腾大模型加速的关键
人工智能·深度学习·transformer·cann
禁默32 分钟前
不仅是 FlashAttention:揭秘 CANN ops-transformer 如何重构大模型推理
深度学习·重构·aigc·transformer·cann
笔画人生33 分钟前
进阶解读:`ops-transformer` 内部实现与性能调优实战
人工智能·深度学习·transformer
CV@CV44 分钟前
2026自动驾驶商业化提速——从智驾平权到Robotaxi规模化落地
人工智能·机器学习·自动驾驶
种时光的人1 小时前
CANN仓库核心解读:ascend-transformer-boost解锁AIGC大模型加速新范式
深度学习·aigc·transformer
island13141 小时前
CANN ops-nn 算子库深度解析:核心算子(如激活函数、归一化)的数值精度控制与内存高效实现
开发语言·人工智能·神经网络
brave and determined2 小时前
CANN ops-nn算子库使用教程:实现神经网络在NPU上的加速计算
人工智能·深度学习·神经网络
笔画人生2 小时前
系统级整合:`ops-transformer` 在 CANN 全栈架构中的角色与实践
深度学习·架构·transformer