基于CNN实现Mnist手写数字识别

一、Mnist数据集介绍

MNIST(Modified National Institute of Standards and Technology database)是一个大型的手写数字数据库,广泛用于训练和测试图像处理系统。它包含了从0到9的共10个类别的灰度手写数字图像。

数据集详情

  • 来源:由美国国家标准与技术研究院(NIST)提供的原始数据集修改而来。

  • 样本数量 :共有 70,000 张图像。

    • 训练集:60,000 张

    • 测试集:10,000 张

  • 图像格式

    • 尺寸 :每张图像为 28x28 像素。

    • 色彩灰度图,每个像素的值在0(黑色)到255(白色)之间。

    • 数据格式 :通常被展平(Flatten) 成一个 784(28*28) 维的向量作为输入。

  • 标签:每张图像都有一个对应的标签,是0到9之间的整数,表示图像中写的数字。

二、构建网络模型

复制代码
网络结构:
Conv2D -> ReLU -> MaxPool -> Conv2D -> ReLU -> MaxPool -> FC -> Dropout -> FC

代码实现:

python 复制代码
class MNISTCNN(nn.Module):
    """
    一个简单的CNN模型,专门为MNIST设计
    网络结构:Conv2D -> ReLU -> MaxPool -> Conv2D -> ReLU -> MaxPool -> FC -> Dropout -> FC
    """
    def __init__(self):
        super(MNISTCNN, self).__init__()
        
        # 卷积层1:输入通道1(灰度图),输出32个特征图,卷积核3x3
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        # 卷积层2:输入32,输出64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        
        # 池化层
        self.pool = nn.MaxPool2d(2, 2)
        
        # Dropout层防止过拟合
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)
        
        # 全连接层
        # 经过两次池化后,28x28 -> 14x14 -> 7x7
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 7x7x64 -> 128
        self.fc2 = nn.Linear(128, 10)  # 128 -> 10个类别
        
    def forward(self, x):
        # 第一个卷积块
        x = self.pool(F.relu(self.conv1(x)))  # [batch, 32, 14, 14]
        x = self.dropout1(x)
        
        # 第二个卷积块
        x = self.pool(F.relu(self.conv2(x)))  # [batch, 64, 7, 7]
        x = self.dropout1(x)
        
        # 展平
        x = x.view(-1, 64 * 7 * 7)  # [batch, 3136]
        
        # 全连接层
        x = F.relu(self.fc1(x))  # [batch, 128]
        x = self.dropout2(x)
        x = self.fc2(x)  # [batch, 10]
        
        return x

三、数据加载和预处理

python 复制代码
def load_and_preprocess_data():
    """
    加载和预处理MNIST数据
    """
    # 直接从torchvision下载MNIST
    from torchvision import datasets
    
    # 数据变换:转换为张量并归一化
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST的均值和标准差
    ])
    
    # 下载训练集
    train_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=True,
        download=False,
        transform=transform
    )
    
    # 下载测试集
    test_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=False,
        download=False,
        transform=transform
    )
    
    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=2
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=2
    )
    
    return train_loader, test_loader

代码实现:

python 复制代码
def load_and_preprocess_data():
    """
    加载和预处理MNIST数据
    """
    # 直接从torchvision下载MNIST
    from torchvision import datasets
    
    # 数据变换:转换为张量并归一化
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST的均值和标准差
    ])
    
    # 下载训练集
    train_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=True,
        download=False,
        transform=transform
    )
    
    # 下载测试集
    test_dataset = datasets.MNIST(
        root='./mnist_dataset/train',
        train=False,
        download=False,
        transform=transform
    )
    
    # 创建数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        num_workers=2
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=2
    )
    
    return train_loader, test_loader

四、编写训练函数和推理函数

需要的完整代码的小伙伴可以私信我

五、模型最终预测结果

Using device: cuda

正在加载MNIST数据集...

创建CNN模型...

MNISTCNN(

(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

(dropout1): Dropout2d(p=0.25, inplace=False)

(dropout2): Dropout(p=0.5, inplace=False)

(fc1): Linear(in_features=3136, out_features=128, bias=True)

(fc2): Linear(in_features=128, out_features=10, bias=True)

)

总参数数量: 421,642

可训练参数数量: 421,642

开始训练CNN模型...

开始训练...

Epoch [1/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:37<00:00, 25.18it/s, Loss=0.0833, Batch Acc=100.00%]

============================================================

Epoch 1/10 训练完成

训练准确率: 91.92%, 训练损失: 0.2609

测试准确率: 98.06%, 测试损失: 0.0572

============================================================

Epoch [2/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:35<00:00, 26.35it/s, Loss=0.0187, Batch Acc=100.00%]

============================================================

Epoch 2/10 训练完成

训练准确率: 97.00%, 训练损失: 0.1036

测试准确率: 98.65%, 测试损失: 0.0388

============================================================

Epoch [3/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 28.06it/s, Loss=0.0487, Batch Acc=96.88%]

============================================================

Epoch 3/10 训练完成

训练准确率: 97.56%, 训练损失: 0.0825

测试准确率: 99.04%, 测试损失: 0.0311

============================================================

Epoch [4/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 27.14it/s, Loss=0.1242, Batch Acc=96.88%]

============================================================

Epoch 4/10 训练完成

训练准确率: 97.85%, 训练损失: 0.0713

测试准确率: 98.80%, 测试损失: 0.0350

============================================================

Epoch [5/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:32<00:00, 28.73it/s, Loss=0.1142, Batch Acc=96.88%]

============================================================

Epoch 5/10 训练完成

训练准确率: 98.17%, 训练损失: 0.0612

测试准确率: 99.19%, 测试损失: 0.0265

============================================================

Epoch [6/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 27.44it/s, Loss=0.0197, Batch Acc=100.00%]

============================================================

Epoch 6/10 训练完成

训练准确率: 98.68%, 训练损失: 0.0459

测试准确率: 99.30%, 测试损失: 0.0221

============================================================

Epoch [7/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 27.59it/s, Loss=0.0261, Batch Acc=100.00%]

============================================================

Epoch 7/10 训练完成

训练准确率: 98.75%, 训练损失: 0.0414

测试准确率: 99.26%, 测试损失: 0.0237

============================================================

Epoch [8/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 26.99it/s, Loss=0.0027, Batch Acc=100.00%]

============================================================

Epoch 8/10 训练完成

训练准确率: 98.80%, 训练损失: 0.0388

测试准确率: 99.20%, 测试损失: 0.0232

============================================================

Epoch [9/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:30<00:00, 30.67it/s, Loss=0.0218, Batch Acc=100.00%]

============================================================

Epoch 9/10 训练完成

训练准确率: 98.93%, 训练损失: 0.0356

测试准确率: 99.27%, 测试损失: 0.0222

============================================================

Epoch [10/10]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 28.20it/s, Loss=0.0053, Batch Acc=100.00%]

============================================================

Epoch 10/10 训练完成

训练准确率: 98.89%, 训练损失: 0.0338

测试准确率: 99.28%, 测试损失: 0.0231

============================================================

最终评估模型性能...

最终测试准确率: 99.28%

生成可视化结果...

分类报告:

precision recall f1-score support

0 0.99 1.00 1.00 980

1 1.00 1.00 1.00 1135

2 0.99 1.00 1.00 1032

3 0.99 1.00 1.00 1010

4 0.99 0.98 0.99 982

5 0.99 0.99 0.99 892

6 1.00 0.99 0.99 958

7 0.99 0.99 0.99 1028

8 0.99 0.99 0.99 974

9 0.98 0.99 0.99 1009

accuracy 0.99 10000

macro avg 0.99 0.99 0.99 10000

weighted avg 0.99 0.99 0.99 10000

模型已保存到 mnist_cnn_model_final.pth

============================================================

训练总结:

============================================================

最终测试准确率: 99.28%

最佳测试准确率: 99.30% (第6个epoch)

最终训练准确率: 98.89%

✅ 成功达到96%以上的准确率目标!

============================================================

需要完整代码和数据集的小伙伴私信博主吧~

相关推荐
说私域2 小时前
基于AI智能名片链动2+1模式预约服务商城小程序的数据管理与系统集成研究
大数据·人工智能·小程序
AC赳赳老秦2 小时前
技术文档合著:DeepSeek辅助多人协作文档的风格统一与内容补全
android·大数据·人工智能·微服务·golang·自动化·deepseek
咚咚王者2 小时前
人工智能之核心基础 机器学习 第十四章 半监督与自监督学习总结归纳
人工智能·学习·机器学习
风栖柳白杨2 小时前
【语音识别】SenseVoice非流式改流式
人工智能·语音识别
Aloudata2 小时前
企业落地 AI 数据分析,如何做好敏感数据安全防护?
人工智能·安全·数据挖掘·数据分析·chatbi·智能问数·dataagent
安达发公司2 小时前
安达发|煤炭行业APS高级排产:开启高效生产新时代
大数据·人工智能·aps高级排程·安达发aps·车间排产软件·aps高级排产
中科天工2 小时前
如何实现工业4.0智能制造的自动化包装解决方案?
大数据·人工智能·智能
ai_top_trends2 小时前
AI 生成 PPT 工具横评:效率、质量、稳定性一次说清
人工智能·python·powerpoint
三千世界0062 小时前
Claude Code Agent Skills 自动发现原理详解
人工智能·ai·大模型·agent·claude·原理