This article is a technical walkthrough based on several repositories from the CANN open-source community.
CANN organization: https://atomgit.com/cann
hccl repository: https://atomgit.com/cann/hccl
runtime repository: https://atomgit.com/cann/runtime
Preface
As model sizes keep growing, a single device can no longer hold a full model. Model parallelism makes training ultra-large models feasible by partitioning the model across multiple devices, and CANN provides end-to-end support for it.
This article digs into the principles, implementations, communication optimizations, and best practices of model parallelism on CANN, to help you train and deploy large models efficiently.
Model Parallelism Basics
1. Comparing Parallelism Strategies
python
# Comparison of parallelism strategies
import torch
import torch_npu
import torch.distributed as dist

class ParallelismComparison:
    """Comparison of parallelism strategies."""

    def data_parallelism(self):
        """
        Data parallelism:
        every device holds the full model and processes different data.
        """
        example = """
        # Data parallelism example
        Device 0: Model (full) + Data[0:32]
        Device 1: Model (full) + Data[32:64]
        Device 2: Model (full) + Data[64:96]
        Device 3: Model (full) + Data[96:128]
        Characteristics:
        - Model replication: one full copy per device
        - Data sharding: each device processes a different shard
        - Gradient synchronization: AllReduce over gradients
        Suitable when:
        - The model is small enough to fit on one device
        - The dataset is large
        - High throughput is required
        """
        return example

    def model_parallelism(self):
        """
        Model parallelism:
        the model is partitioned across devices; all devices process the same data.
        """
        example = """
        # Model parallelism example
        Device 0: Layer[0-10]  + Data (full)
        Device 1: Layer[11-20] + Data (full)
        Device 2: Layer[21-30] + Data (full)
        Device 3: Layer[31-40] + Data (full)
        Characteristics:
        - Model sharding: each device holds different layers
        - Data replication: every device sees the same data
        - Activation passing: activations flow between layers across devices
        Suitable when:
        - The model is too large for a single device
        - The model has many layers
        - Memory is the bottleneck
        """
        return example

    def pipeline_parallelism(self):
        """
        Pipeline parallelism:
        the model is split into stages and data flows through them as a pipeline.
        """
        example = """
        # Pipeline parallelism example
        Step 1: Device 0 runs Batch 0
        Step 2: Device 0 runs Batch 1, Device 1 runs Batch 0
        Step 3: Device 0 runs Batch 2, Device 1 runs Batch 1, Device 2 runs Batch 0
        Step 4: Device 0 runs Batch 3, Device 1 runs Batch 2, Device 2 runs Batch 1, Device 3 runs Batch 0
        Characteristics:
        - Model staging: similar to model parallelism
        - Pipelining: multiple micro-batches in flight at once
        - Bubble time: idle slots that need optimization
        Suitable when:
        - The model is very large
        - High throughput is required
        - Some latency is tolerable
        """
        return example

# Print the comparison
comparison = ParallelismComparison()
print(comparison.data_parallelism())
print(comparison.model_parallelism())
print(comparison.pipeline_parallelism())
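In practice, data parallelism on Ascend NPUs is expressed through PyTorch's DistributedDataParallel running on the hccl backend, which performs the gradient AllReduce automatically. A minimal sketch, assuming the process is launched with torchrun (which sets the RANK/LOCAL_RANK/WORLD_SIZE environment variables):
python
# Minimal data-parallel training step on NPU with the hccl backend
import os
import torch
import torch_npu
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='hccl')      # HCCL is the Ascend collective backend
local_rank = int(os.environ['LOCAL_RANK'])
torch.npu.set_device(local_rank)

model = torch.nn.Linear(1024, 10).npu()
ddp_model = DDP(model, device_ids=[local_rank])  # gradients AllReduced during backward

optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
inputs = torch.randn(32, 1024).npu()         # this rank's data shard
targets = torch.randint(0, 10, (32,)).npu()

loss = torch.nn.functional.cross_entropy(ddp_model(inputs), targets)
loss.backward()                              # AllReduce happens inside backward
optimizer.step()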
2. Tensor Parallelism
python
# Tensor parallelism implementation
class TensorParallelism:
    """
    Tensor parallelism:
    shards individual tensors across multiple devices.
    """

    def __init__(self, world_size):
        self.world_size = world_size
        self.rank = dist.get_rank()

    def column_parallel_linear(self, input, weight, bias=None):
        """
        Column-parallel linear layer:
        the weight is sharded along its output (column) dimension.
        """
        # Full weight shape: (out_features, in_features)
        # Shard: each device holds (out_features / world_size, in_features)
        # The local matmul produces this device's slice of the output features
        output_parallel = torch.matmul(input, weight.t())
        if bias is not None:
            output_parallel = output_parallel + bias
        # AllGather the per-device output slices along the last dim
        output = self.all_gather(output_parallel)
        return output

    def row_parallel_linear(self, input, weight, bias=None):
        """
        Row-parallel linear layer:
        the weight is sharded along its input (row) dimension.
        """
        # Full weight shape: (out_features, in_features)
        # Shard: each device holds (out_features, in_features / world_size)
        # The input must be sharded the same way along its last dim
        input_parallel = self.split_tensor(input, dim=-1)
        # The local matmul yields a partial sum of the full output
        output_parallel = torch.matmul(input_parallel, weight.t())
        # AllReduce sums the partial results across devices
        output = self.all_reduce(output_parallel)
        # Add the bias once, after the reduction, on every rank,
        # so that all ranks hold the identical full output
        if bias is not None:
            output = output + bias
        return output

    def split_tensor(self, tensor, dim):
        """Take this rank's shard of a tensor along `dim`."""
        chunk_size = tensor.size(dim) // self.world_size
        start = self.rank * chunk_size
        return tensor.narrow(dim, start, chunk_size)

    def all_gather(self, tensor):
        """AllGather and concatenate along the last dim."""
        tensor_list = [torch.zeros_like(tensor) for _ in range(self.world_size)]
        dist.all_gather(tensor_list, tensor)
        return torch.cat(tensor_list, dim=-1)

    def all_reduce(self, tensor):
        """In-place AllReduce (sum)."""
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor

# Usage example
# Initialize the distributed environment first:
# dist.init_process_group(backend='hccl')
# tp = TensorParallelism(world_size=4)
# Column parallel: each device holds 1/4 of the output features
# weight_col = torch.randn(256, 1024).npu()
# output = tp.column_parallel_linear(input, weight_col)
# Row parallel: each device holds 1/4 of the input features
# weight_row = torch.randn(1024, 256).npu()
# output = tp.row_parallel_linear(input, weight_row)
print("Tensor parallelism shards large tensors across multiple devices")
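The two sharding schemes above are usually composed rather than used in isolation: in a Transformer MLP, the first linear layer is made column-parallel and the second row-parallel, so the intermediate activation stays sharded and each MLP block needs only a single AllReduce. This pairing is the Megatron-LM pattern. A sketch reusing the TensorParallelism helpers above (ParallelMLP and the weight-shard arguments are illustrative names, not a CANN API):
python
# Column-parallel followed by row-parallel: the intermediate activation is
# already sharded along the last dim, so no AllGather is needed in between.
class ParallelMLP:
    def __init__(self, tp, w1_shard, w2_shard):
        self.tp = tp              # a TensorParallelism instance
        self.w1 = w1_shard        # (4*hidden/world_size, hidden): column shard
        self.w2 = w2_shard        # (hidden, 4*hidden/world_size): row shard

    def forward(self, x):
        # Column-parallel matmul, but skip the AllGather: keep the sharded output
        h = torch.relu(torch.matmul(x, self.w1.t()))
        # Row-parallel matmul consumes the sharded activation directly
        out = torch.matmul(h, self.w2.t())
        # One AllReduce per MLP block restores the full result on every rank
        return self.tp.all_reduce(out)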
Pipeline Parallelism
1. GPipe Implementation
python
# GPipe pipeline parallelism
class GPipePipeline:
    """
    GPipe pipeline parallelism:
    the model is split into stages, and micro-batches are run through
    them in a pipelined fashion.
    """

    def __init__(self, model_stages, num_microbatches):
        self.stages = model_stages
        self.num_microbatches = num_microbatches
        self.num_stages = len(model_stages)

    def split_batch(self, batch, num_splits):
        """Split a batch into micro-batches."""
        batch_size = batch.size(0)
        micro_batch_size = batch_size // num_splits
        micro_batches = []
        for i in range(num_splits):
            start = i * micro_batch_size
            end = start + micro_batch_size
            micro_batches.append(batch[start:end])
        return micro_batches

    def forward(self, batch):
        """Forward pass over all micro-batches."""
        micro_batches = self.split_batch(batch, self.num_microbatches)
        outputs = []
        activations = [[] for _ in range(self.num_stages)]
        for micro_batch in micro_batches:
            x = micro_batch
            for stage_idx, stage in enumerate(self.stages):
                x = stage(x)
                activations[stage_idx].append(x)
            outputs.append(x)
        return outputs, activations

    def backward(self, outputs, activations, targets):
        """Backward pass; `targets` is a list of per-micro-batch targets."""
        losses = []
        for output, target in zip(outputs, targets):
            loss = torch.nn.functional.cross_entropy(output, target)
            losses.append(loss)
        # Backward in reverse micro-batch order
        for loss in reversed(losses):
            loss.backward()

# Usage example: define the model stages
stage1 = torch.nn.Sequential(
    torch.nn.Linear(1024, 2048),
    torch.nn.ReLU()
)
stage2 = torch.nn.Sequential(
    torch.nn.Linear(2048, 2048),
    torch.nn.ReLU()
)
stage3 = torch.nn.Sequential(
    torch.nn.Linear(2048, 1024),
    torch.nn.ReLU()
)
stage4 = torch.nn.Linear(1024, 10)

# Build the pipeline
pipeline = GPipePipeline(
    model_stages=[stage1, stage2, stage3, stage4],
    num_microbatches=4
)

# Training
# batch = torch.randn(32, 1024)
# outputs, activations = pipeline.forward(batch)
# pipeline.backward(outputs, activations, targets)
print("GPipe reduces pipeline bubbles via micro-batching")
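The bubble mentioned above can be estimated directly: with p stages and m micro-batches, the idle fraction of a GPipe schedule is (p - 1) / (m + p - 1), which is exactly why raising num_microbatches improves utilization. A quick check:
python
# GPipe pipeline bubble fraction: (p - 1) / (m + p - 1)
def bubble_fraction(num_stages, num_microbatches):
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)

for m in (4, 8, 32):
    print(f"stages=4, micro-batches={m}: bubble = {bubble_fraction(4, m):.2%}")
# stages=4, micro-batches=4  -> 42.86%
# stages=4, micro-batches=8  -> 27.27%
# stages=4, micro-batches=32 -> 8.57%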
2. 1F1B Scheduling
python
# 1F1B scheduling strategy
class OneFOneBScheduler:
    """
    1F1B (One Forward One Backward) scheduling:
    alternates forward and backward passes to reduce peak memory usage.
    """

    def __init__(self, num_stages, num_microbatches):
        self.num_stages = num_stages
        self.num_microbatches = num_microbatches

    def schedule(self):
        """Generate the schedule (from the first stage's point of view)."""
        schedule = []
        # Warmup phase: forward passes only
        warmup_steps = self.num_stages - 1
        for step in range(warmup_steps):
            schedule.append(('F', step))  # forward of micro-batch `step`
        # Steady 1F1B phase: alternate one forward with one backward
        for step in range(warmup_steps, self.num_microbatches):
            schedule.append(('F', step))
            schedule.append(('B', step - warmup_steps))
        # Cooldown phase: drain the remaining backward passes
        for step in range(self.num_microbatches - warmup_steps, self.num_microbatches):
            schedule.append(('B', step))
        return schedule

    def execute(self, model_stage, micro_batches):
        """Execute the schedule on one stage."""
        schedule = self.schedule()
        activations = {}
        for action, step in schedule:
            if action == 'F':
                print(f"Forward micro-batch {step}")
                output = model_stage(micro_batches[step])
                activations[step] = output
            elif action == 'B':
                print(f"Backward micro-batch {step}")
                activation = activations[step]
                # Illustrative only: in a real pipeline, backward() receives the
                # gradient tensor sent back from the next stage
                activation.backward(torch.ones_like(activation))
                del activations[step]  # release activation memory promptly

# Usage example
scheduler = OneFOneBScheduler(num_stages=4, num_microbatches=8)
schedule = scheduler.schedule()
print("1F1B schedule:")
for action, step in schedule:
    print(f"  {action} micro-batch {step}")

# Advantages of 1F1B:
# - Memory efficiency: activations are freed as soon as their backward is done
# - Pipeline efficiency: less bubble time
# - Fits large models: lower peak memory
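The memory advantage can be quantified: GPipe keeps the activations of all m micro-batches alive until the backward phase starts, while 1F1B keeps at most one per warmup slot, bounded by the number of stages p. A rough sketch, assuming every micro-batch's activations cost the same:
python
# Peak number of live micro-batch activations on a stage (first stage is worst case)
def peak_live_activations(num_stages, num_microbatches, schedule):
    if schedule == 'gpipe':
        # All forwards complete before any backward begins
        return num_microbatches
    if schedule == '1f1b':
        # Bounded by the warmup depth, regardless of micro-batch count
        return min(num_stages, num_microbatches)
    raise ValueError(f"unknown schedule: {schedule}")

print(peak_live_activations(4, 32, 'gpipe'))  # 32
print(peak_live_activations(4, 32, '1f1b'))   # 4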
Communication Optimization
1. Overlapping Communication and Computation
python
# Overlapping communication with computation
class OverlapCommunication:
    """
    Overlap communication with computation:
    launch gradient communication while computation is still running.
    """

    def __init__(self, model):
        self.model = model
        self.comm_stream = torch.npu.Stream()
        self.compute_stream = torch.npu.Stream()

    def forward_backward_overlap(self, inputs, targets):
        """Overlap forward/backward with gradient communication."""
        # Forward pass on the compute stream
        with torch.npu.stream(self.compute_stream):
            outputs = self.model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, targets)
        # Backward pass on the compute stream
        with torch.npu.stream(self.compute_stream):
            loss.backward()
        # Gradient communication on a separate stream
        # (can overlap with the next batch's forward pass)
        handles = []
        with torch.npu.stream(self.comm_stream):
            # Make sure the backward pass has finished producing gradients
            self.comm_stream.wait_stream(self.compute_stream)
            # Launch an asynchronous AllReduce on each gradient
            for param in self.model.parameters():
                if param.grad is not None:
                    handles.append(dist.all_reduce(param.grad, async_op=True))
        # The caller should wait on these handles before optimizer.step()
        return handles

# Usage example
# model = MyModel().npu()
# overlap = OverlapCommunication(model)
# for inputs, targets in dataloader:
#     handles = overlap.forward_backward_overlap(inputs, targets)
#     for h in handles:
#         h.wait()
#     optimizer.step()
print("Overlapping communication with computation hides communication latency")
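Note that PyTorch's DistributedDataParallel already implements this overlap out of the box: gradients are grouped into buckets, and each bucket's AllReduce is launched as soon as all of its gradients are ready, while the rest of the backward pass is still computing. A minimal sketch (bucket_cap_mb is a real DDP parameter; 25 MB is its default):
python
# DDP overlaps gradient AllReduce with the backward pass via gradient bucketing
from torch.nn.parallel import DistributedDataParallel as DDP

# model = MyModel().npu()
# Smaller buckets start communicating earlier; larger buckets amortize launch latency
# ddp_model = DDP(model, bucket_cap_mb=25)
#
# loss = ddp_model(inputs).sum()
# loss.backward()   # early buckets are AllReduced while later grads are computed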
2. Gradient Compression
python
# Gradient compression
class GradientCompression:
    """
    Gradient compression:
    reduces the communication volume.
    """

    def __init__(self, compression_ratio=0.01):
        self.compression_ratio = compression_ratio

    def compress(self, gradient):
        """Compress a gradient with Top-K selection."""
        # Top-K compression: keep only the k largest-magnitude entries
        k = int(gradient.numel() * self.compression_ratio)
        flat_grad = gradient.flatten()
        # Find the indices of the top-k entries by absolute value
        _, indices = torch.topk(torch.abs(flat_grad), k)
        # The sparse gradient is just (indices, values)
        values = flat_grad[indices]
        return indices, values

    def decompress(self, indices, values, shape):
        """Reconstruct a dense gradient from (indices, values)."""
        gradient = torch.zeros(shape, dtype=values.dtype, device=values.device).flatten()
        gradient[indices] = values
        return gradient.view(shape)

    def communicate_compressed(self, gradient):
        """Exchange a compressed gradient."""
        indices, values = self.compress(gradient)
        # Communicate only indices and values. Note that each rank selects a
        # different index set, so in practice the pairs are exchanged with
        # all_gather and summed locally, not with a plain all_reduce:
        # dist.all_gather(index_list, indices)
        # dist.all_gather(value_list, values)
        decompressed = self.decompress(indices, values, gradient.shape)
        return decompressed

# Usage example
compressor = GradientCompression(compression_ratio=0.01)
# gradient = torch.randn(1000, 1000)
# compressed_grad = compressor.communicate_compressed(gradient)
print("Gradient compression cuts communication volume and speeds up training")
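Top-K compression is lossy, so in practice it is paired with error feedback: each rank accumulates the part of the gradient it dropped and adds it back before the next compression step, which is known to preserve convergence. A sketch layered on the class above (ErrorFeedbackCompression is an illustrative name, not a CANN API):
python
# Error feedback: locally accumulate whatever Top-K compression dropped
class ErrorFeedbackCompression(GradientCompression):
    def __init__(self, compression_ratio=0.01):
        super().__init__(compression_ratio)
        self.residual = None  # accumulated compression error, kept local

    def compress_with_feedback(self, gradient):
        if self.residual is None:
            self.residual = torch.zeros_like(gradient)
        corrected = gradient + self.residual            # re-inject past error
        indices, values = self.compress(corrected)
        sent = self.decompress(indices, values, gradient.shape)
        self.residual = corrected - sent                # remember what was dropped
        return indices, values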
Hybrid Parallelism
1. 3D Parallelism
python
# 3D parallelism (data + tensor + pipeline)
class ThreeDParallelism:
    """
    3D parallelism:
    combines data, tensor, and pipeline parallelism.
    """

    def __init__(self, dp_size, tp_size, pp_size):
        self.dp_size = dp_size  # data-parallel degree
        self.tp_size = tp_size  # tensor-parallel degree
        self.pp_size = pp_size  # pipeline-parallel degree
        # Total number of devices
        self.world_size = dp_size * tp_size * pp_size

    def get_parallel_groups(self):
        """Compute the rank lists for each kind of parallel group."""
        groups = {
            'data_parallel': [],
            'tensor_parallel': [],
            'pipeline_parallel': []
        }
        # Data-parallel groups: devices holding the same model shard
        for pp in range(self.pp_size):
            for tp in range(self.tp_size):
                group = []
                for dp in range(self.dp_size):
                    rank = dp * (self.tp_size * self.pp_size) + pp * self.tp_size + tp
                    group.append(rank)
                groups['data_parallel'].append(group)
        # Tensor-parallel groups: devices within the same pipeline stage
        for dp in range(self.dp_size):
            for pp in range(self.pp_size):
                group = []
                for tp in range(self.tp_size):
                    rank = dp * (self.tp_size * self.pp_size) + pp * self.tp_size + tp
                    group.append(rank)
                groups['tensor_parallel'].append(group)
        # Pipeline-parallel groups: devices holding the same tensor shard
        for dp in range(self.dp_size):
            for tp in range(self.tp_size):
                group = []
                for pp in range(self.pp_size):
                    rank = dp * (self.tp_size * self.pp_size) + pp * self.tp_size + tp
                    group.append(rank)
                groups['pipeline_parallel'].append(group)
        return groups

# Usage example
# Configuration: data parallel = 8, tensor parallel = 4, pipeline parallel = 2
parallel_3d = ThreeDParallelism(dp_size=8, tp_size=4, pp_size=2)
groups = parallel_3d.get_parallel_groups()
print(f"Total devices: {parallel_3d.world_size}")
print(f"Data-parallel groups: {len(groups['data_parallel'])}")
print(f"Tensor-parallel groups: {len(groups['tensor_parallel'])}")
print(f"Pipeline-parallel groups: {len(groups['pipeline_parallel'])}")

# Advantages of 3D parallelism:
# - Flexibility: degrees can be tuned to the model and the hardware
# - Scalability: supports training at very large scale
# - Efficiency: keeps all devices busy
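The rank lists computed above map directly onto torch.distributed process groups. A sketch of materializing them (note that dist.new_group must be called by every rank for every group, in the same order, even for groups the rank does not belong to):
python
# Turn the 3D rank lists into torch.distributed process groups
def build_process_groups(parallel_3d):
    groups = parallel_3d.get_parallel_groups()
    my_rank = dist.get_rank()
    my_groups = {}
    for kind, rank_lists in groups.items():
        for ranks in rank_lists:
            pg = dist.new_group(ranks=ranks)  # collective call on all ranks
            if my_rank in ranks:
                my_groups[kind] = pg
    return my_groups

# my_groups['tensor_parallel'] can then be passed as the `group` argument of
# collectives, e.g. dist.all_reduce(tensor, group=my_groups['tensor_parallel'])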
Summary
Key points of model parallelism on CANN:
- Parallelism strategies: data, model, pipeline, and tensor parallelism
- Pipeline optimization: GPipe and 1F1B scheduling
- Communication optimization: compute/communication overlap, gradient compression
- Hybrid parallelism: 3D parallelism with flexible composition
- Best practice: choose the strategy based on your model and hardware
With model parallelism applied sensibly, ultra-large models become practical to train and deploy.
Related Links
hccl repository: https://atomgit.com/cann/hccl
runtime repository: https://atomgit.com/cann/runtime
CANN organization: https://atomgit.com/cann