LLM 推理加速工程实战:从 KV Cache 到 Continuous Batching,把吞吐拉满但不把延迟搞崩

这篇文章写给做过线上推理服务的人:你可能已经把模型跑起来了,也知道"开缓存""开 batch""上 vLLM/SGlang"这些词;但真到线上你会发现:

  • 吞吐提升了,P99 延迟炸了;
  • KV cache 开了,显存不够了;
  • batch 大了,排队时间把 decode 省下来的全吃回去;
  • 模型换成更大/更小,系统瓶颈位置完全变了。

我想把这些"工程里真的会遇到的 tradeoff"讲清楚,并给一套可以直接复制的压测+指标口径。

0. 先把问题说清楚:你在优化的到底是什么?

我见过很多团队一上来就说"要提升 QPS",然后开始堆机器、调 batch、换框架。最后上线一周:

  • 平均延迟看起来还行
  • P95/P99偶发尖刺
  • 用户抱怨"有时候特别慢"
  • 成本居高不下

原因是:LLM 推理服务通常要同时满足三个目标:

  1. 吞吐(Throughput):单位时间能处理多少 token / request
  2. 延迟(Latency):尤其是 TTFT(Time To First Token)和 P99
  3. 成本(Cost):显存/显卡利用率、单位 token 成本

它们是一个三角形:动一个角,另两个角经常会变形。

本文会用统一的指标口径来讲:

  • TTFT:从请求到第一个 token 输出
  • TPOT(Time Per Output Token):每个输出 token 的平均耗时(不含排队)
  • Queue Wait:排队等待调度的时间(batching 的副作用通常在这里)
  • Tokens/s:吞吐(每秒输出 token 数)

1. 推理的两个阶段:Prefill vs Decode(别混着优化)

LLM 推理可以粗暴拆成:

  • Prefill(又叫 prompt 处理):把输入 prompt 喂进去,构建 KV cache
  • Decode(逐 token 生成):每一步用 KV cache 继续生成下一个 token

在工程上这非常关键:

  • Prefill 更像"大矩阵乘法",吞吐通常高,延迟与 prompt 长度强相关
  • Decode 更像"小步迭代",每一步计算量小,但要做很多步,且更容易被调度/通信/内存访问拖慢

一个简单但很有效的做法:把压测拆成两类

  • 固定输出长度,扫输入长度(prefill 压测)
  • 固定输入长度,扫输出长度(decode 压测)

你会很快看到瓶颈到底在哪。

2. KV Cache:它不是"开了就快",而是"开了就占内存"

KV cache 的本质:把 attention 里历史 token 的 Key/Value 存下来,避免每一步重算。

2.1 工程上最常见的坑:显存被 KV 吃光

很多人看到"KV cache 加速",就默认开到最大。然后出现两类问题:

  • OOM:并发一上来就爆
  • 吞吐下降:为了不 OOM,把 batch/并发压低,整体吞吐反而下降

你需要一个"显存预算"的概念:

显存 = 模型权重 + 激活/临时 buffer + KV cache

而 KV cache 与以下因素线性相关:

  • 并发请求数(或同时 decode 的序列数)
  • 上下文长度(prompt + 已生成 token)
  • 层数、头数、head dim
  • dtype(fp16/bf16/int8 量化方案)

2.2 一个可以直接用的 KV cache 估算脚本(真实代码)

下面这段 Python 代码可以粗估 KV cache 占用(偏保守),你可以拿去给容量评审用。

python 复制代码
# kv_estimate.py
from dataclasses import dataclass

@dataclass
class ModelCfg:
    num_layers: int
    num_kv_heads: int
    head_dim: int
    dtype_bytes: int  # fp16/bf16 = 2


def estimate_kv_bytes(cfg: ModelCfg, batch: int, seq_len: int) -> int:
    # 每层 KV:K 和 V 各一份
    # shape ~ (batch, num_kv_heads, seq_len, head_dim)
    per_layer = 2 * batch * cfg.num_kv_heads * seq_len * cfg.head_dim * cfg.dtype_bytes
    return per_layer * cfg.num_layers


def human(n: int) -> str:
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if n < 1024:
            return f"{n:.2f}{unit}"
        n /= 1024
    return f"{n:.2f}PB"


if __name__ == '__main__':
    # 以一个常见 7B-13B 量级的配置举例(请按你的模型改)
    cfg = ModelCfg(num_layers=32, num_kv_heads=32, head_dim=128, dtype_bytes=2)

    for batch in [1, 4, 8, 16]:
        for seq in [512, 2048, 8192]:
            kv = estimate_kv_bytes(cfg, batch=batch, seq_len=seq)
            print(f"batch={batch:>2} seq={seq:>5} -> KV ~ {human(kv)}")

你会发现:上下文长度从 2k 到 8k 是 4 倍,并发从 4 到 16 也是 4 倍,叠加就是 16 倍。

这就是为什么"我只是把 max_tokens 调大一点"可能会让线上直接炸。

3. Batching:吞吐的灵丹妙药,也是 P99 的头号杀手

batching 的逻辑很简单:把多个请求合并,让 GPU 一次算更多。

但线上最常见的问题是:

  • 你 batch 越大,排队时间越长
  • 你为了吞吐把 queue 拉长,TTFT 变差,用户感觉"卡"

3.1 Continuous Batching(连续批处理)为什么是关键

传统 batching 是"凑齐一批再算",会导致等待。

Continuous batching(vLLM、SGLang 等框架里常见)是:

  • GPU 一直在跑
  • 新请求可以插进来
  • 旧请求生成完就退出

它的价值是:在不牺牲太多吞吐的情况下,显著改善 TTFT 和 P99。

3.2 一个最小可用的"队列 + 批处理"模拟(真实代码)

这段代码不是为了精准模拟 GPU,而是让你在白板上解释清楚:为什么 batch 会拖慢 P99。

python 复制代码
# batch_queue_sim.py
import random


def simulate(arrival_rate, service_ms, batch_size, duration_s=10):
    """Rough simulation: Poisson arrival + batch service."""
    now = 0.0
    end = duration_s * 1000
    queue = []
    latencies = []

    while now < end:
        inter_arrival = random.expovariate(arrival_rate / 1000.0)  # ms
        now += inter_arrival
        queue.append(now)

        while len(queue) >= batch_size:
            batch = [queue.pop(0) for _ in range(batch_size)]
            start = max(now, batch[0])
            finish = start + service_ms
            for t in batch:
                latencies.append(finish - t)
            now = finish

    if not latencies:
        return None

    latencies.sort()
    p50 = latencies[int(0.50 * len(latencies))]
    p95 = latencies[int(0.95 * len(latencies))]
    p99 = latencies[int(0.99 * len(latencies))]
    return p50, p95, p99, len(latencies)


if __name__ == '__main__':
    random.seed(42)
    for bs in [1, 2, 4, 8, 16]:
        r = simulate(arrival_rate=50, service_ms=40, batch_size=bs)
        print(f"batch={bs:>2} -> {r}")

你会看到一个趋势:batch 越大,吞吐上去了,但尾延迟会被排队拉长。

工程上真正要做的是:

  • 给 batching 一个最大等待时间(max wait / batching window)
  • 给交互式请求更高优先级(例如 chat vs batch job)

4. 你应该如何压测:别只看 QPS,至少看这 6 个指标

一个可落地的压测方式是:用一个脚本同时输出

  • request/s
  • tokens/s
  • TTFT
  • TPOT
  • P95/P99
  • GPU 利用率(sm%、mem%、显存占用)

4.1 一个可直接跑的压测客户端(Python + httpx)

假设你的服务是一个 OpenAI-compatible 的 /v1/chat/completions,支持 stream=true

python 复制代码
# loadgen.py
import asyncio
import time
import json
import statistics
import httpx

API_URL = "http://127.0.0.1:8000/v1/chat/completions"
MODEL = "your-model"

PROMPT = """你是一个严谨的工程师。请用 3 点总结 continuous batching 的优缺点,并给出一个线上调参建议。"""


def now_ms():
    return time.time() * 1000


async def one(client: httpx.AsyncClient, max_tokens=256):
    t0 = now_ms()
    ttft = None
    out_tokens = 0

    payload = {
        "model": MODEL,
        "stream": True,
        "max_tokens": max_tokens,
        "messages": [
            {"role": "user", "content": PROMPT}
        ],
    }

    async with client.stream("POST", API_URL, json=payload, timeout=120) as r:
        r.raise_for_status()
        async for line in r.aiter_lines():
            if not line:
                continue
            if line.startswith("data: "):
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                obj = json.loads(data)
                delta = obj["choices"][0]["delta"].get("content")
                if delta is not None:
                    if ttft is None:
                        ttft = now_ms() - t0
                    # rough token estimate by chars; replace with tokenizer in prod
                    out_tokens += max(1, len(delta) // 4)

    t1 = now_ms()
    total = t1 - t0
    return ttft or total, total, out_tokens


async def main(concurrency=10, seconds=30):
    ttfts, totals, toks = [], [], []

    async with httpx.AsyncClient() as client:
        start = time.time()

        async def worker():
            while time.time() - start < seconds:
                ttft, total, out = await one(client)
                ttfts.append(ttft)
                totals.append(total)
                toks.append(out)

        await asyncio.gather(*[worker() for _ in range(concurrency)])

    def p(xs, q):
        xs = sorted(xs)
        return xs[int(q * len(xs))]

    print(f"requests={len(totals)}")
    print(f"avg_total_ms={statistics.mean(totals):.1f} p95={p(totals,0.95):.1f} p99={p(totals,0.99):.1f}")
    print(f"avg_ttft_ms ={statistics.mean(ttfts):.1f} p95={p(ttfts,0.95):.1f} p99={p(ttfts,0.99):.1f}")
    print(f"tokens_total={sum(toks)} tokens/s={sum(toks)/seconds:.1f}")


if __name__ == '__main__':
    asyncio.run(main(concurrency=20, seconds=30))

这份脚本的价值在于:它会把 TTFT 单独拉出来,让你看到 batching/排队的真实代价。

5. vLLM / SGLang / TensorRT-LLM:工程选型时我会看什么

这里不做"文档复述",我只说上线会遇到的点:

5.1 你的瓶颈是算力还是调度?

  • 如果 GPU 算力吃满(SM 利用率高),但 tokens/s 仍不够:考虑量化、算子融合、TensorRT-LLM
  • 如果 SM 利用率不高,但延迟大:多半是调度/queue/IO/CPU 端瓶颈,先把 batching 和服务架构理顺

5.2 KV 管理策略

  • PagedAttention 这类方案能缓解碎片化,但不是免费午餐:会引入额外管理开销
  • 对长上下文,prefix caching / prompt cache(复用系统 prompt / 业务模板)往往比"无脑扩显存"更划算

5.3 多租户/多模型

一个现实问题:线上不是只有一个模型。

  • 多模型共享 GPU:调度更复杂,容易互相干扰
  • 多模型分 GPU:资源更浪费,但稳定

我更倾向的策略是:

  • 交互式主模型独占一组 GPU(保证 P99)
  • 批处理/离线模型用另一组 GPU(吞吐优先)
  • 需要弹性时,再做跨池迁移

6. 线上调优清单(我真正会按这个顺序做)

按优先级:

  1. 先把指标口径打通:TTFT、TPOT、queue wait、tokens/s
  2. 拆 prefill/decode:分别压测,不要用一个平均值糊弄
  3. 给 batching 加上上限:batching window + 最大并发
  4. 做显存预算:权重/kv/buffer,明确最大上下文与最大并发
  5. 把请求分类:交互式 vs 批处理,走不同队列/不同 GPU 池
  6. 再考虑框架/量化升级:否则你可能在错误的瓶颈上花 2 周

7. 结语:优化推理不是"换个框架",是把系统当系统看

推理加速的本质是:

  • 你在做一个有排队、有调度、有资源竞争的在线系统
  • LLM 只是其中最贵、最显眼的那个组件

当你把 TTFT/TPOT/queue wait 拆开看,把 KV cache 当成显存预算的一部分,把 batching 当成排队系统的一部分,很多"玄学"就会变成可解释、可调参、可复现。

如果你愿意进一步做工程化:

  • 把压测脚本接入 CI,做回归
  • 把线上参数变更纳入变更流程
  • 给 P99 配置 SLO + 自动扩缩容

你会发现:推理性能这件事,不再靠"某个同学经验很强",而是靠体系。

相关推荐
虎鲸不是鱼4 小时前
LM Studio使用MTP的qwen3.6-27B-以7840hs的780M为例
大模型·llm·qwen·lm studio·mtp
数据智能老司机5 小时前
领域专用小型语言模型——端到端 Transformer 微调
llm
风雨中的小七6 小时前
和AI一起搞事情#6. 如何实现图片文字元素编辑?
人工智能·llm
Komorebi_99996 小时前
LangChain Day2 课程:提示词模板 + Chain 链精讲
大模型·llm
程序员三明治6 小时前
【AI】Tika:一次文档解析引擎的工程实践
java·人工智能·大模型·llm·后端开发·rag·tika文件解析
冬奇Lab17 小时前
Agent系列(四):工具调用深度解析——Agent 的手和眼
人工智能·llm
冬奇Lab17 小时前
一天一个开源项目(第111篇):Understand Anything - 把代码库变成可探索知识图谱的 AI 引擎
人工智能·开源·llm
养肥胖虎18 小时前
完整学习LLM(四):Token是什么
大模型·llm·token·学习路线
qcx2321 小时前
【系统学AI】03 LLM训练全流程:预训练→SFT→对齐五条路线
人工智能·llm·sft·预训练·奖励模型·对齐·路线