【手搓神经网络:从零实现三层BP神经网络识别手写数字】

基于Python实现的三层神经网络手写数字识别系统

项目概述

本项目实现了一个基于Python的三层神经网络,用于识别手写数字。系统使用MNIST数据集进行训练,通过反向传播算法更新权重,最终能够准确识别0-9的手写数字。

核心代码实现

python 复制代码
import numpy
import cv2
import pickle
import scipy

# 读取文件函数
def read(file):
    f=open(file,'r')
    str=f.readlines()
    return str

# 激活函数(Sigmoid函数)
def expit(data):
    return scipy.special.expit(data)

class NnetWork:
    # 初始化三层神经网络
    def __init__(self,n1_num,n2_num,n3_num):
        # 输入层节点个数
        self.n1_num=n1_num
        # 隐藏层节点个数
        self.n2_num=n2_num
        # 输出层节点个数
        self.n3_num=n3_num
        # 学习率,默认0.3
        self.learnRate=0.3
        # 创建输入到隐藏层的权重矩阵
        self.w1=numpy.random.normal(0.0, pow(self.n1_num, -0.5),(self.n2_num, self.n1_num))
        # 创建隐藏层到输出层的权重矩阵
        self.w2=numpy.random.normal(0.0, pow(self.n2_num, -0.5),(self.n3_num, self.n2_num))
        pass
    
    # 训练函数,更新权重值
    def xun(self,inputsL,targetsL):
        inputs=numpy.array(inputsL,ndmin=2).T
        in_n2=numpy.dot(self.w1, inputs)
        ou_n2=expit(in_n2)
        in_n3=numpy.dot(self.w2,ou_n2)
        ou_n3=expit(in_n3)

        targets=numpy.array(targetsL,ndmin=2).T
        # 计算误差
        e3=targets-ou_n3
        e2=numpy.dot(self.w2.T,e3)
        # 反向传播更新权重
        self.w2 +=self.learnRate * numpy.dot((e3 *ou_n3 * (1.0- ou_n3)),numpy.transpose(ou_n2))
        self.w1 +=self.learnRate * numpy.dot((e2 *ou_n2 * (1.0- ou_n2)),numpy.transpose(inputs))
        pass
    
    # 使用模型进行预测
    def use(self,inputsL):
        inputs=numpy.array(inputsL,ndmin=2).T
        in_n2=numpy.dot(self.w1, inputs)
        ou_n2=expit(in_n2)
        in_n3=numpy.dot(self.w2,ou_n2)
        ou_n3=expit(in_n3)
        return ou_n3
        pass

# 训练函数
def xun():
    n=NnetWork(784,100,10)
    i=0
    xs_list=read("C:\\Users\\hua'wei\\Desktop\\mnist_train.csv")
    for xs in xs_list:
        xl=xs.strip().split(",")
        flag=xl[0]
        data=(numpy.asfarray(xl[1:])/255.0*0.99)+0.01
        t=numpy.zeros(n.n3_num)+0.01
        t[int(flag)]=0.99
        print("训练数据:",flag,len(data))
        n.xun(data,t)
        if i>=10000:
            break
        i=i+1
    # 保存权重
    numpy.savetxt("w1.csv",n.w1)
    numpy.savetxt("w2.csv",n.w2)

# 加载训练好的模型
n=NnetWork(784,100,10)
n.w1=numpy.loadtxt("w1.csv")
n.w2=numpy.loadtxt("w2.csv")

# 测试识别
xs_list=read("C:\\Users\\hua'wei\\Desktop\\d.csv")
for i in xs_list:
    xl=i.strip().split(",")
    data=(numpy.asfarray(xl[1:])/255.0*0.99)+0.01
    res=n.use(data)
    max=0
    ii=0
    xb=0
    for i in res:
        if i[0]>max:
            max=i[0]
            xb=ii
        ii=ii+1
    img=numpy.asfarray(xl[1:]).reshape((28,28))
    cv2.imshow(str(xb),img)
    cv2.waitKey(2000)
    cv2.destroyAllWindows()
    print("神经网络结果:",xb)

技术要点解析

1. 神经网络结构

  • 输入层:784个节点(对应28×28像素的手写数字图像)
  • 隐藏层:100个节点
  • 输出层:10个节点(对应0-9十个数字)

2. 激活函数

使用Sigmoid函数作为激活函数,将输入信号转换为0-1之间的输出,模拟生物神经元的激活状态。

3. 权重初始化

采用正态分布随机初始化权重,均值为0,标准差为节点数的-0.5次方,确保初始权重不会过大或过小。

4. 反向传播算法

通过计算输出层误差,逐层反向传播更新权重,使用梯度下降法优化网络参数。

5. 数据预处理

将像素值从0-255归一化到0.01-1.0之间,避免Sigmoid函数在边界处梯度消失的问题。

使用说明

  1. 训练模型 :运行xun()函数,使用MNIST训练集训练神经网络
  2. 保存权重:训练完成后自动保存权重到w1.csv和w2.csv文件
  3. 加载模型:直接加载已保存的权重文件
  4. 测试识别:使用测试集进行数字识别,并显示识别结果和图像

注意事项

  • 需要安装numpy、cv2、scipy等依赖库
  • 训练数据文件路径需要根据实际情况修改
  • 训练10000个样本后自动停止,可根据需要调整训练次数

这个神经网络虽然结构简单,但能够有效识别手写数字,是学习神经网络原理的入门级项目。

相关推荐
DogDaoDao2 分钟前
【GitHub】VoxCPM2 实战全解析:原理、部署与效果对比
深度学习·大模型·github·音频·语音模型·tss·文本生成语音
xrgs_shz3 分钟前
基于K-Means聚类分析的鸢尾花分类
人工智能·机器学习
Chef_Chen17 分钟前
论文解读:GAIA给通用AI助手泼冷水,人类92分GPT-4插件版只到30分
人工智能
Black蜡笔小新27 分钟前
自动化AI算法训练服务器DLTM训推一体工作站赋能多行业智能化升级
人工智能·算法·自动化
KaMeidebaby27 分钟前
卡梅德生物技术快报|噬菌体文库构建实验优化及偶联体系实验数据分析
大数据·人工智能·架构·spark·新浪微博
NineData33 分钟前
SQL 都在等锁时,ChatDBA 先帮 MySQL 找到谁在挡路
数据库·人工智能·sql·mysql·安全·数据复制·数据迁移工具
意图共鸣36 分钟前
意图共鸣科技《AI记忆链商业化白皮书3.0》技术解读:“AI焦虑的解药”——从通用AI到个人记忆链架构
人工智能·科技·架构
小e说说41 分钟前
AI 时代,IT 职业教育如何为学习者赋能?——职坐标的 AI+教育实践
人工智能
后端小肥肠44 分钟前
不会做视频的我,用 Codex 跑通口播 + 自动剪辑,获客 20+
人工智能·aigc·agent
某林2121 小时前
跨越底层与AI的鸿沟:ROS2+多模态大模型(Qwen-VL)机器人全链路排障实录
人工智能·stm32·机器人·人机交互·ros2·技术复盘