一、整体功能概述
这两段代码组合起来实现了一个 深度学习图像分类推理系统:
-
代码一(服务端) :
使用 Flask 搭建 HTTP 服务器,加载一个 PyTorch 训练好的模型(如 ResNet18),接受图片上传请求,并返回分类预测结果(前 3 名类别与概率)。
-
代码二(客户端) :
使用
requests
库向服务端发送图片(HTTP POST 请求),获取预测结果并打印。
这种结构在工业场景中非常常见,被称为:
模型服务化部署\]+\[客户端调用
二、运行流程详解
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
(1)模型加载阶段
# 初始化Flask app
app = flask.Flask(__name__) # 创建一个新的Flask应用程序实例,
# __name__参数通常被传递给Flask应用程序来定位应用程序中的模块,这样Flask就可以知道在哪里找到模板、静态文件等。
# 总体来说app = flask.Flask(__name__)是Flask应用程序的起点。它初始化了一个新的Flask应用程序实例,为后续添加路由、配置等奠定了基础。
use_gpu = False # 是否使用GPU训练
def load_model():
global model
# 加载resnet18网络
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 类别数自己根据自己任务来
# print(model)
checkpoint = torch.load("best.pth")
model.load_state_dict(checkpoint['state_dict'])
# 将模型指定为测试格式
model.eval()
# 是否使用gpu
if use_gpu:
model.cuda()
解释:
初始化Flask实例
使用 torchvision 内置的 resnet18
结构;
修改全连接层输出改为 102 类;
从 best.pth
文件加载权重;
state_dict,用于存储模型的所有可学习参数。
torch.load('best.pth')
从磁盘读取
.pth
文件,反序列化为 Python 字典。
checkpoint['state_dict']
提取出模型参数部分(即训练保存的 state_dict)。
model.load_state_dict(...)
把这些参数加载到你当前定义的模型结构中去。
设定为评估模式(model.eval()
)以禁用 dropout、BN 更新。
(2)图像预处理阶段
def prepare_image(image, target_size):
# 针对不同情况,image格式不一样,需要统一到RGB格式
if image.mode != "RGB":
image = image.convert("RGB")
# 按照所使用的模型调整输入图片的尺寸格式,并转为tensor
image = transforms.Resize(target_size)(image) #.forword(image)
image = transforms.ToTensor()(image)
# (RGB三通道)这的参数和数据集中是对应的
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
# 增加一个维度,用于做batch测试 本次这里一次测试一张
image = image[None]
if use_gpu:
image = image.cuda()
return torch.tensor(image)
功能:
将输入图像统一为 RGB;
调整大小至模型输入要求;
转换为 Tensor;
像素分布进行归一化(标准化)按 ImageNet 统计值进行标准化;
增加 batch 维度,使输入符合 (1, 3, 224, 224)
形状。
torch.Size([3, 224, 224]) -> torch.Size([1, 3, 224, 224])
3 → 通道数(R、G、B)
224 → 图像高度
224 → 图像宽度
多了一个 batch 维度,表示"1 张图像"。batch size(一次输入的图片数量)
对于模型推理(inference)是必不可少的。
(3)请求接收与预测阶段
@app.route("/predict", methods=["POST"])
def predict():
# 做一个标志,刚开始无图像传入时为False,传入图像时为True
data = {"success": False}
if flask.request.method == "POST": # 如果收到请求
if flask.request.files.get("image"): # 判断是否为图像
image = flask.request.files["image"].read() # 收到的图像进行读取,内容为二进制
# BytesIO()提供了一个类似文件对象的接口,允许你像操作文件一样在内存中读写二进制。传输过来的数据一般在缓存
image = Image.open(io.BytesIO(image))
# 利用上面的预处理函数对传入的图像进行预处理
image = prepare_image(image, target_size=(224, 224))
preds = F.softmax(model(image), dim=1) # 得到各个类的概率
results = torch.topk(preds.cpu().data, k=3, dim=1) # 概率最大的前3个结果
# torch.topk用于返回输入张量中每行的最大的k个元素及其对应的索引
results = (results[0].cpu().numpy(), results[1].cpu().numpy())
# 向data字典增加一个key,value,其中value为list格式
data["predictions"] = list()
流程:
项目 | 含义 |
---|---|
@app.route() |
Flask 用于定义 路由 的装饰器 |
"/predict" |
指定访问的 URL 路径 |
methods=["POST"] |
指定 HTTP 请求方式(这里只允许 POST) |
这意味着:
当客户端向服务器发送一个 POST 请求 ,访问
/predict
路径时,Flask 就会调用下面定义的
predict()
函数来处理。
-
Flask 接收到客户端上传的图片文件;
-
读取图片二进制流并转为
PIL.Image
; -
调用预处理函数;
-
使用模型进行前向推理;
-
计算 softmax 概率;
-
取出 top-3 预测结果。
-
转回 CPU 将 PyTorch 张量转换为 NumPy 数组
(4)返回结果给客户端
for prob, label in zip(results[0][0], results[1][0]):
r = {"label": str(label), "probability": float(prob)}
# 将预测结果添加到data字典
data["predictions"].append(r)
data["success"] = True
# jsonify非常有利于网络传输,字典可以直接转换为json文件的内容,用作网络传输,一般都用json格式的数据
return flask.jsonify(data)
结果示例:

这段结果将通过 JSON 格式返回,供客户端使用。
(5)服务启动
if __name__ == "__main__": #是Python中的内置变量,是有值,和你当前运行的代码文件有关
print("Loading PyTorch model and Flask starting server ...")
print("Please wait until server has fully started")
load_model() # 先加载模型
# 再开启服务
app.run(port='5012') # 最好都有自己的端口号,有些端口是固定的(Linux),若使用自己的电脑,端口号为5012来跑这个保护、服务器。进入1
Flask 启动后默认监听本机地址:
三、客户端请求
import requests
'''客户端的程序代码'''
flask_url = 'http://127.0.0.1:5012/predict' # url和端口写成自己的本地ip
def predict_result(image_path):
image = open(image_path, 'rb').read()
payload = {'image': image} #
r = requests.post(flask_url, files = payload).json() # 是因为网络传输一般使用json类型的数据 字符类型的数据
# 向flask_url服务发送一个POST请求,并尝试将返回的JSON响应解析为一个Python字典。
if r['success']: # 成功的话再返回。
# 输出结果
for (i, result) in enumerate(r['predictions']):
print('{}.预测类别为{}:的概率: {}'.format(i + 1, result['label'], result['probability']))
# 失败了就打印
else:
print('Request failed')
if __name__ == '__main__':
predict_result('./flower_data/val_filelist/image_00028.jpg')
流程:
-
指定服务器 IP 和端口;
-
读取本地图片为二进制;
-
使用
requests.post()
向 Flask 服务发送请求; -
解析返回的 JSON 结果;
-
打印预测的 top-3 类别与概率。enumerate()用来在 遍历可迭代对象(如列表、元组、字符串等)时,同时获取每个元素的索引和值。
输出示例:

四、运行准备与注意事项
-
模型文件
需在 Flask 目录下存在
best.pth
文件(训练好的权重)。文件结构大致为:
./server.py ./best.pth ./flower_data/
-
IP 与端口
-
服务端(代码一)运行主机的 IP 地址需与客户端一致;
-
可在命令行中使用以下命令查看 IP:
ipconfig # Windows ifconfig # Linux/Mac
-
-
GPU 支持
若要启用 GPU,请修改:
use_gpu = True
并确保 PyTorch 安装了 CUDA 版本。
-
跨机器访问时防火墙
需开放端口 5012:
netsh advfirewall firewall add rule name="Flask5012" dir=in action=allow protocol=TCP localport=5012
🧩 五、可优化与扩展建议
模块 | 优化方向 | 建议 |
---|---|---|
模型加载 | 减少重复加载 | 可在全局定义模型并仅加载一次 |
接口设计 | 更友好输出 | 返回类别名称而非数字索引(通过类别映射表) |
性能优化 | 批处理或异步请求 | 可支持批量图片预测 |
安全性 | 限制上传文件类型 | 仅允许 .jpg , .png 等图片格式 |
部署 | 使用 gunicorn + nginx | 提升并发与稳定性 |
客户端 | GUI 或 Web 页面 | 可做简单上传界面查看预测结果 |