使用 Flask 实现本机 PyTorch 模型部署:从服务端搭建到客户端调用

目录

前言

一、部署前准备

[1.1 环境要求](#1.1 环境要求)

[1.2 必备文件](#1.2 必备文件)

[二、服务端搭建:让模型 "听候指令"](#二、服务端搭建:让模型 “听候指令”)

[2.1 服务端完整代码(server.py)](#2.1 服务端完整代码(server.py))

[2.2 服务端启动与验证](#2.2 服务端启动与验证)

[三、客户端开发:向服务端 "发送请求"](#三、客户端开发:向服务端 “发送请求”)

[3.1 客户端完整代码(client.py)](#3.1 客户端完整代码(client.py))

[3.2 客户端运行与结果示例](#3.2 客户端运行与结果示例)

四、常见问题排查

[4.1 服务端启动失败](#4.1 服务端启动失败)

[4.2 客户端连接失败](#4.2 客户端连接失败)

[4.3 模型推理报错](#4.3 模型推理报错)

五、扩展与优化建议

总结


前言

在机器学习项目中,训练好的模型只有部署到实际环境中才能发挥价值。对于本机测试或小规模应用场景,Flask 框架是实现模型部署的轻量优选 ------ 它能快速搭建 HTTP 服务,让模型以接口形式接收请求、返回预测结果,无需复杂的服务器配置。

本文将以 ResNet18 图像分类模型为例,完整讲解如何用 Flask 实现 "本机模型部署":从服务端代码编写(模型加载、接口定义),到客户端代码开发(图像上传、结果解析),再到常见问题排查,确保新手也能一步到位跑通流程。

一、部署前准备

在开始编写代码前,需先确认环境和依赖是否齐全,避免后续因版本或包缺失导致报错。

1.1 环境要求

  • Python 版本:3.7~3.9(PyTorch 对高版本 Python 兼容性可能不稳定)
  • 核心依赖包:
python 复制代码
# 安装Flask(Web服务框架)
pip install flask
# 安装PyTorch+TorchVision(模型加载与图像预处理)
pip install torch torchvision
# 安装PIL(图像读取处理)和requests(客户端请求)
pip install pillow requests

1.2 必备文件

  • 训练好的模型权重文件:本文使用best.pth(ResNet18 微调后权重,需确保与代码中类别数匹配)
  • 测试图像:准备 1~2 张用于验证的图像(如 JPG/PNG 格式)
  • 代码结构:建议按如下目录组织,避免路径混乱

二、服务端搭建:让模型 "听候指令"

服务端的核心作用是:加载预训练模型、定义预测接口、监听本机请求。当客户端发送图像请求时,服务端会完成图像预处理、模型推理,并返回 JSON 格式的预测结果。

2.1 服务端完整代码(server.py

python 复制代码
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models

# 1. 初始化Flask应用
app = flask.Flask(__name__)  # __name__定位应用根路径,用于查找静态资源
model = None  # 全局变量存储模型,避免重复加载
use_gpu = False  # 本机部署默认用CPU(若有GPU可设为True,需确保PyTorch支持CUDA)


# 2. 加载预训练模型
def load_model():
    """加载ResNet18模型,替换全连接层适配自定义分类任务"""
    global model
    # 加载ResNet18基础网络(pretrained=False表示不加载默认预训练权重,用自己的best.pth)
    model = models.resnet18(pretrained=False)
    # 获取全连接层输入特征数,替换为自定义类别数(本文以102类为例)
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 输出维度=类别数
    
    # 加载训练好的权重文件(需确保best.pth路径正确)
    checkpoint = torch.load('best.pth', map_location=torch.device('cpu'))  # 强制CPU加载,避免GPU报错
    model.load_state_dict(checkpoint['state_dict'])  # 加载权重参数
    
    # 设为评估模式(禁用Dropout、BatchNorm等训练特有的层)
    model.eval()
    
    # 若启用GPU且设备支持,将模型移至CUDA
    if use_gpu and torch.cuda.is_available():
        model = model.cuda()
    print("模型加载完成,等待请求...")


# 3. 图像预处理函数(需与训练时保持一致)
def prepare_image(image, target_size=(224, 224)):
    """将客户端传入的图像转为模型可接受的Tensor格式"""
    # 统一图像为RGB格式(避免灰度图/透明图报错)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    
    # 预处理 pipeline:Resize→ToTensor→Normalize(ResNet默认预处理参数)
    preprocess = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_tensor = preprocess(image)
    
    # 增加batch维度(模型要求输入为[batch_size, C, H, W],单张图batch_size=1)
    image_tensor = image_tensor.unsqueeze(0)  # 等价于image_tensor[None]
    
    # 若启用GPU,将Tensor移至CUDA
    if use_gpu and torch.cuda.is_available():
        image_tensor = image_tensor.cuda()
    return image_tensor


# 4. 定义预测接口(POST方法)
@app.route("/predict", methods=["POST"])
def predict():
    """
    接收客户端POST请求:
    - 请求体包含image字段(二进制图像)
    - 返回JSON格式结果:success(请求状态)、predictions(Top3预测结果)
    """
    # 初始化返回结果字典
    result = {"success": False, "predictions": []}
    
    # 检查请求方法是否为POST,且包含image文件
    if flask.request.method == "POST" and flask.request.files.get("image"):
        # 步骤1:读取客户端传入的二进制图像
        image_bytes = flask.request.files["image"].read()
        # 将二进制数据转为PIL图像对象
        image = Image.open(io.BytesIO(image_bytes))
        
        # 步骤2:图像预处理
        image_tensor = prepare_image(image, target_size=(224, 224))
        
        # 步骤3:模型推理(禁用梯度计算,加速推理)
        with torch.no_grad():
            # 计算各类别概率(softmax归一化)
            preds = F.softmax(model(image_tensor), dim=1)
            # 获取概率Top3的类别和概率值(cpu()转为CPU张量,避免GPU与CPU数据冲突)
            top3_probs, top3_labels = torch.topk(preds.cpu(), k=3, dim=1)
        
        # 步骤4:处理结果(转为numpy数组→构造JSON格式)
        top3_probs = top3_probs.numpy()[0]  # 取第1个batch(仅1张图)
        top3_labels = top3_labels.numpy()[0]
        
        # 遍历Top3结果,添加到返回字典
        for prob, label in zip(top3_probs, top3_labels):
            result["predictions"].append({
                "label": str(label),  # 类别标签(若有类别名映射,可此处替换为中文)
                "probability": round(float(prob), 4)  # 概率值(保留4位小数)
            })
        
        # 标记请求成功
        result["success"] = True
    
    # 以JSON格式返回结果(Flask自动设置Content-Type为application/json)
    return flask.jsonify(result)


# 5. 启动服务
if __name__ == "__main__":
    # 先加载模型(确保模型加载成功后再启动服务)
    try:
        load_model()
    except Exception as e:
        print(f"模型加载失败:{str(e)}")
        exit(1)  # 模型加载失败则退出程序
    
    # 启动Flask服务(本机部署关键参数)
    # host='127.0.0.1':仅本机可访问(推荐本机测试用)
    # port=5012:端口号(避免与其他服务冲突,如8080、5000)
    app.run(host='127.0.0.1', port=5012, debug=False)

2.2 服务端启动与验证

  1. 运行server.py,若控制台输出以下内容,说明服务启动成功:

2.初步验证:打开浏览器访问http://127.0.0.1:5012/predict,若返回{"success":false,"predictions":[]},证明接口正常监听(无图像请求时返回默认状态)。

三、客户端开发:向服务端 "发送请求"

客户端的作用是:读取本地图像、以 POST 方式向服务端接口发送请求、解析返回的 JSON 结果并打印。

3.1 客户端完整代码(client.py

python 复制代码
import requests

# 1. 配置服务端地址(需与服务端host和port一致)
# 本机部署用http://127.0.0.1:5012/predict,跨设备需替换为服务端局域网IP
FLASK_URL = "http://127.0.0.1:5012/predict"


def predict_image(image_path):
    """
    向服务端发送图像预测请求:
    param image_path: 本地图像路径(如"./test_img/image_06975.jpg")
    return: 打印预测结果
    """
    try:
        # 步骤1:以二进制形式读取本地图像(保持图像原始格式)
        with open(image_path, 'rb') as f:
            image_bytes = f.read()
        
        # 步骤2:构造请求体(key为"image",与服务端flask.request.files.get("image")对应)
        payload = {"image": image_bytes}
        
        # 步骤3:发送POST请求(timeout设为10秒,避免请求超时)
        response = requests.post(FLASK_URL, files=payload, timeout=10)
        
        # 步骤4:解析响应结果(JSON→字典)
        result = response.json()
        
        # 步骤5:判断请求是否成功并打印结果
        if response.status_code == 200 and result["success"]:
            print("请求成功(状态码:200)")
            print("Top3预测结果:")
            for i, pred in enumerate(result["predictions"], 1):
                print(f"  {i}. 类别:{pred['label']},概率:{pred['probability']}")
        else:
            print(f"请求失败:{result}")
    
    except Exception as e:
        print(f"客户端报错:{str(e)}")


# 6. 运行客户端(测试单张图像)
if __name__ == "__main__":
    # 替换为你的测试图像路径(相对路径/绝对路径均可)
    test_image_path = "./test_img/image_06975.jpg"
    predict_image(test_image_path)

3.2 客户端运行与结果示例

  1. 确保服务端已启动,运行client.py,若请求成功,控制台输出如下:

状态码说明:

◦ 200:请求成功(服务端正常处理)

◦ 404:接口地址错误(如 URL 写错)

◦ 500:服务端内部错误(如模型加载失败、代码报错)

四、常见问题排查

在部署过程中,新手容易遇到连接失败、模型报错等问题,以下是高频问题的解决方案:

4.1 服务端启动失败

  • 报错 1:FileNotFoundError: [Errno 2] No such file or directory: 'best.pth' 原因:模型权重文件路径错误。解决:确认best.pthserver.py在同一目录,或使用绝对路径(如torch.load("C:/model_deployment/best.pth"))。

  • 报错 2:WinError 10049 请求的地址无效 原因:app.run(host=...)中 IP 地址错误(非本机 IP)。解决:本机部署用host='127.0.0.1',或通过ipconfig(Windows)/ifconfig(Linux)查看本机正确 IP。

4.2 客户端连接失败

  • 报错 1:requests.exceptions.ConnectionError: HTTPConnectionPool 原因:服务端未启动,或 IP / 端口不匹配。解决:先启动服务端,确认客户端FLASK_URL与服务端host:port完全一致(如均为127.0.0.1:5012)。

  • **报错 2:PIL.UnidentifiedImageError: cannot identify image file**原因:图像路径错误,或文件不是有效图像(如后缀为.jpg 但实际是.txt)。解决:检查图像路径,用画图工具打开图像确认是否正常。

4.3 模型推理报错

  • 报错:RuntimeError: Expected 4-dimensional input for 4-dimensional weight 原因:图像未增加 batch 维度(模型要求输入为[batch_size, C, H, W])。解决:确保prepare_image函数中调用image_tensor.unsqueeze(0)

五、扩展与优化建议

  1. 类别名映射 :当前返回的是类别编号(如 35),可添加字典映射为中文(如label_map = {35: "玫瑰", 36: "百合"}),在服务端predict函数中替换"label": str(label)"label": label_map[label]

  2. 多图批量预测:修改客户端代码,支持遍历文件夹下所有图像,批量发送请求。

  3. 生产环境优化 :Flask 开发服务器不适合生产环境,可改用Gunicorn(Linux)或Waitress(Windows)作为 WSGI 服务器,搭配Nginx反向代理,提升并发能力。

总结

本文通过 Flask 框架实现了本机 PyTorch 模型的完整部署流程:服务端负责加载模型和提供接口,客户端负责发送请求和解析结果,整个流程轻量、易上手,适合小规模测试或个人项目使用。

核心要点可总结为 3 点:

  1. 服务端与客户端的host:port必须一致;
  2. 图像预处理需与训练时保持一致(如 Resize 尺寸、Normalize 参数);
  3. 先启动服务端,再运行客户端,避免连接失败。

按照本文步骤操作,即可快速将自己的 PyTorch 模型部署到本机,实现 "训练→部署→调用" 的闭环。

相关推荐
后端小肥肠7 小时前
【n8n 入门系列】10 分钟部署 n8n,手把手教你搭第一个自动化工作流,小白可学!
人工智能·aigc
mwq301237 小时前
从 Word2Vec 到 GPT:词向量的上下文进化史
人工智能
(时光煮雨)7 小时前
【Python进阶】Python爬虫-Selenium
爬虫·python·selenium
爱读源码的大都督7 小时前
RAG效果不理想?试试用魔法打败魔法:让大模型深度参与优化的三阶段实战
java·人工智能·后端
小政同学8 小时前
【Python】小练习-考察变量作用域问题
开发语言·python
Lynnxiaowen8 小时前
今天我们开始学习python3编程之python基础
linux·运维·python·学习
青青草原羊村懒大王8 小时前
1、pycharm相关知识
python
嫂子的姐夫8 小时前
10-七麦js扣代码
前端·javascript·爬虫·python·node.js·网络爬虫
极客BIM工作室8 小时前
机器学习之规则学习(Rule Learning)
人工智能·机器学习