手写数字识别(分类任务)

1. 导入必要的库

python 复制代码
from pathlib import Path
import requests
import pickle
import gzip
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim
  • 功能: 导入所需的库,以便进行文件操作、数据处理、构建神经网络、计算损失以及加载和处理数据。

2. 数据准备

a. 创建数据路径并下载数据集
python 复制代码
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).open("wb").write(content)
  • 功能 :
    • 创建存储数据集的路径 data/mnist
    • 从指定 URL 下载 MNIST 数据集,并保存为 mnist.pkl.gz 文件,如果文件已经存在则不重复下载。
b. 加载数据集
python 复制代码
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
  • 功能 :
    • 使用 gzip 解压缩文件并使用 pickle 加载数据集,得到训练数据 x_train、训练标签 y_train、验证数据 x_valid 和验证标签 y_valid

3. 数据转换为 PyTorch 张量

python 复制代码
x_train_test, y_train_test, x_valid_test, y_valid_test = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
  • 功能: 将 NumPy 数组转换为 PyTorch 张量,以便后续的模型训练和计算。

4. 模型构建

a. 定义神经网络结构
python 复制代码
class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)  # 第一隐藏层
        self.hidden2 = nn.Linear(128, 256)  # 第二隐藏层
        self.out = nn.Linear(256, 10)       # 输出层
        self.dropout = nn.Dropout(0.5)      # Dropout层
  • 功能 :
    • 创建一个名为 Mnist_NN 的神经网络类,继承自 nn.Module
    • 在初始化方法中定义网络的结构,包括两个隐藏层和一个输出层,以及一个 Dropout 层以减少过拟合。
b. 定义前向传播方法
python 复制代码
def forward(self, x):
    x = F.relu(self.hidden1(x))  # 第一层激活
    x = self.dropout(x)          # Dropout
    x = F.relu(self.hidden2(x))  # 第二层激活
    x = self.dropout(x)          # Dropout
    x = self.out(x)              # 输出层
    return x
  • 功能 :
    • 定义前向传播过程,输入数据通过各层进行计算,并应用 ReLU 激活函数和 Dropout。

5. 损失函数和优化器

a. 定义损失函数
python 复制代码
loss_func = F.cross_entropy  # 使用交叉熵损失函数
  • 功能: 选择交叉熵损失函数作为模型的损失计算标准。
b. 定义优化器
python 复制代码
def get_model():
    model = Mnist_NN()  # 实例化模型
    return model, optim.SGD(model.parameters(), lr=0.001)  # 使用 SGD 优化器
  • 功能 :
    • 创建模型实例,并定义 SGD 优化器,学习率设置为 0.001。

6. 数据加载

a. 创建数据集和数据加载器
python 复制代码
train_ds = TensorDataset(x_train, y_train)  # 创建训练数据集
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)  # 创建训练数据加载器

valid_ds = TensorDataset(x_valid, y_valid)  # 创建验证数据集
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)  # 创建验证数据加载器
  • 功能 :
    • 将训练和验证数据转换为 TensorDataset 格式,以便于进行批量处理。
    • 创建数据加载器 DataLoader,在训练时打乱训练集的顺序,方便分批次读取数据。

6.5 loss_batch

python 复制代码
def loss_batch(model, loss_func, xb, yb, opt=None):  
    # 计算当前批次的损失
    loss = loss_func(model(xb), yb)  

    # 如果提供了优化器 opt,则进行反向传播和优化
    if opt is not None:
        loss.backward()  # 计算损失的梯度(反向传播)
        opt.step()       # 更新模型参数
        opt.zero_grad()  # 清空梯度,避免累积

    # 返回当前批次的损失值和该批次数据的大小
    return loss.item(), len(xb)  

7. 训练过程

a. 定义训练函数
python 复制代码
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()  # 设置模型为训练模式
        for xb, yb in train_dl:  # 遍历训练数据
            loss_batch(model, loss_func, xb, yb, opt)  # 计算并优化损失

        model.eval()  # 设置模型为评估模式
        with torch.no_grad():  # 禁用梯度计算
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )  # 计算验证集损失

        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)  # 计算加权平均损失
        print('当前step:' + str(step), '验证集损失:' + str(val_loss))  # 打印损失
  • 功能 :
    • 定义 fit 函数,负责训练过程,包括:
      • 将模型设置为训练模式并进行训练。
      • 在验证模式下计算验证集损失并输出。

8. 运行训练

python 复制代码
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)  # 获取数据加载器
model, opt = get_model()  # 获取模型和优化器
fit(25, model, loss_func, opt, train_dl, valid_dl)  # 开始训练
  • 功能 :
    • 通过调用 get_dataget_model 函数获取数据加载器和模型,然后调用 fit 函数进行训练。

总结

以上步骤展示了从数据准备到模型训练的完整过程。每一步都围绕着构建一个用于手写数字识别的神经网络进行,确保数据的加载、模型的构建和训练过程都能顺利进行。通过这些步骤,最终可以得到一个能够对手写数字进行分类的模型。

相关推荐
梦云澜1 小时前
论文阅读(十二):全基因组关联研究中生物通路的图形建模
论文阅读·人工智能·深度学习
远洋录1 小时前
构建一个数据分析Agent:提升分析效率的实践
人工智能·ai·ai agent
IT古董2 小时前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师3 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)4 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
kakaZhui4 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20255 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥5 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
云空7 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析
AIGC大时代7 小时前
对比DeepSeek、ChatGPT和Kimi的学术写作关键词提取能力
论文阅读·人工智能·chatgpt·数据分析·prompt