通义灵码分析ms-swift框架中CHORD算法实现

提示词:请分析这个项目,重点分析chord.sh这个文件中的参数是如何配置,如何在算法中发挥作用,整个算法是怎么实现的

CHORD 算法概述

CHORD 是一种结合了监督微调(SFT)和强化学习(RL)的算法,通过动态权重控制两者的平衡。核心思想是在 GRPO 损失函数中引入 SFT 损失,实现模仿学习和自主探索的平衡。

chord.sh 脚本参数详解

基础环境配置

bash 复制代码
# 使用8块80G GPU进行训练
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
NPROC_PER_NODE=8

核心训练参数

  • --rlhf_type grpo: 指定使用 GRPO 算法作为基础 RL 方法
  • --model Qwen/Qwen2.5-7B-Instruct: 使用 Qwen2.5-7B-Instruct 作为基础模型
  • --dataset AI-MO/NuminaMath-TIR: 使用 NuminaMath-TIR 作为训练数据集

批量大小配置

bash 复制代码
# GRPO 批量大小配置
--per_device_train_batch_size 4                    # 每设备GRPO批次大小
--gradient_accumulation_steps 8                    # 梯度累积步数
# 总GRPO批量大小 = 4 × 8 × 8 = 256 (32 prompts × 8 generations)

# SFT 批量大小配置
--chord_sft_per_device_train_batch_size 1          # 每设备SFT批次大小
# 总SFT批量大小 = 1 × 8 × 8 = 64

CHORD 特有参数

SFT 数据配置
bash 复制代码
--chord_sft_dataset AI-MO/NuminaMath-TIR           # SFT使用的数据集
CHORD-µ 参数配置
bash 复制代码
--chord_mu_warmup_steps 0                          # μ值预热步数
--chord_mu_decay_steps 200                         # μ值衰减步数
--chord_mu_peak 0.9                                # μ峰值
--chord_mu_valley 0.05                             # μ谷值

这表示 μ 值会从 0.9 逐渐衰减到 0.05,在 200 步内完成衰减过程。

CHORD-ϕ 参数配置
bash 复制代码
--chord_enable_phi_function false                  # 是否启用token级别权重函数

当前设置未启用 token 级别的权重调整。

训练相关参数

bash 复制代码
--num_generations 8                                # 每个prompt生成的样本数量
--train_type full                                  # 全量微调模式
--reward_funcs accuracy                            # 使用准确率作为奖励函数
--beta 0.0                                         # GRPO中的KL散度系数
--steps_per_generation 4                           # 每代训练步数

模型推理配置

bash 复制代码
--use_vllm true                                    # 使用vLLM加速推理
--vllm_mode colocate                               # vLLM与训练并置运行
--vllm_gpu_memory_utilization 0.4                  # vLLM GPU内存利用率
--vllm_max_model_len 8192                          # 最大模型长度
--max_completion_length 4096                       # 最大生成长度

优化器和训练技巧

bash 复制代码
--learning_rate 1e-6                               # 学习率
--warmup_ratio 0.05                                # 学习率预热比例
--deepspeed zero3                                  # 使用DeepSpeed ZeRO-3优化
--offload_optimizer true                           # 优化器卸载到CPU
--offload_model true                               # 模型卸载到CPU

算法实现原理

损失函数设计

CHORD 的核心损失函数为:
LCHORD=(1−μ)⋅LGRPO+μ⋅LSFT\mathcal{L}{\text{CHORD}} = (1 - \mu) \cdot \mathcal{L}{\text{GRPO}} + \mu \cdot \mathcal{L}_{\text{SFT}}LCHORD=(1−μ)⋅LGRPO+μ⋅LSFT

其中:

  • LGRPO\mathcal{L}_{\text{GRPO}}LGRPO: 基于在线采样的强化学习损失
  • LSFT\mathcal{L}_{\text{SFT}}LSFT: 监督微调损失
  • μ\muμ: 控制两者平衡的全局权重系数

CHORD-µ 实现机制

通过动态调节 μ 值实现从模仿到探索的过渡:

  1. 初始阶段 μ 较高,更多依赖 SFT 损失进行模仿学习
  2. 随着训练进行,μ 逐渐降低,增加 RL 探索成分
  3. 最终以 RL 为主导进行精细调优

CHORD-ϕ 实现机制(可选)

当启用时,通过 token 级别权重函数 φ 动态调整每个 token 的梯度贡献:
ϕ(yt⋆,πθ)=pt⋅(1−pt)\phi(y_t^\star, \pi_\theta) = p_t \cdot (1 - p_t)ϕ(yt⋆,πθ)=pt⋅(1−pt)

其中 ptp_tpt 是模型预测专家 token 的概率,当模型不确定性高时(pt≈0.5p_t ≈ 0.5pt≈0.5)权重最大。

整体工作流程

  1. 数据准备: 同时加载 RLHF 和 SFT 数据集(注意:示例中使用相同数据集可能导致重叠)
  2. 模型初始化: 加载预训练模型并配置训练环境
  3. 训练循环 :
    • 在每个训练步骤中同时计算 GRPO 和 SFT 损失
    • 根据当前 μ 值组合两个损失
    • 使用 DeepSpeed 优化分布式训练
    • 动态调整 μ 值(如果启用了 CHORD-µ)
  4. 推理优化: 使用 vLLM 加速生成过程
  5. 监控和保存: 定期保存模型检查点并记录训练日志

这种设计使得模型能够在训练初期借助 SFT 数据快速收敛,然后通过 RL 探索进一步提升性能,最终达到更好的综合表现。

loss代码实现

rlhf_trainer/grpo_trainer.py中:

python 复制代码
        if mode == 'train' and self.chord_sft_iterator is not None:
            loss = compute_chord_loss(self, grpo_loss=loss)

rlhf_trainer/utils.py->compute_chord_loss

python 复制代码
def compute_chord_loss(trainer, grpo_loss: torch.Tensor) -> torch.Tensor:
    """
    Compute CHORD loss combining GRPO loss with SFT loss.

    Args:
        trainer: The trainer instance
        grpo_loss: The GRPO loss tensor

    Returns:
        Combined CHORD loss tensor
    """
    from swift.trainers import per_token_loss_func
    from swift.llm import to_device

    current_step = trainer.state.global_step
    mu = mu_schedule_function(current_step, trainer.args.chord_mu_warmup_steps, trainer.args.chord_mu_decay_steps,
                              trainer.args.chord_mu_peak, trainer.args.chord_mu_valley)
    chord_sft_loss = torch.tensor(0.0, device=grpo_loss.device, dtype=grpo_loss.dtype)
    if mu > 0:
        sft_inputs = next(trainer.chord_sft_iterator)
        sft_inputs = to_device(trainer.template.data_collator(sft_inputs), trainer.accelerator.device)

        labels = sft_inputs.pop('labels')
        loss_scale = sft_inputs.pop('loss_scale', None)
        outputs = trainer.model(**sft_inputs)
        chord_sft_loss = per_token_loss_func(outputs, labels)

        if trainer.args.chord_enable_phi_function:
            per_token_probs = torch.exp(-chord_sft_loss)
            phi = per_token_probs * (1 - per_token_probs)
            chord_sft_loss *= phi

        if loss_scale is not None:
            loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1).view(-1)
            chord_sft_loss *= loss_scale

        num_items_in_batch = (labels[:, 1:] != -100).sum()
        chord_sft_loss = chord_sft_loss.sum() / num_items_in_batch
    else:
        assert mu == 0
        chord_sft_loss = torch.tensor(0.0, device=grpo_loss.device, dtype=grpo_loss.dtype)
    loss = (1 - mu) * grpo_loss + mu * chord_sft_loss
    return loss

1、loss分析

提示词:请重点分析这个函数,每一行代码

我们来逐行分析 compute_chord_loss 函数的实现,这是 CHORD 算法的核心部分。

python 复制代码
    current_step = trainer.state.global_step
    mu = mu_schedule_function(current_step, trainer.args.chord_mu_warmup_steps, trainer.args.chord_mu_decay_steps,
                              trainer.args.chord_mu_peak, trainer.args.chord_mu_valley)

这两行代码负责计算当前训练步数对应的 μ 值:

  1. 获取当前全局训练步数 current_step
  2. 调用 mu_schedule_function 函数,根据配置参数计算当前的 μ 值
    • chord_mu_warmup_steps: μ 值上升到峰值的步数
    • chord_mu_decay_steps: μ 值从峰值衰减到谷值的步数
    • chord_mu_peak: μ 的峰值
    • chord_mu_valley: μ 的谷值
python 复制代码
    chord_sft_loss = torch.tensor(0.0, device=grpo_loss.device, dtype=grpo_loss.dtype)

初始化 SFT 损失为 0,确保它与 GRPO 损失在同一设备和数据类型上。

python 复制代码
    if mu > 0:

只有当 μ 值大于 0 时才计算 SFT 损失,这样可以在训练后期完全关闭 SFT 损失的影响。

python 复制代码
        sft_inputs = next(trainer.chord_sft_iterator)
        sft_inputs = to_device(trainer.template.data_collator(sft_inputs), trainer.accelerator.device)

获取下一个 SFT 批次的数据并将其移动到正确设备上:

  1. chord_sft_iterator 中获取下一个批次的原始输入数据
  2. 使用模板的数据整理器([data_collator](file://E:\PycharmProjects\ms-swift-main\ms-swift-main\swift\llm\template\base.py#L1436-L1469))处理数据
  3. 使用 [to_device](file://E:\PycharmProjects\ms-swift-main\ms-swift-main\swift\llm\utils.py#L63-L72) 将处理后的数据移动到训练设备上
python 复制代码
        labels = sft_inputs.pop('labels')
        loss_scale = sft_inputs.pop('loss_scale', None)
        outputs = trainer.model(**sft_inputs)

准备并执行前向传播:

  1. 从输入中弹出标签和损失缩放因子
  2. 将剩余的输入传递给模型进行前向传播,得到输出
python 复制代码
        chord_sft_loss = per_token_loss_func(outputs, labels)

使用 per_token_loss_func 计算每个 token 的 SFT 损失。

python 复制代码
        if trainer.args.chord_enable_phi_function:
            per_token_probs = torch.exp(-chord_sft_loss)
            phi = per_token_probs * (1 - per_token_probs)
            chord_sft_loss *= phi

这是 CHORD-ϕ 的核心实现:

  1. 如果启用了 chord_enable_phi_function 参数,则计算 token 级别的权重 φ
  2. 首先通过 torch.exp(-chord_sft_loss) 计算每个 token 的预测概率 ptp_tpt
  3. 然后按照公式 ϕ=pt⋅(1−pt)\phi = p_t \cdot (1 - p_t)ϕ=pt⋅(1−pt) 计算权重
  4. 将 φ 权重应用到 SFT 损失上
python 复制代码
        if loss_scale is not None:
            loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1).view(-1)
            chord_sft_loss *= loss_scale

如果存在损失缩放因子,则对其进行处理并应用到损失上:

  1. 对 [loss_scale](file://E:\PycharmProjects\ms-swift-main\ms-swift-main\swift\llm\template\base.py#L0-L0) 进行移位操作(向左移动一位)
  2. 将其展平为一维张量
  3. 将缩放因子应用到 SFT 损失上
python 复制代码
        num_items_in_batch = (labels[:, 1:] != -100).sum()
        chord_sft_loss = chord_sft_loss.sum() / num_items_in_batch

计算平均 SFT 损失:

  1. 计算批次中有效的 token 数量(排除标签为 -100 的填充 token)
  2. 对所有 token 的损失求和并除以有效 token 数量,得到平均损失
python 复制代码
    else:
        assert mu == 0
        chord_sft_loss = torch.tensor(0.0, device=grpo_loss.device, dtype=grpo_loss.dtype)

如果 μ 值等于 0,则直接将 SFT 损失设为 0。

python 复制代码
    loss = (1 - mu) * grpo_loss + mu * chord_sft_loss
    return loss

最后按照 CHORD 算法的公式计算组合损失并返回:
LCHORD=(1−μ)⋅LGRPO+μ⋅LSFT\mathcal{L}{\text{CHORD}} = (1 - \mu) \cdot \mathcal{L}{\text{GRPO}} + \mu \cdot \mathcal{L}_{\text{SFT}}LCHORD=(1−μ)⋅LGRPO+μ⋅LSFT

这个函数完整实现了 CHORD 算法的核心机制:

  1. 动态 μ 值控制 : 通过 mu_schedule_function 实现 CHORD-µ,动态调整 SFT 和 RL 损失的权重
  2. CHORD-ϕ 实现 : 当启用时,通过 φ 函数 pt⋅(1−pt)p_t \cdot (1 - p_t)pt⋅(1−pt) 对每个 token 的损失进行加权
  3. 损失组合: 按照论文公式将 GRPO 损失和 SFT 损失进行线性组合

在您提供的 [chord.sh](file://E:\PycharmProjects\ms-swift-main\ms-swift-main\examples\train\grpo\internal\chord.sh) 配置中,--chord_enable_phi_function false 表示当前并未启用 CHORD-ϕ 功能,因此 φ 权重计算部分不会执行。

2、 CHORD-µ 计算公式实现

提示词:CHORD-µ 计算公式。通义灵码:提供了文件上下文

swift/trainers/rlhf_trainer/utils.py 文件中的 mu_schedule_function 函数实现了 CHORD-µ 的调度逻辑:

python 复制代码
def mu_schedule_function(global_step: int, mu_warmup_steps: int, mu_decay_steps: int, mu_peak: float,
                         mu_valley: float) -> float:
    """
    Computes a cosine decay schedule with a warmup phase for the mu parameter.

    Args:
        global_step: Current global training step
        mu_warmup_steps: Number of warmup steps
        mu_decay_steps: Number of decay steps
        mu_peak: Peak value of mu during warmup
        mu_valley: Final value of mu after decay

    Returns:
        Current mu value based on the schedule
    """
    # Warmup
    if global_step < mu_warmup_steps:
        return (global_step / mu_warmup_steps) * mu_peak

    # Decay
    if global_step >= (mu_warmup_steps + mu_decay_steps):
        return mu_valley

    adjusted_step = global_step - mu_warmup_steps
    cosine_decay = 0.5 * (1 + math.cos(math.pi * adjusted_step / mu_decay_steps))
    decayed_mu = (mu_peak - mu_valley) * cosine_decay + mu_valley
    return decayed_mu

具体计算过程

1. 预热阶段 (Warmup Phase)

global_step < mu_warmup_steps 时:
μ=global_stepmu_warmup_steps×mu_peak\mu = \frac{\text{global\_step}}{\text{mu\_warmup\_steps}} \times \text{mu\_peak}μ=mu_warmup_stepsglobal_step×mu_peak

这是一个线性增长过程,从 0 增长到 mu_peak

2. 衰减阶段 (Decay Phase)

mu_warmup_steps ≤ global_step < mu_warmup_steps + mu_decay_steps 时:
adjusted_step=global_step−mu_warmup_steps\text{adjusted\_step} = \text{global\_step} - \text{mu\_warmup\_steps}adjusted_step=global_step−mu_warmup_steps

cosine_decay=0.5×(1+cos⁡(π×adjusted_stepmu_decay_steps))\text{cosine\_decay} = 0.5 \times (1 + \cos(\pi \times \frac{\text{adjusted\_step}}{\text{mu\_decay\_steps}}))cosine_decay=0.5×(1+cos(π×mu_decay_stepsadjusted_step))

μ=(mu_peak−mu_valley)×cosine_decay+mu_valley\mu = (\text{mu\_peak} - \text{mu\_valley}) \times \text{cosine\_decay} + \text{mu\_valley}μ=(mu_peak−mu_valley)×cosine_decay+mu_valley

这是一个余弦衰减过程,从 mu_peak 平滑衰减到 mu_valley

3. 稳定阶段 (Stable Phase)

global_step ≥ mu_warmup_steps + mu_decay_steps 时:
μ=mu_valley\mu = \text{mu\_valley}μ=mu_valley

μ 值保持稳定在 mu_valley

chord.sh 中的参数配置

根据 [chord.sh](file://E:\PycharmProjects\ms-swift-main\ms-swift-main\examples\train\grpo\internal\chord.sh) 文件中的配置:

bash 复制代码
--chord_mu_warmup_steps 0 \
--chord_mu_decay_steps 200 \
--chord_mu_peak 0.9 \
--chord_mu_valley 0.05 \

这意味着:

  • 没有预热阶段(mu_warmup_steps = 0
  • 衰减阶段持续 200 步(mu_decay_steps = 200
  • μ 值从 0.9 开始衰减(mu_peak = 0.9
  • 最终稳定在 0.05(mu_valley = 0.05

由于没有预热阶段,μ 值会在第 0 步就处于峰值 0.9,然后在接下来的 200 步中按照余弦衰减规律逐渐降低到 0.05。

在 CHORD 损失函数中的应用

计算出的 μ 值随后在 compute_chord_loss 函数中用于组合 GRPO 损失和 SFT 损失:

LCHORD=(1−μ)⋅LGRPO+μ⋅LSFT\mathcal{L}{\text{CHORD}} = (1 - \mu) \cdot \mathcal{L}{\text{GRPO}} + \mu \cdot \mathcal{L}_{\text{SFT}}LCHORD=(1−μ)⋅LGRPO+μ⋅LSFT

随着训练的进行,μ 值从高到低变化,意味着算法从主要依赖 SFT 损失(模仿专家行为)逐渐转向主要依赖 GRPO 损失(自主探索)。

相关推荐
1024小神23 分钟前
在 Swift 中,self. 的使用遵循明确的规则
开发语言·ios·swift
做怪小疯子24 分钟前
LeetCode 热题 100——二叉树——二叉树的层序遍历&将有序数组转换为二叉搜索树
算法·leetcode·职场和发展
Swift社区26 分钟前
Swift 类型系统升级:当协议遇上不可拷贝的类型
开发语言·ios·swift
chengpei14726 分钟前
I²C协议简介
c语言·开发语言
唐古乌梁海26 分钟前
【IT】常见计算机编程语言多继承问题
开发语言
雨中散步撒哈拉29 分钟前
18、做中学 | 初升高 | 考场一 | 面向过程-家庭收支记账软件
开发语言·后端·golang
CoderYanger36 分钟前
递归、搜索与回溯-记忆化搜索:38.最长递增子序列
java·算法·leetcode·1024程序员节
翔云 OCR API1 小时前
承兑汇票识别接口技术解析-开发者接口
开发语言·前端·数据库·人工智能·ocr
小白学大数据2 小时前
基于Splash的搜狗图片动态页面渲染爬取实战指南
开发语言·爬虫·python