神经网络入门实战:(二十)MNIST训练网络(只用线性层和ReLU)

MNIST训练网络(只用线性层和ReLU)

该数据集一共有7万张图片,其中6万张是训练集,1万张是测试集;每张图片都是28×28像素的单通道(黑白)图片

类比 CIFAR10 的训练过程:

python 复制代码
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from NN_models import *


# 检查CUDA是否可用,并设置设备为 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataclass_transform = transforms.Compose([
	transforms.ToTensor(),
	transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(root='E:\\4_Data_sets\\MNIST', train=True,transform=dataclass_transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='E:\\4_Data_sets\\MNIST', train=False,transform=dataclass_transform, download=True)

# 训练和测试数据集的长度
train_data_size = len(train_dataset)
test_size = len(test_dataset)
print(train_data_size,test_size)

train_dataloader = DataLoader(dataset=train_dataset,batch_size=64)
test_dataloader = DataLoader(dataset=test_dataset,batch_size=64)

# 创建网络模型
class MNIST_NET(nn.Module):
	def __init__(self):
		super(MNIST_NET, self).__init__()
		self.model = nn.Sequential(
			nn.Flatten(),
			nn.Linear(784, 512),
			nn.ReLU(),  # 添加ReLU激活函数
			nn.Linear(512, 256),
			nn.ReLU(),  # 添加ReLU激活函数
			nn.Linear(256, 128),
			nn.ReLU(),  # 添加ReLU激活函数
			nn.Linear(128, 64),
			nn.ReLU(),  # 添加ReLU激活函数
			nn.Linear(64, 10)
		)

	def forward(self, x):
		x = self.model(x)
		return x

MNIST_NET_Instance = MNIST_NET().to(device)

# 定义损失函数
loss = nn.CrossEntropyLoss()
# 定义优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(MNIST_NET_Instance.parameters(), lr=learning_rate, momentum=0.9)

# 开始训练
total_train_step = 0
first_train_step = 0
total_test_step = 0
epoch_sum = 10 # 迭代次数

# 添加tensorboard
writer = SummaryWriter('logs')

for i in range(epoch_sum):
	print("------------第 {} 轮训练开始了------------:".format(i+1))

	# 训练步骤开始
	for data in train_dataloader:
		imgs, labels = data
		imgs, labels = imgs.to(device), labels.to(device)  # 将数据和目标移动到GPU
		outputs = MNIST_NET_Instance(imgs)
		loss_real = loss(outputs, labels) # 这里的损失变量 loss_real,千万别和损失函数 loss 相同,否则会报错!
		optimizer.zero_grad()
		loss_real.backward()
		optimizer.step()

		total_train_step += 1
		# 表示第一轮训练结束,取每一轮的第一个batch_size来看看训练效果
		if total_train_step % 938 == 0:
			first_train_step += 1
			print("训练次数为:{}, loss为:{}".format(total_train_step, loss_real)) # 此训练次数非训练轮次,而是训练到第几个batch_size了
			writer.add_scalar('first_batch_size', loss_real.item(), first_train_step)
		writer.add_scalar('total_batch_size', loss_real.item(), total_train_step)


	# 每训练一轮,就使用测试集看看训练效果
	total_test_loss = 0
	with torch.no_grad(): # 后续测试不计算梯度    
		for data in test_dataloader:
			imgs, labels = data
			imgs, labels = imgs.to(device), labels.to(device)
			outputs = MNIST_NET_Instance(imgs)
			loss_fake = loss(outputs, labels)
			total_test_loss += loss_fake.item()
	print("# # 整体测试集上的LOSS为:{}".format(total_test_loss))

writer.close()

torch.save(MNIST_NET_Instance,"E:\\5_NN_model\\MNIST_NET_train10")
print("模型已保存!!")

结果如下:


上一篇 下一篇
神经网络入门实战(十九) 待发布
相关推荐
运维行者_18 小时前
OpManager 对接 ERP 避坑指南,网络自动化提升数据同步效率
运维·服务器·开发语言·网络·microsoft·网络安全·php
AI科技星18 小时前
统一场论理论下理解物体在不同运动状态的本质
人工智能·线性代数·算法·机器学习·概率论
乾元18 小时前
数据为王——安全数据集的清洗与特征工程
大数据·网络·人工智能·安全·web安全·机器学习·架构
wangmengxxw18 小时前
SpringAI-结构化输出API
java·人工智能·springai
国际期刊-秋秋18 小时前
[ACM] 2026 年人工智能系统、区块链与数字经济国际学术会议(DEAI 2026)
人工智能·国际会议·会议投稿
2501_9402778018 小时前
告别碎片化集成:使用 MCP 标准化重构企业内部遗留 API,构建统一的 AI 原生接口中心
人工智能·重构
萤丰信息18 小时前
智慧园区:科技赋能的未来产业生态新载体
大数据·运维·人工智能·科技·智慧园区
ASD123asfadxv18 小时前
【医疗影像检测】VFNet模型在医疗器械目标检测中的应用与优化
人工智能·目标检测·计算机视觉
小真zzz18 小时前
2025-2026年AI PPT工具排行榜:ChatPPT的全面领先与竞品格局解析
人工智能·ai·powerpoint·ppt·aippt
智慧化智能化数字化方案18 小时前
详解人工智能安全治理框架(中文版)【附全文阅读】
大数据·人工智能·人工智能安全治理框架