基于 Flask的深度学习模型部署服务端详解

基于 Flask 的深度学习模型部署服务端详解

在深度学习领域,训练出一个高精度的模型只是第一步,将其部署到生产环境中,为实际业务提供服务才是最终目标。本文将详细解析一个基于 Flask 和 PyTorch 的深度学习模型部署服务端代码,帮助你理解如何将训练好的模型以 API 形式提供给客户端使用。

一、整体概述

这段代码的主要功能是搭建一个基于 Flask 的 Web 服务,用于接收客户端发送的图像数据,使用预训练的 PyTorch 模型对图像进行分类预测,并将预测结果以 JSON 格式返回给客户端。

二、代码详细解析

1. 导入必要的库

python 复制代码
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
  • io:用于处理二进制数据,这里主要用于将客户端发送的图像二进制数据转换为图像对象。
  • flask:一个轻量级的 Web 框架,用于搭建 Web 服务。
  • torchtorch.nn.functional:PyTorch 的核心库,用于深度学习模型的构建和计算。
  • PIL.Image:Python Imaging Library(PIL)的一部分,用于处理图像文件。
  • torch.nn:用于定义神经网络的层和模块。
  • torchvision.transformstorchvision.modelstransforms 用于图像预处理,models 提供了预训练的深度学习模型。

2. 初始化 Flask 应用和模型相关变量

python 复制代码
app = flask.Flask(__name__)
model = None
use_gpu = False
  • app = flask.Flask(__name__):创建一个新的 Flask 应用实例,__name__ 参数用于确定应用的根路径。
  • model:用于存储加载的深度学习模型,初始化为 None
  • use_gpu:一个布尔变量,用于控制是否使用 GPU 进行模型推理,初始化为 False

3. 加载模型

python 复制代码
def load_model():
    global model
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))

    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    if use_gpu:
        model.cuda()
  • global model:声明 model 为全局变量,以便在函数内部修改它。
  • model = models.resnet18():加载预训练的 ResNet-18 模型。
  • num_ftrs = model.fc.in_features:获取 ResNet-18 模型最后一层全连接层的输入特征数。
  • model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)):修改最后一层全连接层,将输出维度改为 102,这里的 102 可以根据实际任务的类别数进行调整。
  • checkpoint = torch.load('best.pth'):从文件 best.pth 中加载训练好的模型参数。
  • model.load_state_dict(checkpoint['state_dict']):将加载的参数应用到模型中。
  • model.eval():将模型设置为评估模式,关闭一些在训练时使用的特殊层(如 Dropout)。
  • if use_gpu: model.cuda():如果 use_gpuTrue,将模型移动到 GPU 上。

4. 图像预处理

python 复制代码
def prepare_image(image, target_size):
    if image.mode != 'RGB':
        image = image.convert('RGB')

    image = transforms.Resize(target_size)(image)
    image = transforms.ToTensor()(image)
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
    image = image[None]
    if use_gpu:
        image = image.cuda()
    return torch.tensor(image)
  • if image.mode != 'RGB': image = image.convert('RGB'):确保输入图像为 RGB 格式。
  • image = transforms.Resize(target_size)(image):将图像调整为指定的大小。
  • image = transforms.ToTensor()(image):将图像转换为 PyTorch 张量。
  • image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image):对图像进行归一化处理,使用的均值和标准差是在 ImageNet 数据集上计算得到的。
  • image = image[None]:增加一个维度,将图像转换为批量输入的格式。
  • if use_gpu: image = image.cuda():如果 use_gpuTrue,将图像移动到 GPU 上。

5. 定义预测接口

python 复制代码
@app.route('/predict', methods=['POST'])
def predict():
    data = {'success': False}

    if flask.request.method == 'POST':
        if flask.request.files.get('image'):
            image = flask.request.files['image'].read()
            image = Image.open(io.BytesIO(image))
            image = prepare_image(image, target_size=(224, 224))

            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())

            data['prediction'] = list()
            for prob, label in zip(results[0][0], results[1][0]):
                r = {'label': str(label), 'probability': float(prob)}
                data['prediction'].append(r)

            data['success'] = True
    return flask.jsonify(data)
  • @app.route('/predict', methods=['POST']):使用 Flask 的装饰器定义一个路由,当客户端向 /predict 路径发送 POST 请求时,会调用 predict 函数。
  • data = {'success': False}:初始化一个字典,用于存储预测结果和状态信息,初始状态为 success = False
  • if flask.request.method == 'POST':检查请求方法是否为 POST。
  • if flask.request.files.get('image'):检查请求中是否包含名为 image 的文件。
  • image = flask.request.files['image'].read():读取客户端发送的图像文件内容。
  • image = Image.open(io.BytesIO(image)):将二进制数据转换为图像对象。
  • image = prepare_image(image, target_size=(224, 224)):对图像进行预处理。
  • preds = F.softmax(model(image), dim=1):使用模型进行预测,并通过 softmax 函数将输出转换为概率分布。
  • results = torch.topk(preds.cpu().data, k=3, dim=1):获取概率最大的前 3 个结果。
  • results = (results[0].cpu().numpy(), results[1].cpu().numpy()):将结果转换为 NumPy 数组。
  • data['prediction'] = list():初始化一个列表,用于存储预测结果。
  • for prob, label in zip(results[0][0], results[1][0]):遍历前 3 个结果,将标签和概率封装成字典,并添加到 data['prediction'] 列表中。
  • data['success'] = True:将状态信息设置为 success = True,表示预测成功。
  • return flask.jsonify(data):将结果以 JSON 格式返回给客户端。

6. 启动服务

python 复制代码
if __name__ == '__main__':
    print('Loading PyTorch model and Flask starting server ...')
    print('Please wait until server has fully started')

    load_model()
    app.run(host='192.168.1.20', port=5012)
  • if __name__ == '__main__':确保代码作为主程序运行时才执行以下操作。
  • print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started'):打印启动信息。
  • load_model():调用 load_model 函数加载模型。
  • app.run(host='192.168.1.20', port=5012):启动 Flask 服务,监听 192.168.1.20 地址的 5012 端口。运行结果如下

三、总结

通过上述代码,我们成功搭建了一个基于 Flask 和 PyTorch 的深度学习模型部署服务端。客户端可以通过向 /predict 路径发送包含图像文件的 POST 请求,获取图像分类的预测结果。在实际应用中,可以根据需要对代码进行扩展,如增加更多的模型、优化图像预处理流程、添加错误处理机制等。希望本文能帮助你更好地理解深度学习模型的部署过程。

相关推荐
亚里随笔8 分钟前
StreamRL:弹性、可扩展、异构的RLHF架构
人工智能·架构·大语言模型·rlhf·推理加速
RUZHUA18 分钟前
阿里打通内网权限,变革再出发
人工智能
陈奕昆1 小时前
4.3【LLaMA-Factory实战】教育大模型:个性化学习路径生成系统全解析
人工智能·python·学习·llama·大模型微调
wzx_Eleven1 小时前
【论文阅读】基于客户端数据子空间主角度的聚类联邦学习分布相似性高效识别
论文阅读·人工智能·机器学习·网络安全·聚类
ykjhr_3d1 小时前
场景可视化与数据编辑器:构建数据应用情境
人工智能
补三补四1 小时前
遗传算法(GA)
人工智能·算法·机器学习·启发式算法
梁小憨憨1 小时前
循环卷积(Circular Convolutions)
人工智能·笔记·深度学习·机器学习
非凡ghost1 小时前
水印云:AI赋能,让图像处理变得简单高效
图像处理·人工智能
EQ-雪梨蛋花汤1 小时前
【相机标定】OpenCV 相机标定中的重投影误差与角点三维坐标计算详解
人工智能·opencv
deepdata_cn2 小时前
双流卷积神经网络架构(OpenPose)
深度学习·人体姿态