模型部署实战:PyTorch生产化指南

‌**一、为什么要做模型部署?**‌

模型部署是将训练好的模型‌投入实际应用‌的关键步骤,涉及:

  1. 模型格式转换(TorchScript/ONNX)
  2. 性能优化(量化/剪枝)
  3. 构建API服务
  4. 移动端集成

本章使用ResNet18实现图像分类,并演示完整部署流程。

二、模型转换:TorchScript与ONNX

1. 准备预训练模型

python 复制代码
import torch
import torchvision

# 加载预训练模型
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
model.eval()

# 示例输入
dummy_input = torch.rand(1, 3, 224, 224)

2. 导出为TorchScript

python 复制代码
# 方法一:追踪执行路径(适合无控制流模型)
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, "resnet18_traced.pt")

# 方法二:直接转换(适合含if/for的模型)
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, "resnet18_scripted.pt")

# 加载测试
loaded_model = torch.jit.load("resnet18_traced.pt")
output = loaded_model(dummy_input)
print("TorchScript输出形状:", output.shape)  # 应输出torch.Size([1, 1000])

3. 导出为ONNX格式

python 复制代码
torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        'input': {0: 'batch_size'}, 
        'output': {0: 'batch_size'}
    }
)

# 验证ONNX模型
import onnx
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX模型输入输出:")
print(onnx_model.graph.input)
print(onnx_model.graph.output)

三、构建API服务

1. 使用FastAPI创建Web服务

python 复制代码
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import io
import numpy as np
import torchvision.transforms as transforms

app = FastAPI()

# 加载TorchScript模型
model = torch.jit.load("resnet18_traced.pt")

# 图像预处理
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    # 读取并预处理图像
    image_data = await image.read()
    img = Image.open(io.BytesIO(image_data)).convert("RGB")
    tensor = preprocess(img).unsqueeze(0)
    
    # 执行推理
    with torch.no_grad():
        output = model(tensor)
    
    # 获取预测结果
    _, pred = torch.max(output, 1)
    return {"class_id": int(pred)}

# 运行命令:uvicorn main:app --reload

2. 测试API服务

python 复制代码
import requests

# 准备测试图片
url = "https://images.unsplash.com/photo-1517849845537-4d257902454a?auto=format&fit=crop&w=224&q=80"
response = requests.get(url)
with open("test_dog.jpg", "wb") as f:
    f.write(response.content)

# 发送预测请求
with open("test_dog.jpg", "rb") as f:
    files = {"image": f}
    response = requests.post("http://localhost:8000/predict", files=files)
    print("预测结果:", response.json())  # 应输出对应类别ID

‌**四、移动端部署(Android/iOS)**‌

1. 转换Core ML格式(iOS)

python 复制代码
import coremltools as ct

# 从PyTorch转换
example_input = torch.rand(1, 3, 224, 224) 
traced_model = torch.jit.trace(model, example_input)

mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)]
)
mlmodel.save("ResNet18.mlmodel")

2. 使用PyTorch Mobile(Android)

python 复制代码
// Android示例代码(Java)
Module module = Module.load(assetFilePath(this, "resnet18_traced.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
    bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
    TensorImageUtils.TORCHVISION_NORM_STD_RGB
);
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

五、性能优化技巧

1. 模型量化(减少体积/提升速度)

python 复制代码
# 动态量化
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
torch.jit.save(torch.jit.script(quantized_model), "resnet18_quantized.pt")

# 测试量化效果
print("原始模型大小:", sum(p.numel() for p in model.parameters()))
print("量化模型大小:", sum(p.numel() for p in quantized_model.parameters()))

2. ONNX Runtime加速推理

python 复制代码
import onnxruntime

ort_session = onnxruntime.InferenceSession("resnet18.onnx")
ort_inputs = {ort_session.get_inputs().name: dummy_input.numpy()}
ort_outputs = ort_session.run(None, ort_inputs)

print("ONNX Runtime输出形状:", ort_outputs.shape)

六、常见问题解答

‌**Q1:如何处理模型版本兼容性问题?**‌

  • 保持PyTorch版本一致(使用requirements.txt固定版本)
  • 转换时指定opset_version:
python 复制代码
torch.onnx.export(..., opset_version=13)

‌**Q2:部署时出现形状不匹配错误?**‌

  • 检查预处理是否与训练时一致
  • 使用Netron可视化模型输入输出:
python 复制代码
pip install netron
netron resnet18.onnx

‌**Q3:如何监控API性能?**‌

  • 添加中间件记录响应时间:
python 复制代码
@app.middleware("http")
async def add_process_time(request, call_next):
    start_time = time.time()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.time() - start_time)
    return response

七、小结与下篇预告

  • 本文重点‌:

    1. 模型格式转换(TorchScript/ONNX)
    2. 构建高并发API服务
    3. 移动端部署与性能优化
  • 下篇预告 ‌:

    第六篇将深入PyTorch生态,介绍分布式训练与多GPU加速策略,实现工业级训练效率!

相关推荐
忍者算法5 分钟前
什么是 LLM(大语言模型)?——从直觉到应用的全面解读
人工智能·语言模型·自然语言处理
钢铁男儿5 分钟前
Python 序列构成的数组(列表推导和生成器表达式)
开发语言·windows·python
aiweker11 分钟前
Celery 全面指南:Python 分布式任务队列详解
开发语言·分布式·python
SoFlu软件机器人12 分钟前
从 Copilot 到垂直工具:AI 编程的 “专精特新“ 进化论
人工智能·copilot
蚝油菜花14 分钟前
Video-T1:视频生成实时手术刀!清华腾讯「帧树算法」终结闪烁抖动
人工智能·开源
JavaEdge在掘金16 分钟前
让 LLM 既能“看”又能“推理”!
python
可乐张20 分钟前
大模型MCP 教程:从原理到实战的全攻略
python·mcp
KARL31 分钟前
cursor、cline很🔥,AI浪潮下作为前端如何构建自己的vscode编程agent
前端·人工智能
无关风雪月32 分钟前
[Python]不要使用可变对象作为函数默认参数或者作为字典的键
后端·python
James. 常德 student38 分钟前
深度学习之自动求导
人工智能·深度学习