手写数字识别学习笔记

一、任务背景

本次学习围绕利用PyTorch实现神经网络对手写数字(MNIST数据集)进行识别展开,旨在了解借助`nn`工具箱构建神经网络的流程,后续还会深入介绍`nn`各模块。

二、主要步骤与关键知识点

(一)准备数据

  1. **数据导入与预处理**
  • 从`torchvision.datasets`导入MNIST数据集,利用`torchvision.transforms`对数据进行预处理,如`ToTensor()`将图像转为张量,`Normalize([0.5], [0.5])`进行标准化,使数据分布更适合模型训练。

  • 通过`torch.utils.data.DataLoader`创建数据迭代器(`train_loader`和`test_loader`),`batch_size`控制每次训练的样本数,`shuffle`在训练集设为`True`实现数据打乱,增强模型泛化能力,测试集设为`False`保证评估稳定性。

  1. **数据可视化前的准备**
  • 使用`enumerate`和`next`获取测试集的一个批次数据,通过`shape`查看数据维度(如`torch.Size([128, 1, 28, 28])`,表示128个样本,单通道,28×28像素)。

(二)可视化源数据

借助`matplotlib.pyplot`库,遍历测试集批次数据,用`subplot`和`imshow`展示手写数字图像,并标注真实标签(`Ground Truth`),直观感受数据形态。

(三)构建模型

  1. **网络结构设计**
  • 定义`Net`类继承`nn.Module`,在`init`方法中,用`nn.Sequential`组合网络层,包括`nn.Flatten`(将28×28的图像展平为784维向量)、两个含`nn.Linear`(线性变换)和`nn.BatchNorm1d`(批量归一化,加速训练、提升稳定性)的隐藏层,以及输出层(将特征映射到10个类别,对应0 - 9数字)。

  • `forward`方法定义前向传播过程,对各层输出应用`F.relu`(隐藏层,引入非线性)和`F.softmax`(输出层,将输出转为概率分布,`dim=1`按行计算)。

(四)实例化模型与定义优化相关组件

  1. **设备与模型实例化**
  • 根据`torch.cuda.is_available()`判断是否使用GPU,实例化`Net`模型并转移到对应设备。
  1. **损失函数与优化器**
  • 损失函数选用`nn.CrossEntropyLoss`,适合多分类任务;优化器使用`optim.SGD`,设置学习率`lr`和动量`momentum`,加速梯度下降过程。

(五)训练模型

  1. **训练循环**
  • 遍历`num_epochs`个epoch,每个epoch内分训练和测试阶段。训练时,模型设为`train()`模式,开启梯度计算。

  • 动态调整学习率:每5个epoch将学习率乘以0.9,平衡训练速度与精度。

  • 正向传播:输入图像,经模型得到输出,计算与真实标签的损失;反向传播:通过`zero_grad()`清空梯度,`backward()`计算梯度,`step()`更新参数。

  • 记录训练损失和准确率,以及测试集上的损失和准确率,用于后续分析。

  1. **结果可视化**
  • 用`matplotlib`绘制训练损失曲线,可观察到损失随epoch增加逐渐下降,说明模型在不断学习。

三、总结

通过本次学习,掌握了利用PyTorch完成手写数字识别任务的完整流程,从数据准备、模型构建到训练评估,理解了各环节关键技术(如数据预处理、网络层设计、损失函数与优化器选择等)的作用,为后续深入学习神经网络打下基础。

相关推荐
峰顶听歌的鲸鱼6 分钟前
9.OpenStack管理(三)
运维·笔记·分布式·openstack·学习方法
我命由我1234535 分钟前
Photoshop - Photoshop 工具栏(22)单行选框工具
学习·ui·职场和发展·求职招聘·职场发展·学习方法·photoshop
立志成为大牛的小牛1 小时前
数据结构——三十七、关键路径(王道408)
数据结构·笔记·程序人生·考研·算法
User_芊芊君子1 小时前
【成长纪实】我的鸿蒙成长之路:从“小白”到独立开发,带你走进鸿蒙的世界
学习·华为·harmonyos·鸿蒙开发
oe10192 小时前
好文与笔记分享 A Survey of Context Engineering for Large Language Models(下)
人工智能·笔记·语言模型·agent
冷雨夜中漫步2 小时前
高级系统架构师笔记——系统质量属性与架构评估(1)软件系统质量属性
笔记·架构·系统架构
oe10193 小时前
好文与笔记分享 A Survey of Context Engineering for Large Language Models(中)
人工智能·笔记·语言模型·agent开发
嵌入式-老费3 小时前
自己动手写深度学习框架(快速学习python和关联库)
开发语言·python·学习
许长安3 小时前
C++中指针和引用的区别
c++·经验分享·笔记
摇滚侠4 小时前
Spring Boot3零基础教程,StreamAPI 介绍,笔记98
java·spring boot·笔记