项目经验总结|修改深度学习模型训练结果的标签

文章背景

笔者不是研究深度学习算法的,但是在工作过程中会稍微涉及一点算法工程化的内容。通俗来说就是对成熟的算法进行应用落地,并且使该应用能够与特定机器进行适配。

截至目前,笔者已经跑了好几个深度学习-计算机视觉方面的算法应用了。它们被封装的程度各异,所以对a应用问题的解决方案没法直接应用到b应用上,不过总体的思路是一致的。

项目需求

笔者要解决的问题是:原本的模型展示的标签结果是数字0123...,现在需要把它更改为对应的物品名称。

解决思路

  • 标签结果的数字不是随机产生的,而是有一定的规律。比如检测到人就是0,检测到汽车就是1。说明源代码里面一定是定义了一个从数字下标到文本信息的映射信息,只是它没有被用上。所以关键在于找到这个映射关系
  • 修改界面:找到显示方框、标签、计算概率的代码,修改其中的标签值。

源码定位

通过搜索源码可知,在代码的model目录下有一个onnx文件,在文件的最后一行,以数组形式记录了标签数字与文本的映射关系:

css 复制代码
namesÝ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

不过项目量化时所使用的是另一个onnx文件,另一个文件里没有这段定义。

什么是onnx文件?

Open Neural Network Exchange(开放神经网络交换),由微软和Facebook提出,是用来表示深度学习模型的开放格式。所谓开放就是ONNX定义了一组和环境,平台均无关的标准格式,来增强各种AI模型的可交互性。

  • 无论使用何种训练框架训练模型(比如TensorFlow/Pytorch/OneFlow/Paddle),在训练完毕后你都可以将这些框架的模型统一转换为ONNX这种统一的格式进行存储。
  • ONNX文件不仅仅存储了神经网络模型的权重,同时也存储了模型的结构信息以及网络中每一层的输入输出和一些其它的辅助信息。

而在network目录下,有一个和onnx文件同前缀的cpp文件,其中有关于显示结果的代码:

cpp 复制代码
sprintf(text, "type-%d, score-%.2f", pObj->type, (double)pObj->score);
cv::putText(img, text, Point, cv::FONT_HERSHEY_SIMPLEX, 1.0, cvScalar(0, 255, 0), 1);
  • 第一行定义了文本信息的格式
  • 第二行通过opencv库将文本信息与图像等等结合起来,最终体现为在视频流帧中能对识别出的移动物体生成动态的方框、标签及识别物体的概率。

至此,两处关键的地方都已经找到了,接下来就是修改代码。

具体解决方案

  • 在cpp文件中添加一个字符串类型的数组,内容参见onnx中的数组:
c 复制代码
string map[]={"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"};
  • 将sprintf那行代码修改为:
c 复制代码
string type=map[pObj->type];
sprintf(text, "type-%d, score-%.2f", type.c_str(), (double)pObj->score);

注意:

  • 使用c_str()函数是为了返回一个指向正规C字符串的指针常量 ,因为sprintf不支持string类型。如果不进行该转换,会导致无法解析数据,最终显示的数据形如???u
  • 本项目仅仅是将下标转化为英文,如果要转为中文,可能会出现乱码的问题,这是由于使用的编码格式不一致导致的。如果读者有需要,也可以找到源代码中的对应位置进行修改。但也有可能找不到,因为这些可能已经被封装好了并不对外暴露。

源码据说是开源的,但是我并没有在github和gitee中找到,因此也无法提供相关链接。如果有一天哪位有缘人刚好发现了和我描述很像的代码仓库,欢迎在评论区留言!

参考资料

相关推荐
China_Yanhy40 分钟前
动手学大模型第一篇学习总结
人工智能
空间机器人1 小时前
自动驾驶 ADAS 器件选型:算力只是门票,系统才是生死线
人工智能·机器学习·自动驾驶
C+++Python1 小时前
提示词、Agent、MCP、Skill 到底是什么?
人工智能
小松要进步1 小时前
机器学习1
人工智能·机器学习
泰恒1 小时前
openclaw近期怎么样了?
人工智能·深度学习·机器学习
KaneLogger1 小时前
从传统笔记到 LLM 驱动的结构化 Wiki
人工智能·程序员·架构
tinygone2 小时前
OpenClaw之Memory配置成本地模式,Ubuntu+CUDA+cuDNN+llama.cpp
人工智能·ubuntu·llama
正在走向自律2 小时前
第二章-AIGC入门-AIGC工具全解析:技术控的效率神器,DeepSeek国产大模型的骄傲(8/36)
人工智能·chatgpt·aigc·可灵·deepseek·即梦·阿里通义千问
轩轩分享AI2 小时前
DeepSeek、Kimi、笔灵谁最好用?5款网文作者亲测的AI写作神器横评
人工智能·ai·ai写作·小说写作·小说·小说干货
Aevget2 小时前
基于嵌入向量的智能检索!HOOPS AI 解锁 CAD 零件相似性搜索新方式
人工智能·hoops·cad·hoops ai·cad数据格式