多种策略提升线上 tensorflow 模型推理速度

前言

本文以最常见的模型 Bi-LSTM-CRF 为例,总结了在实际工作中能有效提升在 CPU/GPU 上的推理速度的若干方法,包括优化模型结构优化超参数,使用 onnx 框架等。当然如果你有充足的 GPU ,结合以上方法提升推理速度的效果将更加震撼。

数据

本文使用的数据就是常见的 NER 数据,我这里使用的是 BMEO 标注方法,如下列举一个样本作为说明:

华\B_ORG 东\M_ORG 师\M_ORG 范\M_ORG 大\M_ORG 学\E_ORG 位\O 于\O 上\B_LOC 海\E_LOC。

具体的标注方法标注规则可以根据自己的实际业务中的实体类型进行定义,这里不做深入探讨,但是有个基本原则就是标注的实体是符合实际业务意义的内容。

优化模型结构

对于 Bi-LSTM-CRF 这一模型的具体细节,我这里默认都是知道的,所以不再赘述。我们平时在使用模型的时候有个误区觉得 LSTM 层堆叠的越多效果越好,其实不然,如果是对于入门级的 NER 任务,只需要一个 Bi-LSTM 就足够可以把实体识别出来,完全没有必要堆叠多个 Bi-LSTM ,这样有点杀鸡用牛刀了,而且多层的模型参数量会激增,这也会拖垮最终的训练和推理速度。

对于其他的模型来说,也是同样的道理,优化模型结构,砍掉过量的层和参数,可能会取到意想不到的推理效果和速度。

优化超参数

在我看来三个最重要的超参数就是 batch_sizehidden_sizeembedding_dim ,这三个分别表示批处理样本数,隐层状态维度,嵌入纬度。这里的常见误区和模型参数量一样,会认为越大效果越好。其实不然,太大的超参数也会拖垮最终的训练和推理速度。正常在模型推理过程中,耗时基本是和这三个参数呈正相关关系。常见的参数设置可以按照以下的推荐值来进行即可:

batch_size:32、64
hidden_size:128、256
embedding_dim:128、256

对于简单的 NER 任务来说,这些超参数的设置已经足够使用了,如果是比较复杂的任务,那就需要适当调大 hidden_sizeembedding_dim,最好以 2 的 N 次方为值。batch_size 如果没有特殊业务要求,按照推荐值即可。

另外,如果你使用的是 tensorflow2.x 框架,可以使用 Keras Tuner 提到的 API ,不仅可以挑选最优的模型超参数,还能挑选最优的算法超参数。

onnx

ONNX(Open Neural Network Exchange)是一个用于表示深度学习模型的开放式标准。ONNX 的设计目标是使得在不同框架中训练的模型能够轻松地在其他框架中部署和运行。ONNX 支持在不同的部署环境中(例如移动设备、边缘计算、云端服务器)更加灵活地使用深度学习模型。

ONNX 在模型部署的时候会对模型做很多优化策略,如图结构优化、节点通信优化、量化、硬件加速、多线程和并行计算等。onnxruntime 是一个对 ONNX 模型提供推理加速的 python 库,支持 CPU 和 GPU 加速,GPU 加速版本为onnxruntime-gpu,默认版本为 CPU 加速。安装也很简单,直接使用 pip 安装即可。另外安装 tf2onnx 需要将 tensorflow2.x 模型转换为 onnx 模型

下面以本文中使用的模型来进行转化,需要注意的有两点,第一是要有已经训练并保存好的 h5 模型,第二是明确指定模型的输入结构,代码中的是 (None, config['max_len']) ,意思是输入的 batch_size 可以是任意数量,输入的序列长度为 config['max_len'] , 具体代码如下:

scss 复制代码
def tensorflow2onnx():
    model = NerModel()
    model.build((None, config['max_len']))
    model.load_weights(best.h5)
    input_signature = (tf.TensorSpec((None, config['max_len']), tf.int32, name="input"),)
    onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=input_signature,)
    onnx.save(onnx_model, 'best.onnx')

保存好 onnx 模型之后,我们使用 onnx 模型进行 CPU 推理。只需要简单的几步即可完成推理任务, results_ort 就是推理结果 logits ,具体代码如下:

css 复制代码
def inference():
    x_train, y_train, x_test, y_test = getData()
    sess = ort.InferenceSession(config['onnxPath'], providers=['CPUExecutionProvider'])   
    results_ort = sess.run(["output_1"], {'input': x_train})[0]

效果对比

在综合运用以上的三种,将之前的模型结构进行减小到一层的 Bi-LSTM ,并且将超参数进行适当的减少到都为 256 ,然后使用 onnx 加速推理,在 CPU 上面最终从推理速度 278 ms ,下降到 29 ms ,提升了 9 倍的推理速度。

如果有 GPU ,我们可以安装 onnxruntime-gpu (如果安装时候和 onnxruntime 有冲突,可以先卸载 onnxruntime ),然后将上面的代码改为如下即可,最终的推理时间进一步减少了一半:

css 复制代码
sess = ort.InferenceSession(config['onnxPath'], providers=['CUDAExecutionProvider'])

结论

最终我们从 278 ms 下降到 15 ms ,实现了 18 倍的推理提速,综上可以看出本文介绍的几种策略的综合使用确实能够加速推理速度,也说明了工业上进行模型部署优化是很有必要的。

参考

相关推荐
桃花键神22 分钟前
AI可信论坛亮点:合合信息分享视觉内容安全技术前沿
人工智能
野蛮的大西瓜44 分钟前
开源呼叫中心中,如何将ASR与IVR菜单结合,实现动态的IVR交互
人工智能·机器人·自动化·音视频·信息与通信
CountingStars6191 小时前
目标检测常用评估指标(metrics)
人工智能·目标检测·目标跟踪
tangjunjun-owen1 小时前
第四节:GLM-4v-9b模型的tokenizer源码解读
人工智能·glm-4v-9b·多模态大模型教程
冰蓝蓝1 小时前
深度学习中的注意力机制:解锁智能模型的新视角
人工智能·深度学习
橙子小哥的代码世界1 小时前
【计算机视觉基础CV-图像分类】01- 从历史源头到深度时代:一文读懂计算机视觉的进化脉络、核心任务与产业蓝图
人工智能·计算机视觉
新加坡内哥谈技术2 小时前
苏黎世联邦理工学院与加州大学伯克利分校推出MaxInfoRL:平衡内在与外在探索的全新强化学习框架
大数据·人工智能·语言模型
fanstuck3 小时前
Prompt提示工程上手指南(七)Prompt编写实战-基于智能客服问答系统下的Prompt编写
人工智能·数据挖掘·openai
lovelin+v175030409663 小时前
安全性升级:API接口在零信任架构下的安全防护策略
大数据·数据库·人工智能·爬虫·数据分析
wydxry3 小时前
LoRA(Low-Rank Adaptation)模型微调
深度学习