模型部署——Flask 部署 PyTorch 模型


一、整体功能概述

这两段代码组合起来实现了一个 深度学习图像分类推理系统

  • 代码一(服务端)

    使用 Flask 搭建 HTTP 服务器,加载一个 PyTorch 训练好的模型(如 ResNet18),接受图片上传请求,并返回分类预测结果(前 3 名类别与概率)。

  • 代码二(客户端)

    使用 requests 库向服务端发送图片(HTTP POST 请求),获取预测结果并打印。

这种结构在工业场景中非常常见,被称为:

模型服务化部署\]+\[客户端调用


二、运行流程详解

复制代码
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models

(1)模型加载阶段

复制代码
# 初始化Flask app
app = flask.Flask(__name__)  # 创建一个新的Flask应用程序实例,
# __name__参数通常被传递给Flask应用程序来定位应用程序中的模块,这样Flask就可以知道在哪里找到模板、静态文件等。
# 总体来说app = flask.Flask(__name__)是Flask应用程序的起点。它初始化了一个新的Flask应用程序实例,为后续添加路由、配置等奠定了基础。
use_gpu = False  # 是否使用GPU训练


def load_model():
    global model
    # 加载resnet18网络
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 类别数自己根据自己任务来

    # print(model)
    checkpoint = torch.load("best.pth")
    model.load_state_dict(checkpoint['state_dict'])
    # 将模型指定为测试格式
    model.eval()

    # 是否使用gpu
    if use_gpu:
        model.cuda()

解释:

初始化Flask实例

使用 torchvision 内置的 resnet18 结构;

修改全连接层输出改为 102 类;

best.pth 文件加载权重;

state_dict,用于存储模型的所有可学习参数。

  • torch.load('best.pth')

    从磁盘读取 .pth 文件,反序列化为 Python 字典。

  • checkpoint['state_dict']

    提取出模型参数部分(即训练保存的 state_dict)。

  • model.load_state_dict(...)

    把这些参数加载到你当前定义的模型结构中去。

设定为评估模式(model.eval())以禁用 dropout、BN 更新。


(2)图像预处理阶段

复制代码
def prepare_image(image, target_size):
    # 针对不同情况,image格式不一样,需要统一到RGB格式
    if image.mode != "RGB":
        image = image.convert("RGB")

    # 按照所使用的模型调整输入图片的尺寸格式,并转为tensor
    image = transforms.Resize(target_size)(image)   #.forword(image)
    image = transforms.ToTensor()(image)

    # (RGB三通道)这的参数和数据集中是对应的
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    # 增加一个维度,用于做batch测试  本次这里一次测试一张
    image = image[None]
    if use_gpu:
        image = image.cuda()
    return torch.tensor(image)

功能:

将输入图像统一为 RGB;

调整大小至模型输入要求;

转换为 Tensor;

像素分布进行归一化(标准化)按 ImageNet 统计值进行标准化;

增加 batch 维度,使输入符合 (1, 3, 224, 224) 形状。

torch.Size([3, 224, 224]) -> torch.Size([1, 3, 224, 224])

  • 3 → 通道数(R、G、B)

  • 224 → 图像高度

  • 224 → 图像宽度

多了一个 batch 维度,表示"1 张图像"。batch size(一次输入的图片数量)

对于模型推理(inference)是必不可少的。


(3)请求接收与预测阶段

复制代码
@app.route("/predict", methods=["POST"])
def predict():
    # 做一个标志,刚开始无图像传入时为False,传入图像时为True
    data = {"success": False}

    if flask.request.method == "POST":  # 如果收到请求
        if flask.request.files.get("image"):  # 判断是否为图像
            image = flask.request.files["image"].read()  # 收到的图像进行读取,内容为二进制
            # BytesIO()提供了一个类似文件对象的接口,允许你像操作文件一样在内存中读写二进制。传输过来的数据一般在缓存
            image = Image.open(io.BytesIO(image))

            # 利用上面的预处理函数对传入的图像进行预处理
            image = prepare_image(image, target_size=(224, 224))
            preds = F.softmax(model(image), dim=1)  # 得到各个类的概率
            results = torch.topk(preds.cpu().data, k=3, dim=1)  # 概率最大的前3个结果
            # torch.topk用于返回输入张量中每行的最大的k个元素及其对应的索引
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())

            # 向data字典增加一个key,value,其中value为list格式
            data["predictions"] = list()

流程:

项目 含义
@app.route() Flask 用于定义 路由 的装饰器
"/predict" 指定访问的 URL 路径
methods=["POST"] 指定 HTTP 请求方式(这里只允许 POST)

这意味着:

当客户端向服务器发送一个 POST 请求 ,访问 /predict 路径时,

Flask 就会调用下面定义的 predict() 函数来处理。

  1. Flask 接收到客户端上传的图片文件;

  2. 读取图片二进制流并转为 PIL.Image

  3. 调用预处理函数;

  4. 使用模型进行前向推理;

  5. 计算 softmax 概率;

  6. 取出 top-3 预测结果。

  7. 转回 CPU 将 PyTorch 张量转换为 NumPy 数组


(4)返回结果给客户端

复制代码
            for prob, label in zip(results[0][0], results[1][0]):
                r = {"label": str(label), "probability": float(prob)}
                # 将预测结果添加到data字典
                data["predictions"].append(r)

            data["success"] = True
    # jsonify非常有利于网络传输,字典可以直接转换为json文件的内容,用作网络传输,一般都用json格式的数据
    return flask.jsonify(data)

结果示例:

这段结果将通过 JSON 格式返回,供客户端使用。


(5)服务启动

复制代码
if __name__ == "__main__": #是Python中的内置变量,是有值,和你当前运行的代码文件有关
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")

    load_model()  # 先加载模型
    # 再开启服务
    app.run(port='5012')  # 最好都有自己的端口号,有些端口是固定的(Linux),若使用自己的电脑,端口号为5012来跑这个保护、服务器。进入1

Flask 启动后默认监听本机地址:

http://127.0.0.1:5012/predict


三、客户端请求

复制代码
import requests
'''客户端的程序代码'''

flask_url = 'http://127.0.0.1:5012/predict'  # url和端口写成自己的本地ip
def predict_result(image_path):
    image = open(image_path, 'rb').read()
    payload = {'image': image}  #
    r = requests.post(flask_url, files = payload).json()  # 是因为网络传输一般使用json类型的数据 字符类型的数据
    # 向flask_url服务发送一个POST请求,并尝试将返回的JSON响应解析为一个Python字典。
    if r['success']:  # 成功的话再返回。
        # 输出结果
        for (i, result) in enumerate(r['predictions']):
            print('{}.预测类别为{}:的概率: {}'.format(i + 1, result['label'], result['probability']))
    # 失败了就打印
    else:
        print('Request failed')

if __name__ == '__main__':
    predict_result('./flower_data/val_filelist/image_00028.jpg')

流程:

  • 指定服务器 IP 和端口;

  • 读取本地图片为二进制;

  • 使用 requests.post() 向 Flask 服务发送请求;

  • 解析返回的 JSON 结果;

  • 打印预测的 top-3 类别与概率。enumerate()用来在 遍历可迭代对象(如列表、元组、字符串等)时,同时获取每个元素的索引和值。

输出示例:


四、运行准备与注意事项

  1. 模型文件

    需在 Flask 目录下存在 best.pth 文件(训练好的权重)。

    文件结构大致为:

    复制代码
    ./server.py
    ./best.pth
    ./flower_data/
  2. IP 与端口

    • 服务端(代码一)运行主机的 IP 地址需与客户端一致;

    • 可在命令行中使用以下命令查看 IP:

      复制代码
      ipconfig  # Windows
      ifconfig  # Linux/Mac
  3. GPU 支持

    若要启用 GPU,请修改:

    复制代码
    use_gpu = True

    并确保 PyTorch 安装了 CUDA 版本。

  4. 跨机器访问时防火墙

    需开放端口 5012:

    复制代码
    netsh advfirewall firewall add rule name="Flask5012" dir=in action=allow protocol=TCP localport=5012

🧩 五、可优化与扩展建议

模块 优化方向 建议
模型加载 减少重复加载 可在全局定义模型并仅加载一次
接口设计 更友好输出 返回类别名称而非数字索引(通过类别映射表)
性能优化 批处理或异步请求 可支持批量图片预测
安全性 限制上传文件类型 仅允许 .jpg, .png 等图片格式
部署 使用 gunicorn + nginx 提升并发与稳定性
客户端 GUI 或 Web 页面 可做简单上传界面查看预测结果
相关推荐
羊羊小栈3 小时前
基于「多模态大模型 + BGE向量检索增强RAG」的航空维修智能问答系统(vue+flask+AI算法)
vue.js·人工智能·python·语言模型·flask·毕业设计
weixin_456904273 小时前
SHAP可视化代码详细讲解
python
DTS小夏3 小时前
算法社Python基础入门面试题库(新手版·含答案)
python·算法·面试
刘一哥GIS3 小时前
Windows环境搭建:PostGreSQL+PostGIS安装教程
数据库·python·arcgis·postgresql·postgis
西柚小萌新3 小时前
【深入浅出PyTorch】--4.PyTorch基础实战
人工智能·pytorch·python
用户8356290780513 小时前
掌控PDF页面:使用Python轻松实现添加与删除
后端·python
用户3721574261354 小时前
Python 实现 Excel 文件加密与保护
python
Derrick__14 小时前
Python访问数据库——使用SQLite
数据库·python·sqlite
总有刁民想爱朕ha4 小时前
AI大模型学习(17)python-flask AI大模型和图片处理工具的从一张图到多平台适配的简单方法
人工智能·python·学习·电商图片处理