项目经验总结|修改深度学习模型训练结果的标签

文章背景

笔者不是研究深度学习算法的,但是在工作过程中会稍微涉及一点算法工程化的内容。通俗来说就是对成熟的算法进行应用落地,并且使该应用能够与特定机器进行适配。

截至目前,笔者已经跑了好几个深度学习-计算机视觉方面的算法应用了。它们被封装的程度各异,所以对a应用问题的解决方案没法直接应用到b应用上,不过总体的思路是一致的。

项目需求

笔者要解决的问题是:原本的模型展示的标签结果是数字0123...,现在需要把它更改为对应的物品名称。

解决思路

  • 标签结果的数字不是随机产生的,而是有一定的规律。比如检测到人就是0,检测到汽车就是1。说明源代码里面一定是定义了一个从数字下标到文本信息的映射信息,只是它没有被用上。所以关键在于找到这个映射关系
  • 修改界面:找到显示方框、标签、计算概率的代码,修改其中的标签值。

源码定位

通过搜索源码可知,在代码的model目录下有一个onnx文件,在文件的最后一行,以数组形式记录了标签数字与文本的映射关系:

css 复制代码
namesÝ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

不过项目量化时所使用的是另一个onnx文件,另一个文件里没有这段定义。

什么是onnx文件?

Open Neural Network Exchange(开放神经网络交换),由微软和Facebook提出,是用来表示深度学习模型的开放格式。所谓开放就是ONNX定义了一组和环境,平台均无关的标准格式,来增强各种AI模型的可交互性。

  • 无论使用何种训练框架训练模型(比如TensorFlow/Pytorch/OneFlow/Paddle),在训练完毕后你都可以将这些框架的模型统一转换为ONNX这种统一的格式进行存储。
  • ONNX文件不仅仅存储了神经网络模型的权重,同时也存储了模型的结构信息以及网络中每一层的输入输出和一些其它的辅助信息。

而在network目录下,有一个和onnx文件同前缀的cpp文件,其中有关于显示结果的代码:

cpp 复制代码
sprintf(text, "type-%d, score-%.2f", pObj->type, (double)pObj->score);
cv::putText(img, text, Point, cv::FONT_HERSHEY_SIMPLEX, 1.0, cvScalar(0, 255, 0), 1);
  • 第一行定义了文本信息的格式
  • 第二行通过opencv库将文本信息与图像等等结合起来,最终体现为在视频流帧中能对识别出的移动物体生成动态的方框、标签及识别物体的概率。

至此,两处关键的地方都已经找到了,接下来就是修改代码。

具体解决方案

  • 在cpp文件中添加一个字符串类型的数组,内容参见onnx中的数组:
c 复制代码
string map[]={"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"};
  • 将sprintf那行代码修改为:
c 复制代码
string type=map[pObj->type];
sprintf(text, "type-%d, score-%.2f", type.c_str(), (double)pObj->score);

注意:

  • 使用c_str()函数是为了返回一个指向正规C字符串的指针常量 ,因为sprintf不支持string类型。如果不进行该转换,会导致无法解析数据,最终显示的数据形如???u
  • 本项目仅仅是将下标转化为英文,如果要转为中文,可能会出现乱码的问题,这是由于使用的编码格式不一致导致的。如果读者有需要,也可以找到源代码中的对应位置进行修改。但也有可能找不到,因为这些可能已经被封装好了并不对外暴露。

源码据说是开源的,但是我并没有在github和gitee中找到,因此也无法提供相关链接。如果有一天哪位有缘人刚好发现了和我描述很像的代码仓库,欢迎在评论区留言!

参考资料

相关推荐
九尾狐ai1 分钟前
从九尾狐AI案例拆解智能矩阵技术架构:如何实现AI获客300万播放?
人工智能
wasp5201 分钟前
Hudi 客户端实现分析
java·开发语言·人工智能·hudi
秦苒&3 分钟前
【脉脉】AI 创作者 xAMA 知无不言:在浪潮里,做会发光的造浪者
大数据·c语言·数据库·c++·人工智能·ai·操作系统
chinesegf4 分钟前
嵌入模型和大语言模型的关系
人工智能·语言模型·自然语言处理
啊阿狸不会拉杆4 分钟前
《计算机操作系统》 第十一章 -多媒体操作系统
开发语言·c++·人工智能·os·计算机操作系统
_ziva_7 分钟前
分布式(三)深入浅出理解PyTorch分布式训练:nn.parallel.DistributedDataParallel详解
人工智能·pytorch·分布式
江南小书生7 分钟前
非标制造行业装配报工工时不准?缺料干扰+标准缺失如何破局?
大数据·人工智能
组合缺一10 分钟前
Solon AI Remote Skills:开启分布式技能的“感知”时代
java·人工智能·分布式·agent·langgraph·mcp
m0_7373025815 分钟前
火山引擎安全增强型云服务器,筑牢AI时代数据屏障
网络·人工智能
zl_vslam19 分钟前
SLAM中的非线性优-3D图优化之绝对位姿SE3约束SO3/t形式(十八)
人工智能·算法·计算机视觉·3d