一、模型部署
1.模型部署是什么
模型部署(Model Deployment)是指将训练好的机器学习模型集成到实际的生产环境或应用系统中,使其能够接收输入数据(例如图片、文本、数值等),并实时或批量地输出预测结果,从而为最终用户或其他服务提供智能能力。
简单来说,模型训练阶段就像我们在实验室里培养了一个"专家"------这个专家通过学习大量数据掌握了某种技能(比如识别图像中的花卉种类)。但专家本身只是一堆参数和计算逻辑,无法直接对外服务。模型部署就是把这个"专家"请到工作现场(比如服务器、手机App、嵌入式设备),并给他配备一个"问答窗口"(API接口),让外部程序可以随时向它提问(发送数据),并立即得到答案(预测结果)。
2.为什么需要模型部署?
-
让模型产生实际价值:模型只有在被使用时才具有意义,比如一个图像识别模型部署后,才能帮助用户自动分类照片。
-
实时响应:很多场景需要毫秒级的预测反馈,例如自动驾驶、实时推荐系统。
-
可扩展性:部署后的模型可以同时服务大量请求,支持业务增长。
-
集成到现有系统:模型可以作为微服务嵌入到更大的软件架构中,与其他组件协同工作。
3.常见部署方式
-
Web API 服务(如你提供的代码) 将模型封装在一个HTTP服务器中,通过RESTful API对外提供预测接口。任何能发送HTTP请求的客户端(网页、移动App、其他后端服务)都可以调用。这是最灵活、最通用的方式。
-
嵌入式部署 将模型压缩并移植到移动设备(手机、IoT设备)上运行,例如在手机本地进行图像识别,无需联网。
-
批处理预测 定期处理大量离线数据,例如每天凌晨对用户行为日志进行预测,生成推荐列表。
-
边缘计算部署 将模型部署在网络边缘节点(如路由器、基站),减少数据传输延迟,适用于物联网场景。
4.部署时需要关注的问题
-
性能与延迟:模型推理速度能否满足实时要求?是否需要GPU加速?
-
并发能力:能同时处理多少个请求?如何水平扩展?
-
模型版本管理:如何平滑升级模型而不中断服务?
-
监控与日志:如何监控模型效果、捕捉异常?
-
安全与隐私:如何防止恶意攻击、保护用户数据?
二、实际运用
任务:
实现一个简单的图像分类服务,采用客户端-服务器架构。服务器端使用 Flask 框架搭建一个 REST API,加载预训练的 ResNet18 模型(在 102 类花卉数据集上微调),接收客户端上传的图片并进行推理,返回分类结果(Top-3 概率及对应类别标签)。客户端则通过 requests 库向服务器发送图片,并打印预测结果
服务器端代码(server.py)
1. 导入库
python
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
-
io:用于在内存中处理二进制数据流(如图片)。 -
flask:构建 Web 服务器和路由。 -
torch和torch.nn.functional:PyTorch 核心库,用于张量运算和 softmax。 -
PIL.Image:Python 图像库,用于打开和处理图像。 -
torchvision:提供预训练模型、数据变换等。
2. 初始化 Flask 应用
python
app = flask.Flask(__name__)
model = None
use_gpu = False
-
app是 Flask 应用实例,用于注册路由和启动服务。 -
model为全局变量,存储加载后的模型。 -
use_gpu控制是否使用 GPU(当前设为False,即使用 CPU)。
3. 加载模型函数 load_model()
python
def load_model():
global model
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
checkpoint = torch.load('best.pth')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
if use_gpu:
model.cuda()
-
创建一个 ResNet18 模型,并将最后的全连接层替换为输出 102 类的线性层(假设数据集有 102 个类别)。
-
加载训练好的权重文件
best.pth(该文件应为 PyTorch 的 checkpoint,包含'state_dict'键)。 -
调用
model.eval()切换到评估模式,关闭 Dropout 和 BatchNorm 的训练行为。 -
若
use_gpu为真,则将模型移动到 GPU。
4. 图像预处理函数 prepare_image()
python
def prepare_image(image, target_size):
if image.mode != 'RGB':
image = image.convert('RGB')
image = transforms.Resize(target_size)(image)
image = transforms.ToTensor()(image)
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
image = image[None] # 增加 batch 维度
if use_gpu:
image = image.cuda()
return torch.tensor(image)
-
确保图像为 RGB 模式(有些图片可能是 RGBA 或 L 灰度)。
-
调整尺寸到
target_size(如 224×224),这是 ResNet 输入大小。 -
转换为张量(
[C, H, W],值范围从 0~255 缩放到 0~1)。 -
使用 ImageNet 数据集的均值和标准差进行归一化,与训练时的预处理保持一致。
-
添加 batch 维度(从
[C, H, W]变为[1, C, H, W]),因为模型期望输入是一个 batch。 -
最后返回张量。
5. 预测路由 /predict
python
@app.route("/predict", methods=["POST"])
def predict():
data = {"success": False}
if flask.request.method == 'POST':
if flask.request.files.get("image"):
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image))
image = prepare_image(image, target_size=(224, 224))
preds = F.softmax(model(image), dim=1)
results = torch.topk(preds.cpu().data, k=3, dim=1)
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
data['predictions'] = list()
for prob, label in zip(results[0][0], results[1][0]):
r = {"label": str(label), "probability": float(prob)}
data['predictions'].append(r)
data["success"] = True
return flask.jsonify(data)
-
定义一个只接受 POST 请求的路由。
-
初始化返回数据字典,默认
"success": False。 -
检查请求中是否包含名为
"image"的文件(客户端上传时需用该字段名)。 -
读取文件的二进制内容,用
PIL.Image.open和io.BytesIO将其转换为图像对象。 -
调用预处理函数得到模型输入张量。
-
执行模型推理:
model(image)得到原始 logits,然后F.softmax转换为概率(dim=1表示对类别维度)。 -
torch.topk取概率最高的前 3 个结果,返回值和索引。 -
将结果从 GPU 移到 CPU(
cpu()),转为 NumPy 数组,便于处理。 -
构建预测列表,每个元素包含标签(索引)和概率(浮点数)。
-
将
success设为True,最后用flask.jsonify将字典转为 JSON 返回。
6. 主程序入口
python
if __name__ == '__main__':
print("Loading PyTorch model and Flask starting server ...")
print("Please wait until server has fully started")
load_model()
app.run(host='0.0.0.0', port=5100)
-
仅在直接运行该脚本时执行。
-
先加载模型,然后启动 Flask 开发服务器,监听所有网络接口(
0.0.0.0)的 5100 端口。
7.运行结果

客户端代码(client.py)
1. 导入库并定义服务器地址
python
import requests
flask_url = 'http://127.0.0.1:5100/predict'
# flask_url='http://192.168.31.87:5100/predict' # 另一可选地址,这是我自己的,写代码时换成自己主机的
-
使用
requests库发送 HTTP 请求。 -
定义服务器的完整 URL(注意端口与服务器端一致)。
2. 预测函数 predict_result(image_path)
python
def predict_result(image_path):
image = open(image_path, 'rb').read()
payload = {'image': image}
r = requests.post(flask_url, files=payload).json()
if r['success']:
for i, result in enumerate(r['predictions']):
print('{}.预测类别为{}:的概率:{}'.format(i+1, result['label'], result['probability']))
else:
print('request failed')
-
以二进制读取模式打开图片文件,获取图片的二进制数据。
-
构造
payload字典,键'image'对应图片二进制数据。注意这里直接传二进制,requests.post的files参数会将其包装为 multipart/form-data 格式。 -
发送 POST 请求,并解析返回的 JSON 数据(
.json())。 -
如果服务器返回的
success为真,则遍历predictions列表,打印排名、类别标签和概率。 -
若请求失败(
success为假),打印提示。
3. 主程序调用
python
if __name__ == '__main__':
predict_result('hua.jpg')
- 调用预测函数,传入本地图片文件名
hua.jpg。
4.运行结果

整体工作流程
-
启动服务器:运行
server.py,加载模型,监听端口 5100。 -
客户端运行
client.py,读取本地图片hua.jpg,向服务器发送 POST 请求(携带图片)。 -
服务器接收请求,预处理图片,进行模型推理,得到 Top-3 预测结果,以 JSON 格式返回。
-
客户端接收 JSON,解析并打印结果。
注意事项与潜在问题
-
模型路径 :服务器假设
best.pth在当前目录下,且 checkpoint 包含'state_dict'键。实际使用时需确保文件存在且格式正确。 -
类别映射:服务器返回的是类别索引(0~101),而非实际类别名称。若需要输出具体类别名(如"玫瑰"),需在服务器端维护一个索引到标签的映射字典。
-
错误处理:代码对异常(如图片损坏、模型推理失败)没有捕获,可能返回 500 错误或崩溃。生产环境应添加 try-except 和适当的错误响应。
-
性能:当前使用 CPU 推理,若图片并发量大可能成为瓶颈。可启用 GPU、使用异步框架或模型优化。
-
安全性:直接接受用户上传文件,存在安全风险(如超大图片、恶意文件)。应限制文件大小、类型,并对输入进行验证。
-
Flask 开发服务器 :
app.run()启动的是单进程开发服务器,不适合生产。生产环境建议使用 Gunicorn 等 WSGI 服务器。