落地级分类模型训练框架搭建(1):resnet18/50和mobilenetv2在CIFAR10上测试结果

目录

前言

1.分类结果测试汇总

2.训练过程可视化

ResNet18直接训练(准确率、召回率、Loss)

ResNet50直接训练(准确率、召回率、Loss)

3.模型权重分析

一般的训练,获取的模型权重分布

引入约束化训练的模型权重分布​编辑

4.测试结果

[模型[6] ResNet50测试结果](#模型[6] ResNet50测试结果)

[模型[7] ResNet18测试结果](#模型[7] ResNet18测试结果)

[模型[10] ResNet18_009测试结果](#模型[10] ResNet18_009测试结果)


前言

用CIFAR10测试一下自己搭建的分类模型训练框架,包括基本训练、知识蒸馏、模型稀疏化、剪枝微调、模型量化。最终,ResNet18取得96%的分类准确率,且剪枝91%的ResNet18可以取得94%的准确率。

1.分类结果测试汇总

说明:严格划分了训练集、验证集和测试集,没有将CIFAR10的测试集当做验证集

训练测试集为CIFAR10训练集随机采样95%,验证集为训练集剩余的5%,测试集为原CIFAR10的测试集。实验结果如下:

|--------|---------------|---------------|----------------------|-----------------|-----------|-----------|-----------------------|-------------------|-----------------------|
| 序号 | 模型名称 | 训练 方式 | 输入尺寸 (H x W) | 参数量 (M) | 准确率 | 召回率 | 耗时 (3060)(ms) | Flops (G) | 吞吐量(3060) (K) |
| 1 | MobileNetv3 | - | [64, 64] | 4.215 | 92.12 | 92.12 | 1.706 | 0.216 | 12.3 |
| 2 | ResNet18 | - | [64, 64] | 11.182 | 91.98 | 91.99 | 0.830 | 0.149 | 12.6 |
| 3 | ResNet18 | L1 | [64, 64] | 11.182 | 92.31 | 92.34 | 0.841 | 0.149 | 13.0 |
| 4 | ResNet18-best | L2 | [224, 224] | 11.182 | 95.12 | 95.01 | 1.206 | 1.824 | 1.6 |
| 5 | ResNet18-last | L2 | [224, 224] | 11.182 | 95.41 | 95.39 | 1.267 | 1.824 | 1.6 |
| 6 | ResNet50 | L2 | [224, 224] | 23.529 | 96.04 | 96.01 | 2.897 | 4.132 | 0.55 |
| 7 | ResNet18 | L2+di | [224, 224] | 11.182 | 96.01 | 95.89 | 1.199 | 1.824 | 1.6 |
| 8 | ResNet18 | L1+di | [224, 224] | 11.182 | 95.59 | 95.55 | 1.204 | 1.824 | 1.6 |
| 9 | ResNet18_009 | di | [224, 224] | 0.961 | 93.63 | 93.61 | 0.695 | 0.462 | 3.4 |
| 10 | ResNet18_009 | di+fu | [224, 224] | 0.961 | 93.98 | 93.95 | 0.692 | 0.462 | 3.4 |

说明L1 表示引入L1正则,L2 表示引入L2正则,di 表示使用知识蒸馏,fu 表示微调,009表示参数仅有原模型的9%,即剪枝了91%的参数。

测试结论

  1. CIFAR10上,不使用预训练模型,ResNet18准确率一般在92%~94%,使用预训练模型可以到95%,性能极限在96%左右。(修改第一层卷积核大小可以再提高一些)

2.框架使用ImageNet上的预训练模型,没有修改模型,就能达到95.41%(模型[5]),说明整个训练框架能够充分训练,且很难发生过拟合(最后一轮模型[5]比最好一轮模型[4]测试结果还高,且验证精度基本等于测试精度)。

3.使用知识蒸馏可以显著提高模型训练收敛速度和最终测试精度,模型[7]~[9]均使用模型[6]作为teacher model,且模型[6]仅直接训练一次,还未进行交叉验证的微调。

4.利用L2正则实现权重衰减,再使用L1实现权重置0剪枝,可以使ResNet18稀疏性达到91%,最终也剪枝了91%,导致了性能下降,实际剪枝84%可能更好。

2.训练过程可视化

ResNet18直接训练(准确率、召回率、Loss)

ResNet50直接训练(准确率、召回率、Loss)

结论:使用一系列的数据增强(随机翻转、旋转、裁剪、锐化、仿射变换、对比度调整、区域高斯模糊、区域翻转、区域打码、各种噪声等),外加正则化(L1、L2、dropout、batchnorm等),指数衰减学习率,模型基本不会过拟合。因此,可以对一个数据集使用交叉验证,尽可能拟合,即可提高同源数据的预测准确率(模型[9]到模型[10]的提升)。

3.模型权重分析

一般的训练,获取的模型权重分布

引入约束化训练的模型权重分布

结论:引入约束化训练,使得整体权重更接近高斯分布,却不会出现离群值,更方便后续的模型量化。

4.测试结果

模型[6] ResNet50测试结果

python 复制代码
Class airplane: Precision: 0.970884, Recall: 0.967000, F1-Score: 0.968938
Class automobile: Precision: 0.981763, Recall: 0.969000, F1-Score: 0.975340
Class bird: Precision: 0.947264, Recall: 0.952000, F1-Score: 0.949626
Class cat: Precision: 0.900771, Recall: 0.935000, F1-Score: 0.917566
Class deer: Precision: 0.966169, Recall: 0.971000, F1-Score: 0.968579
Class dog: Precision: 0.941117, Recall: 0.927000, F1-Score: 0.934005
Class frog: Precision: 0.983690, Recall: 0.965000, F1-Score: 0.974255
Class horse: Precision: 0.988753, Recall: 0.967000, F1-Score: 0.977755
Class ship: Precision: 0.970238, Recall: 0.978000, F1-Score: 0.974104
Class truck: Precision: 0.953786, Recall: 0.970000, F1-Score: 0.961824
Class macro avg: Precision: 0.960443, Recall: 0.960100, F1-Score: 0.960199

模型[7] ResNet18测试结果

python 复制代码
Class airplane: Precision: 0.964215, Recall: 0.970000, F1-Score: 0.967099
Class automobile: Precision: 0.975075, Recall: 0.978000, F1-Score: 0.976535
Class bird: Precision: 0.954683, Recall: 0.948000, F1-Score: 0.951330
Class cat: Precision: 0.895494, Recall: 0.934000, F1-Score: 0.914342
Class deer: Precision: 0.963964, Recall: 0.963000, F1-Score: 0.963482
Class dog: Precision: 0.947040, Recall: 0.912000, F1-Score: 0.929190
Class frog: Precision: 0.968379, Recall: 0.980000, F1-Score: 0.974155
Class horse: Precision: 0.992828, Recall: 0.969000, F1-Score: 0.980769
Class ship: Precision: 0.976884, Recall: 0.972000, F1-Score: 0.974436
Class truck: Precision: 0.962376, Recall: 0.972000, F1-Score: 0.967164
Class macro avg: Precision: 0.960094, Recall: 0.959800, F1-Score: 0.959850

模型[10] ResNet18_009测试结果

python 复制代码
Class airplane: Precision: 0.963001, Recall: 0.937000, F1-Score: 0.949823
Class automobile: Precision: 0.969031, Recall: 0.970000, F1-Score: 0.969515
Class bird: Precision: 0.929949, Recall: 0.916000, F1-Score: 0.922922
Class cat: Precision: 0.862275, Recall: 0.864000, F1-Score: 0.863137
Class deer: Precision: 0.937870, Recall: 0.951000, F1-Score: 0.944389
Class dog: Precision: 0.886076, Recall: 0.910000, F1-Score: 0.897879
Class frog: Precision: 0.949654, Recall: 0.962000, F1-Score: 0.955787
Class horse: Precision: 0.968813, Recall: 0.963000, F1-Score: 0.965898
Class ship: Precision: 0.975635, Recall: 0.961000, F1-Score: 0.968262
Class truck: Precision: 0.955268, Recall: 0.961000, F1-Score: 0.958126
Class macro avg: Precision: 0.939757, Recall: 0.939500, F1-Score: 0.939574
相关推荐
Mintopia11 分钟前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮1 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬1 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia1 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区2 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两4 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪5 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232555 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
程序员打怪兽5 小时前
详解Visual Transformer (ViT)网络模型
深度学习