在深度学习项目落地时,将训练好的模型封装成可远程调用的 API 接口是核心环节。本文将完整讲解如何基于 Flask(轻量级 Web 框架)和 PyTorch(深度学习框架),实现图像分类服务端 API 开发 + 客户端调用 的全流程,让训练好的模型通过 HTTP 接口对外提供预测服务。
一、服务端开发:搭建图像分类 API 接口
服务端的核心职责是:加载预训练模型、接收客户端上传的图片、预处理图片、执行预测、返回 JSON 格式的预测结果。
1.完整服务端代码
python
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
# 初始化Flask应用
app = flask.Flask(__name__)
# 全局变量:模型实例、是否使用GPU
model = None
use_gpu = False # 测试阶段建议关闭GPU,避免环境问题
def load_model():
"""加载预训练的ResNet18模型(可替换为自定义模型)"""
global model
# 1. 加载ResNet18主干网络
model = models.resnet18(pretrained=False) # 不加载默认预训练权重
# 2. 修改全连接层,适配自定义分类任务(示例为102分类)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
# 3. 加载训练好的权重文件
checkpoint = torch.load('best.pth', map_location='cpu') # 强制使用CPU,避免GPU/CPU不匹配
model.load_state_dict(checkpoint['state_dict'])
# 4. 设置模型为评估模式(禁用Dropout/BatchNorm更新)
model.eval()
# 5. 可选:使用GPU(需确保环境有CUDA)
if use_gpu and torch.cuda.is_available():
model.cuda()
def prepare_image(image, target_size=(224, 224)):
"""预处理图片:转为RGB、resize、归一化、增加batch维度"""
# 1. 统一转为RGB格式(避免灰度图等格式异常)
if image.mode != 'RGB':
image = image.convert('RGB')
# 2. 图像变换:resize -> 转Tensor -> 归一化
transform = transforms.Compose([
transforms.Resize(target_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet均值/方差
])
image = transform(image)
# 3. 增加batch维度(模型要求输入为[batch, channel, h, w])
image = image.unsqueeze(0)
# 4. 可选:转到GPU
if use_gpu and torch.cuda.is_available():
image = image.cuda()
return image
@app.route("/predict", methods=["POST"])
def predict():
"""API核心接口:接收图片,返回预测结果"""
# 初始化返回数据
data = {"success": False}
# 仅处理POST请求
if flask.request.method == "POST":
# 检查是否有图片上传
if flask.request.files.get("image"):
try:
# 1. 读取并解析图片
image_bytes = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image_bytes))
# 2. 预处理图片
image = prepare_image(image)
# 3. 模型预测(禁用梯度计算,提升速度)
with torch.no_grad():
preds = F.softmax(model(image), dim=1) # 计算类别概率
# 4. 获取概率最高的前3个结果
top3_probs, top3_labels = torch.topk(preds.cpu(), k=3, dim=1)
# 5. 格式化结果
data['predictions'] = []
for prob, label in zip(top3_probs.numpy()[0], top3_labels.numpy()[0]):
data['predictions'].append({
"label": str(label),
"probability": float(prob)
})
# 6. 标记请求成功
data["success"] = True
except Exception as e:
print(f"预测出错:{str(e)}")
# 返回JSON格式结果(HTTP响应)
return flask.jsonify(data)
if __name__ == '__main__':
print("加载PyTorch模型中...")
load_model() # 启动前先加载模型
print("模型加载完成,启动Flask服务...")
# 启动服务:host=0.0.0.0允许局域网访问,port为自定义端口
app.run(host='0.0.0.0', port=5012, debug=False)
2.服务端核心模块解析
(1)模型加载模块 load_model()
-
加载 ResNet18 主干网络,替换全连接层适配自定义分类数(示例为 102 类)
-
加载训练好的权重文件
best.pth,设置map_location='cpu'避免 GPU/CPU 环境不匹配 -
调用
model.eval()将模型设为评估模式(禁用 Dropout、BatchNorm 等训练层)。
(2)图片预处理模块 prepare_image()
-
统一转为 RGB 格式,避免灰度图、RGBA 图等格式异常
-
使用 ImageNet 均值 / 方差归一化(与 ResNet 预训练时的预处理一致)
-
增加 batch 维度(模型要求输入为 4 维张量:[batch, c, h, w])。
(3)API 接口 predict()
-
仅接收 POST 请求,确保数据传输安全
-
通过
flask.request.files读取客户端上传的二进制图片 -
使用
torch.no_grad()禁用梯度计算,大幅提升预测速度 -
用
torch.topk获取概率最高的前 3 个类别,格式化后以 JSON 返回。
二、客户端开发:调用 API 获取预测结果
客户端的核心职责是:读取本地图片、以二进制形式上传到服务端 API、解析返回的 JSON 结果并展示。
1.客户端完整代码
python
import requests#第三方库,爬虫中的一个库。
'''客户端的程序代码:功能;负责按照指定格式上传图片和接受结果'''
flask_url='http://127.0.0.1:5012/predict'#_url和端口写成自己的本地ip
def predict_result(image_path):
image = open(image_path, 'rb').read()
payload= {'image':image}
r=requests.post(flask_url,files =payload).json()#是因为网络传辑
if r['success']:
# 成功的话再返回。
# 输出结果
for(i,result) in enumerate(r['predictions']):
print('{}.预测类别为{}:的概率:{}'.format(i + 1,result['label'],result['probability']))
# 失败了就打印
else:
print('Request failed')
if __name__ == "__main__":
predict_result('./flower_data/2.jpg')
2.代码解析
2.1 配置信息
import requests#第三方库,爬虫中的一个库。
-
引入
requests库:这是 Python 中处理 HTTP 请求的主流第三方库,不仅用于爬虫,更是接口调用的核心工具,能轻松实现 POST/GET 等请求方式。flask_url='http://127.0.0.1:5012/predict'#_url和端口写成自己的本地ip
-
配置服务端接口地址:
127.0.0.1是本地回环地址,对应运行 Flask 服务端的本机;5012是服务端设置的端口号;/predict是服务端开放的预测接口路径,必须与服务端的路由地址完全一致。
2.2 核心预测函数predict_result
def predict_result(image_path):
image = open(image_path, 'rb').read()
-
读取图片文件:
open(image_path, 'rb')以二进制只读模式打开图片,read()将文件内容读取为二进制字节流 ------ 网络传输图片必须用二进制格式,不能直接传文件路径或文本格式。payload= {'image':image}
-
构造请求参数:创建字典
payload,键名image是服务端约定的接收图片的参数名(需与服务端flask.request.files.get("image")中的image一致),值为读取的二进制图片数据。r=requests.post(flask_url,files =payload).json()
-
发送 POST 请求并解析结果:
(1)requests.post():向指定的flask_url发送 POST 请求,files参数专门用于上传文件 / 二进制数据
(2).json():将服务端返回的 JSON 格式字符串自动解析为 Python 字典,方便后续取值。
if r['success']:
for(i,result) in enumerate(r['predictions']):
print('{}.预测类别为{}:的概率:{}'.format(i + 1,result['label'],result['probability']))
else:
print('Request failed')
- 解析并输出结果:
(1)服务端返回的字典中,success为True表示预测成功,predictions是包含前 3 个高概率类别结果的列表
(2)enumerate为结果添加序号,通过result['label']和result['probability']分别获取类别标签和对应概率,格式化输出
(3)若success为False,则打印请求失败提示。
三、运行效果
服务端正常响应时,客户端控制台会输出:
1.预测类别为5:的概率:0.9875643849372864
2.预测类别为8:的概率:0.0089123456789012
3.预测类别为12:的概率:0.0021098765432109