基于Python实现的三层神经网络手写数字识别系统
项目概述
本项目实现了一个基于Python的三层神经网络,用于识别手写数字。系统使用MNIST数据集进行训练,通过反向传播算法更新权重,最终能够准确识别0-9的手写数字。
核心代码实现
python
import numpy
import cv2
import pickle
import scipy
# 读取文件函数
def read(file):
f=open(file,'r')
str=f.readlines()
return str
# 激活函数(Sigmoid函数)
def expit(data):
return scipy.special.expit(data)
class NnetWork:
# 初始化三层神经网络
def __init__(self,n1_num,n2_num,n3_num):
# 输入层节点个数
self.n1_num=n1_num
# 隐藏层节点个数
self.n2_num=n2_num
# 输出层节点个数
self.n3_num=n3_num
# 学习率,默认0.3
self.learnRate=0.3
# 创建输入到隐藏层的权重矩阵
self.w1=numpy.random.normal(0.0, pow(self.n1_num, -0.5),(self.n2_num, self.n1_num))
# 创建隐藏层到输出层的权重矩阵
self.w2=numpy.random.normal(0.0, pow(self.n2_num, -0.5),(self.n3_num, self.n2_num))
pass
# 训练函数,更新权重值
def xun(self,inputsL,targetsL):
inputs=numpy.array(inputsL,ndmin=2).T
in_n2=numpy.dot(self.w1, inputs)
ou_n2=expit(in_n2)
in_n3=numpy.dot(self.w2,ou_n2)
ou_n3=expit(in_n3)
targets=numpy.array(targetsL,ndmin=2).T
# 计算误差
e3=targets-ou_n3
e2=numpy.dot(self.w2.T,e3)
# 反向传播更新权重
self.w2 +=self.learnRate * numpy.dot((e3 *ou_n3 * (1.0- ou_n3)),numpy.transpose(ou_n2))
self.w1 +=self.learnRate * numpy.dot((e2 *ou_n2 * (1.0- ou_n2)),numpy.transpose(inputs))
pass
# 使用模型进行预测
def use(self,inputsL):
inputs=numpy.array(inputsL,ndmin=2).T
in_n2=numpy.dot(self.w1, inputs)
ou_n2=expit(in_n2)
in_n3=numpy.dot(self.w2,ou_n2)
ou_n3=expit(in_n3)
return ou_n3
pass
# 训练函数
def xun():
n=NnetWork(784,100,10)
i=0
xs_list=read("C:\\Users\\hua'wei\\Desktop\\mnist_train.csv")
for xs in xs_list:
xl=xs.strip().split(",")
flag=xl[0]
data=(numpy.asfarray(xl[1:])/255.0*0.99)+0.01
t=numpy.zeros(n.n3_num)+0.01
t[int(flag)]=0.99
print("训练数据:",flag,len(data))
n.xun(data,t)
if i>=10000:
break
i=i+1
# 保存权重
numpy.savetxt("w1.csv",n.w1)
numpy.savetxt("w2.csv",n.w2)
# 加载训练好的模型
n=NnetWork(784,100,10)
n.w1=numpy.loadtxt("w1.csv")
n.w2=numpy.loadtxt("w2.csv")
# 测试识别
xs_list=read("C:\\Users\\hua'wei\\Desktop\\d.csv")
for i in xs_list:
xl=i.strip().split(",")
data=(numpy.asfarray(xl[1:])/255.0*0.99)+0.01
res=n.use(data)
max=0
ii=0
xb=0
for i in res:
if i[0]>max:
max=i[0]
xb=ii
ii=ii+1
img=numpy.asfarray(xl[1:]).reshape((28,28))
cv2.imshow(str(xb),img)
cv2.waitKey(2000)
cv2.destroyAllWindows()
print("神经网络结果:",xb)
技术要点解析
1. 神经网络结构
- 输入层:784个节点(对应28×28像素的手写数字图像)
- 隐藏层:100个节点
- 输出层:10个节点(对应0-9十个数字)
2. 激活函数
使用Sigmoid函数作为激活函数,将输入信号转换为0-1之间的输出,模拟生物神经元的激活状态。
3. 权重初始化
采用正态分布随机初始化权重,均值为0,标准差为节点数的-0.5次方,确保初始权重不会过大或过小。
4. 反向传播算法
通过计算输出层误差,逐层反向传播更新权重,使用梯度下降法优化网络参数。
5. 数据预处理
将像素值从0-255归一化到0.01-1.0之间,避免Sigmoid函数在边界处梯度消失的问题。
使用说明
- 训练模型 :运行
xun()函数,使用MNIST训练集训练神经网络 - 保存权重:训练完成后自动保存权重到w1.csv和w2.csv文件
- 加载模型:直接加载已保存的权重文件
- 测试识别:使用测试集进行数字识别,并显示识别结果和图像
注意事项
- 需要安装numpy、cv2、scipy等依赖库
- 训练数据文件路径需要根据实际情况修改
- 训练10000个样本后自动停止,可根据需要调整训练次数
这个神经网络虽然结构简单,但能够有效识别手写数字,是学习神经网络原理的入门级项目。