模型部署:基于flask和pytorch

目录

一、模型部署的目的和用到的技术

二、系统架构设计

1.整体架构(c/s)

2.具体流程

三、服务端实现

[1. 服务端功能概述](#1. 服务端功能概述)

2.导入所需库

3.Flask应用初始化

[4. 模型加载函数](#4. 模型加载函数)

[5. 图像预处理函数](#5. 图像预处理函数)

[6. API路由定义](#6. API路由定义)

[7. 服务启动](#7. 服务启动)

四、客户端

[1. 客户端功能概述](#1. 客户端功能概述)

[2. 客户端代码](#2. 客户端代码)

五、部署

[1. 环境配置](#1. 环境配置)

[2. 模型准备](#2. 模型准备)

[3. 启动服务](#3. 启动服务)

[4. 修改IP地址](#4. 修改IP地址)

[5. 客户端调用](#5. 客户端调用)


在深度学习领域,模型训练只是整个项目生命周期的一部分。如何将训练好的模型部署到生产环境中,让其他应用程序能够方便地调用,是许多开发者面临的挑战。本文将详细介绍如何使用Flask框架和PyTorch,构建一个完整的图像分类API服务,实现模型的在线部署与调用。

一、模型部署的目的和用到的技术

  • 提供实时预测服务:让模型能够处理实时请求
  • 实现模型复用:避免重复训练,提高开发效率
  • 支持多端调用:Web、移动端、桌面应用都可以通过API访问
  • 便于模型更新:只需更新服务端,客户端无需改动

需要用到的:

  • Flask:轻量级Web框架,适合快速构建API服务
  • PyTorch:深度学习框架,用于加载和运行模型
  • Requests:Python HTTP库,用于客户端请求
  • PIL:图像处理库
  • Torchvision:提供预训练模型和图像变换工具

二、系统架构设计

1.整体架构(c/s)

复制代码
客户端 (Client) → HTTP请求 → Flask服务器(server) → PyTorch模型 → 预测结果
客户端 (Client) ← HTTP响应 ← Flask服务器 ← 返回结果

2.具体流程

  • 服务端启动:加载预训练模型,监听指定端口
  • 客户端请求:发送图像文件到服务器
  • 服务器处理:接收图像,预处理,模型推理
  • 结果返回:将预测结果封装成JSON格式返回
  • 客户端展示:解析结果并显示

三、服务端实现

1. 服务端功能概述

根据代码注释,服务端需要实现以下核心功能:

  • 接收来自客户端的信息,24小时运行
  • 将模型部署起来
  • 对图片进行识别
  • 将识别结果返回给客户端

2.导入所需库

python 复制代码
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models

io:处理字节流数据,用于图像读取

flask:Web框架,构建API服务

torch 和 torch.nn.functional:PyTorch核心库和函数式接口

PIL:Python图像处理库

torchvision:提供预训练模型和数据预处理工具

3.Flask应用初始化

python 复制代码
app = flask.Flask(__name__)
model = None
use_gpu = False

app = flask.Flask(name):初始化Flask应用程序实例。__name__参数用于定位应用程序的根路径,这样Flask就可以知道在哪里找到模板、静态文件等,是Flask应用程序的起点,为后续添加路由、配置等奠定了基础。

model = None:全局变量,用于存储加载的模型

use_gpu = False:是否使用GPU进行推理

4. 模型加载函数

python 复制代码
def load_model():
    """Load the pre-trained model, you can use your model just as easily."""
    global model
    # 加载resnet18网络
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 修改最后的类别数

    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    # 将模型指定为测试模式
    model.eval()

    # 是否使用gpu
    if use_gpu:
        model.cuda()

global model:声明使用全局变量model,在函数内部修改全局变量

models.resnet18():加载ResNet18预训练网络结构

model.fc.in_features:获取原全连接层的输入特征数

nn.Sequential(nn.Linear(num_ftrs, 102)):修改最后的全连接层,适配102类分类任务

torch.load('best.pth'):加载训练好的模型权重文件

model.load_state_dict(checkpoint['state_dict']):将权重加载到模型中

model.eval():将模型设置为评估模式,关闭Dropout和Batch Normalization的训练行为

model.cuda():如果use_gpu为True,将模型移动到GPU

5. 图像预处理函数

python 复制代码
def prepare_image(image, target_size):
    if image.mode != 'RGB':
        image = image.convert('RGB')

    image = transforms.Resize(target_size)(image)
    image = transforms.ToTensor()(image)

    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    image = image[None]
    if use_gpu:
        image = image.cuda()

    return torch.tensor(image)

if image.mode != 'RGB':检查图像颜色模式,如果不是RGB则转换

transforms.Resize(target_size)(image):调整图像尺寸到目标大小(224×224)

transforms.ToTensor()(image):将PIL图像转换为PyTorch张量,并将像素值从[0,255]缩放到[0.0,1.0]

transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])(image):使用ImageNet数据集的均值和标准差进行标准化

image = image[None]:在第0维添加一个维度,将形状从(C,H,W)变为(1,C,H,W),满足模型输入的batch要求

return torch.tensor(image):返回处理后的张量

6. API路由定义

python 复制代码
@app.route("/predict", methods=["POST"])
def predict():
    data = {"success": False}

    if flask.request.method == "POST":
        if flask.request.files.get("image"):
            # 读取图像文件
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image))

            # 预处理图像
            image = prepare_image(image, target_size=(224, 224))
            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())

            data['predictions'] = list()

            for prob, label in zip(results[0][0], results[1][0]):
                r = {'label': str(label), 'probability': float(prob)}
                data['predictions'].append(r)
            data['success'] = True

    return flask.jsonify(data)

@app.route("/predict", methods=["POST"]):装饰器,将URL路径"/predict"与predict函数关联,并指定只处理POST请求

data = {"success": False}:初始化返回数据,success默认为False

flask.request.files.get("image"):获取上传的文件,key为"image"

image = flask.request.files["image"].read():读取文件的二进制数据

Image.open(io.BytesIO(image)):将二进制数据转换为PIL图像

F.softmax(model(image), dim=1):模型推理后应用softmax函数,将输出转换为概率分布

torch.topk(preds.cpu().data, k=3, dim=1):获取概率最高的前3个类别及其概率

results[0].cpu().numpy():将概率值转换为numpy数组

results[1].cpu().numpy():将类别索引转换为numpy数组

循环遍历,将结果封装成字典列表

flask.jsonify(data):将字典转换为JSON格式的HTTP响应

7. 服务启动

python 复制代码
if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    load_model()  # 加载模型
    # 开启服务
    app.run(host='0.0.0.0', port=5010)

if name == 'main':判断是否直接运行此脚本

load_model():启动时加载模型

app.run(host='0.0.0.0', port=5010):启动Flask开发服务器,host='0.0.0.0'表示监听所有网络接口,允许局域网内的其他设备访问,port=5010指定端口号

四、客户端

1. 客户端功能概述

根据代码注释,客户端需要实现以下核心功能:

  1. 负责发送图片到指定的服务器

  2. 接收服务器端返回的结果信息

2. 客户端代码

python 复制代码
import requests

flask_url = 'http://192.168.31.83:5010/predict'  # url和端口携程自己的本地ip

def predict_result(image_path):
    image = open(image_path, 'rb').read()
    payload = {'image': image}
    r = requests.post(flask_url, files=payload).json()
    # 向fLask_url服务发送一个P0ST请求,并尝试将返回的JSON响应解析为一个Python字典

    if r['success']:
        for (i, result) in enumerate(r['predictions']):
            print('{}.预测类别为{}:的概率{}'.format(i + 1, result['label'], result['probability']))
    else:
        print('Request failed')

if __name__ == '__main__':
    predict_result('flower.jpg')

import requests:导入requests库,用于发送HTTP请求

flask_url = 'http://192.168.31.83:5010/predict':服务器端API地址,需要根据实际情况修改IP和端口

open(image_path, 'rb').read():以二进制模式打开并读取图像文件

payload = {'image': image}:构建请求数据,键名"image"必须与服务器端flask.request.files.get("image")一致

requests.post(flask_url, files=payload).json():发送POST请求,files参数用于上传文件,.json()将响应解析为Python字典

if r['success']:检查服务器返回的success字段

enumerate(r['predictions']):遍历预测结果列表,同时获取索引和值

打印每个预测结果的类别和概率

五、部署

1. 环境配置

python 复制代码
# 创建虚拟环境(可选)
python -m venv venv

# 激活虚拟环境(Windows)
venv\Scripts\activate

# 安装依赖
pip install flask torch torchvision pillow requests

2. 模型准备

确保在服务端目录下有训练好的模型文件 best.pth。根据代码,这个文件应该包含:

模型权重(通过checkpoint['state_dict']访问)

3. 启动服务

先执行server代码

4. 修改IP地址

客户端代码中的IP地址需要根据实际情况修改:

python 复制代码
flask_url = 'http://你的实际IP:5010/predict'

5. 客户端调用

相关推荐
超级学长2 小时前
Real-ESRGAN:用纯合成数据训练真实世界盲超分辨率模型
图像处理·深度学习·图像超分辨·超分辨
linxinglu2 小时前
DeepMind:解开智能之谜与「科学发现」的终极自动化杠杆
运维·人工智能·自动化
AEIC学术交流中心2 小时前
【快速EI检索 | ACM ICPS出版】2026年人工智能、虚拟现实与文化遗产国际学术会议 (AIVRCH 2026)
人工智能·vr
wenzhangli72 小时前
OUC NLP双链路闭环设计:基于ooderAgent的LLM+知识库+RAG架构深度解析
人工智能·自然语言处理·架构
KKKlucifer2 小时前
动态数据识别与分类分级一体化技术研究
人工智能·分类·数据挖掘
balmtv2 小时前
Gemini 3多模态统一架构深度拆解:从稀疏注意力到原生视频生成的工程实现
人工智能·架构·音视频
IT_陈寒2 小时前
JavaScript开发者必知的5个高效调试技巧,比console.log强10倍!
前端·人工智能·后端
m0_743297422 小时前
将Python Web应用部署到服务器(Docker + Nginx)
jvm·数据库·python