目录
[1. 依赖导入:搭建开发基础](#1. 依赖导入:搭建开发基础)
[2. Flask 应用初始化:启动 Web 服务的第一步](#2. Flask 应用初始化:启动 Web 服务的第一步)
[3. 模型加载函数:加载 ResNet18 与微调权重](#3. 模型加载函数:加载 ResNet18 与微调权重)
[4. 图像预处理函数:将输入图像转为模型可接受格式](#4. 图像预处理函数:将输入图像转为模型可接受格式)
[5. API 接口定义:处理请求与返回结果](#5. API 接口定义:处理请求与返回结果)
[6. 服务启动:加载模型并启动 Flask 服务](#6. 服务启动:加载模型并启动 Flask 服务)
[1. 服务端 URL 配置:指定 API 地址](#1. 服务端 URL 配置:指定 API 地址)
[2. 核心函数predict_result:实现请求与结果解析](#2. 核心函数predict_result:实现请求与结果解析)
[(3)结果解析逻辑:基于服务端的 JSON 结构](#(3)结果解析逻辑:基于服务端的 JSON 结构)
[3. 主函数调用:指定图像路径并执行](#3. 主函数调用:指定图像路径并执行)
在 AI 项目开发中,训练好的模型只有部署到实际应用中才能发挥价值。本文将以一段完整的代码为例,详细讲解如何使用 Flask 框架将 PyTorch 训练的 ResNet18 图像分类模型封装成 API 服务,实现通过网络接口接收图像、返回分类结果的功能。无论你是 AI 开发新手还是需要快速落地模型的工程师,都能通过本文掌握模型部署的核心流程。
一.服务端构建
1. 依赖导入:搭建开发基础
代码开头首先导入了项目所需的所有库,这些库覆盖了从模型处理到 Web 服务的全流程:
python
import io # 内存流操作,用于处理图像字节数据
import flask # Web框架,用于构建API服务
import torch # PyTorch核心库,用于模型加载和推理
import torch.nn.functional as F # PyTorch的函数库,用于计算softmax等
from PIL import Image # 图像处理库,用于读取和转换图像格式
from torch import nn # PyTorch的神经网络模块,用于修改模型结构
from torchvision import transforms, models # 图像变换和预训练模型库
关键库说明:
flask
:轻量级 Web 框架,核心优势是简洁、易上手,适合快速构建 API 服务;torchvision
:PyTorch 官方的计算机视觉工具库,提供了 ResNet 等预训练模型和常用的图像变换方法;PIL.Image
:处理图像的标准库,支持多种图像格式的读取和转换;io
:由于 Flask 接收的是字节流形式的图像,需要用io.BytesIO
将字节数据转换为图像对象。
2. Flask 应用初始化:启动 Web 服务的第一步
这部分代码完成 Flask 应用的初始化,并定义了两个全局变量用于存储模型和设备配置:
python
# 初始化Flask app
app = flask.Flask(__name__) # 创建Flask应用实例
# __name__参数的作用:告诉Flask应用的根路径,用于定位模板、静态文件等资源
# 全局变量:存储模型和设备配置
model = None # 用于存储加载后的PyTorch模型,初始为None
use_gpu = False # 是否使用GPU推理,此处默认关闭(可根据硬件调整)
核心知识点:
flask.Flask(__name__)
:这是 Flask 应用的 "入口",__name__
会自动识别当前脚本的路径,确保 Flask 能正确找到后续可能用到的模板(templates)或静态文件(static)目录;- 全局变量
model
:由于模型加载耗时较长,我们在服务启动时加载一次(而非每次请求都加载),通过全局变量存储模型实例,提升接口响应速度; use_gpu
:控制模型推理是否使用 GPU,若本地有可用 GPU,可将其设为True
(需确保 PyTorch 已安装 GPU 版本)。
3. 模型加载函数:加载 ResNet18 与微调权重
load_model()
函数是模型部署的核心,负责加载预训练的 ResNet18 模型、修改输出层以适配自定义分类任务,并加载训练好的模型文件:
python
def load_model():
global model # 声明使用全局变量model,避免在函数内创建局部变量
# 1. 加载预训练的ResNet18模型
model = models.resnet18(pretrained=False) # 注意:新版torchvision需显式设pretrained=False
# 2. 修改模型的全连接层(fc层),适配自定义分类任务
num_ftrs = model.fc.in_features # 获取fc层的输入特征数(ResNet18默认是512)
# 将fc层替换为输出102类的全连接层(因原代码分类任务为102类)
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
# 3. 加载训练好的权重文件(best.pth)
checkpoint = torch.load('best.pth', map_location=torch.device('cpu')) # 加载权重到CPU
model.load_state_dict(checkpoint['state_dict']) # 将权重加载到模型
# 4. 设置模型为评估模式(至关重要!)
model.eval()
# 5. 若使用GPU,将模型移动到GPU
if use_gpu:
model.cuda()
关键细节解析:
- 为什么修改
fc
层?ResNet18 预训练模型的fc
层默认输出 1000 类(对应 ImageNet 数据集),而该项目是 102 类分类任务,因此需要替换fc
层的输出维度为 102; model.eval()
的作用:将模型从训练模式切换到评估模式,会关闭 Dropout 层、固定 BatchNorm 层的统计参数,确保推理结果的一致性(若不设置,推理结果会异常);torch.load(..., map_location=...)
:若use_gpu=False
,需指定map_location=torch.device('cpu')
,避免因权重文件是在 GPU 上训练而无法在 CPU 加载的问题。
4. 图像预处理函数:将输入图像转为模型可接受格式
PyTorch 模型对输入图像有严格的格式要求(如尺寸、归一化等),prepare_image()
函数负责将用户上传的图像转换为模型能处理的张量:
python
def prepare_image(image, target_size):
# 1. 确保图像为RGB格式(若上传的是灰度图,转换为RGB)
if image.mode != 'RGB':
image = image.convert('RGB')
# 2. 图像尺寸调整(ResNet18默认输入尺寸为224x224)
image = transforms.Resize(target_size)(image) # target_size=(224,224)
# 3. 将PIL图像转为PyTorch张量(形状:[C, H, W],数值范围0-1)
image = transforms.ToTensor()(image)
# 4. 图像归一化(使用ImageNet数据集的均值和标准差,预训练模型的标准操作)
image = transforms.Normalize(
mean=[0.485, 0.456, 0.406], # ImageNet数据集的RGB通道均值
std=[0.229, 0.224, 0.225] # ImageNet数据集的RGB通道标准差
)(image)
# 5. 增加批次维度(模型要求输入为[batch_size, C, H, W],单张图像需手动加batch维度)
image = image[None] # 等价于image.unsqueeze(0),形状从[3,224,224]变为[1,3,224,224]
# 6. 若使用GPU,将张量移动到GPU
if use_gpu:
image = image.cuda()
return image
预处理的必要性:
- 格式统一:用户可能上传灰度图、PNG 图(带透明通道)等,需转为 RGB 格式;
- 尺寸统一:模型训练时用的是 224x224 图像,输入图像必须保持相同尺寸;
- 归一化:消除像素值范围差异(如 0-255 转为 0-1),并使用与预训练模型一致的均值 / 标准差,确保模型推理稳定。
5. API 接口定义:处理请求与返回结果
/predict
路由是整个服务的 "接口门面",负责接收客户端的 POST 请求、调用模型推理、组织返回结果:
python
@app.route("/predict", methods=['POST']) # 定义路由:URL为/predict,仅支持POST方法
def predict():
# 初始化返回数据(默认success=False,避免异常时无返回)
data = {'success': False}
# 1. 检查请求方法是否为POST(POST适合传输文件等大数据)
if flask.request.method == 'POST':
# 2. 检查请求中是否包含'image'字段(即图像文件)
if flask.request.files.get('image'):
# 3. 读取图像文件(从Flask请求中获取字节数据)
image_bytes = flask.request.files['image'].read()
# 将字节数据转为PIL图像对象(通过io.BytesIO模拟文件流)
image = Image.open(io.BytesIO(image_bytes))
# 4. 图像预处理(转为模型可接受的张量)
image_tensor = prepare_image(image, target_size=(224, 224))
# 5. 模型推理(禁用梯度计算,提升速度并节省内存)
with torch.no_grad():
# 计算模型输出并通过softmax转为概率(dim=1表示按类别维度计算)
preds = F.softmax(model(image_tensor), dim=1)
# 获取Top3概率和对应的类别索引(k=3)
top3_probs, top3_labels = torch.topk(preds.cpu().data, k=3, dim=1)
# 转为numpy数组(方便后续处理)
top3_probs = top3_probs.numpy()
top3_labels = top3_labels.numpy()
# 6. 组织返回结果(将Top3结果存入data字典)
data['predictions'] = list() # 用于存储多个预测结果
# 遍历Top3结果,构造每个结果的字典
for prob, label in zip(top3_probs[0], top3_labels[0]):
result = {
"label": str(label), # 类别索引(实际项目中可替换为类别名称)
"probability": float(prob) # 对应类别的概率(保留float类型,便于JSON序列化)
}
data['predictions'].append(result)
# 7. 标记请求成功
data['success'] = True
# 8. 以JSON格式返回结果(Flask自动处理序列化)
return flask.jsonify(data)
接口设计关键点:
- 为什么用 POST 方法?GET 方法适合传递少量参数,而图像文件体积较大,POST 方法支持更大的请求体,且更安全(数据不暴露在 URL 中);
torch.no_grad()
:推理阶段不需要计算梯度,该上下文管理器可显著提升速度并减少内存占用,是 PyTorch 推理的标准操作;- 结果组织:返回
success
字段便于客户端判断请求是否成功,predictions
字段包含 Top3 分类结果,符合实际应用中 "给出多个可能结果" 的需求; - 数据类型转换:PyTorch 张量需转为 numpy 数组,再转为 Python 基础类型(如
float(prob)
),避免 JSON 序列化失败(JSON 不支持张量类型)。、
6. 服务启动:加载模型并启动 Flask 服务
最后一段代码负责在脚本直接运行时加载模型并启动 Flask 服务:
python
if __name__ == '__main__':
print("Loading PyTorch model and Flask starting server ...")
print("Please wait until server has fully started")
# 1. 加载模型(服务启动时仅加载一次)
load_model()
# 2. 启动Flask服务(默认运行在localhost:5012)
app.run(port='5012') # port参数指定服务端口,可根据需求修改
服务启动细节:
if __name__ == '__main__'
:确保只有当脚本被直接运行时才执行服务启动逻辑(若作为模块导入则不执行);- 端口选择:
port='5012'
指定服务运行在 5012 端口,若该端口被占用,可修改为其他未占用端口(如 5000、8080 等); - 默认配置:
app.run()
默认仅允许本地(localhost)访问,若需要外部设备访问,可添加host='0.0.0.0'
参数(如app.run(host='0.0.0.0', port='5012')
)。
启动完成如下:

二.客户端构建及测试
首先明确这段客户端代码的目标:作为 "中间桥梁",读取本地图像文件,向服务端的/predict
接口发送 POST 请求,接收并解析返回的 JSON 结果,最终以友好的格式展示给用户。
它解决了两个关键问题:
- 如何将本地图像文件转换为服务端可接收的请求格式;
- 如何解析服务端返回的 JSON 数据,提取有用的分类结果(类别、概率)。
1. 服务端 URL 配置:指定 API 地址
python
import requests # 导入HTTP请求库,用于向服务端发送POST请求
flask_url = 'http://127.0.0.1:5012/predict'
# flask_url = 'http://192.168.2.114:5012/predict' # 注释掉的跨设备调用地址
这是客户端与服务端建立连接的 "关键地址",需要重点理解两个参数:
127.0.0.1
:表示 "本地回环地址",仅当客户端和服务端在同一台设备 (如同一台电脑)时使用,用于本地测试;192.168.2.114
:服务端设备在局域网中的 IP 地址,当客户端和服务端在同一局域网下的不同设备 (如服务端在台式机、客户端在笔记本)时,需将127.0.0.1
替换为服务端的实际局域网 IP;5012
:端口号,必须与服务端app.run(port='5012')
配置的端口一致(端口不匹配会导致请求失败);/predict
:服务端定义的 API 路由,需与服务端@app.route("/predict")
完全一致。
2. 核心函数predict_result
:实现请求与结果解析
python
def predict_result(image_path):
# 1. 读取本地图像文件(以二进制模式读取,避免编码问题)
image = open(image_path, 'rb').read()
# 2. 构造请求参数(关键:用'files'格式传递二进制图像)
payload = {'image': image}
# 3. 发送POST请求到服务端,并解析返回的JSON结果
# requests.post():发送POST请求,files参数用于传递文件类数据
# .json():将服务端返回的JSON字符串自动转为Python字典
r = requests.post(flask_url, files=payload).json()
# 4. 根据服务端返回的'success'字段判断请求是否成功
if r['success']:
# 若成功,遍历Top3预测结果并格式化输出
for (i, result) in enumerate(r['predictions']):
# 输出格式:1.预测类别为XX的概率:0.9876
print('{}.预测类别为{}的概率:{}'.format(i + 1, result['label'], result['probability']))
else:
# 若失败,提示用户请求失败
print('Request Failed')
这部分是客户端的核心,需要重点理解以下 3 个关键点:
(1)图像文件的读取方式:二进制模式'rb'
- 为什么用
'rb'
(read binary)而不是'r'
(文本模式)?图像是二进制文件(包含像素值、格式信息等二进制数据),若用文本模式读取,会导致数据损坏,服务端无法解析;而'rb'
模式能完整保留图像的二进制数据,确保服务端接收后能正确还原为图像。
(2)请求参数的格式:files=payload
- 服务端
flask.request.files.get('image')
接收的是 "文件类型" 的请求数据,因此客户端必须用requests.post()
的files
参数传递,而非data
或json
参数:data
:用于传递文本类型的键值对(如表单数据);json
:用于传递 JSON 格式的文本数据;files
:专门用于传递二进制文件(如图片、文档),格式为{'字段名': 二进制数据}
,其中 "字段名"'image'
必须与服务端flask.request.files['image']
的字段名完全一致。
(3)结果解析逻辑:基于服务端的 JSON 结构
服务端返回的 JSON 数据格式如下(回顾上面内容):
python
{
"success": true,
"predictions": [
{"label": "23", "probability": 0.9234},
{"label": "45", "probability": 0.0512},
{"label": "18", "probability": 0.0257}
]
}
客户端通过r = ... .json()
将其转为 Python 字典后:
- 用
r['success']
判断请求是否成功(true
表示成功,false
表示失败); - 若成功,遍历
r['predictions']
列表(包含 Top3 结果),通过result['label']
和result['probability']
提取类别和概率,再用enumerate
添加序号,让输出更清晰。
3. 主函数调用:指定图像路径并执行
python
if __name__ == '__main__':
# 调用predict_result函数,传入本地图像的路径
predict_result('./flower_data/flower_data/val_filelist/image_00028.jpg')
if __name__ == '__main__'
:确保只有当脚本被直接运行时,才执行图像预测逻辑(若作为模块导入则不执行);- 图像路径:
./flower_data/.../image_00028.jpg
是相对路径,表示 "当前脚本所在目录下的flower_data
文件夹中的目标图像";若图像在其他位置,需传入绝对路径(如'C:/Users/Admin/Desktop/test.jpg'
)。
测试完成如下(确保提前打开服务端):
