本文深入讲解现代大语言模型的核心优化技术,包括KV Cache自回归加速、Multi-Query Attention(MQA)、Grouped-Query Attention(GQA)、稀疏注意力(Sparse Attention)和混合专家模型(Mixture of Experts, MoE)。通过数学原理、架构对比和PyTorch代码实现,帮助读者理解GPT-4、LLaMA、Mixtral等顶级模型的技术细节,掌握LLM推理加速与显存优化的工程实践。
一、为什么需要优化Transformer?
1.1 原始Transformer的性能瓶颈
自回归逐词生成
大量重复计算"] P2["💾 显存占用高
KV矩阵随序列长度增长
多头存储冗余"] P3["📏 序列长度受限
O(n²)复杂度
长文本处理困难"] end subgraph 解决方案 S1["✅ KV Cache
缓存已计算的KV"] S2["✅ MQA/GQA
共享KV降低显存"] S3["✅ Sparse Attention
稀疏注意力模式"] end P1 --> S1 P2 --> S2 P3 --> S3 style P1 fill:#ffcdd2 style P2 fill:#ffccbc style P3 fill:#ffab91 style S1 fill:#a5d6a7 style S2 fill:#81c784 style S3 fill:#66bb6a
1.2 现代LLM采用的优化技术
| 模型 | KV Cache | MQA/GQA | Sparse Attn | MoE | 上下文长度 |
|---|---|---|---|---|---|
| GPT-3 | ✅ | ❌ | ❌ | ❌ | 2K |
| LLaMA | ✅ | ❌ | ❌ | ❌ | 4K |
| LLaMA2 | ✅ | ✅ GQA | ❌ | ❌ | 4K |
| GPT-4 | ✅ | ✅ | 部分 | 推测✅ | 32K/128K |
| Mixtral 8x7B | ✅ | ✅ GQA | ❌ | ✅ | 32K |
| Claude 3 | ✅ | ✅ | ✅ | ? | 200K |
二、KV Cache:自回归加速的核心技术
2.1 自回归生成的重复计算问题
场景:GPT模型生成"我爱学习AI"
问题分析:
- 生成第1个词:计算1次KV
- 生成第2个词:计算2次KV(1次重复)
- 生成第3个词:计算3次KV(2次重复)
- 生成第n个词:计算n次KV(n-1次重复)
总计算量 : <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 + 2 + 3 + . . . + n = n ( n + 1 ) 2 = O ( n 2 ) 1 + 2 + 3 + ... + n = \frac{n(n+1)}{2} = O(n^2) </math>1+2+3+...+n=2n(n+1)=O(n2)
2.2 KV Cache的工作原理
核心思想:缓存已经计算过的Key和Value矩阵,新token只需计算自己的KV。
计算: [START]"] S2["Step 2
计算: [START, 我]
❌ 重复计算START"] S3["Step 3
计算: [START, 我, 爱]
❌ 重复计算START,我"] end subgraph 有Cache[With KV Cache] C1["Step 1
计算&缓存: [START]"] C2["Step 2
✅ 读取: [START]
计算&缓存: [我]"] C3["Step 3
✅ 读取: [START, 我]
计算&缓存: [爱]"] end S1 --> S2 --> S3 C1 --> C2 --> C3 style S2 fill:#ffcdd2 style S3 fill:#ffcdd2 style C2 fill:#a5d6a7 style C3 fill:#a5d6a7
加速效果:
- 无Cache : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n 2 ) O(n^2) </math>O(n2) 计算
- 有Cache : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n ) O(n) </math>O(n) 计算
- 加速比: 生成100个token,加速约50倍!
2.3 KV Cache数学原理
标准Attention:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q t , K 1 : t , V 1 : t ) = softmax ( Q t K 1 : t T d k ) V 1 : t \text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right)V_{1:t} </math>Attention(Qt,K1:t,V1:t)=softmax(dk QtK1:tT)V1:t
在第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t步:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Q t Q_t </math>Qt: 当前token的Query (新计算)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> K 1 : t K_{1:t} </math>K1:t: 所有历史token的Key (1到t-1从缓存读取,t新计算)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> V 1 : t V_{1:t} </math>V1:t: 所有历史token的Value (同上)
缓存更新:
python
# Pseudo-code
cache_K = [] # 初始化KV缓存
cache_V = []
for t in range(max_len):
# 1. 计算当前token的KV
k_t = compute_key(x_t)
v_t = compute_value(x_t)
# 2. 追加到缓存
cache_K.append(k_t)
cache_V.append(v_t)
# 3. 使用全部缓存计算注意力
q_t = compute_query(x_t)
attention = softmax(q_t @ cache_K.T) @ cache_V
2.4 PyTorch实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttentionWithCache(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, x, cache=None, use_cache=False):
"""
参数:
x: [batch_size, seq_len, d_model]
cache: {'key': [batch, n_heads, past_len, d_k],
'value': [batch, n_heads, past_len, d_k]}
use_cache: 是否返回更新后的cache
"""
batch_size, seq_len, _ = x.size()
# 1. 计算当前输入的QKV
Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_K(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_V(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# 2. 如果有cache,拼接历史KV
if cache is not None:
K = torch.cat([cache['key'], K], dim=2) # 拼接到seq_len维度
V = torch.cat([cache['value'], V], dim=2)
# 3. 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 4. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
output = self.W_O(attn_output)
# 5. 更新cache
if use_cache:
new_cache = {'key': K, 'value': V}
return output, new_cache
return output
# 使用示例:模拟自回归生成
d_model = 512
n_heads = 8
max_len = 10
mha = MultiHeadAttentionWithCache(d_model, n_heads)
# 初始化
cache = None
all_outputs = []
for t in range(max_len):
# 当前token (实际中是上一步的输出)
current_token = torch.randn(1, 1, d_model) # [batch=1, seq_len=1, d_model]
# 前向传播 with cache
output, cache = mha(current_token, cache=cache, use_cache=True)
all_outputs.append(output)
print(f"Step {t+1}:")
print(f" Cache K shape: {cache['key'].shape}")
print(f" Cache V shape: {cache['value'].shape}")
# 输出示例:
# Step 1:
# Cache K shape: torch.Size([1, 8, 1, 64])
# Cache V shape: torch.Size([1, 8, 1, 64])
# Step 2:
# Cache K shape: torch.Size([1, 8, 2, 64]) ← 长度递增
# Cache V shape: torch.Size([1, 8, 2, 64])
# ...
2.5 KV Cache的显存成本
分析:对于单个样本
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> KV Cache Size = 2 × n_layers × n_heads × seq_len × d_k × sizeof(dtype) \text{KV Cache Size} = 2 \times \text{n\_layers} \times \text{n\_heads} \times \text{seq\_len} \times \text{d\_k} \times \text{sizeof(dtype)} </math>KV Cache Size=2×n_layers×n_heads×seq_len×d_k×sizeof(dtype)
示例:LLaMA2-7B
- n_layers = 32
- n_heads = 32
- seq_len = 4096
- d_k = 128
- dtype = float16 (2 bytes)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> KV Cache = 2 × 32 × 32 × 4096 × 128 × 2 = 2.1 GB \text{KV Cache} = 2 \times 32 \times 32 \times 4096 \times 128 \times 2 = 2.1 \text{GB} </math>KV Cache=2×32×32×4096×128×2=2.1GB
单个序列就需要2GB显存! 这就是为什么需要MQA/GQA优化。
三、Multi-Query Attention(MQA):共享KV的激进方案
3.1 MQA的动机
问题:在多头注意力中,每个头都有独立的KV矩阵,造成显存冗余。
核心思想:所有注意力头共享同一组Key和Value,只有Query独立。
3.2 MQA数学公式
标准MHA:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) </math>headi=Attention(QWiQ,KWiK,VWiV)
MQA:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> head i = Attention ( Q W i Q , K W K , V W V ) \text{head}_i = \text{Attention}(QW_i^Q, KW^K, VW^V) </math>headi=Attention(QWiQ,KWK,VWV)
注意: <math xmlns="http://www.w3.org/1998/Math/MathML"> W K , W V W^K, W^V </math>WK,WV 在所有头之间共享。
3.3 显存节省计算
参数量对比:
| 配置 | MHA | MQA | 节省 |
|---|---|---|---|
| Q权重 | <math xmlns="http://www.w3.org/1998/Math/MathML"> h × d m o d e l × d k h \times d_{model} \times d_k </math>h×dmodel×dk | <math xmlns="http://www.w3.org/1998/Math/MathML"> h × d m o d e l × d k h \times d_{model} \times d_k </math>h×dmodel×dk | 0 |
| K权重 | <math xmlns="http://www.w3.org/1998/Math/MathML"> h × d m o d e l × d k h \times d_{model} \times d_k </math>h×dmodel×dk | <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l × d k d_{model} \times d_k </math>dmodel×dk | <math xmlns="http://www.w3.org/1998/Math/MathML"> ( h − 1 ) / h × 100 % (h-1)/h \times 100\% </math>(h−1)/h×100% |
| V权重 | <math xmlns="http://www.w3.org/1998/Math/MathML"> h × d m o d e l × d k h \times d_{model} \times d_k </math>h×dmodel×dk | <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l × d k d_{model} \times d_k </math>dmodel×dk | <math xmlns="http://www.w3.org/1998/Math/MathML"> ( h − 1 ) / h × 100 % (h-1)/h \times 100\% </math>(h−1)/h×100% |
示例(h=32):
- MHA KV缓存: 2.1 GB
- MQA KV缓存: 2.1/32 = 66 MB (节省96.9%!)
3.4 PyTorch实现
python
class MultiQueryAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 每个头独立的Query
self.W_Q = nn.Linear(d_model, d_model)
# 共享的Key和Value
self.W_K = nn.Linear(d_model, self.d_k) # 注意维度!
self.W_V = nn.Linear(d_model, self.d_k)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# 1. 计算多头Query
Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
Q = Q.transpose(1, 2) # [batch, n_heads, seq_len, d_k]
# 2. 计算共享的K和V
K = self.W_K(x) # [batch, seq_len, d_k]
V = self.W_V(x) # [batch, seq_len, d_k]
# 扩展到所有头(通过broadcast)
K = K.unsqueeze(1) # [batch, 1, seq_len, d_k]
V = V.unsqueeze(1)
# 3. 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# [batch, n_heads, seq_len, d_k]
# 4. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
output = self.W_O(attn_output)
return output
# 对比参数量
d_model = 512
n_heads = 8
mha = MultiHeadAttention(d_model, n_heads)
mqa = MultiQueryAttention(d_model, n_heads)
print(f"MHA 参数量: {sum(p.numel() for p in mha.parameters())}")
print(f"MQA 参数量: {sum(p.numel() for p in mqa.parameters())}")
# MHA 参数量: 1,050,624
# MQA 参数量: 820,224 (节省22%)
3.5 MQA的缺点
实验数据(PaLM论文):
- 推理速度: 提升1.5-2x
- 模型质量: 下降约3-5%
四、Grouped-Query Attention(GQA):MHA与MQA的平衡
4.1 GQA的设计哲学
核心思想:将多个Query头分组,每组共享一对KV。
K1,V1 K2,V2 ... Kh,Vh"] end subgraph GQA[Grouped-Query: g组共享KV] GQA_Group1["组1: Head1,2,3,4
共享 K1,V1"] GQA_Group2["组2: Head5,6,7,8
共享 K2,V2"] end subgraph MQA[Multi-Query: 1组共享KV] MQA_All["所有Head
共享 K,V"] end MHA -.折中方案.-> GQA GQA -.极端情况.-> MQA style MHA fill:#ffccbc style GQA fill:#fff9c4 style MQA fill:#a5d6a7
4.2 GQA配置
数学关系:
- Query头数: <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h (如32)
- KV组数: <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g (如4或8)
- 每组Query数: <math xmlns="http://www.w3.org/1998/Math/MathML"> h / g h/g </math>h/g
常见配置:
| 模型 | Query头数 | KV组数 | 每组头数 | 显存节省 |
|---|---|---|---|---|
| LLaMA2-7B | 32 | 8 | 4 | 75% |
| LLaMA2-13B | 40 | 5 | 8 | 87.5% |
| LLaMA2-70B | 64 | 8 | 8 | 87.5% |
| Mixtral 8x7B | 32 | 8 | 4 | 75% |
4.3 GQA架构图
[h个头]"] Linear --> K["Key
[g组]"] Linear --> V["Value
[g组]"] subgraph 组1 Q1["Q头1-4"] --> Attn1["注意力计算"] K1["K1"] --> Attn1 V1["V1"] --> Attn1 end subgraph 组2 Q2["Q头5-8"] --> Attn2["注意力计算"] K2["K2"] --> Attn2 V2["V2"] --> Attn2 end Attn1 --> Concat["拼接"] Attn2 --> Concat Concat --> Output["输出"] style Q fill:#fff9c4 style K fill:#a5d6a7 style V fill:#81c784 style Output fill:#c5e1a5
4.4 PyTorch实现
python
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, n_heads, n_kv_groups):
"""
参数:
d_model: 模型维度(如4096)
n_heads: Query头数(如32)
n_kv_groups: KV组数(如8)
"""
super().__init__()
assert n_heads % n_kv_groups == 0, "n_heads必须能被n_kv_groups整除"
self.d_model = d_model
self.n_heads = n_heads
self.n_kv_groups = n_kv_groups
self.n_heads_per_group = n_heads // n_kv_groups
self.d_k = d_model // n_heads
# Query: 每个头独立
self.W_Q = nn.Linear(d_model, d_model)
# Key & Value: 每组一个
self.W_K = nn.Linear(d_model, n_kv_groups * self.d_k)
self.W_V = nn.Linear(d_model, n_kv_groups * self.d_k)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, _ = x.size()
# 1. 计算Q (所有头)
Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
Q = Q.transpose(1, 2) # [batch, n_heads, seq_len, d_k]
# 2. 计算K, V (每组一个)
K = self.W_K(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
V = self.W_V(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
K = K.transpose(1, 2) # [batch, n_kv_groups, seq_len, d_k]
V = V.transpose(1, 2)
# 3. 将KV复制到每组内的所有头
K = K.repeat_interleave(self.n_heads_per_group, dim=1)
V = V.repeat_interleave(self.n_heads_per_group, dim=1)
# 现在 K, V: [batch, n_heads, seq_len, d_k]
# 4. 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 5. 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
output = self.W_O(attn_output)
return output
# 使用示例
d_model = 4096
n_heads = 32
n_kv_groups = 8 # LLaMA2-7B配置
gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
x = torch.randn(2, 10, d_model)
output = gqa(x)
print(f"输入: {x.shape}") # torch.Size([2, 10, 4096])
print(f"输出: {output.shape}") # torch.Size([2, 10, 4096])
4.5 MHA vs MQA vs GQA对比
(困惑度 Perplexity)"] Speed["推理速度
(tokens/sec)"] Memory["显存占用
(GB)"] end subgraph MHA评分 Q_MHA["最好 ⭐⭐⭐⭐⭐"] S_MHA["最慢 ⭐⭐"] M_MHA["最高 ⭐"] end subgraph GQA评分 Q_GQA["接近MHA ⭐⭐⭐⭐"] S_GQA["较快 ⭐⭐⭐⭐"] M_GQA["适中 ⭐⭐⭐"] end subgraph MQA评分 Q_MQA["略低 ⭐⭐⭐"] S_MQA["最快 ⭐⭐⭐⭐⭐"] M_MQA["最低 ⭐⭐⭐⭐⭐"] end Quality --> Q_MHA Quality --> Q_GQA Quality --> Q_MQA Speed --> S_MHA Speed --> S_GQA Speed --> S_MQA Memory --> M_MHA Memory --> M_GQA Memory --> M_MQA style Q_GQA fill:#fff59d style S_GQA fill:#fff59d style M_GQA fill:#fff59d
实验数据(LLaMA2论文):
- 质量: GQA-8 几乎等同于 MHA
- 速度 : GQA-8 比 MHA 快 1.3x
- 显存 : GQA-8 节省 75% KV缓存
五、稀疏注意力(Sparse Attention)
5.1 长序列的注意力复杂度问题
标准Attention的瓶颈:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 复杂度 = O ( n 2 d ) \text{复杂度} = O(n^2 d) </math>复杂度=O(n2d)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 是序列长度, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是维度。
Claude 3处理200K上下文需要什么?
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 200 K 2 = 40 billion operations per layer! 200K^2 = 40 \text{billion operations per layer!} </math>200K2=40billion operations per layer!
5.2 稀疏注意力模式
核心思想:不是所有token都需要关注所有其他token。
关注所有token"] end subgraph Sparse[稀疏注意力] S1["局部注意力
Sliding Window"] S2["全局注意力
Global Tokens"] S3["随机注意力
Random Sampling"] S4["分块注意力
Blocked"] end Full -.优化.-> Sparse style Full fill:#ffcdd2 style S1 fill:#a5d6a7 style S2 fill:#81c784 style S3 fill:#66bb6a style S4 fill:#4caf50
5.3 常见稀疏注意力模式
(1) Sliding Window Attention
思想:每个token只关注前后固定窗口内的token。
复杂度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n × w ) O(n \times w) </math>O(n×w),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w 是窗口大小(如512)
实现:
python
def sliding_window_mask(seq_len, window_size):
"""
生成滑动窗口mask
"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
start = max(0, i - window_size)
end = min(seq_len, i + window_size + 1)
mask[i, start:end] = 1
return mask
# 示例
mask = sliding_window_mask(10, window_size=2)
print(mask)
# tensor([[1., 1., 1., 0., 0., ...],
# [1., 1., 1., 1., 0., ...],
# [1., 1., 1., 1., 1., ...],
# ...])
(2) Global + Local Attention(Longformer模式)
思想:少数全局token关注所有,大部分token只做局部关注。
关注所有token"] end subgraph 局部Token L["普通token
只关注窗口内"] end G -.全注意力.-> All["全部序列"] L -.局部.-> Window["小窗口"] style G fill:#ffeb3b style L fill:#90caf9
实现:
python
def longformer_mask(seq_len, window_size, global_indices):
"""
Longformer注意力mask
global_indices: 全局token的位置(如[0, 1])
"""
# 基础:滑动窗口
mask = sliding_window_mask(seq_len, window_size)
# 全局token可以关注所有
for idx in global_indices:
mask[idx, :] = 1 # 该行全1
mask[:, idx] = 1 # 该列全1
return mask
(3) Sparse Transformer (分块注意力)
思想:将序列分块,块内全注意力,块间稀疏连接。
5.4 FlashAttention: IO优化而非稀疏化
特殊说明:FlashAttention不改变注意力模式,而是优化GPU内存访问。
写入HBM"] Step2["2. 读取,Softmax
写回HBM"] Step3["3. 读取,乘V
写回HBM"] end subgraph FlashAttn[FlashAttention] Fused["分块计算
全程在SRAM
减少HBM访问"] end Step1 --> Step2 --> Step3 style Step1 fill:#ffccbc style Step2 fill:#ffccbc style Step3 fill:#ffccbc style Fused fill:#a5d6a7
加速效果:
- 训练: 快2-4x
- 长序列: 支持64K+上下文
六、混合专家模型(Mixture of Experts, MoE)
6.1 MoE的核心思想
问题:大模型参数多,但每次前向传播只需要激活部分参数。
决策选择专家"] Router -->|20%概率| E1["专家1
数学推理"] Router -->|5%概率| E2["专家2
代码生成"] Router -->|60%概率| E3["专家3
通用知识"] Router -->|10%概率| E4["专家4
创意写作"] Router -->|5%概率| En["专家N
..."] E1 --> Combine["加权组合"] E2 --> Combine E3 --> Combine E4 --> Combine En --> Combine Combine --> Output["输出"] style Router fill:#fff59d style E3 fill:#a5d6a7 style Combine fill:#90caf9
关键特点:
- 稀疏激活:每个token只激活Top-K个专家(如K=2)
- 参数共享:总参数量大,但实际计算量接近小模型
- 专业化:不同专家学习不同领域知识
6.2 MoE架构
Gating"] subgraph MoE层 Router -->|权重w1| Expert1["FFN 专家1"] Router -->|权重w2| Expert2["FFN 专家2"] Router -->|权重0| Expert3["FFN 专家3
未激活"] Router -->|权重0| ExpertN["FFN 专家N
未激活"] end Expert1 --> Sum["加权求和
w1·E1 + w2·E2"] Expert2 --> Sum Sum --> Norm2["LayerNorm"] Norm2 --> Output["输出"] style Router fill:#fff59d style Expert1 fill:#a5d6a7 style Expert2 fill:#81c784 style Expert3 fill:#e0e0e0 style ExpertN fill:#e0e0e0
6.3 路由机制
Softmax路由:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> G ( x ) = Softmax ( x ⋅ W g ) G(x) = \text{Softmax}(x \cdot W_g) </math>G(x)=Softmax(x⋅Wg)
Top-K选择:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Output = ∑ i ∈ TopK ( G ( x ) ) G ( x ) i ⋅ E i ( x ) \text{Output} = \sum_{i \in \text{TopK}(G(x))} G(x)_i \cdot E_i(x) </math>Output=i∈TopK(G(x))∑G(x)i⋅Ei(x)
PyTorch实现:
python
class MoELayer(nn.Module):
def __init__(self, d_model, d_ff, num_experts, top_k=2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# 路由网络
self.gate = nn.Linear(d_model, num_experts)
# 专家网络(FFN)
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
for _ in range(num_experts)
])
def forward(self, x):
"""
x: [batch_size, seq_len, d_model]
"""
batch_size, seq_len, d_model = x.size()
# 1. 路由打分
gate_logits = self.gate(x) # [batch, seq_len, num_experts]
# 2. 选择Top-K专家
top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
# top_k_indices: [batch, seq_len, top_k]
# 3. Softmax归一化(只在Top-K上)
top_k_gates = F.softmax(top_k_logits, dim=-1)
# [batch, seq_len, top_k]
# 4. 计算专家输出并加权求和
output = torch.zeros_like(x)
for k in range(self.top_k):
# 获取当前专家索引
expert_idx = top_k_indices[:, :, k] # [batch, seq_len]
gate_weight = top_k_gates[:, :, k] # [batch, seq_len]
# 批量处理(简化版,实际中需要更高效的实现)
for i in range(self.num_experts):
mask = (expert_idx == i) # [batch, seq_len]
if mask.any():
expert_output = self.experts[i](x)
output += expert_output * gate_weight.unsqueeze(-1) * mask.unsqueeze(-1)
return output
# 使用示例
d_model = 512
d_ff = 2048
num_experts = 8
top_k = 2
moe = MoELayer(d_model, d_ff, num_experts, top_k)
x = torch.randn(2, 10, d_model)
output = moe(x)
print(f"输入: {x.shape}") # torch.Size([2, 10, 512])
print(f"输出: {output.shape}") # torch.Size([2, 10, 512])
6.4 实际案例:Mixtral 8x7B
架构特点:
- 8个专家,每个7B参数
- Top-2路由:每个token激活2个专家
- 总参数: 47B (8×7B,但共享attention)
- 激活参数: 13B (相当于13B模型的计算量)
性能数据:
- 数学推理: 优于LLaMA2-70B
- 代码生成: 接近GPT-3.5
- 推理速度: 比70B快5x+
6.5 MoE的挑战
| 挑战 | 说明 | 解决方案 |
|---|---|---|
| 负载均衡 | 某些专家被过度使用 | 添加辅助损失函数 |
| 通信开销 | 分布式训练时专家在不同GPU | 专家并行策略 |
| 泛化性 | 专家过度专业化 | 正则化技术 |
负载均衡损失:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L b a l a n c e = α ⋅ CV ( expert_usage ) L_{balance} = \alpha \cdot \text{CV}(\text{expert\_usage}) </math>Lbalance=α⋅CV(expert_usage)
其中 CV 是变异系数,鼓励专家使用均匀。
七、技术对比与选择指南
7.1 综合对比表
| 技术 | 加速比 | 显存节省 | 质量损失 | 实现难度 | 适用场景 |
|---|---|---|---|---|---|
| KV Cache | 50x+ | 0% | 0% | ⭐ | 所有自回归模型(必备) |
| MQA | 2x | 96% | 3-5% | ⭐⭐ | 极致推理速度场景 |
| GQA | 1.3x | 75% | <1% | ⭐⭐ | 推荐,平衡方案 |
| Sparse Attn | 10x+ | 50%+ | 0-5% | ⭐⭐⭐⭐ | 超长文本(100K+) |
| MoE | 5x | 70% | 0% | ⭐⭐⭐⭐⭐ | 超大模型,计算受限 |
7.2 选择决策树
+ KV Cache] Q2 -->|紧张| Use_GQA[使用GQA
+ KV Cache] Medium --> Q3{质量要求?} Q3 -->|最高| MHA_Long[MHA + KV Cache] Q3 -->|平衡| GQA_Long[GQA + Sliding Window] Long --> Sparse[Sparse Attention
必选方案] Start --> Q4{是否超大模型?} Q4 -->|>100B| Consider_MoE[考虑MoE架构] style Use_GQA fill:#fff59d style GQA_Long fill:#fff59d style Sparse fill:#a5d6a7 style Consider_MoE fill:#81c784
7.3 工业界实践
OpenAI GPT系列:
- GPT-3: MHA + KV Cache
- GPT-3.5/4: 推测 MQA/GQA + Sparse + MoE
Meta LLaMA系列:
- LLaMA: MHA + KV Cache
- LLaMA2: GQA-8 + KV Cache (黄金组合)
- LLaMA3: GQA + 更长上下文
Google PaLM/Gemini:
- PaLM: MQA + KV Cache
- PaLM2: MQA改进版
Anthropic Claude:
- Claude 1/2: 推测 GQA + Sparse
- Claude 3: Sparse Attention (200K上下文)
八、实战:构建一个优化的Transformer
完整代码
python
class OptimizedTransformerBlock(nn.Module):
"""
集成GQA + KV Cache的优化Transformer Block
"""
def __init__(self, d_model, n_heads, n_kv_groups, d_ff, dropout=0.1):
super().__init__()
# GQA
self.gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
# FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, cache=None, use_cache=False):
# Self-attention with cache
attn_out, new_cache = self.gqa(x, cache=cache, use_cache=use_cache)
x = self.norm1(x + self.dropout(attn_out))
# FFN
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
if use_cache:
return x, new_cache
return x
# LLaMA2-7B配置
d_model = 4096
n_heads = 32
n_kv_groups = 8 # GQA-8
d_ff = 11008
n_layers = 32
# 构建完整模型
class OptimizedLLM(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
OptimizedTransformerBlock(d_model, n_heads, n_kv_groups, d_ff)
for _ in range(n_layers)
])
self.lm_head = nn.Linear(d_model, vocab_size)
def forward(self, input_ids, caches=None, use_cache=False):
x = self.embedding(input_ids)
new_caches = []
for i, layer in enumerate(self.layers):
cache = caches[i] if caches else None
if use_cache:
x, new_cache = layer(x, cache=cache, use_cache=True)
new_caches.append(new_cache)
else:
x = layer(x)
logits = self.lm_head(x)
if use_cache:
return logits, new_caches
return logits
# 使用示例
vocab_size = 32000
model = OptimizedLLM(vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers)
print(f"模型参数量: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
# 输出: 模型参数量: 6.74B (接近LLaMA2-7B)
九、总结与展望
9.1 核心技术总结
9.2 未来趋势
1. 更长的上下文
- 目标: 100万token上下文
- 技术: 混合注意力模式、分层记忆
2. 更高效的架构
- 线性Attention (RWKV, RetNet)
- 状态空间模型 (Mamba)
3. 动态计算
- 早停机制 (Early Exit)
- 自适应计算 (Adaptive Computation)
4. 硬件协同优化
- 定制芯片(TPU, Groq)
- 混合精度(FP8, INT4)
十、练习与资源
练习题
1. 计算KV Cache节省
python
# 给定LLaMA2-13B配置,计算生成1000个token的KV Cache大小
# n_layers=40, n_heads=40, d_k=128, seq_len=1000
2. 实现Sliding Window Mask
python
def create_sliding_window_mask(seq_len, window_size):
# TODO: 实现并可视化
pass
3. 对比GQA不同配置
python
# 实验GQA-4 vs GQA-8 vs MHA的性能和显存