文章背景
笔者不是研究深度学习算法的,但是在工作过程中会稍微涉及一点算法工程化的内容。通俗来说就是对成熟的算法进行应用落地,并且使该应用能够与特定机器进行适配。
截至目前,笔者已经跑了好几个深度学习-计算机视觉方面的算法应用了。它们被封装的程度各异,所以对a应用问题的解决方案没法直接应用到b应用上,不过总体的思路是一致的。
项目需求
笔者要解决的问题是:原本的模型展示的标签结果是数字0123...,现在需要把它更改为对应的物品名称。
解决思路
- 标签结果的数字不是随机产生的,而是有一定的规律。比如检测到人就是0,检测到汽车就是1。说明源代码里面一定是定义了一个从数字下标到文本信息的映射信息,只是它没有被用上。所以关键在于找到这个映射关系。
- 修改界面:找到显示方框、标签、计算概率的代码,修改其中的标签值。
源码定位
通过搜索源码可知,在代码的model目录下有一个onnx文件,在文件的最后一行,以数组形式记录了标签数字与文本的映射关系:
css
namesÝ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
不过项目量化时所使用的是另一个onnx文件,另一个文件里没有这段定义。
什么是onnx文件?
Open Neural Network Exchange(开放神经网络交换),由微软和Facebook提出,是用来表示深度学习模型的开放格式。所谓开放就是ONNX定义了一组和环境,平台均无关的标准格式,来增强各种AI模型的可交互性。
- 无论使用何种训练框架训练模型(比如TensorFlow/Pytorch/OneFlow/Paddle),在训练完毕后你都可以将这些框架的模型统一转换为ONNX这种统一的格式进行存储。
- ONNX文件不仅仅存储了神经网络模型的权重,同时也存储了模型的结构信息以及网络中每一层的输入输出和一些其它的辅助信息。
而在network目录下,有一个和onnx文件同前缀的cpp文件,其中有关于显示结果的代码:
cpp
sprintf(text, "type-%d, score-%.2f", pObj->type, (double)pObj->score);
cv::putText(img, text, Point, cv::FONT_HERSHEY_SIMPLEX, 1.0, cvScalar(0, 255, 0), 1);
- 第一行定义了文本信息的格式
- 第二行通过opencv库将文本信息与图像等等结合起来,最终体现为在视频流帧中能对识别出的移动物体生成动态的方框、标签及识别物体的概率。
至此,两处关键的地方都已经找到了,接下来就是修改代码。
具体解决方案
- 在cpp文件中添加一个字符串类型的数组,内容参见onnx中的数组:
c
string map[]={"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"};
- 将sprintf那行代码修改为:
c
string type=map[pObj->type];
sprintf(text, "type-%d, score-%.2f", type.c_str(), (double)pObj->score);
注意:
- 使用c_str()函数是为了返回一个指向正规C字符串的指针常量 ,因为sprintf不支持string类型。如果不进行该转换,会导致无法解析数据,最终显示的数据形如
???u。 - 本项目仅仅是将下标转化为英文,如果要转为中文,可能会出现乱码的问题,这是由于使用的编码格式不一致导致的。如果读者有需要,也可以找到源代码中的对应位置进行修改。但也有可能找不到,因为这些可能已经被封装好了并不对外暴露。
源码据说是开源的,但是我并没有在github和gitee中找到,因此也无法提供相关链接。如果有一天哪位有缘人刚好发现了和我描述很像的代码仓库,欢迎在评论区留言!