用C语言构建一个手写数字识别神经网络

(原理和程序基本框架请参见前一篇 "用C语言构建了一个简单的神经网路")

1.准备训练和测试数据集

http://yann.lecun.com/exdb/mnist/下载手写数字训练数据集, 包括图像数据train-images-idx3-ubyte.gz 和标签数据 train-labels-idx1-ubyte.gz.

分别将他们解压后放在本地文件夹中,解压后文件名为train-images-idx3-ubyte和train-labels-idx1-ubyte. 训练数据集一共包含了6万个手写数字灰度图片和对应的标签.

为图方便,我们直接从训练数据集中提取5000个作为测试数据.当然,实际训练数据中并不包含这些测试数据.

2.设计神经网络

采用简单的三层全连接神经网络,包括输入层(wi),中间层(wm)和输出层(wo).这里暂时不使用卷积层,下次替换后进行比较.

输入层: 一共20个神经元,每一张手写数字的图片大小为28x28,将其全部展平后的784个灰度数据归一化,即除以255.0, 使其数值位于[0 1]区间,这样可以防止数据在层层计算和传递后变得过分大.将这784个[0 1]之间的数据与20个神经元进行全连接.神经元激活函数用func_ReLU.

中间层: 一共20个神经元,与输入层的20个神经元输出进行全连接.神经元激活函数用func_ReLU.

输出层: 一共10个神经元,分别对应0~9数字的可能性,与层中间的20个神经元输出进行全连接.神经层激活函数用func_softmax.

特别地,神经元的激活函数在new_nvcell()中设定,层的激活函数直接赋给nerve_layer->transfunc.

损失函数: 采用期望和预测值的交叉熵损失函数func_lossCrossEntropy. 损失函数在nvnet_feed_forward()中以参数形式输入.

3.训练神经网络

由于整个程序是以nvcell神经元结构为基础进行构建的,其不同于矩阵/张量形式的批量数据描述,因此这个神经网络只能以神经元为单位,逐个逐层地进行前向和反向传导.

相应地,这里采用SGD(Stochastic Gradient Descent)梯度下降更新法,即对每一个样本先进行前向和反向传导计算,接着根据计算得到的梯度值马上更新所有参数.与此不同,mini-batch GD方法采用的对小批量样本进行前向和反向传导计算,然后根据累积的梯度数值做1次参数更新.显然,采用SGD方法参数更新更加频繁,计算时间相应也变长了,但是,据分析,采用SGD也更容易达到全局最优解附近.本文程序里所做的分批计算是为了方便监控计算过程和打印中间值.(当然,要实现mini-batch GD也是可以的,先完成一批量样本的前后传导计算,期间将各参数的梯度累计起来, 最后取其平均值更新一次参数.)

这里使用平均损失值mean_err=0.0025来作为训练的终止条件,为防止无法收敛到此数值,同时设置最大的epoch计数.

训练的样本数量由TRAIN_IMGTOTAL来设定, 训练时,先读取一个样本数据和一个标签,分别存入到data_input[28*28]和data_target[10], 为了配合应用softmax函数,这里data_target[]是one-hot编码格式.读入样本数据后先进行前向传导计算nvnet_feed_forward(),接着进行反向传导计算nvnet_feed_backward(), 最后更新参数nvnet_update_params(), 这样就完成了一个样本的训练.如此循环计算,完成一次所有样本的训练(epoch)后计算mean_err, 看是否达到预设目标.

4.测试训练后的神经网络

训练完成后,对模型进行简单评估.方法就是用训练后的模型来预测(predict)或推理(infer)前面的测试数据集中的图像数据,将结果与对应的标签值做对比.

同样,将一个测试样本加载到data_input[], 跑一次nvnet_feed_forward(),直接读取输出层的wo_layer->douts[k] (k=0~9),如果其值大于0.5,就认为模型预测图像上的数字是k.

5.小结

取5万条训练样本进行训练,训练后再进行测试,其准确率可接近94%.

与卷积神经网络相比较,为达到相同的结果,全连接的神经网络的所需要的训练时间会更长.

源代码:

https://github.com/midaszhou/nnc

下载后编译:

make TEST_NAME=test_nnc2

相关推荐
TF男孩8 小时前
重新认识Markdown:它不仅是排版工具,更是写Prompt的最佳结构
人工智能
想打游戏的程序猿9 小时前
AI时代的内容输出
人工智能
小兵张健9 小时前
Playwright MCP 截图标注方案调研:推荐方案 1
人工智能
凌杰11 小时前
AI 学习笔记:Agent 的能力体系
人工智能
IT_陈寒12 小时前
React状态管理终极对决:Redux vs Context API谁更胜一筹?
前端·人工智能·后端
舒一笑13 小时前
如何获取最新的技术趋势和热门技术
人工智能·程序员
聚客AI14 小时前
🎉OpenClaw深度解析:多智能体协同的三种模式、四大必装技能与自动化运维秘籍
人工智能·开源·agent
黄粱梦醒14 小时前
大模型企业级部署方案-vllm
人工智能·llm
IT_陈寒14 小时前
JavaScript代码效率提升50%?这5个优化技巧你必须知道!
前端·人工智能·后端
IT_陈寒14 小时前
Java开发必知的5个性能优化黑科技,提升50%效率不是梦!
前端·人工智能·后端