【LLM技术全景】长上下文技术全景:突破窗口限制的方法论

摘要:本文是《LLM技术全景:从Token到部署》系列第十二篇,技术原理篇第四讲。大模型的上下文窗口,曾是制约其应用的最大"天花板":GPT-2只有1024个Token,想处理一本书?不行。本文从根本原因出发,系统梳理突破上下文限制的四条技术路线------位置编码外推、稀疏注意力、FlashAttention高效计算、StreamingLLM无限推理,帮你彻底理解从4K到1M上下文窗口背后的技术真相。

阅读收获:① 理解上下文长度受限的根本原因(计算复杂度 + 位置编码泛化);② 掌握RoPE位置插值(PI)、YaRN、ALiBi三种外推方案;③ 理解FlashAttention的IO感知设计与加速原理;④ 了解稀疏注意力机制(Longformer、BigBird);⑤ 掌握StreamingLLM无限流式推理的核心思想。


一、引言:上下文窗口为什么是"天花板"?

复制代码
用户问题:
  "帮我分析这本300页的产品手册,找出所有关于安全规范的内容"
  → 300页 ≈ 150,000 个Token

现实:
  GPT-2:4K Token上限,直接拒绝
  GPT-3:4K → 失败
  GPT-3.5:16K → 仍然不够
  GPT-4:128K → 勉强能处理
  Gemini 1.5 Pro:1M Token → 终于!

这个"天花板"从哪里来?有两个独立的根源

1.1 根源一:二次方计算复杂度

标准Transformer的注意力机制,计算复杂度是 O(n²)

复制代码
Attention(Q, K, V) = softmax(QK^T / √d_k) · V

QK^T 的形状:[n, n]
→ n个Token,需要计算 n×n 个注意力权重
→ 内存占用:O(n²·d),n=128K时 ≈ 数百GB!

具体数字(32K Token,d=128):

复制代码
注意力矩阵大小:32768 × 32768 × 4 bytes ≈ 4GB(仅单层!)
GPT-4有96层,则每次前向传播 ≈ 384GB
→ 显然不可能在单卡甚至多卡上直接计算

1.2 根源二:位置编码泛化失败

不论使用绝对位置编码(APE)还是相对位置编码(RoPE),当测试序列长度超过训练时的最大长度时,模型性能会急剧下降:

复制代码
训练时最大长度:4096
测试时序列长度:8192
→ 模型遇到了"没见过"的位置,性能崩溃

直觉理解:
  就像一个只走过4公里路的人,
  突然要走8公里------前4公里还行,
  后4公里方向感全无。

这两个根源,需要用不同的技术手段分别解决:

复制代码
上下文扩展技术路线图:
                  ┌─ 位置插值(PI)
                  ├─ YaRN(非均匀插值)
 位置编码外推 ──── ┤
                  ├─ ALiBi(线性偏置)
                  └─ NTK-aware 插值
                  
                  ┌─ Sliding Window(滑动窗口)
                  ├─ Longformer(局部+全局)
 稀疏注意力 ────── ┤
                  └─ BigBird(随机+局部+全局)
                  
                  ┌─ FlashAttention v1/v2/v3
 高效注意力计算 ── ┤
                  └─ Ring Attention(分布式长序列)
                  
                  └─ StreamingLLM(无限流式)
 无限外推 ──────── ┘

二、位置编码外推:从4K到128K的核心技术

2.1 为什么RoPE需要外推?

LLaMA、GPT-NeoX等现代大模型普遍使用旋转位置编码(RoPE)。让我们先快速回顾其原理:

python 复制代码
# RoPE 核心:对Query和Key施加位置相关的旋转矩阵
# 位置m处的编码:
# q_m = R_m · q   (对q在m位置进行旋转)
# k_n = R_n · k   (对k在n位置进行旋转)
# 注意力权重 ∝ q_m^T · k_n = q^T · R_{m-n} · k
# → 只依赖相对位置 m-n,天然适合相对位置编码!

import torch
import math

def apply_rope(q, k, positions):
    """
    q, k: [batch, heads, seq_len, head_dim]
    positions: [seq_len]
    """
    head_dim = q.shape[-1]
    
    # 计算旋转角度:θ_i = 1 / (10000^(2i/d))
    # i = 0, 1, ..., d/2-1
    theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float) / head_dim))
    
    # 位置 × 频率
    freqs = positions.unsqueeze(1) * theta.unsqueeze(0)  # [seq_len, d/2]
    
    # 构建旋转矩阵的 cos 和 sin 分量
    cos = freqs.cos()  # [seq_len, d/2]
    sin = freqs.sin()  # [seq_len, d/2]
    
    # 旋转操作(交叉乘法)
    def rotate(x, cos, sin):
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)
    
    return rotate(q, cos, sin), rotate(k, cos, sin)

问题所在

复制代码
训练时:positions = [0, 1, 2, ..., 4095]
           频率 θ_i 在此范围内良好工作

测试时:positions = [0, 1, 2, ..., 8191]
           超出4096的部分:θ_i 旋转到了训练时没见过的角度
           → 注意力计算结果失真,性能崩溃

2.2 方案一:线性位置插值(PI)

论文:Extending Context Window of Large Language Models via Positional Interpolation(Chen et al., 2023)

核心思想:与其让位置超出范围,不如把位置"压缩"回训练范围内。

复制代码
原来(直接外推):
  训练范围 [0, 4096],测试需要 [0, 8192]
  → 4097...8192 越界,模型没见过

线性插值(PI):
  将测试位置乘以缩放因子 s = L_train / L_test
  s = 4096 / 8192 = 0.5
  → 测试位置 [0, 8192] 压缩为 [0, 4096]
  → 所有位置都在训练范围内!
python 复制代码
class LinearPositionInterpolation:
    def __init__(self, original_max_len=4096, target_max_len=8192):
        self.scale = original_max_len / target_max_len  # 0.5
    
    def get_positions(self, seq_len):
        """返回缩放后的位置序列"""
        original_positions = torch.arange(seq_len, dtype=torch.float)
        scaled_positions = original_positions * self.scale  # 压缩到训练范围
        return scaled_positions

# 实际使用:
pi = LinearPositionInterpolation(original_max_len=4096, target_max_len=32768)
# 对于位置 i,实际使用 i × (4096/32768) = i × 0.125

代价 :相邻Token的位置差从1缩小到了0.5,模型需要在更密集 的位置空间中区分Token顺序。

好消息:通过少量微调(约1000步),模型可以适应压缩后的位置分布。

实测效果(LLaMA-7B):

复制代码
扩展方式             | 4K长度性能 | 8K长度性能 | 16K长度性能
--------------------|-----------|-----------|------------
原始RoPE(直接推断) |   100%    |    ~15%   |    ~5%
线性插值 + 1K步微调  |    98%    |    91%    |   (未测)
线性插值 + 2K步微调  |    99%    |    95%    |    88%

2.3 方案二:YaRN(Yet another RoPE extensioN)

论文:YaRN: Efficient Context Window Extension of Large Language Models(Peng et al., 2023)

问题:线性插值是"一刀切"------所有频率维度等比例缩放。但RoPE不同频率的维度有不同的特性:

复制代码
低频维度(大i,θ_i 小):捕捉长程依赖
高频维度(小i,θ_i 大):捕捉短程依赖

线性插值的问题:
  对高频维度(已经充分利用)进行压缩 → 损伤短程建模能力
  对低频维度(还有余量)进行压缩 → 还好
  
YaRN的做法:区分对待!

YaRN核心创新:非均匀插值

python 复制代码
def yarn_get_scale(dim_idx, head_dim, original_max_len, target_max_len):
    """
    根据维度索引,计算不同的缩放因子
    """
    # 计算该维度的波长
    theta = 1.0 / (10000 ** (2 * dim_idx / head_dim))
    wavelength = 2 * math.pi / theta  # 该维度的周期
    
    # 三个区间:
    # 1. 高频(波长短):不插值,直接用
    # 2. 中频:线性插值
    # 3. 低频(波长长):NTK-aware(基数外推)
    
    low_freq_factor = 1  # 低频阈值(单位:原始上下文长度的倍数)
    high_freq_factor = 4  # 高频阈值
    
    if wavelength < original_max_len / high_freq_factor:
        # 高频维度:不缩放(保留短程信息)
        return 1.0
    elif wavelength > original_max_len / low_freq_factor:
        # 低频维度:用 NTK 缩放(保留长程泛化)
        scale = target_max_len / original_max_len
        return scale
    else:
        # 中间频率:线性混合
        smooth = (original_max_len / wavelength - low_freq_factor) / \
                 (high_freq_factor - low_freq_factor)
        scale_linear = target_max_len / original_max_len
        return (1 - smooth) * scale_linear / (target_max_len / original_max_len) + smooth

实测对比(以 Mistral-7B 扩展到 128K 为例):

复制代码
方法         | 需要微调步数 | 128K下困惑度 | 4K原始性能保留
------------|------------|-------------|---------------
线性 PI     |   ~2000步   |   较高(差) |      98%
YaRN        |   ~400步    |   较低(好) |      99.5%
不做任何处理 |    无       |   极高(差) |      100%

YaRN以更少的微调步数更低的困惑度在扩展到超长上下文方面显著优于线性插值。

2.4 方案三:ALiBi(注意力线性偏置)

论文:Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation(Press et al., 2021)

根本不同的思路 :与其修改位置编码,不如彻底去掉位置编码,改为在注意力分数上添加线性偏置。

复制代码
标准注意力:
  Score(q_m, k_n) = q_m · k_n / √d_k

ALiBi:
  Score(q_m, k_n) = q_m · k_n / √d_k - slope × |m - n|

其中 slope 是每个注意力头不同的超参数(固定值,不可学习):
  head 1: slope = 2^(-1/4)
  head 2: slope = 2^(-2/4)
  ...
  head 8: slope = 2^(-8/4)  (以8头为例)

直觉理解:
  |m - n| 越大 → 相距越远 → 减去更大的数 → 注意力分数越低
  → 模型自然偏向关注近距离Token
  → 即使遇到超长序列,这个偏置机制天然适用!

可视化:
  位置差 0   1   2   3   ... 1000 ... 10000
  惩罚项 0  -s  -2s  -3s ... -1000s ... -10000s
  (s 是该头的 slope)

ALiBi的优劣势

对比维度 ALiBi RoPE + 插值
外推能力 天然线性外推,无需微调 需要插值 + 微调
短序列性能 与RoPE相当 与RoPE相当
工程复杂度 极简(加一行偏置) 需要修改位置计算逻辑
适用性 训练时就必须使用 可以事后添加
代表模型 MPT、BLOOM LLaMA 3、Mistral

三、FlashAttention:解决二次方内存的工程杰作

3.1 标准注意力的内存瓶颈

前面说了,标准注意力的内存是 O(n²),但为什么这么致命?关键在于 GPU内存层次

复制代码
GPU内存层次(从快到慢,从小到大):
  ┌─────────────────────────────────────────┐
  │  寄存器 (Register)                       │  ~MB级,极快
  ├─────────────────────────────────────────┤
  │  共享内存/L1 Cache (SRAM)                │  ~MB级,很快
  ├─────────────────────────────────────────┤
  │  L2 Cache                               │  ~几十MB,较快
  ├─────────────────────────────────────────┤
  │  显存 (HBM - High Bandwidth Memory)     │  ~几十GB,慢得多
  └─────────────────────────────────────────┘

标准注意力的问题:
  步骤1:计算 S = QK^T               → 写入HBM(O(n²)大!)
  步骤2:计算 P = softmax(S)         → 从HBM读,再写回(O(n²))
  步骤3:计算 O = P·V               → 从HBM读,写入输出
  
  → 大量的HBM读写(I/O bound),GPU计算单元大量等待
  → 这是瓶颈所在,不是"算不过来",而是"传不过来"

3.2 FlashAttention v1:IO感知的分块计算

论文:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness(Dao et al., 2022)

核心思想:避免将 n×n 的注意力矩阵写入显存。利用 SRAM(片上高速内存)分块计算。

复制代码
标准注意力(朴素实现):
  全部数据 → HBM → SRAM(计算)→ HBM → SRAM(计算)→ HBM
  
FlashAttention:
  分块加载 Q、K、V → SRAM
  在SRAM内完成分块注意力计算(从不写入完整 n×n 矩阵)
  直接输出最终结果 O → HBM
  
  关键技巧:在线Softmax(Online Softmax)
  → 不需要先看完全部 n 个Key,就能计算出精确的Softmax
  → 这使得分块计算成为可能

在线Softmax推导(理解FlashAttention的关键):

python 复制代码
# 普通Softmax(需要先看完所有Key)
def naive_softmax(scores):  # scores: [n]
    exp_s = torch.exp(scores - scores.max())  # 数值稳定
    return exp_s / exp_s.sum()

# 问题:需要两遍扫描(第一遍找max,第二遍计算)
# 无法分块!

# 在线Softmax(可以分块计算!)
def online_softmax_update(m_prev, l_prev, o_prev, new_block_scores, new_v):
    """
    增量更新:已处理前面的块,现在处理新的一块
    m_prev: 之前所有块的最大值
    l_prev: 之前所有块的 exp 和
    o_prev: 之前所有块的加权输出
    """
    m_new = max(m_prev, new_block_scores.max())  # 更新全局最大值
    
    # 重新缩放旧的结果(因为全局最大值变了)
    l_new = torch.exp(m_prev - m_new) * l_prev + \
            torch.exp(new_block_scores - m_new).sum()
    
    # 更新加权输出
    o_new = (torch.exp(m_prev - m_new) * l_prev * o_prev + 
             torch.exp(new_block_scores - m_new) * new_v) / l_new
    
    return m_new, l_new, o_new

Flash Attention内存对比

复制代码
方法               | 峰值内存(n=4K) | 峰值内存(n=16K) | 速度
------------------|--------------|----------------|-------
标准注意力          |    ~800MB    |     ~12GB      |  基准
FlashAttention v1  |    ~2MB      |     ~2MB       |  快2-4倍
FlashAttention v2  |    ~2MB      |     ~2MB       |  快4-8倍
FlashAttention v3  |    ~2MB      |     ~2MB       |  快6-12倍(H100上)

→ 内存从O(n²)降到O(n)!!

3.3 FlashAttention v2、v3的进化

v2(2023):改进了工作分配方式

  • 减少非矩阵乘法(non-matmul)操作
  • 并行化序列维度,更好利用多SM(流多处理器)
  • A100上速度提升:~2倍(相比v1)

v3(2024,专为H100设计)

  • 利用H100的异步执行(Async WGMMA指令)
  • 交叉执行softmax和矩阵乘法
  • Tensor Core利用率从35%提升至75%
  • H100上比v2再快50-80%
python 复制代码
# 使用FlashAttention(实际调用很简单)
from flash_attn import flash_attn_func

# 替代 torch.nn.functional.scaled_dot_product_attention
output = flash_attn_func(
    q,         # [batch, seq_len, num_heads, head_dim]
    k,
    v,
    dropout_p=0.0,
    causal=True,  # 自回归模型必须为True
    softmax_scale=1.0 / math.sqrt(head_dim),
)

四、稀疏注意力:选择性地"看"

FlashAttention解决了内存问题,但计算量仍是O(n²)。如果想降低计算量 本身,需要稀疏注意力

4.1 滑动窗口注意力(Sliding Window)

核心思想:每个Token只关注距离自己最近的 w 个Token(窗口大小 w << n)。

复制代码
标准注意力:每个Token关注所有n个Token
  Token_i 关注 Token_0, 1, 2, ..., n-1
  复杂度:O(n²)

滑动窗口:每个Token只关注最近w个Token
  Token_i 关注 Token_{i-w}, ..., Token_{i-1}, Token_i
  复杂度:O(n · w)(线性!)

问题 :全局信息如何传播?通过层叠效应

  • 第1层:窗口大小 w,每个Token感知范围 w
  • 第2层:每个Token的感知已经融合了其前后 w 个Token的信息,等效感知范围 2w
  • 第L层:等效感知范围 L×w

对于L=32层、w=512的模型,理论感知范围 32×512=16384个Token------足够!

代表实现:Mistral-7B使用了 w=4096 的滑动窗口注意力(配合RoPE),在长文本上保持线性内存。

4.2 Longformer:局部 + 全局注意力

论文:Longformer: The Long-Document Transformer(Beltagy et al., 2020)

三种注意力模式混合

复制代码
局部滑动窗口(所有Token都有):
  Token_i ←→ Token_{i-w/2}, ..., Token_{i+w/2}
  捕捉局部上下文(句子内、段落内)
  复杂度:O(n·w)

全局注意力(少数"全局Token"有):
  [CLS], [SEP], 特定问题Token 等特殊标记
  全局Token ←→ 所有其他Token(双向)
  捕捉全局信息(答案通过全局Token向全文传播)
  
稀疏随机注意力(可选):
  随机选取 r 个额外Key进行关注
  补充全局信息传播
python 复制代码
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerModel.from_pretrained('allenai/longformer-base-4096')

# 关键:指定全局注意力位置
inputs = tokenizer(long_document, return_tensors='pt')
inputs['global_attention_mask'] = torch.zeros_like(inputs['input_ids'])
inputs['global_attention_mask'][:, 0] = 1  # CLS Token使用全局注意力

outputs = model(**inputs)

4.3 注意力模式的演进

复制代码
年份  方法                  上下文长度  复杂度      代表模型
2017  Full Attention        512-4K     O(n²)      GPT-3, BERT
2020  Sparse Transformer    16K        O(n√n)     (OpenAI内部)
2020  Longformer            4K-16K     O(n·w)     allenai/longformer
2020  BigBird               4K-16K     O(n)       Google BigBird
2022  FlashAttention v1     超长        O(n²)计算量  广泛集成
2023  FlashAttention v2     超长        O(n²)计算量  Llama3, Mistral
2023  RoPE + YaRN           128K       O(n²)计算量  Mistral-7B 128K
2024  FlashAttention v3     超长        O(n²)计算量  H100专用
2024  Ring Attention        百万级      O(n²)       分布式长序列
2025  Gemini 1.5 Pro        1M         专用架构    Google Gemini

五、StreamingLLM:真正的无限上下文

5.1 问题背景:实际部署的困境

即使上下文窗口扩展到了128K,实际部署仍面临一个棘手问题:

复制代码
场景:长对话服务(如客服AI、持续对话助手)

挑战:
  轮次 1-10:对话积累,上下文 10K Token,OK
  轮次 11-50:上下文 50K Token,勉强
  轮次 51-???:超出128K上下文上限......崩溃!
  
传统解决方案:
  方案A:截断(丢弃早期对话)→ 失忆,体验差
  方案B:滑动窗口(只保留最近K个)→ 仍然"失忆"
  方案C:摘要压缩(用LLM摘要历史)→ 高延迟,信息损失

有没有方法无限处理流式输入,同时内存固定?

5.2 StreamingLLM的核心发现

论文:Efficient Streaming Language Models with Attention Sinks(Xiao et al., 2023,MIT + Meta)

关键发现:注意力汇聚现象(Attention Sink)

复制代码
研究者发现:几乎所有大模型中,
前几个Token(尤其是第1个Token,即 BOS)
会获得极高的注意力分数------无论它们本身的语义如何。

这是为什么?
  LLM "必须" 将注意力分散到某些位置,
  即使当前任务不需要早期Token的信息,
  Softmax也无法输出 [0, 0, 0, ..., 0]。
  
  → 前几个Token充当了"注意力垃圾桶"
  → 删掉它们会导致注意力分布剧烈变化,性能崩溃!

StreamingLLM 的方案

复制代码
保留两类 Token:
  1. Attention Sinks(注意力汇):最开始的 4 个 Token(固定保留)
  2. Recent Tokens(近期记忆):最近的 L 个 Token(滑动窗口)

KV Cache 结构:
  [Sink₁, Sink₂, Sink₃, Sink₄ | Token_{t-L+1}, ..., Token_t]
   ←────── 固定 Sinks ──────→│←────── 滑动窗口 ──────────→
   
总 KV Cache 大小:4 + L(固定!不随对话增长)
python 复制代码
class StreamingLLM:
    def __init__(self, model, sink_size=4, window_size=1020):
        self.model = model
        self.sink_size = sink_size       # 保留最初4个Token
        self.window_size = window_size   # 保留最近1020个Token
        self.total_kv_size = sink_size + window_size  # 总共1024
        
        self.kv_cache = None
    
    def decode_token(self, new_token):
        """流式处理每个新Token"""
        # 1. 如果KV Cache未满,直接添加
        if self.kv_cache is None or len(self.kv_cache) < self.total_kv_size:
            new_kv = self.model.forward_single(new_token, self.kv_cache)
            self.kv_cache = append(self.kv_cache, new_kv)
        else:
            # 2. KV Cache已满:保留Sinks + 滑动窗口
            sinks = self.kv_cache[:self.sink_size]      # 固定保留
            recent = self.kv_cache[-(self.window_size-1):]  # 只保留最近的
            self.kv_cache = concat(sinks, recent)       # 拼接
            
            # 3. 前向传播
            new_kv = self.model.forward_single(new_token, self.kv_cache)
            self.kv_cache = append(self.kv_cache, new_kv)
        
        return self.model.predict_next(self.kv_cache)
    
    def generate_stream(self, token_stream):
        """无限流式生成"""
        for token in token_stream:
            yield self.decode_token(token)
        # 内存:始终是 O(sink_size + window_size) = 常数!

实验结果(LLaMA-7B,4M Token 流式输入):

复制代码
方法           | 4M Token处理 | 性能衰退 | 内存增长
--------------|-------------|---------|--------
标准滑动窗口   |     ❌崩溃   |  完全崩  | 线性增长
StreamingLLM  |     ✅正常   |  几乎无  | 固定不变
(比较基线:完整注意力 Oracle,内存 O(n²),无法部署)

局限性 :StreamingLLM无法"回溯"到早期信息(除了4个Sink以外的历史完全丢弃)。适合流式对话 ,不适合需要引用全文历史的任务(如长文档QA)。


六、综合对比:如何选择长上下文方案?

6.1 技术方案横向对比

方案 内存复杂度 计算复杂度 上下文上限 适用场景
标准注意力 O(n²) O(n²) 受显存限制(~128K) 一般任务
FlashAttention O(n) O(n²) 显存许可范围 替代标准注意力
线性插值 PI O(n) O(n²) 4-8×原始窗口 微调时扩展
YaRN O(n) O(n²) 4-32×原始窗口 高效微调扩展
ALiBi O(n) O(n²) 无限(性能衰减) 训练时直接用
滑动窗口 O(n·w) O(n·w) 理论无限 文档处理
Longformer O(n·w) O(n·w) ~4-16K 长文档理解
StreamingLLM O(常数) O(n·L) 真正无限 流式对话
Ring Attention O(n/节点) O(n²) 百万级 分布式集群

6.2 实践决策树

复制代码
你的任务是什么?
│
├─ 流式实时对话(无限轮次)
│    └─ StreamingLLM(内存固定,适合生产部署)
│
├─ 一次处理超长文档(100K+)
│    ├─ 有微调条件 → YaRN 微调扩展
│    └─ 没有微调条件 → 支持长上下文的模型API(Gemini/GPT-4/Claude)
│
├─ 中等长度文档(16K-128K)
│    ├─ 已有模型需要扩展 → 线性插值 PI + 少量微调
│    └─ 从头训练/选型 → LLaMA-3.1(128K,内置YaRN)
│
├─ 加速计算(不改变上下文长度)
│    └─ FlashAttention(插入即用,几乎无副作用)
│
└─ 超长序列(>1M Token)分布式处理
     └─ Ring Attention(多GPU/节点并行处理长序列)

七、实战:用 Transformers 处理超长文档

7.1 快速上手:加载支持长上下文的模型

python 复制代码
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# LLaMA-3.1-8B-Instruct:官方支持128K上下文
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,    # 推荐 BF16
    device_map="auto",             # 自动分配到可用GPU
    attn_implementation="flash_attention_2",  # 强烈推荐!
)

# 处理长文档
def analyze_long_doc(doc_text: str, question: str, max_new_tokens: int = 500):
    prompt = f"""请仔细阅读以下文档,然后回答问题。
    
文档:
{doc_text}

问题:{question}

回答:"""
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    token_count = inputs['input_ids'].shape[1]
    print(f"文档长度:{token_count} tokens")
    
    if token_count > 128000:
        print("警告:超出128K上下文限制,请考虑分块处理")
        return None
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # 文档分析用贪心解码即可
            pad_token_id=tokenizer.eos_token_id,
        )
    
    answer = tokenizer.decode(outputs[0][token_count:], skip_special_tokens=True)
    return answer

7.2 RAG vs 长上下文:该用哪个?

复制代码
RAG(检索增强生成):
  原理:先检索相关片段,再用短上下文推理
  优点:成本低、速度快、可处理海量知识库
  缺点:检索质量依赖embedding,可能遗漏关键信息

长上下文:
  原理:把全部内容塞入上下文窗口
  优点:不会遗漏信息,理解更完整
  缺点:成本高(推理时间/费用 ∝ n²)、大海捞针问题

选择建议:
  知识库 > 1M Token → RAG(长上下文装不下)
  文档 ≤ 100K Token + 精度要求高 → 长上下文
  折中方案:RAG 粗筛 + 长上下文精读(Coarse-to-Fine)
python 复制代码
class CoarseToFineRetrieval:
    """RAG粗筛 + 长上下文精读的混合方案"""
    
    def __init__(self, llm, retriever, long_ctx_model):
        self.llm = llm              # 短上下文模型(用于RAG)
        self.retriever = retriever  # 向量检索器
        self.long_model = long_ctx_model  # 长上下文模型(用于精读)
    
    def query(self, question, top_k_coarse=20, top_k_fine=5):
        # 第一阶段:RAG粗筛(快速)
        # 从百万Token语料中检索出top-20段落
        coarse_results = self.retriever.search(question, k=top_k_coarse)
        
        # 第二阶段:长上下文精读(精确)
        # 将top-20段落合并(通常几万Token),用长上下文模型精读
        combined_context = "\n\n".join([r.text for r in coarse_results])
        answer = self.long_model.generate(
            context=combined_context,
            question=question
        )
        
        return answer

八、总结与展望

本文核心要点

复制代码
1. 上下文限制的根源:
   ✓ 计算复杂度 O(n²):n=128K时,注意力矩阵4GB以上
   ✓ 位置编码泛化失败:超出训练长度,性能急剧下降

2. 位置编码外推方案:
   ✓ 线性插值 PI:简单有效,需少量微调(2000步内)
   ✓ YaRN:非均匀插值,更好保留短程信息,只需~400步微调
   ✓ ALiBi:根本解决,训练时使用,天然支持外推

3. 高效注意力:
   ✓ FlashAttention:IO感知分块计算,内存从O(n²)降到O(n)
   ✓ v1/v2/v3逐代进化,H100上比v1快6-12倍
   ✓ 集成到几乎所有主流框架,几乎是必装优化

4. 稀疏注意力:
   ✓ 滑动窗口:O(n·w)线性复杂度,通过层叠传播全局信息
   ✓ Longformer:局部+全局混合注意力

5. StreamingLLM:
   ✓ 关键发现:Attention Sink现象(前4个Token必须保留)
   ✓ Sink + 滑动窗口 = 真正无限上下文(固定内存)
   ✓ 适合流式对话,不适合需要回溯历史的任务

6. 实践建议:
   ✓ 所有场景:先装 FlashAttention,零成本收益
   ✓ 扩展现有模型:YaRN > 线性插值
   ✓ 流式部署:StreamingLLM
   ✓ 超长但精度要求高:长上下文 > RAG(但成本高)

技术演进方向

复制代码
2022:FlashAttention v1 → GPU内存不再是瓶颈
2023:YaRN / PI → 长上下文微调成本从天文数字降至可行
2023:StreamingLLM → 流式无限对话成为可能
2024:FlashAttention v3 / Ring Attention → H100最大化利用
2024:Gemini 1.5 Pro 1M上下文 → 长上下文成为产品特性
2025:1M成为标配,10M开始探索
2026:方向:超长上下文压缩(不是全文塞入,而是智能压缩)
      → MemoryBank、Compressive Transformer、RMT(循环记忆Transformer)

下期预告(7月6日,周日):《RLHF与DPO:大模型对齐技术的两条路径》------PPO/RLHF完整三步流程、DPO数学推导、对比实验与选择建议。


参考资料

  1. Chen et al. (2023) --- Extending Context Window of Large Language Models via Positional Interpolation
  2. Peng et al. (2023) --- YaRN: Efficient Context Window Extension of Large Language Models
  3. Press et al. (2021) --- Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
  4. Dao et al. (2022) --- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  5. Dao et al. (2023) --- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  6. Shah et al. (2024) --- FlashAttention-3: Fast and Accurate Attention for Hopper GPUs
  7. Beltagy et al. (2020) --- Longformer: The Long-Document Transformer
  8. Xiao et al. (2023) --- Efficient Streaming Language Models with Attention Sinks
  9. Liu et al. (2023) --- Ring Attention with Blockwise Transformers for Near-Infinite Context

延伸讨论

思考题

  1. FlashAttention降低了内存,但计算量仍是O(n²)。能否设计一种方案,同时降低内存和计算量?(提示:稀疏注意力 + FlashAttention结合?)
  2. StreamingLLM的Attention Sink现象是真实的"语义锚点",还是模型的"习惯性分心"?如何设计实验验证?
  3. 当上下文窗口足够大(>1M),RAG是否会被彻底取代?

实践作业

  • 安装 flash-attn 包,对比标准注意力和FlashAttention在不同序列长度下的速度和内存差异
  • 用 StreamingLLM 实现一个支持无限轮次的聊天机器人原型(固定KV Cache大小为1024)
  • 分别测试:截断策略 vs StreamingLLM vs 完整上下文,在长对话中的答案一致性