Mosaic:面向超长序列的多GPU注意力分片方案

Transformer的"二次方注意力瓶颈"的问题是老生常谈了。这个瓶颈到底卡在哪实际工程里怎么绕过去?本文从一个具体问题出发,介绍Mosaic这套多轴注意力分片方案的设计思路。

注意力的内存困境

注意力机制的计算公式:

复制代码
 Attention(Q, K, V) = softmax(QKᵀ / √d) × V

问题出在 QKᵀ 这个矩阵上,它的形状是

复制代码
(序列长度 × 序列长度)

拿150,000个token的序列算一下:

复制代码
 Memory = 150,000² × 4 bytes = 90 billion bytes ≈ 84 GB

这只是注意力权重本身的开销,而且还是单层、单头。A100的显存上限是80GB,放不下就是放不下。

现有方案的局限

FlashAttention 它通过分块计算,不需要把完整的注意力矩阵实例化出来,内存复杂度从O(n²)降到O(n)。单卡场景下效果很好,但问题是整个序列还是得塞进同一张GPU。

Ring Attention 换了个思路:把序列切片分到多张GPU上,每张卡持有一部分Q,K和V在GPU之间像传令牌一样轮转,一维序列处理起来是很不错的。

但是多维怎么办?

比如处理表格数据的Transformer,输入张量形状是

复制代码
(batch, rows, features, embed)

。模型需要在不同维度上做注意力:features维度只有5个token,rows维度却有150,000个。前者单卡轻松搞定,后者则必须分片。

现有的库都没法干净地处理这种多轴场景。手写的话,每个轴要单独写分片逻辑,进程组管理、张量reshape全得自己来。代码会变得很脏。

Mosaic的设计

Mosaic本质上是个协调层,负责把不同的注意力轴路由到合适的计算后端:

复制代码
 import mosaic

# Small axis: run locally
feature_attn = mosaic.MultiAxisAttention(  
    embed_dim=96,   
    num_heads=4,  
    attention_axis=2,    # features dimension
    backend="local"      # no communication needed
)

# Large axis: shard across GPUs
row_attn = mosaic.MultiAxisAttention(  
    embed_dim=96,   
    num_heads=4,  
    attention_axis=1,    # rows dimension
    backend="ring"       # ring attention across GPUs
 )

底层Mosaic会自动处理轴的置换、QKV投影前的reshape、后端分发、以及计算完成后张量形状的还原。模型代码保持清晰,分布式的复杂性被封装掉了。

Ring Attention的工作机制

核心思想其实很直接:不需要同时持有全部的K和V。可以分批计算注意力分数,逐步累积,最后再做归一化。

比如说4张GPU的情况下流程是这样的:

复制代码
 Initial state:  
  GPU 0: Q₀, K₀, V₀  
  GPU 1: Q₁, K₁, V₁    
  GPU 2: Q₂, K₂, V₂  
  GPU 3: Q₃, K₃, V₃

Step 1: Each GPU computes attention with its local K, V  
  GPU 0: score₀₀ = Q₀ @ K₀ᵀ  
  ...

Step 2: Pass K, V to the next GPU in the ring  
  GPU 0 receives K₃, V₃ from GPU 3  
  GPU 0 sends K₀, V₀ to GPU 1  
    
Step 3: Compute attention with received K, V  
  GPU 0: score₀₃ = Q₀ @ K₃ᵀ  
  Accumulate with score₀₀

Repeat for all chunks...

 Final: Each GPU has complete attention output for its Q chunk

单卡内存占用变成O(n²/p),p是GPU数量。8张卡的话内存需求直接砍到1/8。150k序列从84GB降到约10GB每卡。

Mesh2D:更激进的分片

序列特别长的时候Ring Attention的线性分片可能还不够,这时候可以用Mesh2D把Q和K都切分了:

复制代码
 4 GPUs arranged in 2×2 mesh:

          K₀    K₁  
       ┌──────┬──────┐  
  Q₀   │GPU 0 │GPU 1 │  
       ├──────┼──────┤  
  Q₁   │GPU 2 │GPU 3 │  
       └──────┴──────┘  
         
 Each GPU computes one tile of QKᵀ

内存复杂度降到O(n²/p²)。64张卡组成8×8网格时,每卡内存需求下降64倍。

复制代码
 attn=mosaic.MultiAxisAttention(  
     embed_dim=128,   
     num_heads=8,  
     attention_axis=1,  
     backend="mesh2d",  
     mesh_shape=(8, 8)  
 )

感知集群拓扑的组合策略

在实际部署环境里,不同GPU之间的通信带宽差异很大。节点内GPU走NVLink能到900 GB/s,跨节点通过InfiniBand通常只有200 GB/s左右。

复制代码
ComposedAttention

就是针对这种拓扑特征设计的:

复制代码
 # 4 nodes × 8 GPUs = 32 total
 composed = mosaic.ComposedAttention(  
     mesh_shape=(4, 8),       # (nodes, gpus_per_node)
     head_parallel=True,      # Split heads across nodes (slow link)
     seq_parallel="ring"      # Ring within nodes (fast link)
 )

需要更精细控制的话,可以用

复制代码
HierarchicalAttention

复制代码
 hier = mosaic.HierarchicalAttention(  
     intra_node_size=8,  
     intra_node_strategy="local",   # Compute locally within node
     inter_node_strategy="ring"     # Ring between node leaders
 )

重通信走快链路轻通信才跨节点。

实现细节

整个库大约800行Python,核心代码如下:

复制代码
 class MultiAxisAttention(nn.Module):  
    def forward(self, x):  
        # 1. Move attention axis to seq position
        x, inv_perm = self._permute_to_seq(x)  
          
        # 2. Flatten batch dims, project QKV
        x = x.view(-1, seq_len, embed_dim)  
        qkv = self.qkv_proj(x).view(batch, seq, 3, heads, head_dim)  
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  
          
        # 3. Dispatch to backend
        out = self._attn_fn(q, k, v)  # local, ring, or mesh2d
          
        # 4. Project output, restore shape
        out = self.out_proj(out.transpose(1, 2).reshape(...))  
         return out.permute(inv_perm)

后端封装了现有的成熟实现:

复制代码
local

后端调用

复制代码
F.scaled_dot_product_attention

(也就是FlashAttention),

复制代码
ring

后端用ring-flash-attn库的

复制代码
ring_flash_attn_func

复制代码
mesh2d

是自定义的all-gather加SDPA,所有的底层都跑的是FlashAttention内核。

所有后端统一用FlashAttention的融合GEMM+softmax实现。后端函数在初始化时就绑定好,前向传播不做分支判断。张量操作尽量用

复制代码
x.view()

而不是

复制代码
x.reshape()

,保持内存连续性。集合通信的目标张量预分配好,避免

复制代码
torch.cat

的开销。模块级别做导入不在每次前向传播时产生import开销。

快速上手

安装:

复制代码
 pip install git+https://github.com/stprnvsh/mosaic.git
 
 # With ring attention support
 pip install flash-attn ring-flash-attn

单节点启动:

复制代码
 torchrun --nproc_per_node=4 train.py

多节点的话:

复制代码
 # Node 0
 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \  
          --master_addr=192.168.1.100 --master_port=29500 train.py
 
 # Node 1
 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \  
          --master_addr=192.168.1.100 --master_port=29500 train.py

训练脚本示例:

复制代码
 import mosaic  
import torch.distributed as dist

dist.init_process_group("nccl")  
ctx = mosaic.init(sp_size=dist.get_world_size())

model = MyModel().to(ctx.device)

# Data is pre-sharded: each GPU has seq_total / world_size tokens
x_local = load_my_shard()  
 out = model(x_local)  # Communication handled by Mosaic

总结

最后,Mosaic不会自动并行化模型(这个用nnScaler),不管数据并行(PyTorch DDP/FSDP的事),也不处理模型分片(交给FSDP或Megatron)。

Mosaic专注于一件事:多轴注意力的分片路由,这套方案最初是给 nanoTabPFN 做的,一个表格数据Transformer。

这个模型要同时在rows(150k个)和features(5个)两个维度做注意力。标准Ring Attention对维度语义没有感知,它只认序列这个概念,分不清rows和features的区别。

所以Mosaic需求很明确:小轴本地算,大轴分布式算,轴的路由逻辑不能侵入模型代码,有兴趣的可以试试。

https://avoid.overfit.cn/post/791e0f30540e4d289a43d01d383e8ab2

作者:Pranav Sateesh

相关推荐
陈广亮24 分钟前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬33 分钟前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia1 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区1 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两4 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪4 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232554 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
程序员打怪兽4 小时前
详解Visual Transformer (ViT)网络模型
深度学习
王鑫星4 小时前
SWE-bench 首次突破 80%:Claude Opus 4.5 发布,Anthropic 的野心不止于写代码
人工智能