用C语言构建一个手写数字识别神经网络

(原理和程序基本框架请参见前一篇 "用C语言构建了一个简单的神经网路")

1.准备训练和测试数据集

http://yann.lecun.com/exdb/mnist/下载手写数字训练数据集, 包括图像数据train-images-idx3-ubyte.gz 和标签数据 train-labels-idx1-ubyte.gz.

分别将他们解压后放在本地文件夹中,解压后文件名为train-images-idx3-ubyte和train-labels-idx1-ubyte. 训练数据集一共包含了6万个手写数字灰度图片和对应的标签.

为图方便,我们直接从训练数据集中提取5000个作为测试数据.当然,实际训练数据中并不包含这些测试数据.

2.设计神经网络

采用简单的三层全连接神经网络,包括输入层(wi),中间层(wm)和输出层(wo).这里暂时不使用卷积层,下次替换后进行比较.

输入层: 一共20个神经元,每一张手写数字的图片大小为28x28,将其全部展平后的784个灰度数据归一化,即除以255.0, 使其数值位于[0 1]区间,这样可以防止数据在层层计算和传递后变得过分大.将这784个[0 1]之间的数据与20个神经元进行全连接.神经元激活函数用func_ReLU.

中间层: 一共20个神经元,与输入层的20个神经元输出进行全连接.神经元激活函数用func_ReLU.

输出层: 一共10个神经元,分别对应0~9数字的可能性,与层中间的20个神经元输出进行全连接.神经层激活函数用func_softmax.

特别地,神经元的激活函数在new_nvcell()中设定,层的激活函数直接赋给nerve_layer->transfunc.

损失函数: 采用期望和预测值的交叉熵损失函数func_lossCrossEntropy. 损失函数在nvnet_feed_forward()中以参数形式输入.

3.训练神经网络

由于整个程序是以nvcell神经元结构为基础进行构建的,其不同于矩阵/张量形式的批量数据描述,因此这个神经网络只能以神经元为单位,逐个逐层地进行前向和反向传导.

相应地,这里采用SGD(Stochastic Gradient Descent)梯度下降更新法,即对每一个样本先进行前向和反向传导计算,接着根据计算得到的梯度值马上更新所有参数.与此不同,mini-batch GD方法采用的对小批量样本进行前向和反向传导计算,然后根据累积的梯度数值做1次参数更新.显然,采用SGD方法参数更新更加频繁,计算时间相应也变长了,但是,据分析,采用SGD也更容易达到全局最优解附近.本文程序里所做的分批计算是为了方便监控计算过程和打印中间值.(当然,要实现mini-batch GD也是可以的,先完成一批量样本的前后传导计算,期间将各参数的梯度累计起来, 最后取其平均值更新一次参数.)

这里使用平均损失值mean_err=0.0025来作为训练的终止条件,为防止无法收敛到此数值,同时设置最大的epoch计数.

训练的样本数量由TRAIN_IMGTOTAL来设定, 训练时,先读取一个样本数据和一个标签,分别存入到data_input[28*28]和data_target[10], 为了配合应用softmax函数,这里data_target[]是one-hot编码格式.读入样本数据后先进行前向传导计算nvnet_feed_forward(),接着进行反向传导计算nvnet_feed_backward(), 最后更新参数nvnet_update_params(), 这样就完成了一个样本的训练.如此循环计算,完成一次所有样本的训练(epoch)后计算mean_err, 看是否达到预设目标.

4.测试训练后的神经网络

训练完成后,对模型进行简单评估.方法就是用训练后的模型来预测(predict)或推理(infer)前面的测试数据集中的图像数据,将结果与对应的标签值做对比.

同样,将一个测试样本加载到data_input[], 跑一次nvnet_feed_forward(),直接读取输出层的wo_layer->douts[k] (k=0~9),如果其值大于0.5,就认为模型预测图像上的数字是k.

5.小结

取5万条训练样本进行训练,训练后再进行测试,其准确率可接近94%.

与卷积神经网络相比较,为达到相同的结果,全连接的神经网络的所需要的训练时间会更长.

源代码:

https://github.com/midaszhou/nnc

下载后编译:

make TEST_NAME=test_nnc2

相关推荐
南东山人1 小时前
一文说清:C和C++混合编程
c语言·c++
yusaisai大鱼1 小时前
TensorFlow如何调用GPU?
人工智能·tensorflow
stm 学习ing1 小时前
FPGA 第十讲 避免latch的产生
c语言·开发语言·单片机·嵌入式硬件·fpga开发·fpga
珠海新立电子科技有限公司4 小时前
FPC柔性线路板与智能生活的融合
人工智能·生活·制造
IT古董4 小时前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论
曼城周杰伦4 小时前
自然语言处理:第六十三章 阿里Qwen2 & 2.5系列
人工智能·阿里云·语言模型·自然语言处理·chatgpt·nlp·gpt-3
余炜yw5 小时前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐5 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
96775 小时前
对抗样本存在的原因
深度学习
如若1235 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib