
文章目录
- 前言
-
- 为什么需要训练配方?
- 架构拆解:三层训练场景
-
- 预训练:从头训练的工程挑战
- [微调:SFT 与 LoRA 的效率取舍](#微调:SFT 与 LoRA 的效率取舍)
- RLHF:训练与推理的混合编排
- 链路实战:一个训练任务的启动流程
- 在架构中的位置
- 与其他仓库的协作关系
前言
把一个 70B 参数的大模型从单卡搬上多机多卡分布式训练,需要调多少东西?通信拓扑、并行策略、梯度同步、显存分配......随便漏一项就是 OOM 或者通信死锁。昇腾CANN 的 cann-recipes-train 仓库就是来解决这个问题的------它把大模型在昇腾 NPU 上的分布式训练部署方案打包成「配方」,让开发者不用从零拼装,直接按方抓药跑训练。
为什么需要训练配方?
分布式训练不是「把代码丢到多张卡上就完事」。同样的 LLaMA-2 70B,TP=8 和 TP=4+PP=2 的显存曲线完全不同;MoE 模型的专家并行策略选错了,通信开销能吃掉所有算力增益。训练配方的存在逻辑很清晰:把「踩坑→调优→验证」这条路上的经验沉淀成可复用的方案,而不是让每个团队都从头交一遍学费。
传统开发模式下,一个训练任务从立项到稳定运行,通常要经历以下阶段:首先确定模型架构和参数量级,然后根据显存估算并行策略,接着反复调整 TP/PP/EP 等并行维度配比,还要处理集合通信的死锁问题,最后才能跑通一个基础版本。这个过程对于 70B 甚至更大规模的模型来说,可能需要数周甚至数月的迭代。而且不同团队之间这套经验基本无法复用------A 团队踩过的坑,B 团队很可能要再踩一遍。
cann-recipes-train 正是为了解决这个问题而诞生的。它不提供全新的训练框架,而是把 PyTorch 生态中已经被验证过的分布式训练方案,以「配方」的形式组织起来。每个配方本质上是一组经过实测的配置清单:模型适配、并行策略、通信调优、显存分配、启动脚本......全都有。最重要的是,这些配置不是理论推导,而是基于真实集群测试得出的最优解或者接近最优的方案。
cann-recipes-train 覆盖了预训练、微调、RLHF 三大训练场景,每种场景下提供模型适配、并行配置、启动脚本等一揽子方案。本质上,它不是训练框架本身,而是训练框架的「最佳实践仓库」。
架构拆解:三层训练场景
预训练:从头训练的工程挑战
预训练阶段的核心难题是规模------参数量大、数据量大、训练周期长,任何效率损失都会被放大。cann-recipes-train 的预训练配方主要解决两件事:并行策略选型和通信开销控制。
并行策略的选择直接影响训练效率和显存使用。常见的并行维度包括 Tensor Parallelism(TP)、Pipeline Parallelism(PP)、Context Parallelism(CP)和 Expert Parallelism(EP)。对于 Dense 模型如 LLaMA 系列,TP+PP 组合是最常见的方案:TP 减少单层计算耗时,PP 把不同层分到不同设备上降低显存峰值。对于 MoE 模型,还需要引入 EP 来处理专家并行的通信开销。
yaml
# 典型预训练配置示意
model_name: "llama2-70b"
parallel_config:
tensor_model_parallel_size: 8
pipeline_model_parallel_size: 2
context_parallel_size: 1
/expert_parallel_size: 1
training_config:
batch_size_per_device: 4
gradient_accumulation_steps: 8
learning_rate: 1e-4
warmup_steps: 2000
num.training_steps: 100000
配方中会给出不同模型规模下的推荐并行度。这里的推荐并行度不是拍脑袋的数字,而是经过显存压力测试和吞吐 benchmark 之后的结果。以 LLaMA-2 70B 为例,如果使用 TP=8+PP=2 的配置,单机的 8 张 NPU 需要保证每张卡的显存不少于 80GB 才能安全运行;如果改为 TP=4+PP=4,虽然单卡显存需求降到 40GB 左右,但引入了额外的流水线通信,开销会增加 15%~20%。这些 tradeoff 都已经在配方中标注清楚。
下面是一个典型的大规模预训练环境变量配置,用于设置分布式训练的各项参数:
bash
# 环境变量配置模板
export ASCEND_NPUS_PER_HOST=8
export HCCL_COMM_WORLD_ROOT=0
export RANK_TABLE_FILE=/path/to/rank_table.json
export PYTORCH_CUDA_ALLOC_conf=max_split_size_mb:512
export ASCEND_GLOBAL_LOG_LEVEL=3
rank_table.json 是 HCCL 集合通信初始化所必需的文件,描述了所有参与训练的 NPU 节点的 IP 地址和通信端口。配方目录中通常会提供一个脚本来自动生成这个文件,开发者只需要修改其中的节点 IP 列表即可。
微调:SFT 与 LoRA 的效率取舍
微调配方覆盖全参微调和参数高效微调(PEFT)两大类。全参微调(Full Parameter Fine-tuning,SFT)指对模型所有参数进行更新,而 PEFT 则只更新附加的部分,原始模型参数保持冻结。LoRA 是 PEFT 中最流行的技术方案之一,它通过在 Transformer 的注意力层注入低秩矩阵来实现参数高效微调。
两者的核心差异在于显存需求和精度表现。全参微调的显存的求近似等于预训练阶段,因为梯度、动量、优化器状态都需要存储;而 LoRA 把可训练参数压缩到原模型的 0.1% 以下,单卡就能跑 7B 参数模型的微调。
python
# LoRA 配置示意
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
r=16, # LoRA rank
lora_alpha=32, # scaling factor
lora_dropout=0.05, # dropout 概率
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # 注入的目标模块
bias="none",
task_type=TaskType.CAUSAL_LM
)
# 将 LoRA 适配器挂载到基础模型上
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# 输出类似: "trainable params: 4,194,304 || all params: 7,048,192,512 || trainable%: 0.0595"
但全参微调在精度敏感场景(如指令跟随、对齐调优)中仍有不可替代性。LoRA 虽然效率高,但在某些任务上可能无法达到全参微调的精度水平,特别是当需要模型全面学习新的知识体系时。配方会同时给出两条路径的启动方式和配置模板,让用户根据自身的精度要求和硬件条件进行选择。
微调阶段还需要注意的一个关键点是学习率的调整。由于预训练阶段使用的学习率通常较小(全参数情况下约为 1e-4 到 5e-5),直接沿用到微调阶段会导致 loss 不收敛或者震荡。经验做法是将学习率提升 10 倍左右,同时减小 batch size 以防止过拟合。
yaml
# 全参微调配置示意
model_name: "llama2-7b"
finetune_config:
method: "full_parameter" # 全参微调
learning_rate: 5e-5 # 较预训练阶段提升
batch_size_per_device: 2
gradient_accumulation_steps: 16
num_training_epochs: 3
save_interval_steps: 500
lora_config: null # 不使用 LoRA
RLHF:训练与推理的混合编排
RLHF(Reinforcement Learning from Human Feedback)是目前大语言模型对齐训练的主流技术方案,包括三个阶段:预训练模型的有监督微调(SFT)、奖励模型(Reward Model)的训练、以及基于强化学习(通常是 PPO 算法)的对齐训练。整个流程涉及多个模型的协同工作,是最复杂的训练场景。
在 RLHF 的最后一个阶段(PPO 训练),需要同时运行 Actor 模型(待对齐的语言模型)、Critic 模型(价值函数评估模型)和 Reward 模型(奖励信号来源)。这三个模型的显存需求是叠加的,而且 Actor 在生成阶段是推理模式,在参数更新阶段是训练模式,两者之间的切换会带来额外的开销。
yaml
# RLHF 阶段 PPO 配置示意
rlhf_config:
stage: "ppo" # PPO 训练阶段
actor_model: "llama2-7b"
critic_model: "llama2-7b"
reward_model: "llama2-7b-rm"
kl_coeff: 0.02 # KL 散度惩罚系数
ppo_epochs: 4 # 每个 batch 的 PPO 迭代轮数
generate_max_length: 512 # 生成序列的最大长度
device_placement: "separated" # 模型分离部署
关于 device_placement 参数,这里有一个容易踩的坑:如果设置为「shared」(共享部署),即让 Actor、Critic 和 Reward Model 共用同一组 NPU,那么显存峰值是三者之和,很可能在生成阶段触发 OOM。如果设置为「separated」(分离部署),则需要额外增加节点间的通信来传递 reward 信号,但显存压力会大大缓解。配方文档中会明确标注这两种拓扑的适用场景。
在实际部署时,还需要考虑 Actor 的生成(inference)和训练(training)之间的调度。通常的做法是把生成任务和训练任务放在不同的流(stream)上并行执行,或者干脆分成两个独立的进程,通过共享文件系统来传递生成的样本数据。
python
# RLHF 中 Actor 模型的生成与训练分离实现思路
class RLHFTrainer:
def __init__(self, actor_model, critic_model, reward_model):
self.actor = actor_model
self.critic = critic_model
self.reward_fn = reward_model
def generate_samples(self, prompts, num_return_sequences=4):
"""使用推理模式的 Actor 生成多个回复样本"""
with self.actor.inference_mode():
responses = self.actor.generate(
prompts,
num_return_sequences=num_return_sequences,
max_length=self.max_generation_length
)
return responses
def ppo_update(self, prompt_response_pairs):
"""PPO 更新步骤"""
with self.actor.train_mode():
# 计算 reward、 advantage、 policy loss 等
advantages = self.compute_advantages(prompt_response_pairs)
# 执行 PPO 裁剪更新
self.optimizer.step()
链路实战:一个训练任务的启动流程
从拿到 cann-recipes-train 仓库到跑通第一个训练任务,大致是这样一条链路:
bash
# 1. 克隆仓库
git clone https://atomgit.com/cann/cann-recipes-train.git
cd cann-recipes-train
# 2. 选择模型配方目录
cd recipes/llama2-70b/pretrain
# 3. 检查环境
npu-smi info # 确认 NPU 在位
# 4. 修改配置文件中的数据路径和数据格式
vim configs/llama2_70b_pretrain.yaml
# 5. 启动训练
bash scripts/run_pretrain.sh
配置文件是整个配方的核心入口。它不是一个简单的 key-value 文件,而是包含了并行策略、训练超参、数据加载、检查点保存等所有环节的完整声明。改一个 tensor_model_parallel_size 就能切换并行拓扑,不需要动训练脚本。
训练脚本内部通常会调用 torchtitan-npu 的分布式启动器来执行训练任务。torchtitan-npu 是基于 PyTorch 的分布式训练框架,专门针对昇腾 NPU 进行了深度优化。
python
# 启动脚本内部通常会调用 torchtitan-npu 的分布式启动器
torchrun --nproc_per_node=8 \
--master_port=29500 \
pretrain.py \
--config configs/llama2_70b_pretrain.yaml
torchrun 是 PyTorch 原生的分布式启动工具,负责解析环境变量、建立进程组、初始化分布式通信。然后 pretrain.py 作为训练入口脚本,会进一步调用 torchtitan-npu 框架的初始化接口,完成模型构建、分布式 optimizer 设置、以及混合并行的切分逻辑。
启动后,torchtitan-npu 框架接管分布式调度,底层 hccl 负责集合通信,ops-transformer 提���训���算子------配方本身只是最上层的使用入口。整体的数据流可以概括为:配置文件 → 配方启动脚本 → torchtitan-npu 调度层 → ops-transformer 算子层 → HCCL 通信层 → 昇腾 NPU 硬件。
在架构中的位置
cann-recipes-train 位于 CANN 五层架构的第 5 层(应用层),是距离终端用户最近的仓库。它的上游依赖关系清晰:
cann-recipes-train(应用层,第 5 层)
↑
torchtitan-npu(训练框架层,第 4 层)
--- 提供分布式启动、DDP/PP 封装、图优化pass
↑
ops-transformer(算子层,第 3 层)
--- 提供 FlashAttention、MoE dispatch、 RMSNorm 等训练专用算子
↑
hccl(通信层,第 2 层)
--- 提供 AllReduce、AllGather、ReduceScatter 等集合通信原语
↑
昇腾 NPU(硬件抽象层,第 1 层)
torchtitan-npu 是 Cann Recipes Train 的直接上游。它在 PyTorch 分布式的基础上,增加了对昇腾 NPU 硬件特性的适配,包括混合精度训练、流水并行调度、通信 overlap、显存优化等技术。配方中的启动脚本只是传入配置参数,真正的模型切分和分布式逻辑由 torchtitan-npu 完成。
ops-transformer 提供了训练阶段专用的 CUDA/HCCL 算子,比如融合版的 Attention 算子(对标 FlashAttention)、MoE 的路由dispatch算子、以及各种 fused kernel。这些算子对训练吞吐量有直接影响------相同模型和硬件条件下,好的算子实现能带来 20%~30% 的吞吐提升。
hccl(Hierarchical Collective Communication Library)工作在更底层,提供了昇腾 NPU 之间高速互联的集合通信原语。所有上层的数据聚合操作(梯度同步、模型参数同步、embedding gather 等)最终都会翻译成 HCCL 的调用。
下游就是终端用户------拿配方改改配置,跑训练。没有更上层的消费者了。
与其他仓库的协作关系
cann-recipes-train 最紧密的关联仓库是 cann-recipes-infer。两者是训练-推理的镜像关系:
- cann-recipes-train:训练阶段的模型适配与并行方案,解决的是「怎么把模型训练出来」的问题
- cann-recipes-infer:推理阶段的模型适配与优化方案,解决的是「怎么把训练好的模型部署出去」的问题
同一个 LLaMA-2 70B 模型,训练时关注 TP/PP 并行策略和梯度同步效率,推理时关注 EP 并行和 KV Cache 管理。两个仓库的目录结构高度对称,方便从训练无缝切换到推理部署。
cann-recipes-train/recipes/llama2-70b/
├── pretrain/ # 预训练配方
│ ├── configs/
│ ├── scripts/
│ └── README.md
├── finetune/ # 微调配方
│ ├── sft/
│ └── lora/
└── rlhf/ # RLHF 配方
├── sft/
└── ppo/
cann-recipes-infer/recipes/llama2-70b/
└── decode/ # 推理配方
├── configs/
├── scripts/
└── README.md
与 torchtitan-npu 的关系是上下游依赖。torchtitan-npu 是独立发布的训练框架,而 cann-recipes-train 是 torchtitan-npu 的「配方合集」------每个配方本质上是一套经过验证的 torchtitan-npu 配置方案。用户使用配方时不需要关心 torchtitan-npu 内部的实现细节,只需要修改配置并执行启动脚本即可。需要注意的是,cann-recipes-train 和 torchtitan-npu 的版本需要严格对应,版本不匹配可能导致配置解析失败或者运行结果不符合预期。
与 ops-transformer 的关系更加间接------配方中指定的 Attention 实现方式、是否启用 fused kernel、MoE 的 router 类型等,最终都会影响 ops-transformer 算子的选择。普通用户感知不到这层关系,但如果遇到性能问题需要深入调优时,可能会涉及到算子级别的调整。
与 hccl 的关系同样隐蔽。配方不会直接暴露 HCCL 的 API 调用,但并行策略的每一次调整(TP 还是 PP 或者两者混合)最终都会转化为 HCCL 内部集合通信模式和通信域的构建。如果出现通信死锁或者通信效率不如预期的情况,可能需要检查 HCCL 的配置参数,比如通信算法选择(Ring 还是 Tree)、是否启用 inplace 优化等。