基于Flask测试深度学习模型预测

Flask之最易懂的基础教程一(2020年最新-从入门到精通)-CSDN博客

Flask程序运行过程:

所有Flask程序必须有一个程序实例。

Flask调用视图函数后,会将视图函数的返回值作为响应的内容,返回给客户端。一般情况下,响应内容主要是字符串和状态码。

用户向浏览器发送http请求,web服务器把客户端所有请求交给Flask程序实例,程序用Werkzeug来做路由分发,每个url请求,找到具体的视图函数。路由的实现是通过route装饰器实现的,调用视图函数,获取数据后,把数据传入模块中,模块引擎渲染响应的数据,由Flask返回给浏览器。

二、Flask框架

  1. 简介

是一个非常小的框架,可以称为微型框架,只提供了一个强劲的核心,其他的功能都需要使用拓展来实现。意味着可以根据自己的需求量身打造;

  1. 组成

调试、路由、wsgi系统

模板引擎(Jinja2)

  1. 安装

pip install flask

响应端 flask_server.py

复制代码
import io
import json
import flask
import torch
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
#from torchvision import transforms as T
from torchvision import transforms, models, datasets
from torch.autograd import Variable

# 初始化Flask app
app = flask.Flask(__name__)
model = None
use_gpu = False

# 加载模型进来
def load_model():
    """Load the pre-trained model, you can use your model just as easily.
    """
    global model
    #这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 类别数自己根据自己任务来

    #print(model)
    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    #将模型指定为测试格式
    model.eval()
    #是否使用gpu

    if use_gpu:
        model.cuda()

# 数据预处理
def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.

    :param image:       original image
    :param target_size: target image size
    :return:
                        preprocessed image
    """
    #针对不同模型,image的格式不同,但需要统一至RGB格式
    if image.mode != 'RGB':
        image = image.convert("RGB")

    # Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
    image = transforms.Resize(target_size)(image)
    image = transforms.ToTensor()(image)

    # Convert to Torch.Tensor and normalize. mean与std   (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    # Add batch_size axis.增加一个维度,用于按batch测试   本次这里一次测试一张
    image = image[None]
    if use_gpu:
        image = image.cuda()
    return Variable(image, volatile=True) #不需要求导

# 开启服务   这里的predict只是一个名字,可自定义
@app.route("/predict", methods=["POST"])
def predict():
    # Initialize the data dictionary that will be returned from the view.
    #做一个标志,刚开始无图像传入时为false,传入图像时为true
    data = {"success": False}

    # 如果收到请求
    if flask.request.method == 'POST':
        #判断是否为图像
        if flask.request.files.get("image"):
            # Read the image in PIL format
            # 将收到的图像进行读取
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image)) #二进制数据

            # 利用上面的预处理函数将读入的图像进行预处理
            image = prepare_image(image, target_size=(64, 64))

            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())

            #将data字典增加一个key,value,其中value为list格式
            data['predictions'] = list()

            # Loop over the results and add them to the list of returned predictions
            for prob, label in zip(results[0][0], results[1][0]):
                #label_name = idx2label[str(label)]
                r = {"label": str(label), "probability": float(prob)}
                #将预测结果添加至data字典
                data['predictions'].append(r)

            # Indicate that the request was a success.
            data["success"] = True
    # 将最终结果以json格式文件传出
    return flask.jsonify(data)

"""
test_json = {
                "status_code": 200,
                "success": {
                            "message": "image uploaded",
                            "code": 200
                        },
                "video":{
                    "video_name":opt['source'].split('/')[-1],
                    "video_path":opt['source'],
                    "description":"1",
                    "length": str(hour)+','+str(minute)+','+str(round(second,4)),
                    "model_object_completed":model_flag
                    }
                    "status_txt": "OK"
                    }
                    response = requests.post(
                        'http://xxx.xxx.xxx.xxx:8090/api/ObjectToKafka/',,
                        data={'json': str(test_json)})
"""

if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    #先加载模型
    load_model()
    #再开启服务
    app.run(port='5012')

请求端 flask_predict.py

复制代码
import requests
import argparse

# url和端口携程自己的
flask_url = 'http://127.0.0.1:5012/predict'


def predict_result(image_path):
    #啥方法都行
    image = open(image_path, 'rb').read()
    payload = {'image': image}
    #request发给server.
    r = requests.post(flask_url, files=payload).json()

    # 成功的话在返回.
    if r['success']:
        # 输出结果.
        for (i, result) in enumerate(r['predictions']):
            print('{}. {}: {:.4f}'.format(i + 1, result['label'],
                                          result['probability']))
    # 失败了就打印.
    else:
        print('Request failed')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Classification demo')
    parser.add_argument('--file', default=r'D:\paogu\flask预测\flower_data\train\1\image_06734.jpg', type=str, help='test image file')

    args = parser.parse_args()
    predict_result(args.file)
相关推荐
程序员清风24 分钟前
贝壳一面:年轻代回收频率太高,如何定位?
java·后端·面试
摆烂z33 分钟前
Jupyter Notebook的交互式开发环境方便py开发
ide·python·jupyter
考虑考虑35 分钟前
Java实现字节转bcd编码
java·后端·java ee
AAA修煤气灶刘哥1 小时前
ES 聚合爽到飞起!从分桶到 Java 实操,再也不用翻烂文档
后端·elasticsearch·面试
爱读源码的大都督1 小时前
Java已死?别慌,看我如何用Java手写一个Qwen Code Agent,拯救Java
java·人工智能·后端
星辰大海的精灵2 小时前
SpringBoot与Quartz整合,实现订单自动取消功能
java·后端·算法
天天摸鱼的java工程师2 小时前
RestTemplate 如何优化连接池?—— 八年 Java 开发的踩坑与优化指南
java·后端
一乐小哥2 小时前
一口气同步10年豆瓣记录———豆瓣书影音同步 Notion分享 🚀
后端·python
LSTM972 小时前
如何使用C#实现Excel和CSV互转:基于Spire.XLS for .NET的专业指南
后端
三十_2 小时前
【NestJS】构建可复用的数据存储模块 - 动态模块
前端·后端·nestjs