目录
[1. 服务端功能概述](#1. 服务端功能概述)
[4. 模型加载函数](#4. 模型加载函数)
[5. 图像预处理函数](#5. 图像预处理函数)
[6. API路由定义](#6. API路由定义)
[7. 服务启动](#7. 服务启动)
[1. 客户端功能概述](#1. 客户端功能概述)
[2. 客户端代码](#2. 客户端代码)
[1. 环境配置](#1. 环境配置)
[2. 模型准备](#2. 模型准备)
[3. 启动服务](#3. 启动服务)
[4. 修改IP地址](#4. 修改IP地址)
[5. 客户端调用](#5. 客户端调用)
在深度学习领域,模型训练只是整个项目生命周期的一部分。如何将训练好的模型部署到生产环境中,让其他应用程序能够方便地调用,是许多开发者面临的挑战。本文将详细介绍如何使用Flask框架和PyTorch,构建一个完整的图像分类API服务,实现模型的在线部署与调用。
一、模型部署的目的和用到的技术
- 提供实时预测服务:让模型能够处理实时请求
- 实现模型复用:避免重复训练,提高开发效率
- 支持多端调用:Web、移动端、桌面应用都可以通过API访问
- 便于模型更新:只需更新服务端,客户端无需改动
需要用到的:
- Flask:轻量级Web框架,适合快速构建API服务
- PyTorch:深度学习框架,用于加载和运行模型
- Requests:Python HTTP库,用于客户端请求
- PIL:图像处理库
- Torchvision:提供预训练模型和图像变换工具
二、系统架构设计
1.整体架构(c/s)
客户端 (Client) → HTTP请求 → Flask服务器(server) → PyTorch模型 → 预测结果
客户端 (Client) ← HTTP响应 ← Flask服务器 ← 返回结果
2.具体流程
- 服务端启动:加载预训练模型,监听指定端口
- 客户端请求:发送图像文件到服务器
- 服务器处理:接收图像,预处理,模型推理
- 结果返回:将预测结果封装成JSON格式返回
- 客户端展示:解析结果并显示
三、服务端实现
1. 服务端功能概述
根据代码注释,服务端需要实现以下核心功能:
- 接收来自客户端的信息,24小时运行
- 将模型部署起来
- 对图片进行识别
- 将识别结果返回给客户端
2.导入所需库
python
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
io:处理字节流数据,用于图像读取
flask:Web框架,构建API服务
torch 和 torch.nn.functional:PyTorch核心库和函数式接口
PIL:Python图像处理库
torchvision:提供预训练模型和数据预处理工具
3.Flask应用初始化
python
app = flask.Flask(__name__)
model = None
use_gpu = False
app = flask.Flask(name):初始化Flask应用程序实例。__name__参数用于定位应用程序的根路径,这样Flask就可以知道在哪里找到模板、静态文件等,是Flask应用程序的起点,为后续添加路由、配置等奠定了基础。
model = None:全局变量,用于存储加载的模型
use_gpu = False:是否使用GPU进行推理
4. 模型加载函数
python
def load_model():
"""Load the pre-trained model, you can use your model just as easily."""
global model
# 加载resnet18网络
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 修改最后的类别数
checkpoint = torch.load('best.pth')
model.load_state_dict(checkpoint['state_dict'])
# 将模型指定为测试模式
model.eval()
# 是否使用gpu
if use_gpu:
model.cuda()
global model:声明使用全局变量model,在函数内部修改全局变量
models.resnet18():加载ResNet18预训练网络结构
model.fc.in_features:获取原全连接层的输入特征数
nn.Sequential(nn.Linear(num_ftrs, 102)):修改最后的全连接层,适配102类分类任务
torch.load('best.pth'):加载训练好的模型权重文件
model.load_state_dict(checkpoint['state_dict']):将权重加载到模型中
model.eval():将模型设置为评估模式,关闭Dropout和Batch Normalization的训练行为
model.cuda():如果use_gpu为True,将模型移动到GPU
5. 图像预处理函数
python
def prepare_image(image, target_size):
if image.mode != 'RGB':
image = image.convert('RGB')
image = transforms.Resize(target_size)(image)
image = transforms.ToTensor()(image)
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
image = image[None]
if use_gpu:
image = image.cuda()
return torch.tensor(image)
if image.mode != 'RGB':检查图像颜色模式,如果不是RGB则转换
transforms.Resize(target_size)(image):调整图像尺寸到目标大小(224×224)
transforms.ToTensor()(image):将PIL图像转换为PyTorch张量,并将像素值从[0,255]缩放到[0.0,1.0]
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])(image):使用ImageNet数据集的均值和标准差进行标准化
image = image[None]:在第0维添加一个维度,将形状从(C,H,W)变为(1,C,H,W),满足模型输入的batch要求
return torch.tensor(image):返回处理后的张量
6. API路由定义
python
@app.route("/predict", methods=["POST"])
def predict():
data = {"success": False}
if flask.request.method == "POST":
if flask.request.files.get("image"):
# 读取图像文件
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image))
# 预处理图像
image = prepare_image(image, target_size=(224, 224))
preds = F.softmax(model(image), dim=1)
results = torch.topk(preds.cpu().data, k=3, dim=1)
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
data['predictions'] = list()
for prob, label in zip(results[0][0], results[1][0]):
r = {'label': str(label), 'probability': float(prob)}
data['predictions'].append(r)
data['success'] = True
return flask.jsonify(data)
@app.route("/predict", methods=["POST"]):装饰器,将URL路径"/predict"与predict函数关联,并指定只处理POST请求
data = {"success": False}:初始化返回数据,success默认为False
flask.request.files.get("image"):获取上传的文件,key为"image"
image = flask.request.files["image"].read():读取文件的二进制数据
Image.open(io.BytesIO(image)):将二进制数据转换为PIL图像
F.softmax(model(image), dim=1):模型推理后应用softmax函数,将输出转换为概率分布
torch.topk(preds.cpu().data, k=3, dim=1):获取概率最高的前3个类别及其概率
results[0].cpu().numpy():将概率值转换为numpy数组
results[1].cpu().numpy():将类别索引转换为numpy数组
循环遍历,将结果封装成字典列表
flask.jsonify(data):将字典转换为JSON格式的HTTP响应
7. 服务启动
python
if __name__ == '__main__':
print("Loading PyTorch model and Flask starting server ...")
print("Please wait until server has fully started")
load_model() # 加载模型
# 开启服务
app.run(host='0.0.0.0', port=5010)
if name == 'main':判断是否直接运行此脚本
load_model():启动时加载模型
app.run(host='0.0.0.0', port=5010):启动Flask开发服务器,host='0.0.0.0'表示监听所有网络接口,允许局域网内的其他设备访问,port=5010指定端口号
四、客户端
1. 客户端功能概述
根据代码注释,客户端需要实现以下核心功能:
-
负责发送图片到指定的服务器
-
接收服务器端返回的结果信息
2. 客户端代码
python
import requests
flask_url = 'http://192.168.31.83:5010/predict' # url和端口携程自己的本地ip
def predict_result(image_path):
image = open(image_path, 'rb').read()
payload = {'image': image}
r = requests.post(flask_url, files=payload).json()
# 向fLask_url服务发送一个P0ST请求,并尝试将返回的JSON响应解析为一个Python字典
if r['success']:
for (i, result) in enumerate(r['predictions']):
print('{}.预测类别为{}:的概率{}'.format(i + 1, result['label'], result['probability']))
else:
print('Request failed')
if __name__ == '__main__':
predict_result('flower.jpg')
import requests:导入requests库,用于发送HTTP请求
flask_url = 'http://192.168.31.83:5010/predict':服务器端API地址,需要根据实际情况修改IP和端口
open(image_path, 'rb').read():以二进制模式打开并读取图像文件
payload = {'image': image}:构建请求数据,键名"image"必须与服务器端flask.request.files.get("image")一致
requests.post(flask_url, files=payload).json():发送POST请求,files参数用于上传文件,.json()将响应解析为Python字典
if r['success']:检查服务器返回的success字段
enumerate(r['predictions']):遍历预测结果列表,同时获取索引和值
打印每个预测结果的类别和概率
五、部署
1. 环境配置
python
# 创建虚拟环境(可选)
python -m venv venv
# 激活虚拟环境(Windows)
venv\Scripts\activate
# 安装依赖
pip install flask torch torchvision pillow requests
2. 模型准备
确保在服务端目录下有训练好的模型文件 best.pth。根据代码,这个文件应该包含:
模型权重(通过checkpoint['state_dict']访问)
3. 启动服务
先执行server代码

4. 修改IP地址
客户端代码中的IP地址需要根据实际情况修改:
python
flask_url = 'http://你的实际IP:5010/predict'
5. 客户端调用
