目录
[1.1 环境要求](#1.1 环境要求)
[1.2 必备文件](#1.2 必备文件)
[二、服务端搭建:让模型 "听候指令"](#二、服务端搭建:让模型 “听候指令”)
[2.1 服务端完整代码(server.py)](#2.1 服务端完整代码(server.py))
[2.2 服务端启动与验证](#2.2 服务端启动与验证)
[三、客户端开发:向服务端 "发送请求"](#三、客户端开发:向服务端 “发送请求”)
[3.1 客户端完整代码(client.py)](#3.1 客户端完整代码(client.py))
[3.2 客户端运行与结果示例](#3.2 客户端运行与结果示例)
[4.1 服务端启动失败](#4.1 服务端启动失败)
[4.2 客户端连接失败](#4.2 客户端连接失败)
[4.3 模型推理报错](#4.3 模型推理报错)
前言
在机器学习项目中,训练好的模型只有部署到实际环境中才能发挥价值。对于本机测试或小规模应用场景,Flask 框架是实现模型部署的轻量优选 ------ 它能快速搭建 HTTP 服务,让模型以接口形式接收请求、返回预测结果,无需复杂的服务器配置。
本文将以 ResNet18 图像分类模型为例,完整讲解如何用 Flask 实现 "本机模型部署":从服务端代码编写(模型加载、接口定义),到客户端代码开发(图像上传、结果解析),再到常见问题排查,确保新手也能一步到位跑通流程。
一、部署前准备
在开始编写代码前,需先确认环境和依赖是否齐全,避免后续因版本或包缺失导致报错。
1.1 环境要求
- Python 版本:3.7~3.9(PyTorch 对高版本 Python 兼容性可能不稳定)
- 核心依赖包:
python
# 安装Flask(Web服务框架)
pip install flask
# 安装PyTorch+TorchVision(模型加载与图像预处理)
pip install torch torchvision
# 安装PIL(图像读取处理)和requests(客户端请求)
pip install pillow requests
1.2 必备文件
- 训练好的模型权重文件:本文使用
best.pth
(ResNet18 微调后权重,需确保与代码中类别数匹配) - 测试图像:准备 1~2 张用于验证的图像(如 JPG/PNG 格式)
- 代码结构:建议按如下目录组织,避免路径混乱

二、服务端搭建:让模型 "听候指令"
服务端的核心作用是:加载预训练模型、定义预测接口、监听本机请求。当客户端发送图像请求时,服务端会完成图像预处理、模型推理,并返回 JSON 格式的预测结果。
2.1 服务端完整代码(server.py)
python
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
# 1. 初始化Flask应用
app = flask.Flask(__name__) # __name__定位应用根路径,用于查找静态资源
model = None # 全局变量存储模型,避免重复加载
use_gpu = False # 本机部署默认用CPU(若有GPU可设为True,需确保PyTorch支持CUDA)
# 2. 加载预训练模型
def load_model():
"""加载ResNet18模型,替换全连接层适配自定义分类任务"""
global model
# 加载ResNet18基础网络(pretrained=False表示不加载默认预训练权重,用自己的best.pth)
model = models.resnet18(pretrained=False)
# 获取全连接层输入特征数,替换为自定义类别数(本文以102类为例)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 输出维度=类别数
# 加载训练好的权重文件(需确保best.pth路径正确)
checkpoint = torch.load('best.pth', map_location=torch.device('cpu')) # 强制CPU加载,避免GPU报错
model.load_state_dict(checkpoint['state_dict']) # 加载权重参数
# 设为评估模式(禁用Dropout、BatchNorm等训练特有的层)
model.eval()
# 若启用GPU且设备支持,将模型移至CUDA
if use_gpu and torch.cuda.is_available():
model = model.cuda()
print("模型加载完成,等待请求...")
# 3. 图像预处理函数(需与训练时保持一致)
def prepare_image(image, target_size=(224, 224)):
"""将客户端传入的图像转为模型可接受的Tensor格式"""
# 统一图像为RGB格式(避免灰度图/透明图报错)
if image.mode != 'RGB':
image = image.convert('RGB')
# 预处理 pipeline:Resize→ToTensor→Normalize(ResNet默认预处理参数)
preprocess = transforms.Compose([
transforms.Resize(target_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_tensor = preprocess(image)
# 增加batch维度(模型要求输入为[batch_size, C, H, W],单张图batch_size=1)
image_tensor = image_tensor.unsqueeze(0) # 等价于image_tensor[None]
# 若启用GPU,将Tensor移至CUDA
if use_gpu and torch.cuda.is_available():
image_tensor = image_tensor.cuda()
return image_tensor
# 4. 定义预测接口(POST方法)
@app.route("/predict", methods=["POST"])
def predict():
"""
接收客户端POST请求:
- 请求体包含image字段(二进制图像)
- 返回JSON格式结果:success(请求状态)、predictions(Top3预测结果)
"""
# 初始化返回结果字典
result = {"success": False, "predictions": []}
# 检查请求方法是否为POST,且包含image文件
if flask.request.method == "POST" and flask.request.files.get("image"):
# 步骤1:读取客户端传入的二进制图像
image_bytes = flask.request.files["image"].read()
# 将二进制数据转为PIL图像对象
image = Image.open(io.BytesIO(image_bytes))
# 步骤2:图像预处理
image_tensor = prepare_image(image, target_size=(224, 224))
# 步骤3:模型推理(禁用梯度计算,加速推理)
with torch.no_grad():
# 计算各类别概率(softmax归一化)
preds = F.softmax(model(image_tensor), dim=1)
# 获取概率Top3的类别和概率值(cpu()转为CPU张量,避免GPU与CPU数据冲突)
top3_probs, top3_labels = torch.topk(preds.cpu(), k=3, dim=1)
# 步骤4:处理结果(转为numpy数组→构造JSON格式)
top3_probs = top3_probs.numpy()[0] # 取第1个batch(仅1张图)
top3_labels = top3_labels.numpy()[0]
# 遍历Top3结果,添加到返回字典
for prob, label in zip(top3_probs, top3_labels):
result["predictions"].append({
"label": str(label), # 类别标签(若有类别名映射,可此处替换为中文)
"probability": round(float(prob), 4) # 概率值(保留4位小数)
})
# 标记请求成功
result["success"] = True
# 以JSON格式返回结果(Flask自动设置Content-Type为application/json)
return flask.jsonify(result)
# 5. 启动服务
if __name__ == "__main__":
# 先加载模型(确保模型加载成功后再启动服务)
try:
load_model()
except Exception as e:
print(f"模型加载失败:{str(e)}")
exit(1) # 模型加载失败则退出程序
# 启动Flask服务(本机部署关键参数)
# host='127.0.0.1':仅本机可访问(推荐本机测试用)
# port=5012:端口号(避免与其他服务冲突,如8080、5000)
app.run(host='127.0.0.1', port=5012, debug=False)
2.2 服务端启动与验证
- 运行server.py,若控制台输出以下内容,说明服务启动成功:

2.初步验证:打开浏览器访问http://127.0.0.1:5012/predict
,若返回{"success":false,"predictions":[]}
,证明接口正常监听(无图像请求时返回默认状态)。
三、客户端开发:向服务端 "发送请求"
客户端的作用是:读取本地图像、以 POST 方式向服务端接口发送请求、解析返回的 JSON 结果并打印。
3.1 客户端完整代码(client.py)
python
import requests
# 1. 配置服务端地址(需与服务端host和port一致)
# 本机部署用http://127.0.0.1:5012/predict,跨设备需替换为服务端局域网IP
FLASK_URL = "http://127.0.0.1:5012/predict"
def predict_image(image_path):
"""
向服务端发送图像预测请求:
param image_path: 本地图像路径(如"./test_img/image_06975.jpg")
return: 打印预测结果
"""
try:
# 步骤1:以二进制形式读取本地图像(保持图像原始格式)
with open(image_path, 'rb') as f:
image_bytes = f.read()
# 步骤2:构造请求体(key为"image",与服务端flask.request.files.get("image")对应)
payload = {"image": image_bytes}
# 步骤3:发送POST请求(timeout设为10秒,避免请求超时)
response = requests.post(FLASK_URL, files=payload, timeout=10)
# 步骤4:解析响应结果(JSON→字典)
result = response.json()
# 步骤5:判断请求是否成功并打印结果
if response.status_code == 200 and result["success"]:
print("请求成功(状态码:200)")
print("Top3预测结果:")
for i, pred in enumerate(result["predictions"], 1):
print(f" {i}. 类别:{pred['label']},概率:{pred['probability']}")
else:
print(f"请求失败:{result}")
except Exception as e:
print(f"客户端报错:{str(e)}")
# 6. 运行客户端(测试单张图像)
if __name__ == "__main__":
# 替换为你的测试图像路径(相对路径/绝对路径均可)
test_image_path = "./test_img/image_06975.jpg"
predict_image(test_image_path)
3.2 客户端运行与结果示例
- 确保服务端已启动,运行client.py,若请求成功,控制台输出如下:

状态码说明:
◦ 200:请求成功(服务端正常处理)
◦ 404:接口地址错误(如 URL 写错)
◦ 500:服务端内部错误(如模型加载失败、代码报错)
四、常见问题排查
在部署过程中,新手容易遇到连接失败、模型报错等问题,以下是高频问题的解决方案:
4.1 服务端启动失败
-
报错 1:
FileNotFoundError: [Errno 2] No such file or directory: 'best.pth'
原因:模型权重文件路径错误。解决:确认best.pth
与server.py
在同一目录,或使用绝对路径(如torch.load("C:/model_deployment/best.pth")
)。 -
报错 2:
WinError 10049
请求的地址无效 原因:app.run(host=...)
中 IP 地址错误(非本机 IP)。解决:本机部署用host='127.0.0.1'
,或通过ipconfig
(Windows)/ifconfig
(Linux)查看本机正确 IP。
4.2 客户端连接失败
-
报错 1:
requests.exceptions.ConnectionError: HTTPConnectionPool
原因:服务端未启动,或 IP / 端口不匹配。解决:先启动服务端,确认客户端FLASK_URL
与服务端host:port
完全一致(如均为127.0.0.1:5012
)。 -
**报错 2:
PIL.UnidentifiedImageError: cannot identify image file
**原因:图像路径错误,或文件不是有效图像(如后缀为.jpg 但实际是.txt)。解决:检查图像路径,用画图工具打开图像确认是否正常。
4.3 模型推理报错
- 报错:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight
原因:图像未增加 batch 维度(模型要求输入为[batch_size, C, H, W]
)。解决:确保prepare_image
函数中调用image_tensor.unsqueeze(0)
。
五、扩展与优化建议
-
类别名映射 :当前返回的是类别编号(如 35),可添加字典映射为中文(如
label_map = {35: "玫瑰", 36: "百合"}
),在服务端predict
函数中替换"label": str(label)
为"label": label_map[label]
。 -
多图批量预测:修改客户端代码,支持遍历文件夹下所有图像,批量发送请求。
-
生产环境优化 :Flask 开发服务器不适合生产环境,可改用
Gunicorn
(Linux)或Waitress
(Windows)作为 WSGI 服务器,搭配Nginx
反向代理,提升并发能力。
总结
本文通过 Flask 框架实现了本机 PyTorch 模型的完整部署流程:服务端负责加载模型和提供接口,客户端负责发送请求和解析结果,整个流程轻量、易上手,适合小规模测试或个人项目使用。
核心要点可总结为 3 点:
- 服务端与客户端的
host:port
必须一致; - 图像预处理需与训练时保持一致(如 Resize 尺寸、Normalize 参数);
- 先启动服务端,再运行客户端,避免连接失败。
按照本文步骤操作,即可快速将自己的 PyTorch 模型部署到本机,实现 "训练→部署→调用" 的闭环。