Sliding Window Attention(滑动窗口注意力)解析
Sliding Window Attention(滑动窗口注意力) 是 Longformer (来源:https://arxiv.org/pdf/2004.05150)提出的 稀疏注意力机制 ,旨在解决 标准 Transformer 计算复杂度随序列长度增加呈二次增长 的问题。它的核心思想是:
- 每个 token 仅关注局部窗口内的其他 token,而不是整个序列。
- 计算复杂度从 ( O ( n 2 ) O(n^2) O(n2)) 降至 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w)) ,其中 ( w w w) 是窗口大小。
- 支持更长的文本处理,避免传统 Transformer 处理长序列时的显存和计算瓶颈。
该方法使 Transformer 能够高效处理上千到上万个 token 的长文本 ,特别适用于 文档级任务,如长文摘要、法律文本分析、医疗文档理解等。
1. 为什么需要 Sliding Window Attention?
1.1 传统 Transformer 的问题
Transformer 的 自注意力(Self-Attention) 机制需要计算所有 token 之间的交互:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk QKT)V
其中:
- ( Q, K, V ) 分别是 查询(Query)、键(Key)、值(Value) 矩阵,形状为 ( n × d k n \times d_k n×dk )。
- 计算量为 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk) ),随着序列长度 ( n n n ) 增加,计算量急剧上升。
- 这导致 Transformer 无法处理长文本 ,因为显存需求和计算复杂度都 随 ( n 2 n^2 n2 ) 增长。
1.2 Sliding Window Attention 解决了什么问题?
- 局部注意力(Local Attention) :每个 token 仅与附近窗口内的 token 交互,而不是整个序列。
- 计算复杂度降低 :从 ( O ( n 2 ) O(n^2) O(n2) ) 降为 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) ) ,其中 ( w w w ) 是窗口大小。
- 显存占用减少 :只需要存储窗口内的注意力权重,而非完整的 ( n × n n \times n n×n ) 矩阵。
这意味着,Sliding Window Attention 允许 Transformer 处理更长的序列 ,从传统的 512 tokens 提高到 8K-16K tokens 甚至更长。
2. Sliding Window Attention 计算原理
2.1 标准 Transformer Attention
在标准 Transformer 结构中:
- 每个 token 计算所有其他 token 的注意力权重。
- 形成一个 ( n × n n \times n n×n ) 的注意力矩阵。
- 计算复杂度:( O ( n 2 d ) O(n^2 d) O(n2d) )。
2.2 Sliding Window Attention
在 Sliding Window Attention 结构中:
- 每个 token 仅与窗口内的其他 token 交互。
- 注意力矩阵变为稀疏矩阵 ,只有窗口大小 ( w w w ) 内的注意力权重被计算。
- 计算复杂度变为:( O ( n ⋅ w ⋅ d ) O(n \cdot w \cdot d) O(n⋅w⋅d) )。
示例:
- 设 ( w = 5 w = 5 w=5 ),则:
- 第 10 个 token 仅关注
[8, 9, 10, 11, 12]
。 - 第 20 个 token 仅关注
[18, 19, 20, 21, 22]
。 - 这样每个 token 只计算 5 个注意力权重,而不是所有 n 个。
- 第 10 个 token 仅关注
3. Sliding Window Attention 的 PyTorch 实现
以下是 Longformer 的 Sliding Window Attention 计算的 PyTorch 实现:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SlidingWindowAttention(nn.Module):
def __init__(self, embed_dim, num_heads, window_size):
"""
滑动窗口注意力机制
Args:
embed_dim: 词嵌入维度 d
num_heads: 注意力头的数量 h
window_size: 滑动窗口大小 w
"""
super(SlidingWindowAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.window_size = window_size
self.head_dim = embed_dim // num_heads # 每个头的维度
assert self.head_dim * num_heads == embed_dim, "embed_dim 必须是 num_heads 的整数倍"
# 线性投影层
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(self, x):
"""
前向传播
Args:
x: 输入张量 [batch, seq_len, embed_dim]
Returns:
输出张量 [batch, seq_len, embed_dim]
"""
batch_size, seq_len, _ = x.shape
# 计算 Q, K, V
Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, seq, d]
K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 初始化注意力矩阵
attn_scores = torch.full((batch_size, self.num_heads, seq_len, seq_len), float("-inf"), device=x.device)
# 计算滑动窗口注意力
for i in range(seq_len):
start = max(0, i - self.window_size)
end = min(seq_len, i + self.window_size + 1)
attn_scores[:, :, i, start:end] = torch.matmul(Q[:, :, i, :], K[:, :, start:end, :].transpose(-2, -1))
# 归一化
attn_scores /= self.head_dim ** 0.5
attn_weights = F.softmax(attn_scores, dim=-1)
# 计算注意力加权的 Value
output = torch.matmul(attn_weights, V)
# 重新排列形状
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
output = self.o_proj(output) # 线性变换回原始维度
return output
4. Sliding Window Attention 的优缺点
✅ 优点
- 计算复杂度降低 :从 ( O ( n 2 ) O(n^2) O(n2) ) 降至 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) )。
- 可扩展到长序列 :支持 8K-16K tokens 甚至更长。
- 适用于文档级任务:如长文摘要、法律分析、医疗 NLP 任务。
❌ 缺点
- 不能捕捉远距离依赖:只能处理窗口范围内的 token 交互。
- 需要全局注意力补充 :必须结合 Global Attention 来补充远程信息(例如
CLS
位置)。
5. 结论
- Sliding Window Attention 解决了 Transformer 计算复杂度随序列长度二次增长的问题。
- 通过限制每个 token 只关注局部窗口内的 token,使得计算复杂度降低为 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) )。
- 适用于长文本处理,并可结合 Global Attention 进一步提升模型性能。
🚀 这是 Transformer 在长文本任务上的关键优化方案之一!
Sliding Window Attention 如何支持长序列,并处理远距离依赖?
你提到的 "Sliding Window Attention 只能处理局部信息,那为什么还能支持 8K-16K 甚至更长的序列?" 这个问题很好。我们来详细拆解这个机制,看看它如何 既能高效处理长序列,又能解决远距离依赖问题。
1. Sliding Window Attention 适用于长序列的原因
(1) 计算复杂度降低
传统 全自注意力(Full Self-Attention) 计算复杂度为:
O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d)
即随着序列长度 ( n n n ) 增加,计算量呈二次增长。例如:
- ( n = 1024 n = 1024 n=1024 ) 时,需要计算 百万级 注意力分数。
- ( n = 8192 n = 8192 n=8192 ) 时,需要计算 千万级 注意力分数,显存消耗极大。
而 Sliding Window Attention 仅在局部窗口 ( w w w ) 内计算注意力:
O ( n ⋅ w ⋅ d ) O(n \cdot w \cdot d) O(n⋅w⋅d)
通常 ( w ≪ n w \ll n w≪n )(如 ( w = 512 , n = 8192 w=512, n=8192 w=512,n=8192 )),计算复杂度大幅降低,使得 处理长序列成为可能。
(2) 通过层级叠加,间接传播长距离信息
虽然单个 Sliding Window 只能看到局部范围的 token,但 Transformer 具有多层结构 ,可以通过 层级叠加 逐步扩展信息传播范围。
示例:
- 设窗口大小 ( w = 512 ),模型有 12 层 Transformer。
- 第 1 层:每个 token 只看到 相邻 512 个 token。
- 第 2 层:由于前一层已经融合了 512 个 token 信息,相当于 间接看到 1024 个 token。
- 第 3 层:可看到 1536 个 token。
- ......
- 第 12 层:最终可以捕捉 6144+ token 的信息。
这意味着,即使单层 Sliding Window 只能看到局部信息,但多层叠加后, 整个 Transformer 仍然能捕捉远程依赖。
✅ 这类似 CNN 中的感受野(Receptive Field)扩展:
- 低层捕捉局部信息,高层逐步扩大感受野。
- 顶层的
CLS
token 可以聚合全局信息。
2. 如何进一步增强远程依赖能力?
(1) 结合全局注意力(Global Attention)
Sliding Window 主要用于局部注意力,但为了处理关键任务位置(如 CLS
,任务相关实体) ,通常会额外增加 Global Attention :
Hybrid Attention = Sliding Window + Global Attention \text{Hybrid Attention} = \text{Sliding Window} + \text{Global Attention} Hybrid Attention=Sliding Window+Global Attention
- Global Attention 让
CLS
token 直接看到所有位置,用于捕捉全局信息。 - 关键 token(如问题 token、摘要 token)可被全局注意力连接,使远距离 token 之间的信息传递更高效。
(2) 结合 Dilated Attention(扩张窗口注意力)
为了提高远程依赖能力,可以使用 Dilated Sliding Window Attention(扩张窗口注意力):
- 例如,窗口间隔
gap = 2
,每个 token 除了看到最近的 512 个 token,还能看到更远的 token。 - 这种方法 类似 CNN 的 Dilated Convolution,可以扩大感受野,而不会增加太多计算量。
3. Sliding Window Attention 如何影响长文本任务?
-
适用于 8K+ 长文本摘要(Summarization)
- 长文摘要模型(如 Longformer )使用 Sliding Window + Global Attention,使
CLS
位置能整合全局信息。 - 例如:arXiv 论文摘要任务 ,输入 16K tokens,模型仍然可以高效运行。
- 长文摘要模型(如 Longformer )使用 Sliding Window + Global Attention,使
-
适用于长文 QA(Long Document QA)
- 传统 QA 需要截断上下文(如 BERT 只能用 512 tokens)。
- Longformer 可以处理 8K+ tokens,保证所有信息被覆盖,提升答案查找准确率。
-
适用于长文分类(Long Document Classification)
CLS
位置的 Global Attention 可以整合 8K+ tokens 的全局信息,提高分类准确度。
4. 结论
✅ Sliding Window Attention 可以扩展到 8K+ 序列,原因如下:
- 降低计算复杂度,从 ( O ( n 2 ) O(n^2) O(n2) ) 变为 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) ),可扩展到长文本。
- 通过 Transformer 层级堆叠,多层次传递信息,间接覆盖全局依赖。
- 结合 Global Attention,让关键 token 直接连接全局,提高远程依赖建模能力。
- 结合 Dilated Attention(扩张窗口)可以进一步提升长距离信息传播。
🔹 最终,Sliding Window Attention + Global Attention + Dilated Attention 让 Transformer 既能高效处理长文本,又能捕捉全局依赖! 🚀
Hybrid Attention(滑动窗口注意力 + 全局注意力)解析
在 Longformer 等长序列 Transformer 结构中,Hybrid Attention 结合了 Sliding Window Attention(局部注意力)和 Global Attention(全局注意力),以同时实现:
- 高效计算 :滑动窗口注意力降低计算复杂度到 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) ),适用于长文本。
- 全局依赖捕捉 :全局注意力允许关键 token(如
CLS
、问题 token)能访问所有 tokens,确保长距离信息流通。
1. 什么是 Global Attention?
在标准 全自注意力(Self-Attention) 机制中:
- 每个 token 计算所有 token 的注意力权重 ,计算复杂度 ( O ( n 2 ) O(n^2) O(n2) )。
- 但 Sliding Window Attention 仅计算局部窗口内的 token,无法直接建模远程依赖。
Global Attention 的作用:
- 指定部分 token 作为"全局节点" ,这些 token 可以访问所有 tokens ,同时 所有 tokens 也可以访问这些全局节点。
- 一般用于关键任务相关 tokens,例如:
[CLS]
(分类任务)问题 tokens
(问答任务)摘要 tokens
(摘要任务)
Hybrid Attention 结合两者的方式:
Hybrid Attention = Sliding Window Attention + Global Attention \text{Hybrid Attention} = \text{Sliding Window Attention} + \text{Global Attention} Hybrid Attention=Sliding Window Attention+Global Attention
- 大部分 tokens 采用 Sliding Window 计算注意力 ,计算复杂度为 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) )。
- 关键 tokens 采用 Global Attention,可以访问整个序列,补充长距离信息。
2. Hybrid Attention 计算方法
假设:
window_size = 512
global_mask
指定哪些 token 需要全局注意力(如CLS
)。
计算步骤:
- 计算 Sliding Window Attention
- 每个 token 仅计算 窗口范围内的注意力。
- 计算 Global Attention
- 只有被标记为全局注意力的 token 计算 全局 self-attention。
- 这些 token 可以访问所有 tokens,所有 tokens 也可以访问它们。
- 合并两种注意力机制
- 局部 token 使用 Sliding Window Attention。
- 全局 token 额外加上 Global Attention 权重,确保远程依赖信息传递。
3. PyTorch 实现(可运行)
下面是完整的 Hybrid Attention(滑动窗口注意力 + 全局注意力) 的 PyTorch 实现:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class HybridAttention(nn.Module):
def __init__(self, embed_dim, num_heads, window_size):
"""
Hybrid Attention: Sliding Window Attention + Global Attention
Args:
embed_dim: 词嵌入维度 d
num_heads: 注意力头数量 h
window_size: 滑动窗口大小 w
"""
super(HybridAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.window_size = window_size
self.head_dim = embed_dim // num_heads # 每个头的维度
assert self.head_dim * num_heads == embed_dim, "embed_dim 必须是 num_heads 的整数倍"
# 线性投影层
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(self, x, global_mask):
"""
Args:
x: 输入张量 [batch, seq_len, embed_dim]
global_mask: 是否是全局注意力的 mask [batch, seq_len],1 表示全局注意力,0 表示普通窗口注意力
Returns:
输出张量 [batch, seq_len, embed_dim]
"""
batch_size, seq_len, _ = x.shape
# 计算 Q, K, V
Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, seq, d]
K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 初始化注意力分数
attn_scores = torch.full((batch_size, self.num_heads, seq_len, seq_len), float("-inf"), device=x.device)
# 计算 Sliding Window Attention
for i in range(seq_len):
start = max(0, i - self.window_size)
end = min(seq_len, i + self.window_size + 1)
attn_scores[:, :, i, start:end] = torch.matmul(Q[:, :, i, :], K[:, :, start:end, :].transpose(-2, -1))
# 计算 Global Attention
global_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) # [batch, 1, 1, seq_len]
global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) # 扩展到注意力头维度
attn_scores.masked_fill_(global_indices == 0, float("-inf")) # 让非全局 token 只计算滑动窗口注意力
global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) # 全局 attention 计算
attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores) # 合并全局和局部注意力
# 归一化 softmax
attn_scores /= self.head_dim ** 0.5
attn_weights = F.softmax(attn_scores, dim=-1)
# 计算注意力加权的 Value
output = torch.matmul(attn_weights, V)
# 重新排列形状
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
output = self.o_proj(output) # 线性变换回原始维度
return output
4. 代码解读
-
计算 Sliding Window Attention
- 只在窗口范围内计算注意力分数,保证计算复杂度 ( O(n \cdot w) )。
-
计算 Global Attention
- 通过
global_mask
选择需要全局注意力的 token(如CLS
)。 - 计算这些 token 与所有 tokens 之间的注意力分数。
- 通过
-
融合全局 & 局部注意力
- 使用
torch.where()
选择是否应用全局注意力:- 全局 token:使用全局 self-attention 计算的权重。
- 局部 token:仅计算滑动窗口内的注意力。
- 使用
5. 运行示例
python
# 测试 Hybrid Attention
batch_size, seq_len, embed_dim, num_heads, window_size = 2, 16, 64, 8, 4
x = torch.randn(batch_size, seq_len, embed_dim)
global_mask = torch.zeros(batch_size, seq_len, dtype=torch.long) # 默认无全局注意力
global_mask[:, 0] = 1 # 让 CLS 位置作为全局注意力
hybrid_attn = HybridAttention(embed_dim, num_heads, window_size)
output = hybrid_attn(x, global_mask)
print(output.shape) # 预期: (batch_size, seq_len, embed_dim)
6. 结论
- Hybrid Attention 结合了 Sliding Window 和 Global Attention,使 Transformer 既高效又能捕捉远程依赖。
- PyTorch 实现支持运行,适用于长文本任务(如长文摘要、QA、分类等)。
- 适用于 8K+ 长文本,提高推理效率,同时保持全局信息流通!🚀
Hybrid Attention 计算 Global Attention 详细解析
代码段
python
# 计算 Global Attention
global_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) # [batch, 1, 1, seq_len]
global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) # 扩展到注意力头维度
attn_scores.masked_fill_(global_indices == 0, float("-inf")) # 让非全局 token 只计算滑动窗口注意力
global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) # 全局 attention 计算
attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores) # 合并全局和局部注意力
1. 代码解析
(1) 处理 global_mask
python
global_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) # [batch, 1, 1, seq_len]
作用
global_mask
是一个[batch, seq_len]
形状的张量,标识哪些 token 是全局注意力。unsqueeze(1).unsqueeze(1)
作用是扩展维度,使其形状变成[batch, 1, 1, seq_len]
,以便后续expand()
操作匹配attn_scores
形状。
示例
假设 global_mask
:
python
global_mask = torch.tensor([
[1, 0, 0, 0, 0], # Batch 0:CLS(位置 0)是全局 token
[0, 0, 1, 0, 0] # Batch 1:位置 2 是全局 token
])
经过 unsqueeze()
变成:
python
global_mask_expanded = torch.tensor([
[[[1, 0, 0, 0, 0]]], # Batch 0
[[[0, 0, 1, 0, 0]]], # Batch 1
]) # 形状 [batch=2, 1, 1, seq_len=5]
(2) 扩展 global_mask
维度以匹配 attn_scores
python
global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) # [batch, num_heads, seq_len, seq_len]
作用
expand(-1, self.num_heads, seq_len, -1)
扩展维度 ,使global_indices
形状与attn_scores
匹配。
示例(假设 num_heads = 2
)
python
global_indices = torch.tensor([
# Batch 0
[[[1, 0, 0, 0, 0], # 头 1
[1, 0, 0, 0, 0]]], # 头 2
# Batch 1
[[[0, 0, 1, 0, 0], # 头 1
[0, 0, 1, 0, 0]]] # 头 2
]) # 形状 [batch=2, num_heads=2, seq_len=5, seq_len=5]
- 现在,每个 batch 的
global_indices
里:1
表示全局注意力 token。0
表示普通 token。
(3) 让普通 token 只计算滑动窗口内的注意力
python
attn_scores.masked_fill_(global_indices == 0, float("-inf"))
作用
- 让 非全局 token 的注意力变成
-inf
,确保它们只能计算滑动窗口范围的注意力。
示例
假设 attn_scores
初始值:
python
attn_scores = torch.tensor([
[[10, 20, 30, 40, 50],
[15, 25, 35, 45, 55]],
[[5, 10, 15, 20, 25],
[10, 15, 20, 25, 30]]
], dtype=torch.float32)
执行 masked_fill_()
后:
python
attn_scores = torch.tensor([
[[10, -inf, -inf, -inf, -inf],
[15, -inf, -inf, -inf, -inf]],
[[-inf, -inf, 15, -inf, -inf],
[-inf, -inf, 20, -inf, -inf]]
])
现在:
- 普通 token 的注意力分数变成
-inf
,它们只能计算滑动窗口范围内的 token。 - 全局 token(如 CLS)不受影响,它们仍然可以访问所有 token。
(4) 计算全局 Attention
python
global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) # 计算全局注意力
作用
- 全局 token 应该能访问所有 token ,所以这里 计算完整的注意力分数矩阵。
Q @ K^T
计算所有Q
和K
之间的点积注意力。
(5) 合并全局和滑动窗口注意力
python
attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores)
作用
global_indices == 1
位置使用 完整的全局注意力global_attn_scores
。- 其他位置仍然使用 滑动窗口注意力
attn_scores
。
示例
假设 global_attn_scores
:
python
global_attn_scores = torch.tensor([
[[100, 110, 120, 130, 140],
[105, 115, 125, 135, 145]],
[[50, 60, 70, 80, 90],
[55, 65, 75, 85, 95]]
])
执行 torch.where()
后:
python
attn_scores = torch.tensor([
[[100, -inf, -inf, -inf, -inf],
[105, -inf, -inf, -inf, -inf]],
[[-inf, -inf, 70, -inf, -inf],
[-inf, -inf, 75, -inf, -inf]]
])
现在:
- 全局 token (
global_indices == 1
) 采用全局注意力global_attn_scores
。 - 普通 token (
global_indices == 0
) 继续使用滑动窗口注意力(-inf
表示屏蔽)。
6. 结论
代码 | 作用 |
---|---|
global_mask_expanded = global_mask.unsqueeze(1).unsqueeze(1) |
扩展 global_mask 形状,方便与 attn_scores 匹配 |
global_indices = global_mask_expanded.expand(-1, self.num_heads, seq_len, -1) |
复制 global_mask 到所有注意力头 |
attn_scores.masked_fill_(global_indices == 0, float("-inf")) |
让普通 token 只能访问滑动窗口内的 token |
global_attn_scores = torch.matmul(Q, K.transpose(-2, -1)) |
计算完整的全局注意力分数 |
attn_scores = torch.where(global_indices == 1, global_attn_scores, attn_scores) |
让全局 token 采用全局注意力,而普通 token 继续使用滑动窗口 |
🚀 最终,我们的 Hybrid Attention 既能高效计算长文本,又能让 CLS
等关键 token 访问全局信息! 🎯
后记
2025年2月23日14点36分于上海,在GPT 4o大模型辅助下完成。