基于pth模型文件,使用flask库将服务端部署到开发者电脑

目录

一.服务端构建

[1. 依赖导入:搭建开发基础](#1. 依赖导入:搭建开发基础)

[2. Flask 应用初始化:启动 Web 服务的第一步](#2. Flask 应用初始化:启动 Web 服务的第一步)

[3. 模型加载函数:加载 ResNet18 与微调权重](#3. 模型加载函数:加载 ResNet18 与微调权重)

[4. 图像预处理函数:将输入图像转为模型可接受格式](#4. 图像预处理函数:将输入图像转为模型可接受格式)

[5. API 接口定义:处理请求与返回结果](#5. API 接口定义:处理请求与返回结果)

[6. 服务启动:加载模型并启动 Flask 服务](#6. 服务启动:加载模型并启动 Flask 服务)

二.客户端构建及测试

[1. 服务端 URL 配置:指定 API 地址](#1. 服务端 URL 配置:指定 API 地址)

[2. 核心函数predict_result:实现请求与结果解析](#2. 核心函数predict_result:实现请求与结果解析)

(1)图像文件的读取方式:二进制模式'rb'

(2)请求参数的格式:files=payload

[(3)结果解析逻辑:基于服务端的 JSON 结构](#(3)结果解析逻辑:基于服务端的 JSON 结构)

[3. 主函数调用:指定图像路径并执行](#3. 主函数调用:指定图像路径并执行)


在 AI 项目开发中,训练好的模型只有部署到实际应用中才能发挥价值。本文将以一段完整的代码为例,详细讲解如何使用 Flask 框架将 PyTorch 训练的 ResNet18 图像分类模型封装成 API 服务,实现通过网络接口接收图像、返回分类结果的功能。无论你是 AI 开发新手还是需要快速落地模型的工程师,都能通过本文掌握模型部署的核心流程。

一.服务端构建

1. 依赖导入:搭建开发基础

代码开头首先导入了项目所需的所有库,这些库覆盖了从模型处理到 Web 服务的全流程:

python 复制代码
import io                  # 内存流操作,用于处理图像字节数据
import flask               # Web框架,用于构建API服务
import torch               # PyTorch核心库,用于模型加载和推理
import torch.nn.functional as F  # PyTorch的函数库,用于计算softmax等
from PIL import Image      # 图像处理库,用于读取和转换图像格式
from torch import nn       # PyTorch的神经网络模块,用于修改模型结构
from torchvision import transforms, models  # 图像变换和预训练模型库

关键库说明

  • flask:轻量级 Web 框架,核心优势是简洁、易上手,适合快速构建 API 服务;
  • torchvision:PyTorch 官方的计算机视觉工具库,提供了 ResNet 等预训练模型和常用的图像变换方法;
  • PIL.Image:处理图像的标准库,支持多种图像格式的读取和转换;
  • io:由于 Flask 接收的是字节流形式的图像,需要用io.BytesIO将字节数据转换为图像对象。

2. Flask 应用初始化:启动 Web 服务的第一步

这部分代码完成 Flask 应用的初始化,并定义了两个全局变量用于存储模型和设备配置:

python 复制代码
# 初始化Flask app
app = flask.Flask(__name__)  # 创建Flask应用实例
# __name__参数的作用:告诉Flask应用的根路径,用于定位模板、静态文件等资源

# 全局变量:存储模型和设备配置
model = None  # 用于存储加载后的PyTorch模型,初始为None
use_gpu = False  # 是否使用GPU推理,此处默认关闭(可根据硬件调整)

核心知识点

  • flask.Flask(__name__):这是 Flask 应用的 "入口",__name__会自动识别当前脚本的路径,确保 Flask 能正确找到后续可能用到的模板(templates)或静态文件(static)目录;
  • 全局变量model:由于模型加载耗时较长,我们在服务启动时加载一次(而非每次请求都加载),通过全局变量存储模型实例,提升接口响应速度;
  • use_gpu:控制模型推理是否使用 GPU,若本地有可用 GPU,可将其设为True(需确保 PyTorch 已安装 GPU 版本)。

3. 模型加载函数:加载 ResNet18 与微调权重

load_model()函数是模型部署的核心,负责加载预训练的 ResNet18 模型、修改输出层以适配自定义分类任务,并加载训练好的模型文件:

python 复制代码
def load_model():
    global model  # 声明使用全局变量model,避免在函数内创建局部变量
    # 1. 加载预训练的ResNet18模型
    model = models.resnet18(pretrained=False)  # 注意:新版torchvision需显式设pretrained=False
    # 2. 修改模型的全连接层(fc层),适配自定义分类任务
    num_ftrs = model.fc.in_features  # 获取fc层的输入特征数(ResNet18默认是512)
    # 将fc层替换为输出102类的全连接层(因原代码分类任务为102类)
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
    # 3. 加载训练好的权重文件(best.pth)
    checkpoint = torch.load('best.pth', map_location=torch.device('cpu'))  # 加载权重到CPU
    model.load_state_dict(checkpoint['state_dict'])  # 将权重加载到模型
    # 4. 设置模型为评估模式(至关重要!)
    model.eval()
    # 5. 若使用GPU,将模型移动到GPU
    if use_gpu:
        model.cuda()

关键细节解析

  • 为什么修改fc层?ResNet18 预训练模型的fc层默认输出 1000 类(对应 ImageNet 数据集),而该项目是 102 类分类任务,因此需要替换fc层的输出维度为 102;
  • model.eval()的作用:将模型从训练模式切换到评估模式,会关闭 Dropout 层、固定 BatchNorm 层的统计参数,确保推理结果的一致性(若不设置,推理结果会异常);
  • torch.load(..., map_location=...):若use_gpu=False,需指定map_location=torch.device('cpu'),避免因权重文件是在 GPU 上训练而无法在 CPU 加载的问题。

4. 图像预处理函数:将输入图像转为模型可接受格式

PyTorch 模型对输入图像有严格的格式要求(如尺寸、归一化等),prepare_image()函数负责将用户上传的图像转换为模型能处理的张量:

python 复制代码
def prepare_image(image, target_size):
    # 1. 确保图像为RGB格式(若上传的是灰度图,转换为RGB)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # 2. 图像尺寸调整(ResNet18默认输入尺寸为224x224)
    image = transforms.Resize(target_size)(image)  # target_size=(224,224)
    # 3. 将PIL图像转为PyTorch张量(形状:[C, H, W],数值范围0-1)
    image = transforms.ToTensor()(image)
    # 4. 图像归一化(使用ImageNet数据集的均值和标准差,预训练模型的标准操作)
    image = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet数据集的RGB通道均值
        std=[0.229, 0.224, 0.225]    # ImageNet数据集的RGB通道标准差
    )(image)
    # 5. 增加批次维度(模型要求输入为[batch_size, C, H, W],单张图像需手动加batch维度)
    image = image[None]  # 等价于image.unsqueeze(0),形状从[3,224,224]变为[1,3,224,224]
    # 6. 若使用GPU,将张量移动到GPU
    if use_gpu:
        image = image.cuda()
    return image

预处理的必要性

  • 格式统一:用户可能上传灰度图、PNG 图(带透明通道)等,需转为 RGB 格式;
  • 尺寸统一:模型训练时用的是 224x224 图像,输入图像必须保持相同尺寸;
  • 归一化:消除像素值范围差异(如 0-255 转为 0-1),并使用与预训练模型一致的均值 / 标准差,确保模型推理稳定。

5. API 接口定义:处理请求与返回结果

/predict路由是整个服务的 "接口门面",负责接收客户端的 POST 请求、调用模型推理、组织返回结果:

python 复制代码
@app.route("/predict", methods=['POST'])  # 定义路由:URL为/predict,仅支持POST方法
def predict():
    # 初始化返回数据(默认success=False,避免异常时无返回)
    data = {'success': False}
    
    # 1. 检查请求方法是否为POST(POST适合传输文件等大数据)
    if flask.request.method == 'POST':
        # 2. 检查请求中是否包含'image'字段(即图像文件)
        if flask.request.files.get('image'):
            # 3. 读取图像文件(从Flask请求中获取字节数据)
            image_bytes = flask.request.files['image'].read()
            # 将字节数据转为PIL图像对象(通过io.BytesIO模拟文件流)
            image = Image.open(io.BytesIO(image_bytes))
            
            # 4. 图像预处理(转为模型可接受的张量)
            image_tensor = prepare_image(image, target_size=(224, 224))
            
            # 5. 模型推理(禁用梯度计算,提升速度并节省内存)
            with torch.no_grad():
                # 计算模型输出并通过softmax转为概率(dim=1表示按类别维度计算)
                preds = F.softmax(model(image_tensor), dim=1)
                # 获取Top3概率和对应的类别索引(k=3)
                top3_probs, top3_labels = torch.topk(preds.cpu().data, k=3, dim=1)
                # 转为numpy数组(方便后续处理)
                top3_probs = top3_probs.numpy()
                top3_labels = top3_labels.numpy()
            
            # 6. 组织返回结果(将Top3结果存入data字典)
            data['predictions'] = list()  # 用于存储多个预测结果
            # 遍历Top3结果,构造每个结果的字典
            for prob, label in zip(top3_probs[0], top3_labels[0]):
                result = {
                    "label": str(label),  # 类别索引(实际项目中可替换为类别名称)
                    "probability": float(prob)  # 对应类别的概率(保留float类型,便于JSON序列化)
                }
                data['predictions'].append(result)
            
            # 7. 标记请求成功
            data['success'] = True
    
    # 8. 以JSON格式返回结果(Flask自动处理序列化)
    return flask.jsonify(data)

接口设计关键点

  • 为什么用 POST 方法?GET 方法适合传递少量参数,而图像文件体积较大,POST 方法支持更大的请求体,且更安全(数据不暴露在 URL 中);
  • torch.no_grad():推理阶段不需要计算梯度,该上下文管理器可显著提升速度并减少内存占用,是 PyTorch 推理的标准操作;
  • 结果组织:返回success字段便于客户端判断请求是否成功,predictions字段包含 Top3 分类结果,符合实际应用中 "给出多个可能结果" 的需求;
  • 数据类型转换:PyTorch 张量需转为 numpy 数组,再转为 Python 基础类型(如float(prob)),避免 JSON 序列化失败(JSON 不支持张量类型)。、

6. 服务启动:加载模型并启动 Flask 服务

最后一段代码负责在脚本直接运行时加载模型并启动 Flask 服务:

python 复制代码
if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    # 1. 加载模型(服务启动时仅加载一次)
    load_model()
    # 2. 启动Flask服务(默认运行在localhost:5012)
    app.run(port='5012')  # port参数指定服务端口,可根据需求修改

服务启动细节

  • if __name__ == '__main__':确保只有当脚本被直接运行时才执行服务启动逻辑(若作为模块导入则不执行);
  • 端口选择:port='5012'指定服务运行在 5012 端口,若该端口被占用,可修改为其他未占用端口(如 5000、8080 等);
  • 默认配置:app.run()默认仅允许本地(localhost)访问,若需要外部设备访问,可添加host='0.0.0.0'参数(如app.run(host='0.0.0.0', port='5012'))。

启动完成如下:

二.客户端构建及测试

首先明确这段客户端代码的目标:作为 "中间桥梁",读取本地图像文件,向服务端的/predict接口发送 POST 请求,接收并解析返回的 JSON 结果,最终以友好的格式展示给用户

它解决了两个关键问题:

  1. 如何将本地图像文件转换为服务端可接收的请求格式;
  2. 如何解析服务端返回的 JSON 数据,提取有用的分类结果(类别、概率)。

1. 服务端 URL 配置:指定 API 地址

python 复制代码
import requests  # 导入HTTP请求库,用于向服务端发送POST请求
flask_url = 'http://127.0.0.1:5012/predict'
# flask_url = 'http://192.168.2.114:5012/predict'  # 注释掉的跨设备调用地址

这是客户端与服务端建立连接的 "关键地址",需要重点理解两个参数:

  • 127.0.0.1 :表示 "本地回环地址",仅当客户端和服务端在同一台设备 (如同一台电脑)时使用,用于本地测试
  • 192.168.2.114 :服务端设备在局域网中的 IP 地址,当客户端和服务端在同一局域网下的不同设备 (如服务端在台式机、客户端在笔记本)时,需将127.0.0.1替换为服务端的实际局域网 IP;
  • 5012 :端口号,必须与服务端app.run(port='5012')配置的端口一致(端口不匹配会导致请求失败);
  • /predict :服务端定义的 API 路由,需与服务端@app.route("/predict")完全一致。

2. 核心函数predict_result:实现请求与结果解析

python 复制代码
def predict_result(image_path):
    # 1. 读取本地图像文件(以二进制模式读取,避免编码问题)
    image = open(image_path, 'rb').read()
    
    # 2. 构造请求参数(关键:用'files'格式传递二进制图像)
    payload = {'image': image}
    
    # 3. 发送POST请求到服务端,并解析返回的JSON结果
    # requests.post():发送POST请求,files参数用于传递文件类数据
    # .json():将服务端返回的JSON字符串自动转为Python字典
    r = requests.post(flask_url, files=payload).json()
    
    # 4. 根据服务端返回的'success'字段判断请求是否成功
    if r['success']:
        # 若成功,遍历Top3预测结果并格式化输出
        for (i, result) in enumerate(r['predictions']):
            # 输出格式:1.预测类别为XX的概率:0.9876
            print('{}.预测类别为{}的概率:{}'.format(i + 1, result['label'], result['probability']))
    else:
        # 若失败,提示用户请求失败
        print('Request Failed')

这部分是客户端的核心,需要重点理解以下 3 个关键点:

(1)图像文件的读取方式:二进制模式'rb'
  • 为什么用'rb'(read binary)而不是'r'(文本模式)?图像是二进制文件(包含像素值、格式信息等二进制数据),若用文本模式读取,会导致数据损坏,服务端无法解析;而'rb'模式能完整保留图像的二进制数据,确保服务端接收后能正确还原为图像。
(2)请求参数的格式:files=payload
  • 服务端flask.request.files.get('image')接收的是 "文件类型" 的请求数据,因此客户端必须用requests.post()files参数传递,而非datajson参数:
    • data:用于传递文本类型的键值对(如表单数据);
    • json:用于传递 JSON 格式的文本数据;
    • files:专门用于传递二进制文件(如图片、文档),格式为{'字段名': 二进制数据},其中 "字段名"'image'必须与服务端flask.request.files['image']的字段名完全一致。
(3)结果解析逻辑:基于服务端的 JSON 结构

服务端返回的 JSON 数据格式如下(回顾上面内容):

python 复制代码
{
    "success": true,
    "predictions": [
        {"label": "23", "probability": 0.9234},
        {"label": "45", "probability": 0.0512},
        {"label": "18", "probability": 0.0257}
    ]
}

客户端通过r = ... .json()将其转为 Python 字典后:

  • r['success']判断请求是否成功(true表示成功,false表示失败);
  • 若成功,遍历r['predictions']列表(包含 Top3 结果),通过result['label']result['probability']提取类别和概率,再用enumerate添加序号,让输出更清晰。

3. 主函数调用:指定图像路径并执行

python 复制代码
if __name__ == '__main__':
    # 调用predict_result函数,传入本地图像的路径
    predict_result('./flower_data/flower_data/val_filelist/image_00028.jpg')
  • if __name__ == '__main__':确保只有当脚本被直接运行时,才执行图像预测逻辑(若作为模块导入则不执行);
  • 图像路径:./flower_data/.../image_00028.jpg是相对路径,表示 "当前脚本所在目录下的flower_data文件夹中的目标图像";若图像在其他位置,需传入绝对路径(如'C:/Users/Admin/Desktop/test.jpg')。

测试完成如下(确保提前打开服务端):

相关推荐
NAGNIP5 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab6 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab6 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP10 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年10 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼10 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS11 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区12 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈12 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang12 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx