用 PyTorch 实现全连接网络识别 MNIST 手写数字

目录

一、什么是全连接网络

二、代码实现步骤

[1. 导入必要的库](#1. 导入必要的库)

[2. 数据准备](#2. 数据准备)

[3. 定义网络结构](#3. 定义网络结构)

[4. 模型训练](#4. 模型训练)

[5. 模型保存和加载](#5. 模型保存和加载)

[6. 预测单张图片](#6. 预测单张图片)

[7. 主函数](#7. 主函数)

三、运行结果说明

四、小结


一、什么是全连接网络

全连接神经网络(Fully Connected Neural Network)是一种最基础的神经网络结构,其特点是每一层的每个神经元都与上一层的所有神经元相连。

打个比方,就像公司里的部门架构:输入层是基层员工,隐藏层是中层管理,输出层是高层决策。基层的每个人都要向所有中层汇报,中层再向所有高层汇报,这样信息就能经过多层处理后得到最终结果。

但全连接网络处理图像时有个缺点:它会把图像的二维像素矩阵转换成一维向量,这就像把一张完整的图片撕成一条线,会丢失图像的空间特征。

二、代码实现步骤

1. 导入必要的库

python 复制代码
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

这些库就像我们的工具包:

  • torch 是 PyTorch 的核心库
  • nn 模块包含神经网络相关的工具
  • optim 提供优化器
  • torchvision 有现成的数据集和图像处理工具
  • DataLoader 帮助我们批量加载数据
  • PIL 用于处理图像

2. 数据准备

python 复制代码
def build_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = datasets.MNIST(
        root = '../dataset',
        train = True,
        download = True,
        transform = transform
    )
    test_set = datasets.MNIST(
        root = '../dataset',
        train = False,
        download = True,
        transform = transform
    )
    train_loader = DataLoader(
        dataset = train_set,
        batch_size = 128,
        shuffle = True
    )
    test_loader = DataLoader(
        dataset = test_set,
        batch_size = 64,
        shuffle = True
    )
    return train_loader, test_loader

这段代码做了三件事:

  • 定义了数据转换方式,ToTensor()会把图像转换成张量并归一化
  • 加载 MNIST 数据集(手写数字数据集,包含 0-9 共 10 类数字)
  • DataLoader把数据分成批次,方便训练时批量处理

batch_size表示每次处理多少张图片,shuffle=True表示打乱数据顺序,让模型学习更全面。

3. 定义网络结构

python 复制代码
class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)  # 把28x28的图像展平成784维向量
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x

我们定义了一个 3 层的全连接网络:

  • 输入层:MNIST 图像是 28x28 的,展平后是 784 个像素点
  • 第一个隐藏层:256 个神经元,使用 ReLU 激活函数
  • 第二个隐藏层:128 个神经元,同样使用 ReLU 激活函数
  • 输出层:10 个神经元(对应 0-9 十个数字)

激活函数 ReLU 的作用是引入非线性,让网络能够学习复杂的模式,就像给计算器增加了更多运算功能。

4. 模型训练

python 复制代码
def train(model, train_loader, epochs):
    criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适合分类问题
    opt = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器
    
    for epoch in range(epochs):
        loss_sum = 0
        count = 0
        for x, y in train_loader:
            y_pred = model(x)  # 前向传播,得到预测结果
            loss = criterion(y_pred, y)  # 计算损失
            
            # 反向传播更新参数
            opt.zero_grad()  # 清空梯度
            loss.backward()  # 计算梯度
            opt.step()  # 更新参数
            
            loss_sum += loss.item()
            _, pred = torch.max(y_pred, dim=1)  # 找到概率最大的类别
            count += (pred == y).sum().item()  # 统计正确的数量
        
        acc = count / len(train_loader.dataset)  # 计算准确率
        print(f'epoch: {epoch+1}, Loss: {loss_sum:.4f}, Acc: {acc:.4f}')

训练过程就像学生做习题:

  1. 先用当前模型做预测(前向传播)
  2. 计算预测结果和正确答案的差距(损失函数)
  3. 分析哪里错了,怎么改进(反向传播计算梯度)
  4. 调整模型参数(优化器更新参数)

我们用交叉熵损失函数来衡量预测错误的程度,用随机梯度下降(SGD)来优化模型参数,学习率lr=0.01控制每次调整的幅度。

5. 模型保存和加载

python 复制代码
def save_model(model, model_path):
    torch.save(model.state_dict(), model_path)  # 保存模型参数

def load_model(model_path):
    model = MNISTNet()
    model.load_state_dict(torch.load(model_path))  # 加载模型参数
    return model

训练好的模型可以保存下来,下次用的时候直接加载,不用重新训练,就像保存游戏进度一样。

6. 预测单张图片

python 复制代码
def predict(model, filePath):
    img = Image.open(filePath)
    # 图像预处理:调整大小、转成张量、归一化
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    t_img = transform(img)
    with torch.no_grad():  # 预测时不需要计算梯度
        y_pred = model(t_img)
        _, pred = torch.max(y_pred, dim=1)
        print(f'预测结果: {pred.item()}')

预测时需要对输入图片做和训练数据相同的预处理,with torch.no_grad()可以加快计算速度,因为预测时不需要更新参数。

7. 主函数

python 复制代码
if __name__ == '__main__':
    train_loader, test_loader = build_data()
    model = MNISTNet()
    
    # 训练模型
    train(model, train_loader, epochs=10)
    
    # 保存模型
    save_model(model, './mnist.pt')
    
    # 加载模型并预测
    model_pred = load_model('./mnist.pt')
    predict(model_pred, './img/3.png')

三、运行结果说明

训练过程中,我们会看到损失(Loss)逐渐减小,准确率(Acc)逐渐提高,这说明模型在不断进步。

对于 MNIST 这种简单数据集,用这个全连接网络通常能达到 97% 以上的准确率。如果想进一步提高性能,可以考虑使用卷积神经网络(CNN),它能更好地保留图像的空间特征。

四、小结

本文用 PyTorch 实现了一个全连接神经网络来识别 MNIST 手写数字,主要步骤包括:

  1. 准备数据:加载并预处理 MNIST 数据集
  2. 定义网络:设计 3 层全连接网络
  3. 训练模型:使用交叉熵损失和 SGD 优化器
  4. 保存和加载模型:方便复用
  5. 单张图片预测:实际应用模型

全连接网络虽然简单,但它是理解更复杂神经网络的基础。通过这个例子,我们可以了解神经网络的基本工作原理和 PyTorch 的使用方法。

相关推荐
科技峰行者8 小时前
通义万相2.5系列模型发布,可生成音画同步视频
人工智能·阿里云·ai·大模型·agi
两只程序猿8 小时前
数据可视化 | Violin Plot小提琴图Python实现 数据分布密度可视化科研图表
开发语言·python·信息可视化
Vizio<8 小时前
《面向物理交互任务的触觉传感阵列仿真》2020AIM论文解读
论文阅读·人工智能·机器人·机器人触觉
尤超宇8 小时前
基于卷积神经网络的 CIFAR-10 图像分类实验报告
人工智能·分类·cnn
alex1009 小时前
BeaverTails数据集:大模型安全对齐的关键资源与实战应用
人工智能·算法·安全
大模型真好玩9 小时前
架构大突破! DeepSeek-V3.2发布,五分钟速通DeepSeek-V3.2核心特性
人工智能·python·deepseek
春末的南方城市9 小时前
苏大团队联合阿丘科技发表异常生成新方法:创新双分支训练法,同步攻克异常图像生成、分割及下游模型性能提升难题。
人工智能·科技·深度学习·计算机视觉·aigc
OpenCSG9 小时前
超越颠覆:AI与Web3如何为传统金融的“华兴资本们”提供新生之路
人工智能·金融·web3
玩转C语言和数据结构9 小时前
Jupyter Notebook下载安装使用教程(附安装包,图文并茂)
ide·python·jupyter·anaconda·jupyternotebook·anaconda下载·anaconda安装包
2401_841495649 小时前
【自然语言处理】Universal Transformer(UT)模型
人工智能·python·深度学习·算法·自然语言处理·transformer·ut