手写数字识别项目
这是一个使用PyTorch实现的手写数字识别项目,基于MNIST数据集训练了一个卷积神经网络模型。
项目结构
train.py
- 训练手写数字识别模型predict.py
- 使用训练好的模型预测图像中的数字download_mnist.py
- 直接下载MNIST数据集visualize_mnist.py
- 可视化MNIST数据集中的样本draw_and_predict.py
- 交互式绘图工具,可以绘制数字并进行实时预测
环境要求
- Python 3.6+
- PyTorch
- torchvision
- matplotlib
- numpy
- Pillow (PIL)
- tkinter (Python内置,用于交互式绘图工具)
可以使用以下命令安装所需依赖:
bash
pip install torch torchvision matplotlib numpy pillow
使用说明
1. 下载数据集
运行以下命令直接下载MNIST数据集:
bash
python download_mnist.py
2. 可视化数据集
查看MNIST数据集中的样本:
bash
python visualize_mnist.py
这将生成多个图像文件,显示数据集中的随机样本和每个数字的样本。
3. 训练模型
运行以下命令开始训练模型:
bash
python train.py
训练完成后,模型将保存为mnist_cnn.pt
。同时会生成以下文件:
sample_digits.png
- 显示训练数据集中的一些样本accuracy.png
- 显示训练过程中测试集准确率的变化
4. 预测图像
使用训练好的模型预测图像中的数字:
bash
python predict.py
按照提示输入图像路径,程序将显示预测结果。
5. 交互式绘图和预测
启动交互式绘图工具,可以自己绘制数字并实时预测:
bash
python draw_and_predict.py
使用方法:
- 在黑色画布上用鼠标绘制白色数字
- 调整画笔粗细
- 点击"预测"按钮进行识别
- 点击"清除"按钮清空画布
- 点击"保存"按钮保存当前绘制的图像
模型结构
该项目使用了一个简单的卷积神经网络(CNN),结构如下:
- 2个卷积层
- 最大池化层
- 2个全连接层
- Dropout用于防止过拟合
性能
在MNIST测试集上,该模型通常可以达到约99%的准确率。
代码
git clone https://gitee.com/wan_you_to/digital-recognition.git