Pytorch训练LeNet模型MNIST数据集

如何用torch框架训练深度学习模型(详解)

0. 需要的包

python 复制代码
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

1. 数据加载和导入

以MNIST数据集为例

python 复制代码
# 1.1 需要设置数据归一化
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])
# 1.2 用dataset.MNIST函数下载和加载训练集与测试集 
train_dataset = datasets.MNIST(dataset_path, train=True, 
	download=False, transform=train_transform)
test_dataset = datasets.MNIST(dataset_path, train=False, 
	download=False, transform=test_transform)
# 1.3 加载进dataload用于后续数据按batch取用
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

补充:这里的transform根据不同的数据集选择不同的值

datasets加载数据集时path的路径为:'.\data\' 该目录下包括\MNIST文件夹

2. 加载模型和设置超参数

python 复制代码
# 2.1 这里需要提前定义model的class,包括层结构和forward函数
model = LeNet_Mnist().to(device)
# 2.2 设置优化器、损失函数、训练轮次
learning_rate = 1e-2
# 传入模型参数,用于优化更新
sgd = SGD(model.parameters(), lr=learning_rate)  
loss_fn = CrossEntropyLoss()
all_epoch = 20

3. 训练

python 复制代码
# 3.1 首先设置训练模式
model.train()
# 3.2 按照batch从train_loader中批量选择数据
for idx, (train_x, train_label) in enumerate(train_loader):
    train_x = train_x.to(device)
    train_label = train_label.to(device)
    sgd.zero_grad()
    predict_y = model(train_x.float())
    loss = loss_fn(predict_y, train_label.long())
    loss.backward()
    sgd.step()

补充:可以在外面再套一层迭代次数

python 复制代码
for current_epoch in range(all_epoch):  # local training

4. 测试

python 复制代码
# 4.1 记录测试结果
all_correct_num = 0
all_sample_num = 0
# 4.2 进入模型验证模式,该模式下不会修改梯度
model.eval()
# 4.3 按批次测试
for idx, (test_x, test_label) in enumerate(test_loader):
    test_x = test_x.to(device)
    test_label = test_label.to(device)
    predict_y = model(test_x.float()).detach()
    predict_y = torch.argmax(predict_y, dim=-1)
    current_correct_num = predict_y == test_label
    all_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)
    all_sample_num += current_correct_num.shape[0]
# 4.4 记录结果并输出
acc = all_correct_num / all_sample_num
print('accuracy: {:.3f}'.format(acc), flush=True)

5. 保存结果

python 复制代码
# 5.1 保存参数
print("Save the model state dict")
torch.save(model.state_dict(), "./lenet_mnist.pt")
# 5.2 或者也可以选择保存checkpoint,每轮都保存一次,万一中断能继续
checkpoint = {
                "model": model.state_dict(),
                "optim": sgd.state_dict(),
             }
print("Save the checkpoint")
torch.save(checkpoint, "./checkpoint{}.pt".format(current_epoch))
相关推荐
元宇宙时间3 小时前
RWA加密金融高峰论坛&星链品牌全球发布 —— 稳定币与Web3的香港新篇章
人工智能·web3·区块链
MZ_ZXD0014 小时前
springboot汽车租赁服务管理系统-计算机毕业设计源码58196
java·c++·spring boot·python·django·flask·php
A 计算机毕业设计-小途4 小时前
大四零基础用Vue+ElementUI一周做完化妆品推荐系统?
java·大数据·hadoop·python·spark·毕业设计·毕设
天涯海风6 小时前
检索增强生成(RAG) 缓存增强生成(CAG) 生成中检索(RICHES) 知识库增强语言模型(KBLAM)
人工智能·缓存·语言模型
lxmyzzs7 小时前
基于深度学习CenterPoint的3D目标检测部署实战
人工智能·深度学习·目标检测·自动驾驶·ros·激光雷达·3d目标检测
跟着珅聪学java7 小时前
Apache OpenNLP简介
人工智能·知识图谱
AwhiteV8 小时前
利用图数据库高效解决 Text2sql 任务中表结构复杂时占用过多大模型上下文的问题
数据库·人工智能·自然语言处理·oracle·大模型·text2sql
念念01078 小时前
数学建模竞赛中评价类相关模型
python·数学建模·因子分析·topsis
Black_Rock_br8 小时前
AI on Mac, Your Way!全本地化智能代理,隐私与性能兼得
人工智能·macos
云天徽上8 小时前
【数据可视化-94】2025 亚洲杯总决赛数据可视化分析:澳大利亚队 vs 中国队
python·信息可视化·数据挖掘·数据分析·数据可视化·pyecharts