基于Flask测试深度学习模型预测

Flask之最易懂的基础教程一(2020年最新-从入门到精通)-CSDN博客

Flask程序运行过程:

所有Flask程序必须有一个程序实例。

Flask调用视图函数后,会将视图函数的返回值作为响应的内容,返回给客户端。一般情况下,响应内容主要是字符串和状态码。

用户向浏览器发送http请求,web服务器把客户端所有请求交给Flask程序实例,程序用Werkzeug来做路由分发,每个url请求,找到具体的视图函数。路由的实现是通过route装饰器实现的,调用视图函数,获取数据后,把数据传入模块中,模块引擎渲染响应的数据,由Flask返回给浏览器。

二、Flask框架

  1. 简介

是一个非常小的框架,可以称为微型框架,只提供了一个强劲的核心,其他的功能都需要使用拓展来实现。意味着可以根据自己的需求量身打造;

  1. 组成

调试、路由、wsgi系统

模板引擎(Jinja2)

  1. 安装

pip install flask

响应端 flask_server.py

复制代码
import io
import json
import flask
import torch
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
#from torchvision import transforms as T
from torchvision import transforms, models, datasets
from torch.autograd import Variable

# 初始化Flask app
app = flask.Flask(__name__)
model = None
use_gpu = False

# 加载模型进来
def load_model():
    """Load the pre-trained model, you can use your model just as easily.
    """
    global model
    #这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 类别数自己根据自己任务来

    #print(model)
    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    #将模型指定为测试格式
    model.eval()
    #是否使用gpu

    if use_gpu:
        model.cuda()

# 数据预处理
def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.

    :param image:       original image
    :param target_size: target image size
    :return:
                        preprocessed image
    """
    #针对不同模型,image的格式不同,但需要统一至RGB格式
    if image.mode != 'RGB':
        image = image.convert("RGB")

    # Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor)
    image = transforms.Resize(target_size)(image)
    image = transforms.ToTensor()(image)

    # Convert to Torch.Tensor and normalize. mean与std   (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    # Add batch_size axis.增加一个维度,用于按batch测试   本次这里一次测试一张
    image = image[None]
    if use_gpu:
        image = image.cuda()
    return Variable(image, volatile=True) #不需要求导

# 开启服务   这里的predict只是一个名字,可自定义
@app.route("/predict", methods=["POST"])
def predict():
    # Initialize the data dictionary that will be returned from the view.
    #做一个标志,刚开始无图像传入时为false,传入图像时为true
    data = {"success": False}

    # 如果收到请求
    if flask.request.method == 'POST':
        #判断是否为图像
        if flask.request.files.get("image"):
            # Read the image in PIL format
            # 将收到的图像进行读取
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image)) #二进制数据

            # 利用上面的预处理函数将读入的图像进行预处理
            image = prepare_image(image, target_size=(64, 64))

            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())

            #将data字典增加一个key,value,其中value为list格式
            data['predictions'] = list()

            # Loop over the results and add them to the list of returned predictions
            for prob, label in zip(results[0][0], results[1][0]):
                #label_name = idx2label[str(label)]
                r = {"label": str(label), "probability": float(prob)}
                #将预测结果添加至data字典
                data['predictions'].append(r)

            # Indicate that the request was a success.
            data["success"] = True
    # 将最终结果以json格式文件传出
    return flask.jsonify(data)

"""
test_json = {
                "status_code": 200,
                "success": {
                            "message": "image uploaded",
                            "code": 200
                        },
                "video":{
                    "video_name":opt['source'].split('/')[-1],
                    "video_path":opt['source'],
                    "description":"1",
                    "length": str(hour)+','+str(minute)+','+str(round(second,4)),
                    "model_object_completed":model_flag
                    }
                    "status_txt": "OK"
                    }
                    response = requests.post(
                        'http://xxx.xxx.xxx.xxx:8090/api/ObjectToKafka/',,
                        data={'json': str(test_json)})
"""

if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    #先加载模型
    load_model()
    #再开启服务
    app.run(port='5012')

请求端 flask_predict.py

复制代码
import requests
import argparse

# url和端口携程自己的
flask_url = 'http://127.0.0.1:5012/predict'


def predict_result(image_path):
    #啥方法都行
    image = open(image_path, 'rb').read()
    payload = {'image': image}
    #request发给server.
    r = requests.post(flask_url, files=payload).json()

    # 成功的话在返回.
    if r['success']:
        # 输出结果.
        for (i, result) in enumerate(r['predictions']):
            print('{}. {}: {:.4f}'.format(i + 1, result['label'],
                                          result['probability']))
    # 失败了就打印.
    else:
        print('Request failed')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Classification demo')
    parser.add_argument('--file', default=r'D:\paogu\flask预测\flower_data\train\1\image_06734.jpg', type=str, help='test image file')

    args = parser.parse_args()
    predict_result(args.file)
相关推荐
geovindu12 小时前
go: Lock/Mutex Pattern
开发语言·后端·设计模式·golang·互斥锁模式
counterxing12 小时前
AI Agent 做长任务,问题到底 出在哪?
前端·后端·ai编程
aiopencode13 小时前
iOS开发中Xcode安装不完整问题解决方案与配置指南
后端·ios
狐狐生风13 小时前
使用 UV 创建并运行 Python 项目(完整步骤)
python·uv
该用户已不存在13 小时前
别让 Claude Code 果奔,用 Claude Code MCP 与 Skills 打造自动化开发(Part 2)
后端·ai编程·claude
噜噜噜阿鲁~13 小时前
python学习笔记 | 9.2、模块-安装第三方模块
笔记·python·学习
现代野蛮人13 小时前
【深度学习】 —— VGG-16 网络实现猫狗识别
网络·人工智能·python·深度学习·tensorflow
一个小猴子`13 小时前
Pytorch快速复习
人工智能·pytorch·python
wang3zc13 小时前
mysql如何提升InnoDB写入性能_对比MyISAM的写入锁机制
jvm·数据库·python
一起逃去看海吧13 小时前
工作流原理和实践
python