深度学习------专题《神经网络完成手写数字识别》

目录

[一、为什么从 "手写数字识别" 开始?](#一、为什么从 “手写数字识别” 开始?)

二、准备工作:数据与工具

[1. 数据集:MNIST](#1. 数据集:MNIST)

[2. 工具:PyTorch & 辅助库](#2. 工具:PyTorch & 辅助库)

三、构建第一个神经网络模型

[1. 模型结构设计](#1. 模型结构设计)

[2. 模型实例化与配置](#2. 模型实例化与配置)

[四、训练与验证:让模型学会 "认" 数字](#四、训练与验证:让模型学会 “认” 数字)

[1. 训练的核心逻辑](#1. 训练的核心逻辑)

[2. 代码实现(带注释,方便理解)](#2. 代码实现(带注释,方便理解))

[3. 可视化训练结果](#3. 可视化训练结果)

五、总结与收


深度学习入门实战:手把手教你用 PyTorch 识别手写数字

作为深度学习的新手,最近跟着学习了用 PyTorch 实现手写数字识别的案例,总算摸到了深度学习的 "门槛"!今天把学习过程和关键知识点分享出来,希望能帮到和我一样刚入门的同学~

一、为什么从 "手写数字识别" 开始?

刚接触深度学习时,我总觉得它很 "高大上",怕学不懂。但老师说,MNIST 手写数字数据集是深度学习领域的 "Hello World"------ 它简单(只有数字 0-9 的手写图片)、经典,能帮新手快速理解 "神经网络如何学习" 的核心逻辑。用它入门,能一步步掌握 "数据处理→模型构建→训练验证" 的核心步骤,特别适合建立信心~

二、准备工作:数据与工具

开始前,得备好 "原材料" 和 "工具":

1. 数据集:MNIST

MNIST 包含大量手写数字图片,每张是 28×28 的灰度图,标签是 0-9 的数字。PyTorch 能直接下载它,省去了自己找数据、预处理的麻烦。

2. 工具:PyTorch & 辅助库

我们需要这些工具来搭建、训练模型:

  • torchvision:处理图像(比如转成 Tensor、做归一化);
  • DataLoader:批量加载数据,让训练更高效;
  • matplotlib:可视化数据和训练结果,看模型学得咋样;
  • numpy:辅助数值计算。

下面是我跟着敲的导入代码(亲测能跑通):

python 复制代码
import numpy as np
import torch
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline

三、构建第一个神经网络模型

接下来是核心的 "建模型" 环节 ------ 我们要让模型把 28×28 的图像 "压平" 成一维向量,再通过两层隐藏层,最终输出 10 个类别(0-9)的概率。

1. 模型结构设计

神经网络就像 "信息加工厂":输入是手写数字图像,经过几重 "加工"(隐藏层),最后输出对每个数字的 "判断概率"。

这次我们设计了两个隐藏层 ,还加了BatchNorm让训练更稳定;激活函数用ReLU(避免 "梯度消失",让模型更易训练);最后用Softmax把输出变成 "概率分布"(能直观看到模型对每个数字的置信度)。

用 PyTorch 的nn.Module定义模型(我自己加了注释,方便记忆):

python 复制代码
class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Net, self).__init__()
        self.flatten = nn.Flatten()  # 把28×28图像压成784维向量
        # 第一层隐藏层(带BatchNorm)
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))
        # 第二层隐藏层(带BatchNorm)
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))
        # 输出层(对应10个数字类别)
        self.out = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
        
    def forward(self, x):
        x = self.flatten(x)  # 第一步:压平图像
        x = F.relu(self.layer1(x))  # 过第一层,ReLU激活
        x = F.relu(self.layer2(x))  # 过第二层,ReLU激活
        x = F.softmax(self.out(x), dim=1)  # 输出层用Softmax转成概率
        return x

2. 模型实例化与配置

定义好模型后,还要指定损失函数 (衡量 "预测准不准")和优化器(帮模型调整参数变得更准):

  • 损失函数用CrossEntropyLoss(适合分类问题);
  • 优化器用SGD(随机梯度下降,加了momentum让训练更稳定);
  • 同时判断硬件:有 GPU 就用 GPU 加速,没有就用 CPU。

代码如下:

python 复制代码
# 超参数
lr = 0.01
momentum = 0.9
# 实例化模型(输入28*28=784,两层隐藏层300、100,输出10类)
model = Net(28 * 28, 300, 100, 10)
# 选择设备(GPU/CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

四、训练与验证:让模型学会 "认" 数字

模型建好后,就要训练 (让模型在训练数据上学规律)和验证(用测试数据看模型学得好不好)了。

1. 训练的核心逻辑

训练过程其实是 "预测→算误差→调整参数" 的循环:

  • 把数据喂给模型,得到预测结果;
  • 用损失函数算 "预测值" 和 "真实标签" 的误差;
  • 通过backward()(反向传播)和step()(优化器)调整模型参数;
  • 同时记录损失和准确率,看模型是否在进步。

2. 代码实现(带注释,方便理解)

python 复制代码
# 记录训练过程的指标
losses = []
acces = []
eval_losses = []
eval_acces = []
writer = SummaryWriter(log_dir='logs', comment='train-loss')

# 训练循环,跑num_epochs轮
for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    model.train()  # 切换到训练模式(BatchNorm、Dropout等生效)
    # 动态调整学习率(每5轮乘以0.9,让训练更精细)
    if epoch % 5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.9
        print(f'学习率调整为: {optimizer.param_groups[0]["lr"]:.6f}')
    # 遍历训练数据
    for img, label in train_loader:
        img, label = img.to(device), label.to(device)
        # 1. 正向传播:模型预测
        out = model(img)
        loss = criterion(out, label)
        # 2. 反向传播+优化:调整参数
        optimizer.zero_grad()  # 清空之前的梯度
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新参数
        # 记录训练损失
        train_loss += loss.item()
        # 记录训练准确率
        _, pred = out.max(1)  # 取概率最大的类别
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc
    # 保存本轮训练的平均损失和准确率
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))
    
    # 验证:用测试集看模型效果
    eval_loss = 0
    eval_acc = 0
    model.eval()  # 切换到评估模式(BatchNorm、Dropout等关闭)
    with torch.no_grad():  # 验证时不需要计算梯度,节省资源
        for img, label in test_loader:
            img, label = img.to(device), label.to(device)
            img = img.view(img.size(0), -1)  # 压平图像(匹配模型输入)
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item()
            # 计算验证准确率
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    # 打印本轮的训练+验证结果
    print(f' epoch: {epoch}, 训练损失: {train_loss/len(train_loader):.4f}, 训练准确率: {train_acc/len(train_loader):.4f}, '
          f'测试损失: {eval_loss/len(test_loader):.4f}, 测试准确率: {eval_acc/len(test_loader):.4f}')

3. 可视化训练结果

训练完后,用matplotlib把损失变化画出来,能更直观看到模型的 "学习过程":

python 复制代码
plt.title('训练损失变化')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')

从图中能看到,损失随训练轮数增加逐渐下降 ------ 这说明模型在不断优化~

五、总结与收获

通过这次手写数字识别的实战,我终于对深度学习不再 "发怵",还明白了几个关键知识点:

  1. 数据处理是基础 :从下载 MNIST、预处理(转 Tensor、归一化)到用DataLoader批量加载,每一步都影响模型效果。
  2. 神经网络的核心是 "层的组合":Linear 层、激活函数、BatchNorm 的组合,让模型能学习复杂规律。
  3. 训练是 "迭代优化" 的过程:通过 "正向传播算误差→反向传播调参数",模型才会逐渐 "学会" 识别数字;同时要用测试集验证,避免 "学偏"(过拟合)。

虽然这只是个简单案例,但它帮我把深度学习的理论和实践串了起来。接下来我打算试试更复杂的数据集,或者调整模型结构,看看能不能让效果更好~

如果有同学也在学 PyTorch,欢迎交流呀~

相关推荐
流年染指悲伤、4 小时前
2024年最新技术趋势分析:AI、前端与后端开发新动向
人工智能·前端开发·后端开发·2024·技术趋势
乐迪信息5 小时前
乐迪信息:基于AI算法的煤矿作业人员安全规范智能监测与预警系统
大数据·人工智能·算法·安全·视觉检测·推荐算法
Bugman.5 小时前
分类任务-三个重要网络模型
深度学习·机器学习·分类
oe10195 小时前
好文与笔记分享 Paris, A Decentralized Trained Open-Weight Diffusion Model
人工智能·笔记·去中心化·多模态
HelloWorld__来都来了6 小时前
Agent S / Agent S2 的架构、亮点与局限
人工智能·架构
JAVA学习通6 小时前
发布自己的 jar 包到 Maven 中央仓库 ( mvnrepository.com )
人工智能·docker·自然语言处理·容器·rocketmq
文火冰糖的硅基工坊6 小时前
[嵌入式系统-107]:语音识别的信号处理流程和软硬件职责
人工智能·语音识别·信号处理
lianyinghhh6 小时前
瓦力机器人-舵机控制(基于树莓派5)
人工智能·python·自然语言处理·硬件工程
小殊小殊6 小时前
超越CNN:GCN如何重塑图像处理
图像处理·人工智能·深度学习