一、Mnist数据集介绍
MNIST(Modified National Institute of Standards and Technology database)是一个大型的手写数字数据库,广泛用于训练和测试图像处理系统。它包含了从0到9的共10个类别的灰度手写数字图像。
数据集详情
-
来源:由美国国家标准与技术研究院(NIST)提供的原始数据集修改而来。
-
样本数量 :共有 70,000 张图像。
-
训练集:60,000 张
-
测试集:10,000 张
-
-
图像格式:
-
尺寸 :每张图像为 28x28 像素。
-
色彩 :灰度图,每个像素的值在0(黑色)到255(白色)之间。
-
数据格式 :通常被展平(Flatten) 成一个 784(28*28) 维的向量作为输入。
-
-
标签:每张图像都有一个对应的标签,是0到9之间的整数,表示图像中写的数字。
二、构建网络模型
网络结构: Conv2D -> ReLU -> MaxPool -> Conv2D -> ReLU -> MaxPool -> FC -> Dropout -> FC
代码实现:
python
class MNISTCNN(nn.Module):
"""
一个简单的CNN模型,专门为MNIST设计
网络结构:Conv2D -> ReLU -> MaxPool -> Conv2D -> ReLU -> MaxPool -> FC -> Dropout -> FC
"""
def __init__(self):
super(MNISTCNN, self).__init__()
# 卷积层1:输入通道1(灰度图),输出32个特征图,卷积核3x3
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
# 卷积层2:输入32,输出64
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
# 池化层
self.pool = nn.MaxPool2d(2, 2)
# Dropout层防止过拟合
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout(0.5)
# 全连接层
# 经过两次池化后,28x28 -> 14x14 -> 7x7
self.fc1 = nn.Linear(64 * 7 * 7, 128) # 7x7x64 -> 128
self.fc2 = nn.Linear(128, 10) # 128 -> 10个类别
def forward(self, x):
# 第一个卷积块
x = self.pool(F.relu(self.conv1(x))) # [batch, 32, 14, 14]
x = self.dropout1(x)
# 第二个卷积块
x = self.pool(F.relu(self.conv2(x))) # [batch, 64, 7, 7]
x = self.dropout1(x)
# 展平
x = x.view(-1, 64 * 7 * 7) # [batch, 3136]
# 全连接层
x = F.relu(self.fc1(x)) # [batch, 128]
x = self.dropout2(x)
x = self.fc2(x) # [batch, 10]
return x
三、数据加载和预处理
python
def load_and_preprocess_data():
"""
加载和预处理MNIST数据
"""
# 直接从torchvision下载MNIST
from torchvision import datasets
# 数据变换:转换为张量并归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差
])
# 下载训练集
train_dataset = datasets.MNIST(
root='./mnist_dataset/train',
train=True,
download=False,
transform=transform
)
# 下载测试集
test_dataset = datasets.MNIST(
root='./mnist_dataset/train',
train=False,
download=False,
transform=transform
)
# 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=2
)
test_loader = DataLoader(
test_dataset,
batch_size=1000,
shuffle=False,
num_workers=2
)
return train_loader, test_loader
代码实现:
python
def load_and_preprocess_data():
"""
加载和预处理MNIST数据
"""
# 直接从torchvision下载MNIST
from torchvision import datasets
# 数据变换:转换为张量并归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差
])
# 下载训练集
train_dataset = datasets.MNIST(
root='./mnist_dataset/train',
train=True,
download=False,
transform=transform
)
# 下载测试集
test_dataset = datasets.MNIST(
root='./mnist_dataset/train',
train=False,
download=False,
transform=transform
)
# 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=2
)
test_loader = DataLoader(
test_dataset,
batch_size=1000,
shuffle=False,
num_workers=2
)
return train_loader, test_loader
四、编写训练函数和推理函数
需要的完整代码的小伙伴可以私信我
五、模型最终预测结果
Using device: cuda
正在加载MNIST数据集...
创建CNN模型...
MNISTCNN(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(dropout1): Dropout2d(p=0.25, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
(fc1): Linear(in_features=3136, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
总参数数量: 421,642
可训练参数数量: 421,642
开始训练CNN模型...
开始训练...
Epoch [1/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:37<00:00, 25.18it/s, Loss=0.0833, Batch Acc=100.00%]
============================================================
Epoch 1/10 训练完成
训练准确率: 91.92%, 训练损失: 0.2609
测试准确率: 98.06%, 测试损失: 0.0572
============================================================
Epoch [2/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:35<00:00, 26.35it/s, Loss=0.0187, Batch Acc=100.00%]
============================================================
Epoch 2/10 训练完成
训练准确率: 97.00%, 训练损失: 0.1036
测试准确率: 98.65%, 测试损失: 0.0388
============================================================
Epoch [3/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 28.06it/s, Loss=0.0487, Batch Acc=96.88%]
============================================================
Epoch 3/10 训练完成
训练准确率: 97.56%, 训练损失: 0.0825
测试准确率: 99.04%, 测试损失: 0.0311
============================================================
Epoch [4/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 27.14it/s, Loss=0.1242, Batch Acc=96.88%]
============================================================
Epoch 4/10 训练完成
训练准确率: 97.85%, 训练损失: 0.0713
测试准确率: 98.80%, 测试损失: 0.0350
============================================================
Epoch [5/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:32<00:00, 28.73it/s, Loss=0.1142, Batch Acc=96.88%]
============================================================
Epoch 5/10 训练完成
训练准确率: 98.17%, 训练损失: 0.0612
测试准确率: 99.19%, 测试损失: 0.0265
============================================================
Epoch [6/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 27.44it/s, Loss=0.0197, Batch Acc=100.00%]
============================================================
Epoch 6/10 训练完成
训练准确率: 98.68%, 训练损失: 0.0459
测试准确率: 99.30%, 测试损失: 0.0221
============================================================
Epoch [7/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 27.59it/s, Loss=0.0261, Batch Acc=100.00%]
============================================================
Epoch 7/10 训练完成
训练准确率: 98.75%, 训练损失: 0.0414
测试准确率: 99.26%, 测试损失: 0.0237
============================================================
Epoch [8/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:34<00:00, 26.99it/s, Loss=0.0027, Batch Acc=100.00%]
============================================================
Epoch 8/10 训练完成
训练准确率: 98.80%, 训练损失: 0.0388
测试准确率: 99.20%, 测试损失: 0.0232
============================================================
Epoch [9/10]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:30<00:00, 30.67it/s, Loss=0.0218, Batch Acc=100.00%]
============================================================
Epoch 9/10 训练完成
训练准确率: 98.93%, 训练损失: 0.0356
测试准确率: 99.27%, 测试损失: 0.0222
============================================================
Epoch [10/10]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:33<00:00, 28.20it/s, Loss=0.0053, Batch Acc=100.00%]
============================================================
Epoch 10/10 训练完成
训练准确率: 98.89%, 训练损失: 0.0338
测试准确率: 99.28%, 测试损失: 0.0231
============================================================
最终评估模型性能...
最终测试准确率: 99.28%
生成可视化结果...
分类报告:
precision recall f1-score support
0 0.99 1.00 1.00 980
1 1.00 1.00 1.00 1135
2 0.99 1.00 1.00 1032
3 0.99 1.00 1.00 1010
4 0.99 0.98 0.99 982
5 0.99 0.99 0.99 892
6 1.00 0.99 0.99 958
7 0.99 0.99 0.99 1028
8 0.99 0.99 0.99 974
9 0.98 0.99 0.99 1009
accuracy 0.99 10000
macro avg 0.99 0.99 0.99 10000
weighted avg 0.99 0.99 0.99 10000
模型已保存到 mnist_cnn_model_final.pth
============================================================
训练总结:
============================================================
最终测试准确率: 99.28%
最佳测试准确率: 99.30% (第6个epoch)
最终训练准确率: 98.89%
✅ 成功达到96%以上的准确率目标!
============================================================
需要完整代码和数据集的小伙伴私信博主吧~