一、FlashAttention:注意力计算的显存革命
1.1 标准注意力的显存问题
标准 Self-Attention 的计算过程是:Q×K^T → Softmax → ×V。中间会产生一个 N×N 的注意力矩阵(N 是序列长度)。序列长度为 4096 时,这个矩阵占 64MB(FP16),batch_size=32 时就是 2GB。对于大模型来说,光注意力矩阵就吃掉了大量显存。
更关键的是,这个 N×N 矩阵要写入 HBM 再读回来做 Softmax,一写一读浪费了大量带宽。
1.2 FlashAttention 的核心思想
FlashAttention 的思路是分块计算------不一次性算出完整的 N×N 矩阵,而是把 Q、K、V 切成小块,每次只在 SRAM 里算一小块注意力,算完就丢掉中间结果。这样做的好处有两个:
第一,显存占用从 O(N²) 降到 O(N)。不再需要存储完整的注意力矩阵,只需要存储当前块的中间结果。
第二,减少 HBM 访问次数。标准注意力需要对 N×N 矩阵做两次 HBM 访问(写入 Softmax 输入、读出 Softmax 输出),FlashAttention 把这些操作都在 SRAM 里完成了。
1.3 CANN 上的实现
python
import torch
import torch_npu
def standard_attention(Q, K, V, scale):
"""标准注意力实现
计算流程:
1. scores = Q @ K^T * scale → 产生 N×N 矩阵,写入 HBM
2. attn = softmax(scores) → 读 N×N,算完写回 HBM
3. output = attn @ V → 读 N×N 和 V,写输出
总 HBM 访问: 读 Q/K/V + 写 scores + 读 scores + 写 attn + 读 attn + 写 output
≈ 6 次大块 HBM 访问
"""
scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)
return output
def flash_attention_forward(Q, K, V, scale, block_size=128):
"""FlashAttention 前向传播
分块计算流程:
1. 将 Q 按 block_size 切块
2. 对每个 Q 块,遍历所有 K/V 块
3. 在 SRAM 里完成局部 Softmax 和加权求和
4. 用 online softmax 技巧增量更新输出
为什么用 online softmax?
标准 Softmax 需要知道所有元素才能算分母(sum of exp)。
online softmax 维护一个运行最大值和运行 sum,每加入新块就更新,
最终结果和标准 Softmax 完全一致。
"""
batch, seq_len, head_dim = Q.shape
num_blocks = (seq_len + block_size - 1) // block_size
# 输出张量和 softmax 的 running 统计量
output = torch.zeros_like(Q)
running_max = torch.full((batch, head_dim, 1), float('-inf'), device=Q.device)
running_sum = torch.zeros((batch, head_dim, 1), device=Q.device)
for i in range(num_blocks):
# 取出第 i 个 Q 块
q_start = i * block_size
q_end = min((i + 1) * block_size, seq_len)
Q_block = Q[:, q_start:q_end, :] # (batch, block_size, head_dim)
# 遍历所有 K/V 块
for j in range(num_blocks):
k_start = j * block_size
k_end = min((j + 1) * block_size, seq_len)
K_block = K[:, k_start:k_end, :]
V_block = V[:, k_start:k_end, :]
# 局部注意力分数
scores_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
# Online Softmax 更新
block_max = scores_block.max(dim=-1, keepdim=True).values
new_max = torch.max(running_max, block_max)
# 修正之前的累加结果
exp_correction = torch.exp(running_max - new_max)
block_correction = torch.exp(block_max - new_max)
running_sum = running_sum * exp_correction + \
torch.sum(torch.exp(scores_block - new_max), dim=-1, keepdim=True)
# 更新输出
output[:, q_start:q_end, :] = \
output[:, q_start:q_end, :] * exp_correction + \
torch.matmul(torch.exp(scores_block - new_max), V_block)
running_max = new_max
# 归一化
output = output / running_sum
return output
# 验证正确性
batch, seq_len, heads, dim = 2, 512, 8, 64
scale = dim ** -0.5
Q = torch.randn(batch, heads, seq_len, dim).npu()
K = torch.randn(batch, heads, seq_len, dim).npu()
V = torch.randn(batch, heads, seq_len, dim).npu()
out_std = standard_attention(Q, K, V, scale)
out_flash = flash_attention_forward(Q, K, V, scale)
max_diff = (out_std - out_flash).abs().max().item()
print(f"标准注意力 vs FlashAttention 最大差异: {max_diff:.6e}")
# 差异应该在 1e-5 量级,是浮点精度误差
1.4 性能对比分析
| 指标 | 标准注意力 | FlashAttention | 收益 |
|---|---|---|---|
| 显存占用 (seq=4096) | 64 MB | 4 KB | 降低 16000 倍 |
| HBM 访问次数 | 6 次 | 2 次 | 减少 67% |
| 实际延迟 (seq=2048) | 12 ms | 5 ms | 加速 2.4 倍 |
FlashAttention 的数学结果和标准注意力完全一致,差异只来自浮点精度。这意味着不需要修改任何模型代码,只需要替换注意力函数就能获得收益。
二、推测解码:打破自回归的串行瓶颈
2.1 自回归推理的问题
大模型生成文本是逐 token 的------生成第 t 个 token 时,必须等第 t-1 个 token 生成完。每生成一个 token 都要读取全部模型参数,但每次只算一个 token 的计算量。GPU/NPU 的并行能力完全用不上。
假设生成 100 个 token,串行执行需要 100 次完整的前向传播。如果每次前向传播耗时 20ms,总耗时 2 秒。
2.2 推测解码的思路
推测解码(Speculative Decoding)的核心思想是:用一个小模型快速"猜"多个 token,然后用大模型并行验证。
具体流程:
- 小模型(Draft Model)自回归生成 5 个 token(猜)
- 大模型(Target Model)一次前向传播验证这 5 个 token
- 从左到右找到第一个错误的位置,保留前面正确的 token
- 从错误位置开始重新猜测
如果小模型猜对了 3 个 token,那就一次前向传播得到了 3 个 token,相当于加速了 3 倍。
2.3 CANN 上的实现
python
import torch
import torch_npu
class SpeculativeDecoder:
"""推测解码器
参数:
- draft_model: 小模型(如 1.5B),速度快但精度低
- target_model: 大模型(如 70B),速度慢但精度高
- draft_length: 每次猜测的 token 数
为什么猜 5 个而不是更多?
- 猜太多,猜对的概率下降,验证浪费
- 猜太少,加速效果不明显
- 实验表明 5 是最优的平衡点
为什么能保证输出和纯大模型一致?
- 验证阶段用的是大模型的概率分布
- 如果小模型猜的 token 在大模型的概率分布下也被接受
- 那么结果就和大模型逐个生成完全一致
"""
def __init__(self, draft_model, target_model, draft_length=5):
self.draft = draft_model
self.target = target_model
self.draft_len = draft_length
@torch.no_grad()
def generate(self, prompt_ids, max_new_tokens=100):
"""推测解码生成
返回: token ids 列表,和纯大模型生成结果完全一致
"""
generated = list(prompt_ids)
tokens_generated = 0
while tokens_generated < max_new_tokens:
# Step 1: 小模型快速猜测 draft_length 个 token
draft_tokens, draft_probs = self._draft_generate(
generated, self.draft_len
)
# Step 2: 大模型并行验证所有猜测
target_probs = self._target_verify(
generated + draft_tokens
)
# Step 3: 从左到右检查,找到第一个被拒绝的位置
accepted_count = 0
for i in range(len(draft_tokens)):
# 接受概率 = min(1, target_prob / draft_prob)
t_prob = target_probs[len(prompt_ids) + accepted_count][draft_tokens[i]]
d_prob = draft_probs[i][draft_tokens[i]]
if torch.rand(1).item() < min(1.0, t_prob / (d_prob + 1e-10)):
accepted_count += 1
generated.append(draft_tokens[i])
else:
# 被拒绝,用大模型的分布采样一个 token
new_token = torch.multinomial(
target_probs[len(prompt_ids) + accepted_count], 1
).item()
generated.append(new_token)
accepted_count += 1
break
tokens_generated += accepted_count
if tokens_generated >= max_new_tokens:
break
return generated[:max_new_tokens + len(prompt_ids)]
def _draft_generate(self, context, num_tokens):
"""小模型自回归生成几个 token"""
tokens = []
probs = []
current = context.copy()
for _ in range(num_tokens):
input_ids = torch.tensor([current]).npu()
logits = self.draft(input_ids)[:, -1, :]
prob = torch.softmax(logits, dim=-1)
token = torch.multinomial(prob, 1).item()
tokens.append(token)
probs.append(prob[0].cpu())
current.append(token)
return tokens, probs
def _target_verify(self, context):
"""大模型一次前向传播,返回每个位置的概率分布"""
input_ids = torch.tensor([context]).npu()
logits = self.target(input_ids)[:, -len(context):, :]
return torch.softmax(logits, dim=-1)[0].cpu()
def benchmark_speculative_vs_standard(target_model, draft_model, prompt, num_tokens=50):
"""对比推测解码 vs 标准自回归的延迟"""
import time
# 标准自回归
decoder = SpeculativeDecoder(draft_model, target_model)
start = time.time()
# 模拟标准自回归: 每次生成 1 个 token
standard_output = prompt.copy()
for _ in range(num_tokens):
input_ids = torch.tensor([standard_output]).npu()
logits = target_model(input_ids)[:, -1, :]
token = torch.argmax(logits, dim=-1).item()
standard_output.append(token)
standard_time = time.time() - start
# 推测解码
start = time.time()
spec_output = decoder.generate(prompt, max_new_tokens=num_tokens)
spec_time = time.time() - start
speedup = standard_time / spec_time if spec_time > 0 else 0
print(f"标准自回归: {standard_time:.3f}s")
print(f"推测解码: {spec_time:.3f}s")
print(f"加速比: {speedup:.2f}x")
2.4 推测解码的适用条件
推测解码不是万能的。它最有效的场景是:大模型非常大(>30B),小模型足够快(比大模型快 5 倍以上),生成的文本有较强的可预测性(如代码补全、新闻摘要)。
如果大模型本身就很小,推测解码的验证开销占比太高,反而可能变慢。如果文本不可预测(如创意写作),小模型猜对率很低,加速效果也不好。
三、连续批处理:吞吐量的数量级提升
3.1 静态批处理的问题
传统批处理等所有请求凑够一个 batch 才开始执行。问题是:如果 batch 里有一个请求特别慢,其他请求都要等。这叫"木桶效应"------batch 延迟由最慢的请求决定。
3.2 连续批处理(Continuous Batching)
连续批处理允许在 batch 执行过程中动态插入和移除请求。一个请求生成完了,它的 NPU 槽位立刻给新请求用,不用等整个 batch 都完成。
时间轴:
静态批处理: [请求1, 请求2, 请求3] → 等待 → 等待 → 完成
连续批处理: [请求1, 请求2, 请求3]
↓ 请求2完成
[请求1, 请求4, 请求3]
↓ 请求1完成
[请求5, 请求4, 请求3]
3.3 CANN 上的实现
python
import time
import threading
from collections import deque
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class InferenceRequest:
"""推理请求"""
request_id: str
input_ids: list
max_new_tokens: int
created_at: float = field(default_factory=time.time)
generated_tokens: int = 0
output_ids: list = field(default_factory=list)
is_finished: bool = False
class ContinuousBatchScheduler:
"""连续批处理调度器
核心机制:
1. 请求槽位管理: 每个槽位独立运行一个请求
2. 动态替换: 请求完成后,槽位立即接收新请求
3. KV Cache 复用: 每个槽位的 KV Cache 独立管理
为什么连续批处理能提升吞吐?
- 静态批处理: 10 个请求,最慢的要 5 秒,总耗时 5 秒,吞吐 2 req/s
- 连续批处理: 10 个请求,平均 1 秒完成一个,总耗时 ~2 秒,吞吐 5 req/s
- 吞吐提升来自消除了"等最慢请求"的浪费
为什么连续批处理不增加延迟?
- 对于单个请求来说,它的执行路径和静态批处理完全一样
- 连续批处理只是让 NPU 不闲着,不改变单个请求的处理速度
"""
def __init__(self, max_batch_size=32):
self.max_batch_size = max_batch_size
self.active_slots = {} # slot_id → InferenceRequest
self.waiting_queue = deque()
self.lock = threading.Lock()
def submit(self, request: InferenceRequest):
"""提交请求"""
with self.lock:
if len(self.active_slots) < self.max_batch_size:
# 有空闲槽位,直接执行
slot_id = len(self.active_slots)
self.active_slots[slot_id] = request
print(f"请求 {request.request_id} 进入槽位 {slot_id}")
else:
# 没有空闲槽位,进入等待队列
self.waiting_queue.append(request)
print(f"请求 {request.request_id} 进入等待队列 (队列长度: {len(self.waiting_queue)})")
def on_request_complete(self, slot_id: int):
"""请求完成,槽位空出"""
with self.lock:
completed = self.active_slots.pop(slot_id)
completed.is_finished = True
print(f"请求 {completed.request_id} 完成 (生成 {completed.generated_tokens} tokens)")
# 从等待队列取新请求填充槽位
if self.waiting_queue:
new_request = self.waiting_queue.popleft()
self.active_slots[slot_id] = new_request
print(f"请求 {new_request.request_id} 进入槽位 {slot_id}")
def get_batch(self) -> list:
"""获取当前 batch 的所有请求"""
with self.lock:
return list(self.active_slots.values())
def get_stats(self) -> dict:
"""获取调度统计"""
with self.lock:
return {
'active_slots': len(self.active_slots),
'waiting_queue': len(self.waiting_queue),
'total_processed': sum(
1 for r in self.active_slots.values() if r.is_finished
),
}
def simulate_continuous_batching():
"""模拟连续批处理"""
scheduler = ContinuousBatchScheduler(max_batch_size=4)
# 提交 8 个请求
for i in range(8):
req = InferenceRequest(
request_id=f"req-{i:03d}",
input_ids=[100 + i] * 10,
max_new_tokens=20 + i * 5, # 不同长度,模拟真实场景
)
scheduler.submit(req)
# 模拟执行
print(f"\n初始状态: {scheduler.get_stats()}")
# 模拟请求完成(不同时间)
completion_order = [0, 2, 1, 3, 4, 5, 6, 7]
for slot in completion_order:
if slot in scheduler.active_slots:
scheduler.on_request_complete(slot)
print(f" 状态: {scheduler.get_stats()}")
print(f"\n最终状态: {scheduler.get_stats()}")
四、三个技术的组合收益
| 技术 | 优化维度 | 单独收益 | 组合收益 |
|---|---|---|---|
| FlashAttention | 显存 + 延迟 | 显存降 90%+,延迟降 50%+ | 与推测解码组合:更大 batch |
| 推测解码 | 单请求延迟 | 2-3 倍加速 | 与连续批处理组合:吞吐不降 |
| 连续批处理 | 整体吞吐 | 吞吐提升 2-5 倍 | 与 FlashAttention 组合:更大 batch |
实际生产中,三个技术通常同时使用。FlashAttention 腾出的显存让 batch 更大,连续批处理让 NPU 不空闲,推测解码让单个请求更快完成。
五、常见问题
| 问题 | 原因 | 解决方案 |
|---|---|---|
| FlashAttention 精度下降 | 不应该,可能是实现 bug | 检查 online softmax 的数值稳定性 |
| 推测解码变慢了 | 小模型太慢或猜对率太低 | 换更小的 draft model 或调整 draft_length |
| 连续批处理延迟不稳 | 等待队列太长 | 增加 NPU 数量或降低 batch size |
相关仓库
- CANN - 昇腾计算架构 https://gitee.com/ascend/cann
- FlashAttention - 高效注意力实现 https://github.com/Dao-AILab/flash-attention
- vLLM - 连续批处理推理 https://github.com/vllm-project/vllm
- Speculative Decoding - 推测解码论文 https://arxiv.org/abs/2211.17192