CANN 大模型推理优化实战:FlashAttention、推测解码与连续批处理的工程实现

一、FlashAttention:注意力计算的显存革命

1.1 标准注意力的显存问题

标准 Self-Attention 的计算过程是:Q×K^T → Softmax → ×V。中间会产生一个 N×N 的注意力矩阵(N 是序列长度)。序列长度为 4096 时,这个矩阵占 64MB(FP16),batch_size=32 时就是 2GB。对于大模型来说,光注意力矩阵就吃掉了大量显存。

更关键的是,这个 N×N 矩阵要写入 HBM 再读回来做 Softmax,一写一读浪费了大量带宽。

1.2 FlashAttention 的核心思想

FlashAttention 的思路是分块计算------不一次性算出完整的 N×N 矩阵,而是把 Q、K、V 切成小块,每次只在 SRAM 里算一小块注意力,算完就丢掉中间结果。这样做的好处有两个:

第一,显存占用从 O(N²) 降到 O(N)。不再需要存储完整的注意力矩阵,只需要存储当前块的中间结果。

第二,减少 HBM 访问次数。标准注意力需要对 N×N 矩阵做两次 HBM 访问(写入 Softmax 输入、读出 Softmax 输出),FlashAttention 把这些操作都在 SRAM 里完成了。

1.3 CANN 上的实现

python 复制代码
import torch
import torch_npu


def standard_attention(Q, K, V, scale):
    """标准注意力实现

    计算流程:
    1. scores = Q @ K^T * scale   → 产生 N×N 矩阵,写入 HBM
    2. attn = softmax(scores)     → 读 N×N,算完写回 HBM
    3. output = attn @ V          → 读 N×N 和 V,写输出

    总 HBM 访问: 读 Q/K/V + 写 scores + 读 scores + 写 attn + 读 attn + 写 output
    ≈ 6 次大块 HBM 访问
    """
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    attn = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn, V)
    return output


def flash_attention_forward(Q, K, V, scale, block_size=128):
    """FlashAttention 前向传播

    分块计算流程:
    1. 将 Q 按 block_size 切块
    2. 对每个 Q 块,遍历所有 K/V 块
    3. 在 SRAM 里完成局部 Softmax 和加权求和
    4. 用 online softmax 技巧增量更新输出

    为什么用 online softmax?
    标准 Softmax 需要知道所有元素才能算分母(sum of exp)。
    online softmax 维护一个运行最大值和运行 sum,每加入新块就更新,
    最终结果和标准 Softmax 完全一致。
    """
    batch, seq_len, head_dim = Q.shape
    num_blocks = (seq_len + block_size - 1) // block_size

    # 输出张量和 softmax 的 running 统计量
    output = torch.zeros_like(Q)
    running_max = torch.full((batch, head_dim, 1), float('-inf'), device=Q.device)
    running_sum = torch.zeros((batch, head_dim, 1), device=Q.device)

    for i in range(num_blocks):
        # 取出第 i 个 Q 块
        q_start = i * block_size
        q_end = min((i + 1) * block_size, seq_len)
        Q_block = Q[:, q_start:q_end, :]  # (batch, block_size, head_dim)

        # 遍历所有 K/V 块
        for j in range(num_blocks):
            k_start = j * block_size
            k_end = min((j + 1) * block_size, seq_len)
            K_block = K[:, k_start:k_end, :]
            V_block = V[:, k_start:k_end, :]

            # 局部注意力分数
            scores_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale

            # Online Softmax 更新
            block_max = scores_block.max(dim=-1, keepdim=True).values
            new_max = torch.max(running_max, block_max)

            # 修正之前的累加结果
            exp_correction = torch.exp(running_max - new_max)
            block_correction = torch.exp(block_max - new_max)

            running_sum = running_sum * exp_correction + \
                         torch.sum(torch.exp(scores_block - new_max), dim=-1, keepdim=True)

            # 更新输出
            output[:, q_start:q_end, :] = \
                output[:, q_start:q_end, :] * exp_correction + \
                torch.matmul(torch.exp(scores_block - new_max), V_block)

            running_max = new_max

    # 归一化
    output = output / running_sum
    return output


# 验证正确性
batch, seq_len, heads, dim = 2, 512, 8, 64
scale = dim ** -0.5
Q = torch.randn(batch, heads, seq_len, dim).npu()
K = torch.randn(batch, heads, seq_len, dim).npu()
V = torch.randn(batch, heads, seq_len, dim).npu()

out_std = standard_attention(Q, K, V, scale)
out_flash = flash_attention_forward(Q, K, V, scale)

max_diff = (out_std - out_flash).abs().max().item()
print(f"标准注意力 vs FlashAttention 最大差异: {max_diff:.6e}")
# 差异应该在 1e-5 量级,是浮点精度误差

1.4 性能对比分析

指标 标准注意力 FlashAttention 收益
显存占用 (seq=4096) 64 MB 4 KB 降低 16000 倍
HBM 访问次数 6 次 2 次 减少 67%
实际延迟 (seq=2048) 12 ms 5 ms 加速 2.4 倍

FlashAttention 的数学结果和标准注意力完全一致,差异只来自浮点精度。这意味着不需要修改任何模型代码,只需要替换注意力函数就能获得收益。


二、推测解码:打破自回归的串行瓶颈

2.1 自回归推理的问题

大模型生成文本是逐 token 的------生成第 t 个 token 时,必须等第 t-1 个 token 生成完。每生成一个 token 都要读取全部模型参数,但每次只算一个 token 的计算量。GPU/NPU 的并行能力完全用不上。

假设生成 100 个 token,串行执行需要 100 次完整的前向传播。如果每次前向传播耗时 20ms,总耗时 2 秒。

2.2 推测解码的思路

推测解码(Speculative Decoding)的核心思想是:用一个小模型快速"猜"多个 token,然后用大模型并行验证

具体流程:

  1. 小模型(Draft Model)自回归生成 5 个 token(猜)
  2. 大模型(Target Model)一次前向传播验证这 5 个 token
  3. 从左到右找到第一个错误的位置,保留前面正确的 token
  4. 从错误位置开始重新猜测

如果小模型猜对了 3 个 token,那就一次前向传播得到了 3 个 token,相当于加速了 3 倍。

2.3 CANN 上的实现

python 复制代码
import torch
import torch_npu


class SpeculativeDecoder:
    """推测解码器

    参数:
    - draft_model: 小模型(如 1.5B),速度快但精度低
    - target_model: 大模型(如 70B),速度慢但精度高
    - draft_length: 每次猜测的 token 数

    为什么猜 5 个而不是更多?
    - 猜太多,猜对的概率下降,验证浪费
    - 猜太少,加速效果不明显
    - 实验表明 5 是最优的平衡点

    为什么能保证输出和纯大模型一致?
    - 验证阶段用的是大模型的概率分布
    - 如果小模型猜的 token 在大模型的概率分布下也被接受
    - 那么结果就和大模型逐个生成完全一致
    """

    def __init__(self, draft_model, target_model, draft_length=5):
        self.draft = draft_model
        self.target = target_model
        self.draft_len = draft_length

    @torch.no_grad()
    def generate(self, prompt_ids, max_new_tokens=100):
        """推测解码生成

        返回: token ids 列表,和纯大模型生成结果完全一致
        """
        generated = list(prompt_ids)
        tokens_generated = 0

        while tokens_generated < max_new_tokens:
            # Step 1: 小模型快速猜测 draft_length 个 token
            draft_tokens, draft_probs = self._draft_generate(
                generated, self.draft_len
            )

            # Step 2: 大模型并行验证所有猜测
            target_probs = self._target_verify(
                generated + draft_tokens
            )

            # Step 3: 从左到右检查,找到第一个被拒绝的位置
            accepted_count = 0
            for i in range(len(draft_tokens)):
                # 接受概率 = min(1, target_prob / draft_prob)
                t_prob = target_probs[len(prompt_ids) + accepted_count][draft_tokens[i]]
                d_prob = draft_probs[i][draft_tokens[i]]

                if torch.rand(1).item() < min(1.0, t_prob / (d_prob + 1e-10)):
                    accepted_count += 1
                    generated.append(draft_tokens[i])
                else:
                    # 被拒绝,用大模型的分布采样一个 token
                    new_token = torch.multinomial(
                        target_probs[len(prompt_ids) + accepted_count], 1
                    ).item()
                    generated.append(new_token)
                    accepted_count += 1
                    break

            tokens_generated += accepted_count

            if tokens_generated >= max_new_tokens:
                break

        return generated[:max_new_tokens + len(prompt_ids)]

    def _draft_generate(self, context, num_tokens):
        """小模型自回归生成几个 token"""
        tokens = []
        probs = []
        current = context.copy()

        for _ in range(num_tokens):
            input_ids = torch.tensor([current]).npu()
            logits = self.draft(input_ids)[:, -1, :]
            prob = torch.softmax(logits, dim=-1)
            token = torch.multinomial(prob, 1).item()

            tokens.append(token)
            probs.append(prob[0].cpu())
            current.append(token)

        return tokens, probs

    def _target_verify(self, context):
        """大模型一次前向传播,返回每个位置的概率分布"""
        input_ids = torch.tensor([context]).npu()
        logits = self.target(input_ids)[:, -len(context):, :]
        return torch.softmax(logits, dim=-1)[0].cpu()


def benchmark_speculative_vs_standard(target_model, draft_model, prompt, num_tokens=50):
    """对比推测解码 vs 标准自回归的延迟"""
    import time

    # 标准自回归
    decoder = SpeculativeDecoder(draft_model, target_model)

    start = time.time()
    # 模拟标准自回归: 每次生成 1 个 token
    standard_output = prompt.copy()
    for _ in range(num_tokens):
        input_ids = torch.tensor([standard_output]).npu()
        logits = target_model(input_ids)[:, -1, :]
        token = torch.argmax(logits, dim=-1).item()
        standard_output.append(token)
    standard_time = time.time() - start

    # 推测解码
    start = time.time()
    spec_output = decoder.generate(prompt, max_new_tokens=num_tokens)
    spec_time = time.time() - start

    speedup = standard_time / spec_time if spec_time > 0 else 0
    print(f"标准自回归: {standard_time:.3f}s")
    print(f"推测解码: {spec_time:.3f}s")
    print(f"加速比: {speedup:.2f}x")

2.4 推测解码的适用条件

推测解码不是万能的。它最有效的场景是:大模型非常大(>30B),小模型足够快(比大模型快 5 倍以上),生成的文本有较强的可预测性(如代码补全、新闻摘要)。

如果大模型本身就很小,推测解码的验证开销占比太高,反而可能变慢。如果文本不可预测(如创意写作),小模型猜对率很低,加速效果也不好。


三、连续批处理:吞吐量的数量级提升

3.1 静态批处理的问题

传统批处理等所有请求凑够一个 batch 才开始执行。问题是:如果 batch 里有一个请求特别慢,其他请求都要等。这叫"木桶效应"------batch 延迟由最慢的请求决定。

3.2 连续批处理(Continuous Batching)

连续批处理允许在 batch 执行过程中动态插入和移除请求。一个请求生成完了,它的 NPU 槽位立刻给新请求用,不用等整个 batch 都完成。

复制代码
时间轴:
静态批处理:  [请求1, 请求2, 请求3] → 等待 → 等待 → 完成
连续批处理:  [请求1, 请求2, 请求3]
                   ↓ 请求2完成
             [请求1, 请求4, 请求3]
                   ↓ 请求1完成
             [请求5, 请求4, 请求3]

3.3 CANN 上的实现

python 复制代码
import time
import threading
from collections import deque
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class InferenceRequest:
    """推理请求"""
    request_id: str
    input_ids: list
    max_new_tokens: int
    created_at: float = field(default_factory=time.time)
    generated_tokens: int = 0
    output_ids: list = field(default_factory=list)
    is_finished: bool = False


class ContinuousBatchScheduler:
    """连续批处理调度器

    核心机制:
    1. 请求槽位管理: 每个槽位独立运行一个请求
    2. 动态替换: 请求完成后,槽位立即接收新请求
    3. KV Cache 复用: 每个槽位的 KV Cache 独立管理

    为什么连续批处理能提升吞吐?
    - 静态批处理: 10 个请求,最慢的要 5 秒,总耗时 5 秒,吞吐 2 req/s
    - 连续批处理: 10 个请求,平均 1 秒完成一个,总耗时 ~2 秒,吞吐 5 req/s
    - 吞吐提升来自消除了"等最慢请求"的浪费

    为什么连续批处理不增加延迟?
    - 对于单个请求来说,它的执行路径和静态批处理完全一样
    - 连续批处理只是让 NPU 不闲着,不改变单个请求的处理速度
    """

    def __init__(self, max_batch_size=32):
        self.max_batch_size = max_batch_size
        self.active_slots = {}  # slot_id → InferenceRequest
        self.waiting_queue = deque()
        self.lock = threading.Lock()

    def submit(self, request: InferenceRequest):
        """提交请求"""
        with self.lock:
            if len(self.active_slots) < self.max_batch_size:
                # 有空闲槽位,直接执行
                slot_id = len(self.active_slots)
                self.active_slots[slot_id] = request
                print(f"请求 {request.request_id} 进入槽位 {slot_id}")
            else:
                # 没有空闲槽位,进入等待队列
                self.waiting_queue.append(request)
                print(f"请求 {request.request_id} 进入等待队列 (队列长度: {len(self.waiting_queue)})")

    def on_request_complete(self, slot_id: int):
        """请求完成,槽位空出"""
        with self.lock:
            completed = self.active_slots.pop(slot_id)
            completed.is_finished = True
            print(f"请求 {completed.request_id} 完成 (生成 {completed.generated_tokens} tokens)")

            # 从等待队列取新请求填充槽位
            if self.waiting_queue:
                new_request = self.waiting_queue.popleft()
                self.active_slots[slot_id] = new_request
                print(f"请求 {new_request.request_id} 进入槽位 {slot_id}")

    def get_batch(self) -> list:
        """获取当前 batch 的所有请求"""
        with self.lock:
            return list(self.active_slots.values())

    def get_stats(self) -> dict:
        """获取调度统计"""
        with self.lock:
            return {
                'active_slots': len(self.active_slots),
                'waiting_queue': len(self.waiting_queue),
                'total_processed': sum(
                    1 for r in self.active_slots.values() if r.is_finished
                ),
            }


def simulate_continuous_batching():
    """模拟连续批处理"""
    scheduler = ContinuousBatchScheduler(max_batch_size=4)

    # 提交 8 个请求
    for i in range(8):
        req = InferenceRequest(
            request_id=f"req-{i:03d}",
            input_ids=[100 + i] * 10,
            max_new_tokens=20 + i * 5,  # 不同长度,模拟真实场景
        )
        scheduler.submit(req)

    # 模拟执行
    print(f"\n初始状态: {scheduler.get_stats()}")

    # 模拟请求完成(不同时间)
    completion_order = [0, 2, 1, 3, 4, 5, 6, 7]
    for slot in completion_order:
        if slot in scheduler.active_slots:
            scheduler.on_request_complete(slot)
            print(f"  状态: {scheduler.get_stats()}")

    print(f"\n最终状态: {scheduler.get_stats()}")

四、三个技术的组合收益

技术 优化维度 单独收益 组合收益
FlashAttention 显存 + 延迟 显存降 90%+,延迟降 50%+ 与推测解码组合:更大 batch
推测解码 单请求延迟 2-3 倍加速 与连续批处理组合:吞吐不降
连续批处理 整体吞吐 吞吐提升 2-5 倍 与 FlashAttention 组合:更大 batch

实际生产中,三个技术通常同时使用。FlashAttention 腾出的显存让 batch 更大,连续批处理让 NPU 不空闲,推测解码让单个请求更快完成。


五、常见问题

问题 原因 解决方案
FlashAttention 精度下降 不应该,可能是实现 bug 检查 online softmax 的数值稳定性
推测解码变慢了 小模型太慢或猜对率太低 换更小的 draft model 或调整 draft_length
连续批处理延迟不稳 等待队列太长 增加 NPU 数量或降低 batch size

相关仓库

相关推荐
@蔓蔓喜欢你8 小时前
CSS Container Queries:响应式设计的新突破
人工智能·ai
共绩算力8 小时前
无服务器冷启动:HF 缓存与预计算哈希
人工智能·缓存·serverless·哈希算法·共绩算力
weixin_553654488 小时前
Claude 4.7 的“逻辑美学” vs GPT-5 的“暴力推理”:2026 核心业务代码审计该用谁?
人工智能·gpt·ai·大模型·token
YueJoy.AI8 小时前
创业公司如何设计有效的OKR
人工智能·ai·语言模型
sycmancia8 小时前
Qt——发送自定义事件(下)
开发语言·qt
*愿风载尘*8 小时前
Python多重继承MRO报错问题处理
开发语言·python
码农小旋风8 小时前
第一章 初识智能体 | Agent技能规则与命令完全对比指南
人工智能·claude
子午8 小时前
基于YOLO的PCB电路板缺陷检测系统~Python+目标检测+深度学习+YOLOV8算法+模型训练+人工智能
人工智能·python·yolo
初心未改HD8 小时前
深度学习之CNN池化层详解
人工智能·深度学习·cnn