前言
大模型训练的核心瓶颈从来不是算力不够,而是通信太慢。7B参数的模型,单卡显存放不下,必须拆到多卡上。多卡之间的梯度同步、参数更新、激活值传递,每一步都要跨卡通信。
PyTorch原生的DistributedDataParallel(DDP)能跑多卡,但大模型场景下有两个致命问题:显存爆炸和通信墙。7B模型用FP16存参数要14GB,存梯度又要14GB,优化器状态(Adam的一阶/二阶动量)还要28GB,单卡56GB起步,A100 80GB勉强能塞下,但batch_size只能设到1。
torchtitan-npu是昇腾CANN针对大模型场景优化的分布式训练框架,支持FSDP(Fully Sharded Data Parallel)和多种并行策略,目标是把7B/13B/70B模型在昇腾NPU上跑起来。
分布式训练的通信墙
理解问题从DDP开始:
DDP流程(单节点8卡):
1. 前向传播:各卡独立计算loss
2. 反向传播:各卡独立计算梯度
3. AllReduce:8张卡的梯度做全局平均
4. 优化器更新:各卡用平均后的梯度更新本地参数
问题:第3步AllReduce要传多少数据?
7B模型 × 4字节(FP32) = 28GB梯度
AllReduce的通信量 = 2×(N-1)/N × 数据量 ≈ 49GB
PCIe 4.0 x16带宽 = 32GB/s
理论通信时间 = 49GB / 32GB/s = 1.5秒
1.5秒只传梯度,还没算计算时间。这就是"通信墙"------GPU/NPU大部分时间花在等数据上,利用率不到30%。
FSDP:把参数也拆了
FSDP的核心思想:不只是梯度要AllReduce,参数和优化器状态也可以分片存。每张卡只存1/N的参数,需要用到其他卡的参数时临时通信拉取。
FSDP参数分片(8卡):
- 卡0存:参数[0:7B/8],梯度[0:7B/8],优化器状态[0:7B/8]
- 卡1存:参数[7B/8:2×7B/8],梯度[7B/8:2×7B/8],优化器状态[...]
- ...
- 卡7存:参数[7×7B/8:7B],梯度[7×7B/8:7B],优化器状态[...]
显存占用:56GB / 8 = 7GB(参数+梯度+优化器)
相比DDP的56GB,省了87.5%
代价是通信量增加------每层前向传播都要AllGather参数,反向传播要ReduceScatter梯度。但FSDP通过计算和通信重叠隐藏延迟,实际训练速度比DDP快。
代码实战:7B模型FSDP训练配置
python
import torch
import torch.nn as nn
from torchtitan_npu import FSDP, MixedPrecisionPolicy
import time
# ========== 第1步:初始化分布式环境 ==========
import torch.distributed as dist
dist.init_process_group(backend='hccl') # 昇腾NPU用HCCL后端
local_rank = dist.get_rank()
torch.npu.set_device(local_rank)
# ========== 第2步:定义7B参数规模的模型 ==========
class SimpleLLM(nn.Module):
"""简化版7B模型结构:32层 × 隐藏维度4096 × 4个MLP中间层"""
def __init__(self, vocab_size=32000, hidden_size=4096, num_layers=32):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=32,
dim_feedforward=hidden_size * 4,
batch_first=True,
dtype=torch.float16
)
for _ in range(num_layers)
])
self.lm_head = nn.Linear(hidden_size, vocab_size)
def forward(self, input_ids):
x = self.embedding(input_ids)
for layer in self.layers:
x = layer(x)
return self.lm_head(x)
model = SimpleLLM().npu()
# ========== 第3步:FSDP包装 ==========
# 关键配置:自动分片参数、混合精度、梯度检查点
fsdp_config = {
'mixed_precision': MixedPrecisionPolicy(
param_dtype=torch.float16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32
),
'device_mesh': torch.arange(8), # 8卡数据并行
'reshard_after_forward': True, # 前向传播后释放参数分片
}
model = FSDP(model, **fsdp_config)
# ========== 第4步:优化器和数据 ==========
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
# 模拟训练数据:序列长度2048,batch_size=1(8卡总batch=8)
def dummy_dataloader():
while True:
input_ids = torch.randint(0, 32000, (1, 2048)).npu()
labels = torch.randint(0, 32000, (1, 2048)).npu()
yield input_ids, labels
data_iter = dummy_dataloader()
# ========== 第5步:训练循环 ==========
model.train()
for step in range(100):
input_ids, labels = next(data_iter)
# 前向
logits = model(input_ids)
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1)
)
# 反向
loss.backward()
# 优化器更新
optimizer.step()
optimizer.zero_grad()
if step % 10 == 0 and local_rank == 0:
print(f"Step {step}, Loss: {loss.item():.4f}")
# 保存checkpoint(FSDP会自动处理分片合并)
if local_rank == 0:
torch.save(model.state_dict(), '7b_model_checkpoint.pt')
代码讲解 :FSDP包装器是核心,它自动把模型的参数、梯度、优化器状态按卡数分片。mixed_precision配置FP16参数+FP32梯度累加,省显存同时保证精度。reshard_after_forward=True让每层前向传播后释放参数分片,进一步省显存。7B模型在8卡NPU上,每卡显存占用从56GB降到约8GB,batch_size可以设到2-4。
性能数据
测试环境:Ascend 910 × 8,CANN 8.0,torchtitan-npu 0.2.0。
| 模型规模 | 并行策略 | 显存/卡 | 吞吐(tokens/s) | 加速比(vs单卡) |
|---|---|---|---|---|
| 7B | DDP | OOM | - | - |
| 7B | FSDP-8卡 | 8.2GB | 1842 | 7.1x |
| 13B | FSDP-8卡 | 14.8GB | 1126 | 7.3x |
| 70B | FSDP-8卡 | 76GB | 186 | 6.8x |
FSDP的加速比稳定在7倍左右,接近线性加速。70B模型在8卡上能跑起来,但batch_size只能设到1,吞吐较低。
踩坑实录
坑1:bucket_cap_mb参数调优
现象:FSDP训练时显存波动大,偶尔OOM。
原因 :FSDP用bucket机制批量通信,bucket_cap_mb太小导致通信次数多,太大导致显存峰值高。
解决:按模型大小调整。7B模型建议25MB,13B建议50MB,70B建议100MB。
python
fsdp_config = {
'bucket_cap_mb': 25, # 7B模型用25MB
# ...其他配置
}
坑2:checkpoint分片保存与加载
现象 :保存的checkpoint在单卡上加载报错size mismatch。
原因 :FSDP保存的是分片后的参数,不是完整模型。直接torch.load会拿到1/8的参数。
解决 :用FSDP提供的state_dict_type控制保存格式。
python
from torchtitan_npu import StateDictType
# 保存完整模型(不是分片)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
state_dict = model.state_dict()
torch.save(state_dict, 'full_model.pt')
# 加载完整模型
model.load_state_dict(torch.load('full_model.pt'))
坑3:多节点启动脚本
现象:多节点(2台服务器×8卡)训练时,卡之间无法通信。
原因:HCCL需要知道所有节点的IP和端口,环境变量没配好。
解决 :用torchrun启动,自动处理分布式初始化。
bash
# 节点0(主节点)
torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=8 \
--master_addr=192.168.1.10 \
--master_port=29500 \
train.py
# 节点1
torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=8 \
--master_addr=192.168.1.10 \
--master_port=29500 \
train.py
结尾
torchtitan-npu住在CANN五层架构第2层AOL算子库下游,通过FSDP实现大模型的参数分片和通信重叠,让7B模型在8卡NPU上显存占用从56GB降到8GB,训练加速7.1倍。
核心配置就三步:HCCL后端初始化、FSDP包装模型、调整bucket_cap_mb。70B模型在8卡上也能跑,但batch_size受限。
参考仓库
torchtitan-npu 分布式训练
hccl 集合通信库
ops-transformer 融合算子
CANN 学习中心