本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题🍊专栏推荐:深度学习网络原理与实战
🍊近期目标:写好专栏的每一篇文章
🍊支持小苏:点赞👍🏼、收藏⭐、留言📩
写在前面
Hello,大家好,我是小苏👦🏽👦🏽👦🏽
在前面两节,我为大家介绍了如何将pytocrh模型转换成ONNX格式,进而提高深度学习模型部署速度,还不清楚的可以点击下面连接了解详情。
在评论区有小伙伴们问,如何进行服务化,那么今天就为大家介绍介绍通过Flask来实现服务化。
JYM准备好了喵,准备发车了喔。🚖🚖🚖
整体思路
我们先来说说使用Flask实现模型部署的思路,其实很简单啦。首先需要准备模型和一些资源,然后服务端会读取这些数据,并开启服务,这时候就会不断的等待客户端发来请求,此时服务器会处理该请求,并将结果返回客户端,如下图所示:
这里我提醒大家注意一下,下文所介绍代码的相关数据集都是前面两讲中的,如果大家有不明白的地方可以去上两讲寻找寻找答案。🍄🍄🍄
我也将项目上传到Github上,想玩的大家可以试试喔,地址:model_deployment_flask🍄🍄🍄
我也把用到的模型权重文件上传到了百度网盘,大家自行下载,地址:模型权重🍄🍄🍄
flask模型部署初探
下面我们就来开始介绍了喔,让我们一起来学一下叭。【这回采用步骤式讲解看看效果🍋🍋🍋】
首先我先来梳理一下代码运行的整体流程,这样大家可能会更清晰一点,如下图所示。
大家要注意的是,我们会有两个.py文件,一个用于服务端开启服务,另一个用于客户端发送请求并接受服务端的返回值。客户端的代码较为简单,这里重点说说服务端的代码运行流程。服务端代码中主要有三个函数,分别为predict
、get_prediction
、transform_image
。当我们运行服务端程序时,app.run
启动,服务开启,此时会监听客户端是否发送请求,若检测到客户端发送请求,则会进入predict
函数处理这个请求,接着predict
函数会调用get_prediction
函数,而get_prediction
函数会调用transform_image
函数。
先给大家介绍代码运行流程,大家再看下面的代码应该就比较清晰了,下文将分为服务端和客户端两部分介绍代码。
服务端
- 创建一个Flask应用:
ini
app = Flask(__name__)
- 准备模型和资源
python
# 在指定设备上创建 AlexNet 模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AlexNet(num_classes=5).to(device)
# 加载预训练权重
model.load_state_dict(torch.load(r'E:\模型部署\checkpoint\AlexNet.pth', map_location='cpu'))
# 设置模型为评估模式
model.eval()
注意AlexNet.pth
模型是我们在上讲中介绍的花的五分类模型。其实模型部署本质上就是模型的测试,所有我们要将模型设置成评估模式。
- 定义预测函数
scss
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(227),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
大家还是要注意一下这里,我们将图像resize到227×227大小,这是因为AlexNet的输入要求,不清楚的可以看下这篇博客:深度学习经典网络模型汇总1------LeNet、AlexNet、ZFNet🌱🌱🌱
还要注意这里最后使用unsqueeze(0)
方法添加了一个batch维度信息。
- 定义预测函数
ini
def get_prediction(image_bytes):
# 记录该帧开始处理的时间
start_time = time.time()
# 转换图像数据为模型输入格式
tensor = transform_image(image_bytes=image_bytes)
# 通过模型进行前向传播
outputs = model.forward(tensor)
# 对模型输出进行 softmax 操作
pred_softmax = F.softmax(outputs, dim=1)
# 获取前N个预测结果
top_n = torch.topk(pred_softmax, 5)
pred_ids = top_n.indices[0].cpu().detach().numpy() # 转换为NumPy数组
confs = top_n.values[0].cpu().detach().numpy() * 100 # 转换为NumPy数组,并转换为百分比
# 记录该帧处理完毕的时间
end_time = time.time()
# 计算每秒处理图像帧数FPS
FPS = 1 / (end_time - start_time)
# 载入类别和对应 ID
idx_to_labels = np.load('idx_to_labels1.npy', allow_pickle=True).item()
results = [] # 用于存储结果的列表
for i in range(5):
class_name = idx_to_labels[pred_ids[i]] # 获取类别名称
confidence = confs[i] # 获取置信度
text = '{:<6} {:>.3f}'.format(class_name, confidence)
results.append(text) # 将结果添加到列表中
return results, FPS # 返回包含类别和置信度的列表
这步其实和我上一讲的内容是差不多的,这个函数主要是对图像进行推理,并输出推理的结果和推理时间。
- 定义接收上传图片并预测的路由
ini
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_info, FPS = get_prediction(image_bytes=img_bytes)
response_data = {'class_info': class_info, 'FPS': FPS}
return jsonify(response_data)
解释一下这个@app.route('/predict', methods=['POST'])
叭,它是一个 Flask 路由装饰器,它告诉 Flask 在接收到 /predict
路径上的 POST 请求时,会调用下面定义的predict()函数来处理这个请求。
- 启动 Flask 应用
ini
if __name__ == '__main__':
app.run()
app.run()
是 Flask 应用的运行函数,它启动了一个本地的开发服务器,用于监听来自客户端的请求并响应。
客户端
- 发送请求
ini
# 发送 POST 请求到 Flask 服务器
resp = requests.post("http://localhost:5000/predict",
files={"file": open('flower.jpg', 'rb')})
- 处理服务端返回结果
ini
# 检查服务器响应状态码
if resp.status_code == 200: # 如果响应状态码为 200 表示成功
# 从响应中提取 JSON 数据
response_data = resp.json()
class_info = response_data['class_info'] # 提取预测结果信息
fps = response_data['FPS'] # 提取处理帧数信息
# 输出预测的类别信息
for info in class_info:
print(info)
# 输出处理帧数信息
print("FPS:", fps)
else: # 如果响应状态码不是 200,则表示出现了错误
print("Error:", resp.text) # 输出错误信息
运行结果
首先我们要运行服务端的程序test_alexnet.py
开启服务,可以通过anaconda终端执行,如下:
接着我们可以新开一个终端执行客户端程序sent_post.py
发送请求,或者直接在pycharm上执行程序,如下:
从上图可以看出郁金香的识别率达到了99.964,哦,忘记给大家看我测试的图片了,是这张喔:
我们也可以发现FPS为15.6,但是我们一般不取第一次的结果,因为会进行初始化等操作,影响速度,我们在运行几次看看FPS结果。
后面几次FPS大概稳定在20-21左右。🥗🥗🥗
通过ONNX加速模型部署
在上两讲我们介绍了通过ONNX加速模型部署,那么这里我们自然也要试一试,看看速度有没有加速腻。🚀🚀🚀
那么其实这一部分的代码和上一小节是非常类似的,我将服务端代码写在了test_alexnet_onnx.py
中,客户端代码没有改变,仍然是sent_post.py
。我把主要修改的地方说明一下,其它一些细节大家可以自己去github下载源码查看。
- 加载ONNX模型
python
# 加载模型
model = AlexNet(num_classes=5).to(device)
def load_onnx_model():
global ort_session
ort_session = onnxruntime.InferenceSession(r'E:\模型部署\Alex_flower5.onnx')
# 在应用启动时加载 ONNX 模型
load_onnx_model()
这里加载的是ONNX模型,至于如何得到ONNX模型可以看我上一讲的内容。🍭🍭🍭
- ONNX推理引擎推理
css
ort_inputs = {'input': tensor.numpy()}
pred_logits = ort_session.run(['output'], ort_inputs)[0]
剩下的基本都差不多了,我们直接来看看运行的效果叭。⭐⭐⭐
首先开启服务,等待请求,如下:
然后运行客户端代码,发送请求,获得结果:
可以看到预测精度和之前使用pytorch预测时是一致的,但FPS提高到了26.8。当然了,同样的道理,这是第一次运行,FPS会相对较低,我们再运行几次,如下:
可以发现,现在的FPS可以基本稳定在32左右,是不是比之前快了不少呢,大家快去试试叭。🍄🍄🍄
加点佐料
不知道大家发现没有,上面的功能算是实现了通过Flask部署深度学习模型,但是总感觉差点意思,于是准备结合前端来搭建一个稍微好看的界面,通过点击前端的按钮来发送请求。🍵🍵🍵
说干就干,但是好像干不动,因为自己不会前端呀,但是又问题不大,因为我会百度呀,直接找一个前端的代码就好了嘛,于是找到了霹雳吧啦Wz大佬滴代码,对其稍微改进了一下,使用了ONNX进行模型推理 ,并在前端输出FPS信息,代码为main_html_test.py
。同样的,一些细节大家详细移步源码查看🥂🥂🥂
我们先来看一看实现的效果叭,然后我再来解释一下如何实现的,效果如下:
enmmmm,开始准备展示动态图的,但是运行录屏工具后,预测的FPS就下降了,所以大家还是看看图片叭。🍋🍋🍋
首先我们运行main_html_test.py
程序,会得到如下结果:
点击上图中的链接进入前端界面:
然后点击选择文件,再点击预测,即可显示预测结果和FPS,如下:
上面就是最终的效果啦,最后我来稍微解释代码的整个流程,如下:
- 用户在前端界面上选择一个图像文件。
- 用户点击预测按钮,触发
test()
函数。 test()
函数使用 AJAX 将图像文件发送到后端的/predict
路由。- 后端接收到请求,调用
predict()
函数进行图像预测。 predict()
函数返回预测结果和FPS信息,发送回前端。- 前端接收到后端返回的数据,将预测结果和FPS信息展示在页面上。
关于test()函数的内容如下:
ini
function test() {
// 获取选择的文件对象
var fileobj = $("#file0")[0].files[0];
console.log(fileobj);
// 创建一个 FormData 对象,用于将文件对象传递到后端
var form = new FormData();
form.append("file", fileobj);
// 初始化变量用于存储分类结果和FPS信息
var flower='';
var fps = '';
// 发送AJAX请求到后端的predict路由
$.ajax({
type: 'POST',
url: "predict",
data: form,
async: false,
processData: false,
contentType: false,
success: function (data) {
console.log(data);
// 从返回的数据中获取分类结果和FPS信息
var results = data.class_info;
fps = data.FPS;
console.log(results);
console.log(fps);
// 生成分类结果的HTML字符串
var flowerResult = '';
results.forEach(e => {
flowerResult += `<div style="border-bottom: 1px solid #CCCCCC;line-height: 60px;font-size:16px;">${e}</div>`;
});
// 生成FPS信息的HTML字符串
var fpsResult = `<div style="border-top: 1px solid #CCCCCC;line-height: 60px;font-size:16px;">FPS: ${fps.toFixed(2)}</div>`;
// 将生成的分类结果和FPS信息插入到页面元素中
document.getElementById("out").innerHTML = flowerResult + fpsResult;
},
error: function () {
console.log("后台处理错误");
}
});
}
小结
本节就为大家介绍到这里了喔,感兴趣的大家可以自己去玩玩,希望大家都能够有所收获。🌾🌾🌾
参考链接
如若文章对你有所帮助,那就🛴🛴🛴