AI 模型部署从入门到生产 —— ONNX 转换、TensorRT 加速、推理服务搭建

前言

模型训练出来只是第一步,让它稳定高效地服务线上请求才是真正的 challenge。这篇文章覆盖模型部署的完整流程------从模型导出到生产级推理服务。

Step 1:模型导出与格式转换

PyTorch → ONNX

ONNX(Open Neural Network Exchange)是模型部署的中间格式,几乎所有的推理框架和硬件都支持它。

python 复制代码
import torch
import torch.onnx

model = YourModel()
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,
)

关键参数:

  • dynamic_axes:声明哪些维度是可变的,比如 batch_size 和序列长度
  • opset_version:越高支持的算子越多,但硬件兼容性可能下降。推荐 17-19

ONNX 优化

导出的 ONNX 模型可以用 onnxruntime 做优化:

bash 复制代码
python -m onnxruntime.tools.convert_onnx_models_to_ort model.onnx

或者在代码中优化:

python 复制代码
import onnxruntime as ort
import onnx
from onnxruntime.transformers import optimizer

opt = optimizer.optimize_model(
    "model.onnx",
    model_type="bert",    # 根据模型类型选择
    num_heads=12,
    hidden_size=768,
    opt_level=99,         # 最大优化
)
opt.save_model_to_file("model_optimized.onnx")

优化效果:一般能提升 1.5x-3x 的推理速度。

检查 ONNX 模型的正确性

python 复制代码
# 验证输出一致性
with torch.no_grad():
    torch_output = model(dummy_input).numpy()

ort_session = ort.InferenceSession("model.onnx")
ort_input = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
ort_output = ort_session.run(None, ort_input)

# 检查误差
diff = np.abs(torch_output - ort_output[0]).max()
print(f"Max diff: {diff}")
# 对于 FP32,diff < 1e-5 是正常的
# 对于 FP16/INT8,diff < 1e-2 是可接受的

Step 2:推理加速 ------ TensorRT

TensorRT 是 NVIDIA 的推理优化引擎,可以在 GPU 上达到极致性能。

ONNX → TensorRT

python 复制代码
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open("model.onnx", "rb") as f:
    parser.parse(f.read())

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

# FP16 精度(精度几乎无损,速度提升 2x)
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# INT8 量化(需要校准数据集,速度再提升 2x)
# config.set_flag(trt.BuilderFlag.INT8)

engine = builder.build_serialized_network(network, config)
with open("model.trt", "wb") as f:
    f.write(engine)

精度与速度权衡

精度 推理速度 模型大小 质量影响 适用场景
FP32 1x 100% 基准、调试
FP16 1.5-2x 50% 几乎无 生产推荐
INT8 3-4x 25% 轻微 对延迟极度敏感
INT4 5-6x 12.5% 明显 边缘设备

通用建议:FP16 是性价比最高的选择。INT8 需要校准数据集,且对某些模型有精度风险。

TensorRT 推理

python 复制代码
import tensorrt as trt
import pycuda.driver as cuda

class TRTInference:
    def __init__(self, engine_path):
        logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, "rb") as f:
            self.engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()

        # 分配 GPU 内存
        self.inputs = []
        self.outputs = []
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            shape = self.engine.get_tensor_shape(name)
            size = trt.volume(shape)
            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            # 分配内存
            ...  # 完整代码见注

    def infer(self, input_numpy):
        self.inputs[0].host = input_numpy.astype(np.float32)
        [cuda.memcpy_htod(inp.device, inp.host) for inp in self.inputs]
        self.context.execute_v2([inp.device for inp in self.inputs + self.outputs])
        [cuda.memcpy_dtoh(out.host, out.device) for out in self.outputs]
        return self.outputs[0].host

Step 3:推理服务框架

三种主流方案对比

框架 适用场景 特点
Triton Inference Server 多模型混合、A/B 测试 NVIDIA 官方,支持多种硬件和框架
vLLM LLM 推理(7B-70B) PagedAttention,连续批处理
SGLang LLM 推理(结构化输出) RadixAttention,JSON Mode

用 vLLM 部署 LLM

python 复制代码
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-3B",
    tensor_parallel_size=1,      # 单卡部署
    gpu_memory_utilization=0.9,  # GPU 显存利用率
    max_num_seqs=256,            # 最大并发序列数
    enable_prefix_caching=True,  # 前缀缓存加速
)

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=1024,
)

outputs = llm.generate(["请解释什么是 Transformer"], sampling_params)
print(outputs[0].outputs[0].text)

生产部署用 OpenAI 兼容 API:

bash 复制代码
python -m vllm.entrypoints.openai.api_server     --model meta-llama/Llama-3.2-3B     --port 8000     --gpu-memory-utilization 0.9

然后直接当 OpenAI API 用:

python 复制代码
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="sk-xxx",  # vLLM 不验证 key,随便填
)

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B",
    messages=[{"role": "user", "content": "Hello"}],
)

vLLM 关键配置

python 复制代码
LLM(
    model="...",
    # 显存配置
    gpu_memory_utilization=0.85-0.95,  # 留一些给 CUDA kernel
    max_model_len=8192,                # 最大上下文长度
    # 性能配置
    num_scheduler_steps=8,             # 调度步数,越大吞吐越高
    enable_chunked_prefill=True,       # 分块 prefill,降低 TTFT
    # 量化
    quantization="awq",                # AWQ 量化(需要先量化模型)
    # 分布式
    tensor_parallel_size=1,            # >1 时需多卡
)

Step 4:性能压测与监控

关键指标

指标 英文 含义 好值
首 Token 延迟 TTFT 从请求到第一个 Token 的时间 < 500ms
Token 生成速度 TPOT 每个 Token 的生成时间 < 50ms
吞吐量 Throughput 每秒生成的 Token 数 > 1000 tokens/s
并发数 Concurrency 同时处理的请求数 视场景而定

压测工具

bash 复制代码
# 用 vllm 自带的 benchmark
python -m vllm.benchmarks.benchmark_throughput     --model meta-llama/Llama-3.2-3B     --dataset ShareGPT_V3_unfiltered_cleaned_split.json     --num-prompts 1000

# 用 locust 做 HTTP 压测
pip install locust
locust --host http://localhost:8000

生产级部署架构

markdown 复制代码
                   ┌─────────────┐
                   │   Load      │
                   │  Balancer   │
                   └──────┬──────┘
                          │
          ┌───────────────┼───────────────┐
          │               │               │
    ┌─────▼─────┐   ┌─────▼─────┐   ┌─────▼─────┐
    │ Worker 1  │   │ Worker 2  │   │ Worker 3  │
    │ vLLM/Triton│   │ vLLM/Triton│   │ vLLM/Triton│
    └───────────┘   └───────────┘   └───────────┘
          │               │               │
    ┌─────▼───────────────▼───────────────▼─────┐
    │          共享显存 / 模型分片                 │
    └───────────────────────────────────────────┘
  • 水平扩展:多个 Worker 加 Load Balancer
  • 动态缩扩容:根据队列长度自动调整 Worker 数量
  • 缓存层:Redis 缓存常见问题的回答(适合 LLM 场景,回答可以复用)

总结

模型部署的核心链路:

bash 复制代码
PyTorch → ONNX → TensorRT → Triton/vLLM → Load Balancer → API
  ①        ②        ③          ④               ⑤
  1. PyTorch 导出 ONNX(设置 dynamic_axes)
  2. ONNX 优化(onnxruntime transformers)
  3. TensorRT 转换(FP16 是最优性价比)
  4. 推理框架部署(LLM 用 vLLM,多模态用 Triton)
  5. 负载均衡与缓存

每个环节都有优化空间,但先跑通全链路,再逐步压测优化是最务实的方式。


本文发布于 Zyentor(智元界) ------ AI 开发者社区 原文链接:www.zyentor.com/news/3183

相关推荐
董董灿是个攻城狮1 小时前
AI 会吃了天涯吗?
人工智能
A15362551 小时前
从 AI 零引用到高转化:GEO 落地价值解析
人工智能
Omics Pro1 小时前
P4医学4大支柱需绑定4大数字技术才可落地
人工智能·python·算法·机器学习·plotly
段一凡-华北理工大学1 小时前
工业领域的Hadoop架构学习~系列文章07:Spark内存计算引擎
大数据·人工智能·hadoop·学习·架构·高炉炼铁·高炉炼铁智能化
机器学习是魔鬼1 小时前
在矩池云上开箱即用Energy Forecasting:能源电力电价预测实战指南
人工智能·python·机器学习
AINative软件工程1 小时前
LLM Prompt 版本管理工程实践:像管代码一样管理你的 Prompt,告别“改坏了不知道”
人工智能·架构
阿黎梨梨1 小时前
小白也能懂的 AI 黑话手册:从 Token 到 Agent 的硬核科普
人工智能
艺舟先生1 小时前
开源agent源码架构分析之claude(二)
人工智能·架构