深度学习模型部署篇——利用Flask实现深度学习模型部署(三)

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

写在前面

Hello,大家好,我是小苏👦🏽👦🏽👦🏽

在前面两节,我为大家介绍了如何将pytocrh模型转换成ONNX格式,进而提高深度学习模型部署速度,还不清楚的可以点击下面连接了解详情。

在评论区有小伙伴们问,如何进行服务化,那么今天就为大家介绍介绍通过Flask来实现服务化。

JYM准备好了喵,准备发车了喔。🚖🚖🚖

整体思路

我们先来说说使用Flask实现模型部署的思路,其实很简单啦。首先需要准备模型和一些资源,然后服务端会读取这些数据,并开启服务,这时候就会不断的等待客户端发来请求,此时服务器会处理该请求,并将结果返回客户端,如下图所示:

这里我提醒大家注意一下,下文所介绍代码的相关数据集都是前面两讲中的,如果大家有不明白的地方可以去上两讲寻找寻找答案。🍄🍄🍄
我也将项目上传到Github上,想玩的大家可以试试喔,地址:model_deployment_flask🍄🍄🍄
我也把用到的模型权重文件上传到了百度网盘,大家自行下载,地址:模型权重🍄🍄🍄

flask模型部署初探

下面我们就来开始介绍了喔,让我们一起来学一下叭。【这回采用步骤式讲解看看效果🍋🍋🍋】

首先我先来梳理一下代码运行的整体流程,这样大家可能会更清晰一点,如下图所示。

大家要注意的是,我们会有两个.py文件,一个用于服务端开启服务,另一个用于客户端发送请求并接受服务端的返回值。客户端的代码较为简单,这里重点说说服务端的代码运行流程。服务端代码中主要有三个函数,分别为predictget_predictiontransform_image。当我们运行服务端程序时,app.run启动,服务开启,此时会监听客户端是否发送请求,若检测到客户端发送请求,则会进入predict函数处理这个请求,接着predict函数会调用get_prediction函数,而get_prediction函数会调用transform_image函数。

先给大家介绍代码运行流程,大家再看下面的代码应该就比较清晰了,下文将分为服务端和客户端两部分介绍代码。

服务端

  • 创建一个Flask应用:
ini 复制代码
 app = Flask(__name__)
  • 准备模型和资源
python 复制代码
 # 在指定设备上创建 AlexNet 模型
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 model = AlexNet(num_classes=5).to(device)
 # 加载预训练权重
 model.load_state_dict(torch.load(r'E:\模型部署\checkpoint\AlexNet.pth', map_location='cpu'))
 # 设置模型为评估模式
 model.eval()

注意AlexNet.pth模型是我们在上讲中介绍的花的五分类模型。其实模型部署本质上就是模型的测试,所有我们要将模型设置成评估模式。

  • 定义预测函数
scss 复制代码
 def transform_image(image_bytes):
     my_transforms = transforms.Compose([transforms.Resize(255),
                                         transforms.CenterCrop(227),
                                         transforms.ToTensor(),
                                         transforms.Normalize(
                                             [0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])])
     image = Image.open(io.BytesIO(image_bytes))
     return my_transforms(image).unsqueeze(0)

大家还是要注意一下这里,我们将图像resize到227×227大小,这是因为AlexNet的输入要求,不清楚的可以看下这篇博客:深度学习经典网络模型汇总1------LeNet、AlexNet、ZFNet🌱🌱🌱

还要注意这里最后使用unsqueeze(0)方法添加了一个batch维度信息。

  • 定义预测函数
ini 复制代码
 def get_prediction(image_bytes):
     # 记录该帧开始处理的时间
     start_time = time.time()
 ​
     # 转换图像数据为模型输入格式
     tensor = transform_image(image_bytes=image_bytes)
 ​
     # 通过模型进行前向传播
     outputs = model.forward(tensor)
 ​
     # 对模型输出进行 softmax 操作
     pred_softmax = F.softmax(outputs, dim=1)
 ​
     # 获取前N个预测结果
     top_n = torch.topk(pred_softmax, 5)
     pred_ids = top_n.indices[0].cpu().detach().numpy()  # 转换为NumPy数组
     confs = top_n.values[0].cpu().detach().numpy() * 100  # 转换为NumPy数组,并转换为百分比
 ​
     # 记录该帧处理完毕的时间
     end_time = time.time()
     # 计算每秒处理图像帧数FPS
     FPS = 1 / (end_time - start_time)
 ​
     # 载入类别和对应 ID
     idx_to_labels = np.load('idx_to_labels1.npy', allow_pickle=True).item()
 ​
     results = []  # 用于存储结果的列表
     for i in range(5):
         class_name = idx_to_labels[pred_ids[i]]  # 获取类别名称
         confidence = confs[i]  # 获取置信度
         text = '{:<6} {:>.3f}'.format(class_name, confidence)
         results.append(text)  # 将结果添加到列表中
 ​
     return results, FPS  # 返回包含类别和置信度的列表

这步其实和我上一讲的内容是差不多的,这个函数主要是对图像进行推理,并输出推理的结果和推理时间。

  • 定义接收上传图片并预测的路由
ini 复制代码
 @app.route('/predict', methods=['POST'])
 def predict():
     if request.method == 'POST':
         file = request.files['file']
         img_bytes = file.read()
         class_info, FPS = get_prediction(image_bytes=img_bytes)
         response_data = {'class_info': class_info, 'FPS': FPS}
         return jsonify(response_data)

解释一下这个@app.route('/predict', methods=['POST'])叭,它是一个 Flask 路由装饰器,它告诉 Flask 在接收到 /predict 路径上的 POST 请求时,会调用下面定义的predict()函数来处理这个请求。

  • 启动 Flask 应用
ini 复制代码
 if __name__ == '__main__':
     app.run()

app.run()是 Flask 应用的运行函数,它启动了一个本地的开发服务器,用于监听来自客户端的请求并响应。

客户端

  • 发送请求
ini 复制代码
 # 发送 POST 请求到 Flask 服务器
 resp = requests.post("http://localhost:5000/predict",
                      files={"file": open('flower.jpg', 'rb')})
  • 处理服务端返回结果
ini 复制代码
 # 检查服务器响应状态码
 if resp.status_code == 200:  # 如果响应状态码为 200 表示成功
     # 从响应中提取 JSON 数据
     response_data = resp.json()
     class_info = response_data['class_info']  # 提取预测结果信息
     fps = response_data['FPS']  # 提取处理帧数信息
 ​
     # 输出预测的类别信息
     for info in class_info:
         print(info)
 ​
     # 输出处理帧数信息
     print("FPS:", fps)
 else:  # 如果响应状态码不是 200,则表示出现了错误
     print("Error:", resp.text)  # 输出错误信息

运行结果

首先我们要运行服务端的程序test_alexnet.py开启服务,可以通过anaconda终端执行,如下:

接着我们可以新开一个终端执行客户端程序sent_post.py发送请求,或者直接在pycharm上执行程序,如下:

从上图可以看出郁金香的识别率达到了99.964,哦,忘记给大家看我测试的图片了,是这张喔:

我们也可以发现FPS为15.6,但是我们一般不取第一次的结果,因为会进行初始化等操作,影响速度,我们在运行几次看看FPS结果。

后面几次FPS大概稳定在20-21左右。🥗🥗🥗

通过ONNX加速模型部署

在上两讲我们介绍了通过ONNX加速模型部署,那么这里我们自然也要试一试,看看速度有没有加速腻。🚀🚀🚀

那么其实这一部分的代码和上一小节是非常类似的,我将服务端代码写在了test_alexnet_onnx.py中,客户端代码没有改变,仍然是sent_post.py。我把主要修改的地方说明一下,其它一些细节大家可以自己去github下载源码查看。

  • 加载ONNX模型
python 复制代码
 # 加载模型
 model = AlexNet(num_classes=5).to(device)
 def load_onnx_model():
     global ort_session
     ort_session = onnxruntime.InferenceSession(r'E:\模型部署\Alex_flower5.onnx')
 ​
 # 在应用启动时加载 ONNX 模型
 load_onnx_model()

这里加载的是ONNX模型,至于如何得到ONNX模型可以看我上一讲的内容。🍭🍭🍭

  • ONNX推理引擎推理
css 复制代码
 ort_inputs = {'input': tensor.numpy()}
 pred_logits = ort_session.run(['output'], ort_inputs)[0]

剩下的基本都差不多了,我们直接来看看运行的效果叭。⭐⭐⭐

首先开启服务,等待请求,如下:

然后运行客户端代码,发送请求,获得结果:

可以看到预测精度和之前使用pytorch预测时是一致的,但FPS提高到了26.8。当然了,同样的道理,这是第一次运行,FPS会相对较低,我们再运行几次,如下:

可以发现,现在的FPS可以基本稳定在32左右,是不是比之前快了不少呢,大家快去试试叭。🍄🍄🍄

加点佐料

不知道大家发现没有,上面的功能算是实现了通过Flask部署深度学习模型,但是总感觉差点意思,于是准备结合前端来搭建一个稍微好看的界面,通过点击前端的按钮来发送请求。🍵🍵🍵

说干就干,但是好像干不动,因为自己不会前端呀,但是又问题不大,因为我会百度呀,直接找一个前端的代码就好了嘛,于是找到了霹雳吧啦Wz大佬滴代码,对其稍微改进了一下,使用了ONNX进行模型推理 ,并在前端输出FPS信息,代码为main_html_test.py。同样的,一些细节大家详细移步源码查看🥂🥂🥂

我们先来看一看实现的效果叭,然后我再来解释一下如何实现的,效果如下:

enmmmm,开始准备展示动态图的,但是运行录屏工具后,预测的FPS就下降了,所以大家还是看看图片叭。🍋🍋🍋

首先我们运行main_html_test.py程序,会得到如下结果:

点击上图中的链接进入前端界面:

然后点击选择文件,再点击预测,即可显示预测结果和FPS,如下:


上面就是最终的效果啦,最后我来稍微解释代码的整个流程,如下:

  1. 用户在前端界面上选择一个图像文件。
  2. 用户点击预测按钮,触发 test() 函数。
  3. test() 函数使用 AJAX 将图像文件发送到后端的 /predict 路由。
  4. 后端接收到请求,调用 predict() 函数进行图像预测。
  5. predict() 函数返回预测结果和FPS信息,发送回前端。
  6. 前端接收到后端返回的数据,将预测结果和FPS信息展示在页面上。

关于test()函数的内容如下:

ini 复制代码
 function test() {
     // 获取选择的文件对象
     var fileobj = $("#file0")[0].files[0];
     console.log(fileobj);
     
     // 创建一个 FormData 对象,用于将文件对象传递到后端
     var form = new FormData();
     form.append("file", fileobj);
     
     // 初始化变量用于存储分类结果和FPS信息
     var flower='';
     var fps = '';
 ​
     // 发送AJAX请求到后端的predict路由
     $.ajax({
         type: 'POST',
         url: "predict",
         data: form,
         async: false,
         processData: false,
         contentType: false,
         success: function (data) {
             console.log(data);
             
             // 从返回的数据中获取分类结果和FPS信息
             var results = data.class_info;
             fps = data.FPS;
             
             console.log(results);
             console.log(fps);
             
             // 生成分类结果的HTML字符串
             var flowerResult = '';
             results.forEach(e => {
                 flowerResult += `<div style="border-bottom: 1px solid #CCCCCC;line-height: 60px;font-size:16px;">${e}</div>`;
             });
 ​
             // 生成FPS信息的HTML字符串
             var fpsResult = `<div style="border-top: 1px solid #CCCCCC;line-height: 60px;font-size:16px;">FPS: ${fps.toFixed(2)}</div>`;
 ​
             // 将生成的分类结果和FPS信息插入到页面元素中
             document.getElementById("out").innerHTML = flowerResult + fpsResult;
         },
         error: function () {
             console.log("后台处理错误");
         }
     });
 }

小结

本节就为大家介绍到这里了喔,感兴趣的大家可以自己去玩玩,希望大家都能够有所收获。🌾🌾🌾

参考链接

DEPLOYING PYTORCH IN PYTHON VIA A REST API WITH FLASK🍁🍁🍁

pytorch_flask_service🍁🍁🍁

如若文章对你有所帮助,那就🛴🛴🛴

相关推荐
Tianyanxiao2 分钟前
如何利用探商宝精准营销,抓住行业机遇——以AI技术与大数据推动企业信息精准筛选
大数据·人工智能·科技·数据分析·深度优先·零售
撞南墙者9 分钟前
OpenCV自学系列(1)——简介和GUI特征操作
人工智能·opencv·计算机视觉
OCR_wintone42111 分钟前
易泊车牌识别相机,助力智慧工地建设
人工智能·数码相机·ocr
王哈哈^_^32 分钟前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
一者仁心38 分钟前
【AI技术】PaddleSpeech
人工智能
是瑶瑶子啦1 小时前
【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)
论文阅读·人工智能·深度学习·视觉检测·空间变换
EasyCVR1 小时前
萤石设备视频接入平台EasyCVR多品牌摄像机视频平台海康ehome平台(ISUP)接入EasyCVR不在线如何排查?
运维·服务器·网络·人工智能·ffmpeg·音视频
柳鲲鹏1 小时前
OpenCV视频防抖源码及编译脚本
人工智能·opencv·计算机视觉
西柚小萌新1 小时前
8.机器学习--决策树
人工智能·决策树·机器学习
向阳12181 小时前
Bert快速入门
人工智能·python·自然语言处理·bert