Deepseek Natively Sparse Attention

NSA(Natively Sparse Attention)论文原理解析

论文标题: Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
作者团队: DeepSeek-AI, Peking University, University of Washington
核心目标: 提出一种高效、可训练的稀疏注意力机制,以提高长文本处理的计算效率,同时保持模型性能。


1. NSA 研究背景

1.1 长文本建模的挑战

  • 现代大模型(如 GPT-4, Gemini 1.5)需要处理超长文本(64k 甚至更长)。
  • 传统 全注意力(Full Attention) 计算复杂度为 O(N²) ,在长文本上计算开销巨大,导致 训练和推理效率低下
  • 现有的稀疏注意力(Sparse Attention)方法在 训练阶段支持较弱,通常只优化推理阶段。

1.2 现有稀疏注意力方法的局限

  • 理论计算减少 ≠ 真实速度提升
    • 许多方法仅在 推理阶段(Inference Stage) 优化,忽略了 训练时间的计算成本
    • 例如 H2O、Quest 关注 KV 缓存剪枝(KV-cache pruning) ,但 在真实硬件上加速有限
  • 训练阶段支持不足
    • 许多方法采用 离散选择(Discrete Selection) ,导致 梯度无法回传,难以进行端到端训练(End-to-End Training)。

2. NSA 方法:基于层次化的稀疏注意力

NSA 提出的创新点:

  1. 层次化稀疏策略(Hierarchical Sparse Strategy)
    • 结合 粗粒度 token 压缩(Compression)细粒度 token 选择(Selection) ,同时保留 全局信息局部精度
  2. 硬件优化(Hardware-Aligned System)
    • 设计 适用于现代 GPU(如 A100, H100)的优化算子,提升推理效率。
  3. 可训练性增强(Natively Trainable Design)
    • 允许在 训练阶段 进行稀疏优化,而不仅仅是在推理阶段加速。

2.1 NSA 关键机制

NSA 通过 三种注意力路径 进行计算:

  1. 压缩注意力(Compressed Attention)
    • 通过块级 Token 压缩(Blockwise Token Compression),减少计算开销。
  2. 选择性注意力(Selected Attention)
    • 仅保留 Top-k 重要 token,忽略不重要的 Token,提高计算效率。
  3. 滑动窗口注意力(Sliding Attention)
    • 确保局部上下文不会丢失,提高信息完整性。
NSA 计算过程
  1. 查询(Query) 经过 三种注意力路径 计算 注意力得分(Attention Score)
  2. 不同路径的注意力结果通过门控机制(Gating Mechanism)进行加权融合
  3. 最终得到优化后的注意力输出(Sparse Attention Output)

3. NSA 在硬件上的优化

3.1 计算强度均衡(Arithmetic Intensity Balance)

  • 在现代 GPU 上,计算强度(Arithmetic Intensity)决定了性能瓶颈:
    • 高计算强度(Compute-Bound):计算单元占用率高,计算能力未完全发挥。
    • 低计算强度(Memory-Bound):计算单元空闲,受限于显存访问速度。
  • NSA 通过 块级计算(Blockwise Computation) 提高 计算密度(Compute Density),减少显存访问瓶颈。

3.2 Triton 自定义内核(Triton Kernel Optimization)

  • 传统注意力计算 内存访问不连续,GPU 计算利用率低
  • NSA 通过 基于 Triton 的自定义 GPU 内核(Custom GPU Kernel for Sparse Selection)
    • 组级数据加载(Group-Centric Data Loading):避免多次访问 KV 缓存,减少内存带宽压力。
    • 共享 KV 读取(Shared KV Fetching):减少重复数据加载,提高计算效率。

4. NSA 在实验中的表现

4.1 计算加速

  • 相比全注意力(Full Attention),NSA 在 64k 序列上的速度提升最高可达 11.6×
  • 在训练阶段,NSA 前向传播(Forward)速度提高 9.0×,反向传播(Backward)速度提高 6.0×

4.2 模型性能

  • 在多个 自然语言任务(NLP Benchmarks) 上,NSA 在 保持甚至超过全注意力性能 的同时,大幅提高计算效率。
  • 64k 长文本任务 (LongBench Benchmark)中,NSA 超过所有现有稀疏注意力方法

4.3 复杂推理能力

  • NSA 在 数学推理任务(AIME 24 Benchmark) 中表现出色:
    • 8k 和 16k 上下文长度下,NSA 比全注意力基线提高 2.5× 和 1.6×

5. NSA 的关键优势

特点 NSA 贡献
计算复杂度降低 通过 层次化稀疏选择 ,将 O(N²) 降至 O(N log K)
硬件优化 适配 GPU Tensor Cores,优化内存访问,提高计算效率。
训练支持 NSA 可训练(Natively Trainable),不同于只优化推理的稀疏方法。
长文本处理能力 64k 长文本任务上超越全注意力 ,同时加速 推理和训练

6. 论文总结

NSA 通过 层次化稀疏注意力、硬件优化、训练可行性 ,在 计算加速和性能保持之间取得了平衡

相较于现有方法,NSA 不仅优化了推理(Inference),还显著降低了训练(Training)计算成本,为长文本建模提供了新的解决方案。


压缩注意力(Compressed Attention)机制解析

目标

  • 在保持全局信息的同时 降低计算复杂度,减少 Query-Key 计算量。
  • 通过 块级(blockwise)token 聚合,减少注意力计算中需要处理的 Key-Value 数量。

1. 为什么需要压缩注意力?

  • 标准注意力机制 :每个 Query q q q 需要计算所有 Key K K K 的注意力分数,计算复杂度为 O ( N 2 ) O(N^2) O(N2)。
  • 稀疏注意力(Sparse Attention):可以减少部分 Query-Key 计算,但仍然面临计算量和显存占用的问题。
  • 压缩注意力(Compressed Attention) 通过 对 Key-Value 进行块级聚合,减少 Key-Value 数量,降低计算复杂度。

2. 压缩注意力的具体方法

NSA 采用 块级 token 聚合 的方式,将 Key-Value 压缩成更少的代表性 token

这一过程可以分为 四步

2.1. 按块划分 Key-Value

  • 输入序列长度为 T T T ,Key-Value 维度为 d k d_k dk(Key 维度)和 d v d_v dv(Value 维度)。
  • 选择 块大小(block size) l l l ,把 Key-Value 分成多个块:
    • 第 i i i 块的 Key 表示为:

K i = { k i ⋅ l , k i ⋅ l + 1 , ... , k ( i + 1 ) ⋅ l − 1 } K_i = \{ k_{i \cdot l}, k_{i \cdot l+1}, \dots, k_{(i+1) \cdot l - 1} \} Ki={ki⋅l,ki⋅l+1,...,k(i+1)⋅l−1}

  • 第 i i i 块的 Value 表示为:

V i = { v i ⋅ l , v i ⋅ l + 1 , ... , v ( i + 1 ) ⋅ l − 1 } V_i = \{ v_{i \cdot l}, v_{i \cdot l+1}, \dots, v_{(i+1) \cdot l - 1} \} Vi={vi⋅l,vi⋅l+1,...,v(i+1)⋅l−1}

  • 这样,原始 Key-Value 变成了 块级 Key-Value,大幅减少了 Key 的数量。

2.2. 计算块级 Key 的代表性

  • 块级 Key K cmp K_{\text{cmp}} Kcmp 需要能够代表整个块的信息,可以用 平均池化(Mean Pooling)可训练 MLP
    • 平均池化(Mean Pooling):

K cmp , i = 1 l ∑ j = 0 l − 1 K i ⋅ l + j K_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} K_{i \cdot l + j} Kcmp,i=l1j=0∑l−1Ki⋅l+j

  • 可训练 MLP(Multi-Layer Perceptron):

K cmp , i = MLP ( K i ⋅ l : ( i + 1 ) ⋅ l ) K_{\text{cmp}, i} = \text{MLP}(K_{i \cdot l : (i+1) \cdot l}) Kcmp,i=MLP(Ki⋅l:(i+1)⋅l)

  • 其中 MLP 可以学习更丰富的特征,而平均池化计算量更低。

2.3. 计算块级 Value

  • 块级 Value V cmp V_{\text{cmp}} Vcmp 也可以采用类似方法:
    • 平均池化:

V cmp , i = 1 l ∑ j = 0 l − 1 V i ⋅ l + j V_{\text{cmp}, i} = \frac{1}{l} \sum_{j=0}^{l-1} V_{i \cdot l + j} Vcmp,i=l1j=0∑l−1Vi⋅l+j

  • 或使用 MLP:

V cmp , i = MLP ( V i ⋅ l : ( i + 1 ) ⋅ l ) V_{\text{cmp}, i} = \text{MLP}(V_{i \cdot l : (i+1) \cdot l}) Vcmp,i=MLP(Vi⋅l:(i+1)⋅l)

  • 这样可以降低计算量,同时保留重要信息。

2.4. 使用压缩 Key-Value 计算注意力

  • 计算 Query Q Q Q 和压缩后的 Key K cmp K_{\text{cmp}} Kcmp 之间的注意力

A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp=dk QKcmpT

  • 计算 Softmax:

A cmp ′ = Softmax ( A cmp ) A'{\text{cmp}} = \text{Softmax}(A{\text{cmp}}) Acmp′=Softmax(Acmp)

  • 计算最终的注意力输出:

O cmp = A cmp ′ V cmp O_{\text{cmp}} = A'{\text{cmp}} V{\text{cmp}} Ocmp=Acmp′Vcmp


3. 压缩注意力的优势

对比项 普通注意力 稀疏注意力(Sparse Attention) 压缩注意力(Compressed Attention)
计算复杂度 O ( N 2 ) O(N^2) O(N2) O ( N log ⁡ k ) O(N \log k) O(Nlogk) O ( N ⋅ M ) O(N \cdot M) O(N⋅M)(( M \ll N \))
信息保留 完整信息 仅保留 Top-k 信息 保留全局信息,同时减少计算量
适用场景 短文本 长文本,但计算仍然较大 适合超长文本(64k+),计算高效
  • 相比全注意力(Full Attention),压缩注意力减少了计算量。
  • 相比其他稀疏注意力方法,压缩注意力能保留更多全局信息,同时具有更好的计算效率。

4. 代码示例

这里是一个 PyTorch 实现的 压缩注意力

python 复制代码
import torch
import torch.nn as nn

class CompressedAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, block_size=32):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.block_size = block_size
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        B, T, C = query.size()  # Batch, Sequence Length, Embedding Dimension

        # Projection
        Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Blockwise compression (mean pooling)
        num_blocks = T // self.block_size
        K_cmp = K.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3)
        V_cmp = V.view(B, self.num_heads, num_blocks, self.block_size, self.head_dim).mean(dim=3)

        # Compute attention with compressed keys
        attn_weights = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, V_cmp)

        # Reshape and output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)

# 示例调用
B, T, C = 2, 64, 128  # Batch size, Sequence length, Embedding dimension
num_heads = 8
block_size = 16

attention = CompressedAttention(C, num_heads, block_size)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)

output = attention(query, key, value)
print(output.shape)  # (B, T, C)

5. 总结

  • 压缩注意力(Compressed Attention) 通过 块级聚合 Key-Value,大幅降低计算量,同时保留全局信息。
  • 计算复杂度降低 : O ( N 2 ) → O ( N ⋅ M ) O(N^2) \to O(N \cdot M) O(N2)→O(N⋅M),其中 M ≪ N M \ll N M≪N(压缩后的块数)。
  • 适用于超长文本建模 ,在 64k 甚至更长的序列 上能够高效工作。
  • 硬件友好 ,支持 GPU Tensor Core 优化,减少显存占用。

选择性注意力(Selected Attention)机制解析

目标

  • 选择 最重要的 Key-Value 进行计算,而不是对所有 Key 计算注意力,从而降低计算复杂度。
  • 通过 Top-K 选择策略,保留最关键的信息,减少冗余计算,提高长序列建模能力。

1. 为什么需要选择性注意力?

  • 普通注意力(Full Attention) 计算复杂度 O(N²),当序列长度很长(如 64k+),计算量巨大。
  • 压缩注意力(Compressed Attention) 通过 块级聚合 降低计算量,但可能损失部分细节信息。
  • 选择性注意力(Selected Attention) 进一步优化,只保留最重要的 Token 参与计算,避免处理不重要的信息,减少计算开销,同时保持全局和局部信息。

2. 选择性注意力的核心步骤

NSA 采用 基于注意力得分的动态 Top-K 选择(Top-K Token Selection) 方法来筛选关键 Token:

2.1. 计算 Query-Key 相关性

首先,计算 查询(Query)所有键(Key) 的相似性(即注意力分数):

A = Q K T d k A = \frac{Q K^T}{\sqrt{d_k}} A=dk QKT

其中:

  • A A A 是注意力分数矩阵,形状为 ( B , H , T , T ) (B, H, T, T) (B,H,T,T),表示每个 Query 对应 Key 的注意力得分。

2.2. 选择 Top-K 重要 Token

  • 对于每个 Query,选择 Top-K 重要的 Key ,其余的 Key 设为 − ∞ -\infty −∞(即被 Mask)。
  • 具体实现:
    • 计算每个 Query 对所有 Key 的注意力分数。
    • 使用 Top-K 算法 找出最大的 K K K 个值,索引存入 I top-k I_{\text{top-k}} Itop-k:

I top-k = argtopk ( A , K ) I_{\text{top-k}} = \text{argtopk}(A, K) Itop-k=argtopk(A,K)

  • 构造稀疏化的注意力分数矩阵:

A i j ′ = { A i j , j ∈ I top-k ( i ) − ∞ , 否则 A'{ij} = \begin{cases} A{ij}, & j \in I_{\text{top-k}}(i) \\ -\infty, & \text{否则} \end{cases} Aij′={Aij,−∞,j∈Itop-k(i)否则

  • 这样,我们 只在最重要的 Top-K Token 上计算 Softmax

A ~ = Softmax ( A ′ ) \tilde{A} = \text{Softmax}(A') A~=Softmax(A′)

2.3. 计算注意力输出

最终,用选择的 Top-K 注意力分数 计算新的 Value 权重求和

O = A ~ V O = \tilde{A} V O=A~V

这样,Query 只会与 最相关的 Key-Value 交互,提高计算效率,同时保留重要信息。


3. 选择性注意力的优势

方法 计算复杂度 信息保留能力 适用场景
全注意力(Full Attention) O ( N 2 ) O(N^2) O(N2) 完整 适用于短文本
压缩注意力(Compressed Attention) O ( N ⋅ M ) O(N \cdot M) O(N⋅M) 保留全局信息 适用于长文本
选择性注意力(Selected Attention) O ( N ⋅ K ) O(N \cdot K) O(N⋅K) 只保留最重要信息 适用于超长文本(64k+)
  • 相比全注意力(Full Attention) ,选择性注意力只计算 Top-K 重要信息,大幅降低计算量。
  • 相比压缩注意力(Compressed Attention) ,选择性注意力能保留 更精确的局部信息,保证高精度。

4. PyTorch 实现

以下是 选择性注意力 的 PyTorch 代码:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectedAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, top_k):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.top_k = top_k
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        B, T, C = query.shape

        # Projection
        Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, d_k)
        K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, H, T, T)

        # Select Top-K tokens
        topk_values, topk_indices = torch.topk(attn_scores, self.top_k, dim=-1)  # (B, H, T, K)

        # Create a mask for non-Top-K elements
        mask = torch.full_like(attn_scores, float('-inf'))  # Default mask
        mask.scatter_(-1, topk_indices, topk_values)  # Retain Top-K values

        # Apply softmax on selected tokens
        attn_weights = F.softmax(mask, dim=-1)

        # Compute attention output
        attn_output = torch.matmul(attn_weights, V)  # (B, H, T, d_k)

        # Reshape and output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)

# 示例调用
B, T, C = 2, 64, 128  # Batch size, Sequence length, Embedding dimension
num_heads = 8
top_k = 16  # 选择 Top-K 重要 Token

attention = SelectedAttention(C, num_heads, top_k)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)

output = attention(query, key, value)
print(output.shape)  # (B, T, C)

5. 选择性注意力的优化方向

  1. Top-K 选择的优化
    • 目前使用 torch.topk() 进行选择,时间复杂度为 O ( N log ⁡ K ) O(N \log K) O(NlogK)。
    • 可以优化为 Heap Sort + 近似选择算法,进一步提高效率。
  2. 自适应 K 值选择
    • 目前的 K 值是固定的 ,可以使用 Learnable Gate 机制 ,让模型 动态决定 K 的大小
  3. 结合其他稀疏注意力
    • 压缩注意力 + 选择性注意力 可以同时 减少计算量保留最关键信息,适合超长序列任务(64k+)。

6. 总结

  • 选择性注意力(Selected Attention) 通过 Top-K 选择 只保留最重要的 Key-Value,降低计算量。
  • 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ K ) O(N \cdot K) O(N⋅K) ,适用于 超长文本(64k+)
  • 相较于全注意力和压缩注意力,选择性注意力能更精准地保留信息,同时减少计算成本。
  • 可进一步优化 通过 更快的 Top-K 选择算法自适应 K 值选择 提升性能。

滑动窗口注意力(Sliding Attention)机制解析

目标

  • 在减少计算量的同时,保留局部上下文信息 ,确保模型能够感知短期依赖关系
  • 结合 压缩注意力(Compressed Attention)选择性注意力(Selected Attention) ,在局部窗口范围内保留完整的注意力计算,避免远程信息丢失。

1. 为什么需要滑动窗口注意力?

  • 全注意力(Full Attention) 计算复杂度 O(N²),长序列(64k+)下计算成本极高。
  • 压缩注意力(Compressed Attention) 关注全局信息,但可能会丢失局部细节。
  • 选择性注意力(Selected Attention) 关注最关键的信息,但可能无法保留局部语境。
  • 滑动窗口注意力(Sliding Attention) 通过局部窗口机制 ,确保模型可以关注最近的信息,同时减少计算量。

2. 滑动窗口注意力的核心步骤

NSA 采用 基于局部窗口的注意力计算(Local Context Attention) ,主要分为 四步

2.1. 定义窗口范围

  • 设序列长度为 T T T,窗口大小设定为 W W W(window size) ,则对于每个 Query Q i Q_i Qi,它只会计算:

K win , i = { k i − W , k i − W + 1 , ... , k i } K_{\text{win}, i} = \{ k_{i-W}, k_{i-W+1}, \dots, k_i \} Kwin,i={ki−W,ki−W+1,...,ki}

V win , i = { v i − W , v i − W + 1 , ... , v i } V_{\text{win}, i} = \{ v_{i-W}, v_{i-W+1}, \dots, v_i \} Vwin,i={vi−W,vi−W+1,...,vi}

  • 窗口只包含最近 W W W 个 Token ,降低计算复杂度。

  • 可变窗口机制:可根据任务需求设定不同的窗口大小(例如代码生成任务可能需要更大的窗口)。

2.2. 计算窗口内的 Query-Key 注意力

  • 在窗口范围 W W W 内计算标准注意力:

A win , i = Q i K win , i T d k A_{\text{win}, i} = \frac{Q_i K_{\text{win}, i}^T}{\sqrt{d_k}} Awin,i=dk QiKwin,iT

  • 相比于全局注意力(O(N²)) ,窗口内计算量为 O(N × W) ,显著降低复杂度。

  • 仅关注 最近 W W W 个 Token,保证短期依赖关系。

2.3. 计算 Softmax 并加权求和

  • 计算窗口内的注意力分布:

A win , i ′ = Softmax ( A win , i ) A'{\text{win}, i} = \text{Softmax}(A{\text{win}, i}) Awin,i′=Softmax(Awin,i)

  • 计算最终的注意力输出:

O win , i = A win , i ′ V win , i O_{\text{win}, i} = A'{\text{win}, i} V{\text{win}, i} Owin,i=Awin,i′Vwin,i

2.4. 结合其他注意力机制

  • 最终输出:

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin

  • g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin 是可学习的门控参数(Gating Mechanism)。

  • 这样可以在训练过程中,让模型学习最佳的注意力组合方式


3. 滑动窗口注意力的优势

方法 计算复杂度 局部信息保留 适用场景
全注意力(Full Attention) O ( N 2 ) O(N^2) O(N2) ✅ 完整 适用于短文本
压缩注意力(Compressed Attention) O ( N ⋅ M ) O(N \cdot M) O(N⋅M) ⚠️ 可能丢失局部信息 适用于长文本
选择性注意力(Selected Attention) O ( N ⋅ K ) O(N \cdot K) O(N⋅K) ⚠️ 仅保留关键 Token 适用于超长文本
滑动窗口注意力(Sliding Attention) O ( N ⋅ W ) O(N \cdot W) O(N⋅W) ✅ 重点保留局部信息 适用于超长文本(64k+)
  • 相比全注意力(Full Attention) ,滑动窗口注意力 显著减少计算量
  • 相比压缩注意力(Compressed Attention) ,滑动窗口注意力确保局部信息不会丢失
  • 相比选择性注意力(Selected Attention) ,滑动窗口注意力不会忽略短期依赖

4. PyTorch 实现

以下是 滑动窗口注意力(Sliding Attention) 的 PyTorch 代码:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SlidingWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        B, T, C = query.shape

        # Projection
        Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, d_k)
        K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Initialize attention scores (masked)
        attn_scores = torch.full((B, self.num_heads, T, T), float('-inf'), device=query.device)

        # Apply sliding window mask
        for i in range(T):
            start_idx = max(0, i - self.window_size)
            attn_scores[:, :, i, start_idx:i+1] = torch.matmul(
                Q[:, :, i:i+1, :], K[:, :, start_idx:i+1, :].transpose(-2, -1)
            ) / (self.head_dim ** 0.5)

        # Compute attention with masked softmax
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Compute attention output
        attn_output = torch.matmul(attn_weights, V)  # (B, H, T, d_k)

        # Reshape and output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)

# 示例调用
B, T, C = 2, 64, 128  # Batch size, Sequence length, Embedding dimension
num_heads = 8
window_size = 16  # 滑动窗口大小

attention = SlidingWindowAttention(C, num_heads, window_size)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)

output = attention(query, key, value)
print(output.shape)  # (B, T, C)

5. 进一步优化方向

  1. 动态窗口大小
    • 当前的 窗口大小 W W W 是固定的 ,可以使用 自适应机制(Adaptive Window Size) 让模型学习最佳的窗口大小。
  2. 结合 FlashAttention 提高计算效率
    • 目前的 滑动窗口计算仍然需要遍历 Query ,可以优化成 块级计算(Blockwise Computation),提升 GPU 利用率。

6. 总结

  • 滑动窗口注意力(Sliding Attention) 通过 局部窗口计算,减少计算量,同时保留最近的上下文信息。
  • 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降低到 O ( N ⋅ W ) O(N \cdot W) O(N⋅W) ,适用于 超长文本(64k+)
  • 结合其他注意力机制 (压缩 + 选择性 + 滑动窗口)可以 提高计算效率,同时保留全局 + 局部信息

NSA 论文中如何结合三种注意力机制?

Natively Sparse Attention(NSA) 机制中,作者采用了一种 层次化稀疏注意力策略(Hierarchical Sparse Strategy) ,将 压缩注意力(Compressed Attention)、选择性注意力(Selected Attention)和滑动窗口注意力(Sliding Attention) 结合,以 同时保留全局信息、关键 Token 以及局部信息,提高计算效率并优化长序列建模。


1. NSA 采用的三条注意力路径

NSA 通过以下三种不同的注意力计算路径,让 Transformer 既能高效处理长序列,又不会丢失关键信息

  1. 压缩注意力(Compressed Attention)
    • 作用:全局信息提取
    • 方式:将 Key-Value 进行 块级压缩,生成粗粒度的全局 Token 表示。
    • 计算复杂度: O ( N ⋅ M ) O(N \cdot M) O(N⋅M)(其中 M ≪ N M \ll N M≪N)。
  2. 选择性注意力(Selected Attention)
    • 作用:筛选最关键的 Token 进行计算
    • 方式:对所有 Query 计算注意力分数,并选择 Top-K 重要 Token,仅对这些 Key 计算注意力。
    • 计算复杂度: O ( N ⋅ K ) O(N \cdot K) O(N⋅K)(其中 K ≪ N K \ll N K≪N)。
  3. 滑动窗口注意力(Sliding Attention)
    • 作用:局部上下文信息保留
    • 方式:每个 Query 仅在其 最近的 W W W 个 Token 内 计算注意力,保留短期依赖信息。
    • 计算复杂度: O ( N ⋅ W ) O(N \cdot W) O(N⋅W)(其中 W ≪ N W \ll N W≪N)。

最终的注意力输出是三种机制的加权和

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin

其中:

  • g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin 是 可学习的门控参数(Gating Mechanism),用于控制不同注意力机制的重要性。

2. NSA 具体如何组合这三种注意力?

(1) 计算 Query-Key 相关性

首先,对 Query 计算三种不同 Key 形式的注意力分数:

  1. 压缩 Key( K cmp K_{\text{cmp}} Kcmp) :计算 Query 和 压缩后的 Key 的相关性:

A cmp = Q K cmp T d k A_{\text{cmp}} = \frac{Q K_{\text{cmp}}^T}{\sqrt{d_k}} Acmp=dk QKcmpT

  1. 选择性 Key( K sel K_{\text{sel}} Ksel) :计算 Query 和 Top-K 选择的 Key 的相关性:

A sel = Q K sel T d k A_{\text{sel}} = \frac{Q K_{\text{sel}}^T}{\sqrt{d_k}} Asel=dk QKselT

  1. 滑动窗口 Key( K win K_{\text{win}} Kwin) :计算 Query 在 局部窗口范围内 的注意力:

A win = Q K win T d k A_{\text{win}} = \frac{Q K_{\text{win}}^T}{\sqrt{d_k}} Awin=dk QKwinT

(2) 计算 Softmax 归一化

对每个注意力分数进行 Softmax 计算:

A ~ cmp = Softmax ( A cmp ) \tilde{A}{\text{cmp}} = \text{Softmax}(A{\text{cmp}}) A~cmp=Softmax(Acmp)

A ~ sel = Softmax ( A sel ) \tilde{A}{\text{sel}} = \text{Softmax}(A{\text{sel}}) A~sel=Softmax(Asel)

A ~ win = Softmax ( A win ) \tilde{A}{\text{win}} = \text{Softmax}(A{\text{win}}) A~win=Softmax(Awin)

(3) 计算注意力输出

计算不同注意力的加权求和:

O cmp = A ~ cmp V cmp O_{\text{cmp}} = \tilde{A}{\text{cmp}} V{\text{cmp}} Ocmp=A~cmpVcmp

O sel = A ~ sel V sel O_{\text{sel}} = \tilde{A}{\text{sel}} V{\text{sel}} Osel=A~selVsel

O win = A ~ win V win O_{\text{win}} = \tilde{A}{\text{win}} V{\text{win}} Owin=A~winVwin

(4) 加权融合不同注意力结果

最终的输出由三种注意力结果加权融合

O = g cmp O cmp + g sel O sel + g win O win O = g_{\text{cmp}} O_{\text{cmp}} + g_{\text{sel}} O_{\text{sel}} + g_{\text{win}} O_{\text{win}} O=gcmpOcmp+gselOsel+gwinOwin

其中:

  • g cmp , g sel , g win g_{\text{cmp}}, g_{\text{sel}}, g_{\text{win}} gcmp,gsel,gwin 是 可学习的门控参数,通过 MLP 计算:

g = σ ( MLP ( X ) ) g = \sigma(\text{MLP}(X)) g=σ(MLP(X))

其中 σ \sigma σ 是 Sigmoid 激活函数,确保 g g g 取值在 (0,1) 之间。


3. PyTorch 实现

以下是 结合三种注意力的 NSA 模型

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class NSA(nn.Module):
    def __init__(self, embed_dim, num_heads, top_k, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.top_k = top_k
        self.window_size = window_size
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.gate_mlp = nn.Sequential(nn.Linear(embed_dim, 3), nn.Sigmoid())  # 生成3个门控权重

    def forward(self, query, key, value):
        B, T, C = query.shape

        # Projection
        Q = self.q_proj(query).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算门控权重
        gate_weights = self.gate_mlp(query).unsqueeze(-1).unsqueeze(-1)  # (B, T, 3) -> (B, T, 3, 1, 1)

        # 压缩注意力
        K_cmp = K.mean(dim=-2, keepdim=True)
        V_cmp = V.mean(dim=-2, keepdim=True)
        attn_cmp = torch.matmul(Q, K_cmp.transpose(-2, -1)) / (self.head_dim ** 0.5)
        O_cmp = torch.matmul(F.softmax(attn_cmp, dim=-1), V_cmp)

        # 选择性注意力
        topk_values, topk_indices = torch.topk(attn_cmp, self.top_k, dim=-1)
        attn_sel = torch.zeros_like(attn_cmp).scatter_(-1, topk_indices, topk_values)
        O_sel = torch.matmul(F.softmax(attn_sel, dim=-1), V)

        # 滑动窗口注意力
        attn_win = attn_cmp.masked_fill(torch.arange(T)[:, None] < (torch.arange(T) - self.window_size), float('-inf'))
        O_win = torch.matmul(F.softmax(attn_win, dim=-1), V)

        # 加权求和
        O = gate_weights[..., 0] * O_cmp + gate_weights[..., 1] * O_sel + gate_weights[..., 2] * O_win
        O = O.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(O)

# 测试
B, T, C = 2, 64, 128
attention = NSA(C, num_heads=8, top_k=16, window_size=16)
query = torch.randn(B, T, C)
key = torch.randn(B, T, C)
value = torch.randn(B, T, C)
output = attention(query, key, value)
print(output.shape)  # (B, T, C)

总结

  • NSA 通过三种注意力机制的组合,既保证全局信息,又保留关键 Token 和局部上下文信息。
  • 最终的注意力结果通过可学习的门控机制(Gating Mechanism)进行融合,实现动态调整。
  • 计算复杂度降低到 O ( N log ⁡ K ) O(N \log K) O(NlogK),适用于超长文本(64k+)。

代码是AI生成的 还在调试中

相关推荐
workflower2 分钟前
Prompt Engineering的重要性
大数据·人工智能·设计模式·prompt·软件工程·需求分析·ai编程
curemoon21 分钟前
理解都远正态分布中指数项的精度矩阵(协方差逆矩阵)
人工智能·算法·矩阵
胡桃不是夹子1 小时前
CPU安装pytorch(别点进来)
人工智能·pytorch·python
Fansv5871 小时前
深度学习-6.用于计算机视觉的深度学习
人工智能·深度学习·计算机视觉
xjxijd2 小时前
AI 为金融领域带来了什么突破?
人工智能·其他
无奈何杨2 小时前
免费使用满血版DeepSeek-R1的多种方案
openai·deepseek
SKYDROID云卓小助手2 小时前
无人设备遥控器之如何分享数传篇
网络·人工智能·算法·计算机视觉·电脑
deephub2 小时前
LLM高效推理:KV缓存与分页注意力机制深度解析
人工智能·深度学习·语言模型
奋斗的袍子0072 小时前
Spring AI + Ollama 实现调用DeepSeek-R1模型API
人工智能·spring boot·深度学习·spring·springai·deepseek
青衫弦语2 小时前
【论文精读】VLM-AD:通过视觉-语言模型监督实现端到端自动驾驶
人工智能·深度学习·语言模型·自然语言处理·自动驾驶