用C语言构建一个手写数字识别神经网络

(原理和程序基本框架请参见前一篇 "用C语言构建了一个简单的神经网路")

1.准备训练和测试数据集

http://yann.lecun.com/exdb/mnist/下载手写数字训练数据集, 包括图像数据train-images-idx3-ubyte.gz 和标签数据 train-labels-idx1-ubyte.gz.

分别将他们解压后放在本地文件夹中,解压后文件名为train-images-idx3-ubyte和train-labels-idx1-ubyte. 训练数据集一共包含了6万个手写数字灰度图片和对应的标签.

为图方便,我们直接从训练数据集中提取5000个作为测试数据.当然,实际训练数据中并不包含这些测试数据.

2.设计神经网络

采用简单的三层全连接神经网络,包括输入层(wi),中间层(wm)和输出层(wo).这里暂时不使用卷积层,下次替换后进行比较.

输入层: 一共20个神经元,每一张手写数字的图片大小为28x28,将其全部展平后的784个灰度数据归一化,即除以255.0, 使其数值位于[0 1]区间,这样可以防止数据在层层计算和传递后变得过分大.将这784个[0 1]之间的数据与20个神经元进行全连接.神经元激活函数用func_ReLU.

中间层: 一共20个神经元,与输入层的20个神经元输出进行全连接.神经元激活函数用func_ReLU.

输出层: 一共10个神经元,分别对应0~9数字的可能性,与层中间的20个神经元输出进行全连接.神经层激活函数用func_softmax.

特别地,神经元的激活函数在new_nvcell()中设定,层的激活函数直接赋给nerve_layer->transfunc.

损失函数: 采用期望和预测值的交叉熵损失函数func_lossCrossEntropy. 损失函数在nvnet_feed_forward()中以参数形式输入.

3.训练神经网络

由于整个程序是以nvcell神经元结构为基础进行构建的,其不同于矩阵/张量形式的批量数据描述,因此这个神经网络只能以神经元为单位,逐个逐层地进行前向和反向传导.

相应地,这里采用SGD(Stochastic Gradient Descent)梯度下降更新法,即对每一个样本先进行前向和反向传导计算,接着根据计算得到的梯度值马上更新所有参数.与此不同,mini-batch GD方法采用的对小批量样本进行前向和反向传导计算,然后根据累积的梯度数值做1次参数更新.显然,采用SGD方法参数更新更加频繁,计算时间相应也变长了,但是,据分析,采用SGD也更容易达到全局最优解附近.本文程序里所做的分批计算是为了方便监控计算过程和打印中间值.(当然,要实现mini-batch GD也是可以的,先完成一批量样本的前后传导计算,期间将各参数的梯度累计起来, 最后取其平均值更新一次参数.)

这里使用平均损失值mean_err=0.0025来作为训练的终止条件,为防止无法收敛到此数值,同时设置最大的epoch计数.

训练的样本数量由TRAIN_IMGTOTAL来设定, 训练时,先读取一个样本数据和一个标签,分别存入到data_input[28*28]和data_target[10], 为了配合应用softmax函数,这里data_target[]是one-hot编码格式.读入样本数据后先进行前向传导计算nvnet_feed_forward(),接着进行反向传导计算nvnet_feed_backward(), 最后更新参数nvnet_update_params(), 这样就完成了一个样本的训练.如此循环计算,完成一次所有样本的训练(epoch)后计算mean_err, 看是否达到预设目标.

4.测试训练后的神经网络

训练完成后,对模型进行简单评估.方法就是用训练后的模型来预测(predict)或推理(infer)前面的测试数据集中的图像数据,将结果与对应的标签值做对比.

同样,将一个测试样本加载到data_input[], 跑一次nvnet_feed_forward(),直接读取输出层的wo_layer->douts[k] (k=0~9),如果其值大于0.5,就认为模型预测图像上的数字是k.

5.小结

取5万条训练样本进行训练,训练后再进行测试,其准确率可接近94%.

与卷积神经网络相比较,为达到相同的结果,全连接的神经网络的所需要的训练时间会更长.

源代码:

https://github.com/midaszhou/nnc

下载后编译:

make TEST_NAME=test_nnc2

相关推荐
陈苏同学几秒前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
吾名招财19 分钟前
yolov5-7.0模型DNN加载函数及参数详解(重要)
c++·人工智能·yolo·dnn
FL162386312929 分钟前
[深度学习][python]yolov11+bytetrack+pyqt5实现目标追踪
深度学习·qt·yolo
羊小猪~~36 分钟前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
Cons.W37 分钟前
Codeforces Round 975 (Div. 1) C. Tree Pruning
c语言·开发语言·剪枝
鼠鼠龙年发大财1 小时前
【鼠鼠学AI代码合集#7】概率
人工智能
龙的爹23331 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
挥剑决浮云 -1 小时前
Linux 之 安装软件、GCC编译器、Linux 操作系统基础
linux·服务器·c语言·c++·经验分享·笔记
工业机器视觉设计和实现1 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn