大模型--模型部署

一、模型部署

1.模型部署是什么

模型部署(Model Deployment)是指将训练好的机器学习模型集成到实际的生产环境或应用系统中,使其能够接收输入数据(例如图片、文本、数值等),并实时或批量地输出预测结果,从而为最终用户或其他服务提供智能能力。

简单来说,模型训练阶段就像我们在实验室里培养了一个"专家"------这个专家通过学习大量数据掌握了某种技能(比如识别图像中的花卉种类)。但专家本身只是一堆参数和计算逻辑,无法直接对外服务。模型部署就是把这个"专家"请到工作现场(比如服务器、手机App、嵌入式设备),并给他配备一个"问答窗口"(API接口),让外部程序可以随时向它提问(发送数据),并立即得到答案(预测结果)。

2.为什么需要模型部署?
  • 让模型产生实际价值:模型只有在被使用时才具有意义,比如一个图像识别模型部署后,才能帮助用户自动分类照片。

  • 实时响应:很多场景需要毫秒级的预测反馈,例如自动驾驶、实时推荐系统。

  • 可扩展性:部署后的模型可以同时服务大量请求,支持业务增长。

  • 集成到现有系统:模型可以作为微服务嵌入到更大的软件架构中,与其他组件协同工作。

3.常见部署方式
  1. Web API 服务(如你提供的代码) 将模型封装在一个HTTP服务器中,通过RESTful API对外提供预测接口。任何能发送HTTP请求的客户端(网页、移动App、其他后端服务)都可以调用。这是最灵活、最通用的方式。

  2. 嵌入式部署 将模型压缩并移植到移动设备(手机、IoT设备)上运行,例如在手机本地进行图像识别,无需联网。

  3. 批处理预测 定期处理大量离线数据,例如每天凌晨对用户行为日志进行预测,生成推荐列表。

  4. 边缘计算部署 将模型部署在网络边缘节点(如路由器、基站),减少数据传输延迟,适用于物联网场景。

4.部署时需要关注的问题
  • 性能与延迟:模型推理速度能否满足实时要求?是否需要GPU加速?

  • 并发能力:能同时处理多少个请求?如何水平扩展?

  • 模型版本管理:如何平滑升级模型而不中断服务?

  • 监控与日志:如何监控模型效果、捕捉异常?

  • 安全与隐私:如何防止恶意攻击、保护用户数据?

二、实际运用

任务:

实现一个简单的图像分类服务,采用客户端-服务器架构。服务器端使用 Flask 框架搭建一个 REST API,加载预训练的 ResNet18 模型(在 102 类花卉数据集上微调),接收客户端上传的图片并进行推理,返回分类结果(Top-3 概率及对应类别标签)。客户端则通过 requests 库向服务器发送图片,并打印预测结果

服务器端代码(server.py

1. 导入库
python 复制代码
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
  • io:用于在内存中处理二进制数据流(如图片)。

  • flask:构建 Web 服务器和路由。

  • torchtorch.nn.functional:PyTorch 核心库,用于张量运算和 softmax。

  • PIL.Image:Python 图像库,用于打开和处理图像。

  • torchvision:提供预训练模型、数据变换等。

2. 初始化 Flask 应用
python 复制代码
app = flask.Flask(__name__)
model = None
use_gpu = False
  • app 是 Flask 应用实例,用于注册路由和启动服务。

  • model 为全局变量,存储加载后的模型。

  • use_gpu 控制是否使用 GPU(当前设为 False,即使用 CPU)。

3. 加载模型函数 load_model()
python 复制代码
def load_model():
    global model
    model = models.resnet18()
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
    checkpoint = torch.load('best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    if use_gpu:
        model.cuda()
  • 创建一个 ResNet18 模型,并将最后的全连接层替换为输出 102 类的线性层(假设数据集有 102 个类别)。

  • 加载训练好的权重文件 best.pth(该文件应为 PyTorch 的 checkpoint,包含 'state_dict' 键)。

  • 调用 model.eval() 切换到评估模式,关闭 Dropout 和 BatchNorm 的训练行为。

  • use_gpu 为真,则将模型移动到 GPU。

4. 图像预处理函数 prepare_image()
python 复制代码
def prepare_image(image, target_size):
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = transforms.Resize(target_size)(image)
    image = transforms.ToTensor()(image)
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
    image = image[None]   # 增加 batch 维度
    if use_gpu:
        image = image.cuda()
    return torch.tensor(image)
  • 确保图像为 RGB 模式(有些图片可能是 RGBA 或 L 灰度)。

  • 调整尺寸到 target_size(如 224×224),这是 ResNet 输入大小。

  • 转换为张量([C, H, W],值范围从 0~255 缩放到 0~1)。

  • 使用 ImageNet 数据集的均值和标准差进行归一化,与训练时的预处理保持一致。

  • 添加 batch 维度(从 [C, H, W] 变为 [1, C, H, W]),因为模型期望输入是一个 batch。

  • 最后返回张量。

5. 预测路由 /predict
python 复制代码
@app.route("/predict", methods=["POST"])
def predict():
    data = {"success": False}
    if flask.request.method == 'POST':
        if flask.request.files.get("image"):
            image = flask.request.files["image"].read()
            image = Image.open(io.BytesIO(image))
            image = prepare_image(image, target_size=(224, 224))
            preds = F.softmax(model(image), dim=1)
            results = torch.topk(preds.cpu().data, k=3, dim=1)
            results = (results[0].cpu().numpy(), results[1].cpu().numpy())
​
            data['predictions'] = list()
            for prob, label in zip(results[0][0], results[1][0]):
                r = {"label": str(label), "probability": float(prob)}
                data['predictions'].append(r)
            data["success"] = True
​
    return flask.jsonify(data)
  • 定义一个只接受 POST 请求的路由。

  • 初始化返回数据字典,默认 "success": False

  • 检查请求中是否包含名为 "image" 的文件(客户端上传时需用该字段名)。

  • 读取文件的二进制内容,用 PIL.Image.openio.BytesIO 将其转换为图像对象。

  • 调用预处理函数得到模型输入张量。

  • 执行模型推理:model(image) 得到原始 logits,然后 F.softmax 转换为概率(dim=1 表示对类别维度)。

  • torch.topk 取概率最高的前 3 个结果,返回值和索引。

  • 将结果从 GPU 移到 CPU(cpu()),转为 NumPy 数组,便于处理。

  • 构建预测列表,每个元素包含标签(索引)和概率(浮点数)。

  • success 设为 True,最后用 flask.jsonify 将字典转为 JSON 返回。

6. 主程序入口
python 复制代码
if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    load_model()
    app.run(host='0.0.0.0', port=5100)
  • 仅在直接运行该脚本时执行。

  • 先加载模型,然后启动 Flask 开发服务器,监听所有网络接口(0.0.0.0)的 5100 端口。

7.运行结果

客户端代码(client.py

1. 导入库并定义服务器地址
python 复制代码
import requests
​
flask_url = 'http://127.0.0.1:5100/predict'
# flask_url='http://192.168.31.87:5100/predict'   # 另一可选地址,这是我自己的,写代码时换成自己主机的
  • 使用 requests 库发送 HTTP 请求。

  • 定义服务器的完整 URL(注意端口与服务器端一致)。

2. 预测函数 predict_result(image_path)
python 复制代码
def predict_result(image_path):
    image = open(image_path, 'rb').read()
    payload = {'image': image}
    r = requests.post(flask_url, files=payload).json()
​
    if r['success']:
        for i, result in enumerate(r['predictions']):
            print('{}.预测类别为{}:的概率:{}'.format(i+1, result['label'], result['probability']))
    else:
        print('request failed')
  • 以二进制读取模式打开图片文件,获取图片的二进制数据。

  • 构造 payload 字典,键 'image' 对应图片二进制数据。注意这里直接传二进制,requests.postfiles 参数会将其包装为 multipart/form-data 格式。

  • 发送 POST 请求,并解析返回的 JSON 数据(.json())。

  • 如果服务器返回的 success 为真,则遍历 predictions 列表,打印排名、类别标签和概率。

  • 若请求失败(success 为假),打印提示。

3. 主程序调用
python 复制代码
if __name__ == '__main__':
    predict_result('hua.jpg')
  • 调用预测函数,传入本地图片文件名 hua.jpg
4.运行结果
整体工作流程
  1. 启动服务器:运行 server.py,加载模型,监听端口 5100。

  2. 客户端运行 client.py,读取本地图片 hua.jpg,向服务器发送 POST 请求(携带图片)。

  3. 服务器接收请求,预处理图片,进行模型推理,得到 Top-3 预测结果,以 JSON 格式返回。

  4. 客户端接收 JSON,解析并打印结果。

注意事项与潜在问题
  • 模型路径 :服务器假设 best.pth 在当前目录下,且 checkpoint 包含 'state_dict' 键。实际使用时需确保文件存在且格式正确。

  • 类别映射:服务器返回的是类别索引(0~101),而非实际类别名称。若需要输出具体类别名(如"玫瑰"),需在服务器端维护一个索引到标签的映射字典。

  • 错误处理:代码对异常(如图片损坏、模型推理失败)没有捕获,可能返回 500 错误或崩溃。生产环境应添加 try-except 和适当的错误响应。

  • 性能:当前使用 CPU 推理,若图片并发量大可能成为瓶颈。可启用 GPU、使用异步框架或模型优化。

  • 安全性:直接接受用户上传文件,存在安全风险(如超大图片、恶意文件)。应限制文件大小、类型,并对输入进行验证。

  • Flask 开发服务器app.run() 启动的是单进程开发服务器,不适合生产。生产环境建议使用 Gunicorn 等 WSGI 服务器。

相关推荐
佳木逢钺2 小时前
机器人/无人机视觉开发选型指南:RealSense D455 vs D435i 与奥比中光的互补方案
c++·人工智能·计算机视觉·机器人·ros·无人机
实在智能RPA2 小时前
2026年企业级实测:企业部署智能体要什么电脑配置?从硬件门槛到架构选型的深度拆解
人工智能·ai·架构
大力财经2 小时前
阿里发布全球首个企业级Agent平台“悟空”
大数据·人工智能
大刘讲IT2 小时前
AI 革命:生产力范式跃迁与数字文明重构
人工智能·程序人生·重构·制造
2301_821700532 小时前
使用Scikit-learn进行机器学习模型评估
jvm·数据库·python
星爷AG I2 小时前
14-11 双手协调(AGI基础理论)
人工智能·agi
大囚长2 小时前
游戏主机神经纹理压缩与AI重建技术的综合应用方案分析
人工智能·游戏
Chasing__Dreams2 小时前
python--设计模式--13.1--结构性--享元模式
python·设计模式·享元模式
休息一下接着来2 小时前
神经网络与卷积神经网络(CNN)
人工智能·神经网络·cnn