基于CNN实现Mnist手写数字识别

一、Mnist数据集介绍

MNIST(Modified National Institute of Standards and Technology database)是一个大型的手写数字数据库,广泛用于训练和测试图像处理系统。它包含了从0到9的共10个类别的灰度手写数字图像。

数据集详情

  • 来源:由美国国家标准与技术研究院(NIST)提供的原始数据集修改而来。

  • 样本数量 :共有 70,000 张图像。

    • 训练集:60,000 张

    • 测试集:10,000 张

  • 图像格式

    • 尺寸 :每张图像为 28x28 像素。

    • 色彩灰度图,每个像素的值在0(黑色)到255(白色)之间。

    • 数据格式 :通常被展平(Flatten) 成一个 784(28*28) 维的向量作为输入。

  • 标签:每张图像都有一个对应的标签,是0到9之间的整数,表示图像中写的数字。

二、构建网络模型

复制代码
网络结构:
Conv2D -> ReLU -> MaxPool -> Conv2D -> ReLU -> MaxPool -> FC -> Dropout -> FC

代码实现:

python 复制代码
class MNISTCNN(nn.Module):
    """
    一个简单的CNN模型,专门为MNIST设计
    网络结构:Conv2D -> ReLU -> MaxPool -> Conv2D -> ReLU -> MaxPool -> FC -> Dropout -> FC
    """
    def __init__(self):
        super(MNISTCNN, self).__init__()
        
        # 卷积层1:输入通道1(灰度图),输出32个特征图,卷积核3x3
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        # 卷积层2:输入32,输出64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        
        # 池化层
        self.pool = nn.MaxPool2d(2, 2)
        
        # Dropout层防止过拟合
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)
        
        # 全连接层
        # 经过两次池化后,28x28 -> 14x14 -> 7x7
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 7x7x64 -> 128
        self.fc2 = nn.Linear(128, 10)  # 128 -> 10个类别
        
    def forward(self, x):
        # 第一个卷积块
        x = self.pool(F.relu(self.conv1(x)))  # [batch, 32, 14, 14]
        x = self.dropout1(x)
        
        # 第二个卷积块
        x = self.pool(F.relu(self.conv2(x)))  # [batch, 64, 7, 7]
        x = self.dropout1(x)
        
        # 展平
        x = x.view(-1, 64 * 7 * 7)  # [batch, 3136]
        
        # 全连接层
        x = F.relu(self.fc1(x))  # [batch, 128]
        x = self.dropout2(x)
        x = self.fc2(x)  # [batch, 10]
        
        return x

三、数据加载和预处理

python 复制代码
def load_and_preprocess_data():
    """
    加载和预处理MNIST数据
    """
    # 直接从torchvision下载MNIST
    from torchvision import datasets
    
    # 数据变换:转换为张量并归一化
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST的均值和标准差
    ])
    
    # 下载训练集
    train_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=True,
        download=False,
        transform=transform
    )
    
    # 下载测试集
    test_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=False,
        download=False,
        transform=transform
    )
    
    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=2
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=2
    )
    
    return train_loader, test_loader

代码实现:

python 复制代码
def load_and_preprocess_data():
    """
    加载和预处理MNIST数据
    """
    # 直接从torchvision下载MNIST
    from torchvision import datasets
    
    # 数据变换:转换为张量并归一化
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST的均值和标准差
    ])
    
    # 下载训练集
    train_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=True,
        download=False,
        transform=transform
    )
    
    # 下载测试集
    test_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=False,
        download=False,
        transform=transform
    )
    
    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=2
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=2
    )
    
    return train_loader, test_loader

四、编写训练函数和推理函数

需要的完整代码的小伙伴可以私信我

五、模型最终预测结果

Using device: cuda

正在加载MNIST数据集...

创建CNN模型...

MNISTCNN(

(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

(dropout1): Dropout2d(p=0.25, inplace=False)

(dropout2): Dropout(p=0.5, inplace=False)

(fc1): Linear(in_features=3136, out_features=128, bias=True)

(fc2): Linear(in_features=128, out_features=10, bias=True)

)

总参数数量: 421,642

可训练参数数量: 421,642

开始训练CNN模型...

开始训练...

Epoch [1/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:37<00:00, 25.18it/s, Loss=0.0833, Batch Acc=100.00%]

============================================================

Epoch 1/10 训练完成

训练准确率: 91.92%, 训练损失: 0.2609

测试准确率: 98.06%, 测试损失: 0.0572

============================================================

Epoch [2/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:35<00:00, 26.35it/s, Loss=0.0187, Batch Acc=100.00%]

============================================================

Epoch 2/10 训练完成

训练准确率: 97.00%, 训练损失: 0.1036

测试准确率: 98.65%, 测试损失: 0.0388

============================================================

Epoch [3/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 28.06it/s, Loss=0.0487, Batch Acc=96.88%]

============================================================

Epoch 3/10 训练完成

训练准确率: 97.56%, 训练损失: 0.0825

测试准确率: 99.04%, 测试损失: 0.0311

============================================================

Epoch [4/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 27.14it/s, Loss=0.1242, Batch Acc=96.88%]

============================================================

Epoch 4/10 训练完成

训练准确率: 97.85%, 训练损失: 0.0713

测试准确率: 98.80%, 测试损失: 0.0350

============================================================

Epoch [5/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:32<00:00, 28.73it/s, Loss=0.1142, Batch Acc=96.88%]

============================================================

Epoch 5/10 训练完成

训练准确率: 98.17%, 训练损失: 0.0612

测试准确率: 99.19%, 测试损失: 0.0265

============================================================

Epoch [6/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 27.44it/s, Loss=0.0197, Batch Acc=100.00%]

============================================================

Epoch 6/10 训练完成

训练准确率: 98.68%, 训练损失: 0.0459

测试准确率: 99.30%, 测试损失: 0.0221

============================================================

Epoch [7/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 27.59it/s, Loss=0.0261, Batch Acc=100.00%]

============================================================

Epoch 7/10 训练完成

训练准确率: 98.75%, 训练损失: 0.0414

测试准确率: 99.26%, 测试损失: 0.0237

============================================================

Epoch [8/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 26.99it/s, Loss=0.0027, Batch Acc=100.00%]

============================================================

Epoch 8/10 训练完成

训练准确率: 98.80%, 训练损失: 0.0388

测试准确率: 99.20%, 测试损失: 0.0232

============================================================

Epoch [9/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:30<00:00, 30.67it/s, Loss=0.0218, Batch Acc=100.00%]

============================================================

Epoch 9/10 训练完成

训练准确率: 98.93%, 训练损失: 0.0356

测试准确率: 99.27%, 测试损失: 0.0222

============================================================

Epoch [10/10]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 28.20it/s, Loss=0.0053, Batch Acc=100.00%]

============================================================

Epoch 10/10 训练完成

训练准确率: 98.89%, 训练损失: 0.0338

测试准确率: 99.28%, 测试损失: 0.0231

============================================================

最终评估模型性能...

最终测试准确率: 99.28%

生成可视化结果...

分类报告:

precision recall f1-score support

0 0.99 1.00 1.00 980

1 1.00 1.00 1.00 1135

2 0.99 1.00 1.00 1032

3 0.99 1.00 1.00 1010

4 0.99 0.98 0.99 982

5 0.99 0.99 0.99 892

6 1.00 0.99 0.99 958

7 0.99 0.99 0.99 1028

8 0.99 0.99 0.99 974

9 0.98 0.99 0.99 1009

accuracy 0.99 10000

macro avg 0.99 0.99 0.99 10000

weighted avg 0.99 0.99 0.99 10000

模型已保存到 mnist_cnn_model_final.pth

============================================================

训练总结:

============================================================

最终测试准确率: 99.28%

最佳测试准确率: 99.30% (第6个epoch)

最终训练准确率: 98.89%

✅ 成功达到96%以上的准确率目标!

============================================================

需要完整代码和数据集的小伙伴私信博主吧~

相关推荐
J_Xiong01171 分钟前
【Agents篇】04:Agent 的推理能力——思维链与自我反思
人工智能·ai agent·推理
星爷AG I21 分钟前
9-26 主动视觉(AGI基础理论)
人工智能·计算机视觉·agi
爱吃泡芙的小白白27 分钟前
CNN参数量计算全解析:从基础公式到前沿优化
人工智能·神经网络·cnn·参数量
拐爷37 分钟前
vibe‑coding 九阳神功之喂:把链接喂成“本地知识”,AI 才能稳定干活(API / 设计 / 报道 / 截图)
人工智能
石去皿38 分钟前
大模型面试通关指南:28道高频考题深度解析与实战要点
人工智能·python·面试·职场和发展
yuezhilangniao1 小时前
AI智能体全栈开发工程化规范 备忘 ~ fastAPI+Next.js
javascript·人工智能·fastapi
好奇龙猫1 小时前
【人工智能学习-AI入试相关题目练习-第十八次】
人工智能·学习
Guheyunyi1 小时前
智能守护:视频安全监测系统的演进与未来
大数据·人工智能·科技·安全·信息可视化
程序员辣条1 小时前
AI产品经理:2024年职场发展的新机遇
人工智能·学习·职场和发展·产品经理·大模型学习·大模型入门·大模型教程
AI大模型测试1 小时前
大龄程序员想转行到AI大模型,好转吗?
人工智能·深度学习·机器学习·ai·语言模型·职场和发展·大模型