落地级分类模型训练框架搭建(1):resnet18/50和mobilenetv2在CIFAR10上测试结果

目录

前言

1.分类结果测试汇总

2.训练过程可视化

ResNet18直接训练(准确率、召回率、Loss)

ResNet50直接训练(准确率、召回率、Loss)

3.模型权重分析

一般的训练,获取的模型权重分布

引入约束化训练的模型权重分布​编辑

4.测试结果

[模型[6] ResNet50测试结果](#模型[6] ResNet50测试结果)

[模型[7] ResNet18测试结果](#模型[7] ResNet18测试结果)

[模型[10] ResNet18_009测试结果](#模型[10] ResNet18_009测试结果)


前言

用CIFAR10测试一下自己搭建的分类模型训练框架,包括基本训练、知识蒸馏、模型稀疏化、剪枝微调、模型量化。最终,ResNet18取得96%的分类准确率,且剪枝91%的ResNet18可以取得94%的准确率。

1.分类结果测试汇总

说明:严格划分了训练集、验证集和测试集,没有将CIFAR10的测试集当做验证集

训练测试集为CIFAR10训练集随机采样95%,验证集为训练集剩余的5%,测试集为原CIFAR10的测试集。实验结果如下:

|--------|---------------|---------------|----------------------|-----------------|-----------|-----------|-----------------------|-------------------|-----------------------|
| 序号 | 模型名称 | 训练 方式 | 输入尺寸 (H x W) | 参数量 (M) | 准确率 | 召回率 | 耗时 (3060)(ms) | Flops (G) | 吞吐量(3060) (K) |
| 1 | MobileNetv3 | - | [64, 64] | 4.215 | 92.12 | 92.12 | 1.706 | 0.216 | 12.3 |
| 2 | ResNet18 | - | [64, 64] | 11.182 | 91.98 | 91.99 | 0.830 | 0.149 | 12.6 |
| 3 | ResNet18 | L1 | [64, 64] | 11.182 | 92.31 | 92.34 | 0.841 | 0.149 | 13.0 |
| 4 | ResNet18-best | L2 | [224, 224] | 11.182 | 95.12 | 95.01 | 1.206 | 1.824 | 1.6 |
| 5 | ResNet18-last | L2 | [224, 224] | 11.182 | 95.41 | 95.39 | 1.267 | 1.824 | 1.6 |
| 6 | ResNet50 | L2 | [224, 224] | 23.529 | 96.04 | 96.01 | 2.897 | 4.132 | 0.55 |
| 7 | ResNet18 | L2+di | [224, 224] | 11.182 | 96.01 | 95.89 | 1.199 | 1.824 | 1.6 |
| 8 | ResNet18 | L1+di | [224, 224] | 11.182 | 95.59 | 95.55 | 1.204 | 1.824 | 1.6 |
| 9 | ResNet18_009 | di | [224, 224] | 0.961 | 93.63 | 93.61 | 0.695 | 0.462 | 3.4 |
| 10 | ResNet18_009 | di+fu | [224, 224] | 0.961 | 93.98 | 93.95 | 0.692 | 0.462 | 3.4 |

说明L1 表示引入L1正则,L2 表示引入L2正则,di 表示使用知识蒸馏,fu 表示微调,009表示参数仅有原模型的9%,即剪枝了91%的参数。

测试结论

  1. CIFAR10上,不使用预训练模型,ResNet18准确率一般在92%~94%,使用预训练模型可以到95%,性能极限在96%左右。(修改第一层卷积核大小可以再提高一些)

2.框架使用ImageNet上的预训练模型,没有修改模型,就能达到95.41%(模型[5]),说明整个训练框架能够充分训练,且很难发生过拟合(最后一轮模型[5]比最好一轮模型[4]测试结果还高,且验证精度基本等于测试精度)。

3.使用知识蒸馏可以显著提高模型训练收敛速度和最终测试精度,模型[7]~[9]均使用模型[6]作为teacher model,且模型[6]仅直接训练一次,还未进行交叉验证的微调。

4.利用L2正则实现权重衰减,再使用L1实现权重置0剪枝,可以使ResNet18稀疏性达到91%,最终也剪枝了91%,导致了性能下降,实际剪枝84%可能更好。

2.训练过程可视化

ResNet18直接训练(准确率、召回率、Loss)

ResNet50直接训练(准确率、召回率、Loss)

结论:使用一系列的数据增强(随机翻转、旋转、裁剪、锐化、仿射变换、对比度调整、区域高斯模糊、区域翻转、区域打码、各种噪声等),外加正则化(L1、L2、dropout、batchnorm等),指数衰减学习率,模型基本不会过拟合。因此,可以对一个数据集使用交叉验证,尽可能拟合,即可提高同源数据的预测准确率(模型[9]到模型[10]的提升)。

3.模型权重分析

一般的训练,获取的模型权重分布

引入约束化训练的模型权重分布

结论:引入约束化训练,使得整体权重更接近高斯分布,却不会出现离群值,更方便后续的模型量化。

4.测试结果

模型[6] ResNet50测试结果

python 复制代码
Class airplane: Precision: 0.970884, Recall: 0.967000, F1-Score: 0.968938
Class automobile: Precision: 0.981763, Recall: 0.969000, F1-Score: 0.975340
Class bird: Precision: 0.947264, Recall: 0.952000, F1-Score: 0.949626
Class cat: Precision: 0.900771, Recall: 0.935000, F1-Score: 0.917566
Class deer: Precision: 0.966169, Recall: 0.971000, F1-Score: 0.968579
Class dog: Precision: 0.941117, Recall: 0.927000, F1-Score: 0.934005
Class frog: Precision: 0.983690, Recall: 0.965000, F1-Score: 0.974255
Class horse: Precision: 0.988753, Recall: 0.967000, F1-Score: 0.977755
Class ship: Precision: 0.970238, Recall: 0.978000, F1-Score: 0.974104
Class truck: Precision: 0.953786, Recall: 0.970000, F1-Score: 0.961824
Class macro avg: Precision: 0.960443, Recall: 0.960100, F1-Score: 0.960199

模型[7] ResNet18测试结果

python 复制代码
Class airplane: Precision: 0.964215, Recall: 0.970000, F1-Score: 0.967099
Class automobile: Precision: 0.975075, Recall: 0.978000, F1-Score: 0.976535
Class bird: Precision: 0.954683, Recall: 0.948000, F1-Score: 0.951330
Class cat: Precision: 0.895494, Recall: 0.934000, F1-Score: 0.914342
Class deer: Precision: 0.963964, Recall: 0.963000, F1-Score: 0.963482
Class dog: Precision: 0.947040, Recall: 0.912000, F1-Score: 0.929190
Class frog: Precision: 0.968379, Recall: 0.980000, F1-Score: 0.974155
Class horse: Precision: 0.992828, Recall: 0.969000, F1-Score: 0.980769
Class ship: Precision: 0.976884, Recall: 0.972000, F1-Score: 0.974436
Class truck: Precision: 0.962376, Recall: 0.972000, F1-Score: 0.967164
Class macro avg: Precision: 0.960094, Recall: 0.959800, F1-Score: 0.959850

模型[10] ResNet18_009测试结果

python 复制代码
Class airplane: Precision: 0.963001, Recall: 0.937000, F1-Score: 0.949823
Class automobile: Precision: 0.969031, Recall: 0.970000, F1-Score: 0.969515
Class bird: Precision: 0.929949, Recall: 0.916000, F1-Score: 0.922922
Class cat: Precision: 0.862275, Recall: 0.864000, F1-Score: 0.863137
Class deer: Precision: 0.937870, Recall: 0.951000, F1-Score: 0.944389
Class dog: Precision: 0.886076, Recall: 0.910000, F1-Score: 0.897879
Class frog: Precision: 0.949654, Recall: 0.962000, F1-Score: 0.955787
Class horse: Precision: 0.968813, Recall: 0.963000, F1-Score: 0.965898
Class ship: Precision: 0.975635, Recall: 0.961000, F1-Score: 0.968262
Class truck: Precision: 0.955268, Recall: 0.961000, F1-Score: 0.958126
Class macro avg: Precision: 0.939757, Recall: 0.939500, F1-Score: 0.939574
相关推荐
菜鸟一枚在这33 分钟前
深度解析建造者模式:复杂对象构建的优雅之道
java·开发语言·算法
谨慎谦虚36 分钟前
Trae 体验:探索被忽视的 Chat 模式
人工智能·trae
北极的树39 分钟前
AI驱动的大前端开发工作流
人工智能
gyeolhada1 小时前
2025蓝桥杯JAVA编程题练习Day5
java·数据结构·算法·蓝桥杯
阿巴~阿巴~1 小时前
多源 BFS 算法详解:从原理到实现,高效解决多源最短路问题
开发语言·数据结构·c++·算法·宽度优先
IT古董1 小时前
【漫话机器学习系列】100.L2 范数(L2 Norm,欧几里得范数)
人工智能·机器学习
给bug两拳1 小时前
Day9 25/2/22 SAT
算法
亲持红叶1 小时前
sklearn中的决策树-分类树:重要参数
决策树·分类·sklearn
小天努力学java1 小时前
【面试系列】Java开发--AI常见面试题
java·人工智能·面试
数造科技2 小时前
紧随“可信数据空间”政策风潮,数造科技正式加入开放数据空间联盟
大数据·人工智能·科技·安全·敏捷开发