PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战

在人工智能深度学习的学习和实践中,选择一款合适的框架是入门的关键。而 PyTorch 凭借其上手简单、灵活性高的特性,成为了当下深度学习学习者和开发者的主流选择,更是入门者的首选框架。本文将结合 PyTorch 核心知识点,从深度学习框架对比、PyTorch 安装、CPU 与 GPU 的差异,到基于 MNIST 数据集的手写数字识别实战,全方位讲解 PyTorch 入门的核心内容,帮助大家快速掌握 PyTorch 的基础使用。

一、深度学习框架怎么选?主流框架优劣势对比

当下深度学习领域有不少经典框架,不同框架的设计理念、使用难度和适用场景各有不同,我们先对几款主流框架做核心对比,理解 PyTorch 的核心优势:

框架名称 核心信息
Caffe 优点:仅需配置文件即可搭建深度神经网络模型;缺点:安装麻烦,缺少新网络模型,近年几乎不更新
TensorFlow 开发方:Google 公司;缺点:1.x 版本代码冗余、上手难;2.x 版本不兼容 1.x 版本
Keras 优点:在 TensorFlow 基础上封装,简化代码编写难度
PyTorch 开发方:Facebook 公司;优点:上手极易,可直接套用模板

简单来说,如果你是深度学习入门者,想要快速上手并理解神经网络的核心逻辑,PyTorch 无疑是最佳选择。

二、PyTorch 实战:MNIST 手写数字识别

MNIST 数据集是深度学习入门的 "Hello World",包含 70000 张 28×28 的手写数字灰度图片,其中 60000 张训练集、10000 张测试集,是快速理解神经网络核心逻辑的最佳数据集。接下来我们基于 PyTorch 实现 MNIST 手写数字识别,拆解核心步骤和知识点。

2.1 核心知识点铺垫

在实战前,先了解两个核心组件,这是 PyTorch 构建和训练模型的基础:

(1)优化器:模型参数更新的核心

优化器的作用是根据模型的损失值更新网络参数,让模型不断逼近最优解,主流的优化器本质是梯度下降法的不同改进版本:

优化器名称 核心特点
BGD(批量梯度下降) 用全样本数据计算梯度,收敛稳定,但占用内存大、训练耗时极长
SGD(随机梯度下降) 随机抽取一组数据计算梯度,训练速度快,但收敛波动大、易陷入局部最优
Mini-batch SGD(小批量梯度下降) 数据集分小批次计算梯度,是 BGD 与 SGD 的结合,兼顾训练速度与收敛稳定性,为实战基础
Adam(自适应矩估计优化器) 自带自适应学习率,收敛更快更稳定,是当下深度学习实战最常用的优化器,适配绝大多数场景
(2)激活函数:解决梯度消失问题

激活函数为神经网络引入非线性因素,让模型能拟合复杂的非线性关系。而梯度消失 是深层网络训练的核心问题(梯度反向传播时,连乘的因子小于 1,最终梯度趋于 0,网络深层参数无法更新),而ReLU 激活函数是解决该问题的主流选择,其公式为f(x)=max(0,x),导数在正区间恒为 1,从根本上避免了梯度消失的问题,也是入门实战的首选激活函数。

2.2 MNIST 实战核心步骤

以下是基于 PyTorch 实现 MNIST 手写数字识别的完整代码,我会按 "数据集加载→数据预处理→网络搭建→模型训练→模型测试" 的核心流程,逐段拆解讲解代码逻辑:

python 复制代码
import torch
import torchvision
import torchaudio

print(torch.__version__)
print(torchvision.__version__)
print(torchaudio.__version__)

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# ===================== 1. 数据集加载 =====================
'''下载训练数据集'''
training_data = datasets.MNIST(
    root='data',          # 数据存储路径
    train=True,           # 标记为训练集
    download=True,        # 本地无数据则自动下载
    transform=ToTensor(), # 将图片转为PyTorch张量(0-1归一化,维度:C×H×W)
)

'''下载测试数据集'''
test_data = datasets.MNIST(
    root='data',
    train=False,          # 标记为测试集
    download=True,
    transform=ToTensor(),
)
print(len(training_data)) # 输出训练集数量(60000)

# 可视化前9张手写数字图片
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):
    img, label = training_data[i]  # 取出第i个样本(图片+标签)
    figure.add_subplot(3, 3, i+1)  # 创建3×3子图
    plt.title(label)               # 显示图片对应的数字标签
    plt.axis('off')                # 关闭坐标轴
    plt.imshow(img.squeeze(), cmap='gray')  # 展示图片(squeeze()去掉维度为1的通道维度)
plt.show()

# ===================== 2. 数据预处理(DataLoader封装) =====================
'''创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size个数据。
优点:可以减少内存的使用,提高训练速度。'''
training_dataloader = DataLoader(training_data, batch_size=64)  # 训练集按64个样本为一批打包
test_dataloader = DataLoader(test_data, batch_size=64)          # 测试集同理

# 验证DataLoader输出的维度
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")  # 输出:torch.Size([64, 1, 28, 28])
    print(f"Shape of y: {y.shape} {y.dtype}")     # 输出:torch.Size([64]) torch.int64
    break

# ===================== 3. 网络搭建(自定义神经网络) =====================
'''判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"using {device} device")

'''定义神经网络类的继承'''
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()          # 展平层:将28×28的二维张量转为784维一维张量
        self.hidden1 = nn.Linear(28*28, 128) # 全连接层1:784输入→128输出
        self.hidden2 = nn.Linear(128, 256)   # 全连接层2:128输入→256输出
        self.out = nn.Linear(256, 10)        # 输出层:256输入→10输出(对应0-9十个数字)

    def forward(self, x):
        # 前向传播逻辑
        x = self.flatten(x)    # 展平:[64,1,28,28] → [64,784]
        x = self.hidden1(x)    # 第一层全连接:[64,784] → [64,128]
        x = torch.relu(x)      # ReLU激活函数:引入非线性
        x = self.hidden2(x)    # 第二层全连接:[64,128] → [64,256]
        x = torch.relu(x)      # ReLU激活函数
        x = self.out(x)        # 输出层:[64,256] → [64,10]
        return x

# 实例化模型并移至指定设备(CPU/GPU)
model = NeuralNetwork().to(device)
print(model)  # 打印网络结构

# ===================== 4. 模型训练 =====================
def train(dataloader, model, loss_fn, optimizer):
    model.train()  # 切换为训练模式(启用Dropout、BatchNorm等训练相关层)
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)  # 数据移至指定设备
        pred = model.forward(X)            # 前向传播:得到模型预测值
        loss = loss_fn(pred, y)            # 计算损失(交叉熵损失)

        # 反向传播+参数更新
        optimizer.zero_grad()  # 清空梯度(避免梯度累积)
        loss.backward()        # 反向传播:计算梯度
        optimizer.step()       # 优化器更新模型参数

        # 每100个批次打印一次损失值
        loss_value = loss.item()
        if batch_size_num % 100 == 0:
            print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失(分类任务首选)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # SGD优化器,学习率0.01
train(training_dataloader, model, loss_fn, optimizer)  # 首次训练

# ===================== 5. 模型测试 =====================
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)    # 测试集总样本数(10000)
    num_batches = len(dataloader)     # 测试集总批次数
    model.eval()                      # 切换为评估模式(关闭Dropout、BatchNorm等)
    test_loss, correct = 0, 0
    with torch.no_grad():             # 关闭梯度计算(节省内存,加速推理)
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)   # 前向传播得到预测值
            test_loss += loss_fn(pred, y).item()  # 累加测试损失
            # 计算正确数:pred.argmax(1)取每行最大值索引(预测数字),与真实标签y比较
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    # 计算平均损失和准确率
    test_loss /= num_batches
    correct /= size
    print(f"Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}")

# 多轮训练(10轮)
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-----------------")
    train(training_dataloader, model, loss_fn, optimizer)
    print("Done!")
    test(test_dataloader, model, loss_fn)

这是代码在root="data"路径下自动下载的 MNIST 数据集文件

  1. 核心流程:这段代码完整实现了 "加载数据→预处理→搭网络→训练→测试" 的深度学习通用流程,是 PyTorch 入门的经典模板;

  2. 关键 API:核心依赖torch.nn(网络层)、torch.util.data.DataLoader(数据加载)、torchvision.datasets(数据集)三大模块;

  3. 核心逻辑:训练的本质是 "前向传播算损失→反向传播算梯度→优化器更参数" 的闭环,多轮训练(epochs)让模型逐步收敛。

这段代码的基础准确率约 90% 左右,若要提升到 93% 以上,只需将优化器从 SGD 改为 Adam(optimizer = torch.optim.SGD(model.parameters(), lr=0.01))即可。

三、PyTorch 入门总结与学习建议

PyTorch 作为入门首选的深度学习框架,其核心优势在于上手简单、逻辑清晰、灵活性高,从框架安装到实战开发,所有步骤都能让学习者清晰理解深度学习的核心逻辑,而非单纯的 "调包"。结合本次学习,给新手几点核心学习建议:

  1. 先搭环境,再练基础:优先完成 PyTorch 的 CPU/GPU 版本安装,掌握张量操作、数据加载等基础 API,打好代码基础;
  2. 从 MNIST 入手,理解核心闭环:MNIST 是入门必做的实战案例,通过该案例吃透 "数据预处理 - 网络搭建 - 模型训练 - 模型测试" 的全流程,理解优化器、激活函数、损失函数的作用;
  3. 重视硬件差异:理解 GPU 并行计算的核心优势,掌握 CUDA 的配置方法,这是后续进行大模型、大数据集训练的基础;
  4. 循序渐进,逐步深入:在掌握基础的全连接网络后,再学习卷积神经网络(CNN)、循环神经网络(RNN)等更复杂的网络结构,PyTorch 的模块化设计能让你快速迁移基础知识。

深度学习的学习是一个 "理论 + 实战" 的过程,而 PyTorch 正是连接理论和实战的最佳桥梁。希望本文能帮助大家快速入门 PyTorch,在深度学习的道路上迈出坚实的第一步。

相关推荐
智驱力人工智能3 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144873 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile3 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算
人工不智能5773 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
盟接之桥3 小时前
盟接之桥说制造:引流品 × 利润品,全球电商平台高效产品组合策略(供讨论)
大数据·linux·服务器·网络·人工智能·制造
kfyty7253 小时前
集成 spring-ai 2.x 实践中遇到的一些问题及解决方案
java·人工智能·spring-ai
猫头虎3 小时前
如何排查并解决项目启动时报错Error encountered while processing: java.io.IOException: closed 的问题
java·开发语言·jvm·spring boot·python·开源·maven
h64648564h3 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
数据与后端架构提升之路4 小时前
论系统安全架构设计及其应用(基于AI大模型项目)
人工智能·安全·系统安全
忆~遂愿4 小时前
ops-cv 算子库深度解析:面向视觉任务的硬件优化与数据布局(NCHW/NHWC)策略
java·大数据·linux·人工智能