面试-冷启动

1. 介绍

在 DeepSeek-R1 等推理模型的训练范式中,冷启动 SFT 是第一步。

  • 普通 SFT:教模型学会"说话"和"听指令"(比如:请帮我写个请假条)。
  • 冷启动 SFT(本脚本) :教模型学会**"思考的格式"**。小模型(如 MiniMind)最初并不知道 <think> 标签是什么意思。通过这一步,我们先给它"喂"几千条高质量的推理链数据,让它形成一种肌肉记忆:看到问题 -> 开启 <think> -> 进行逻辑推演 -> 开启 <answer> -> 给出结论。
为什么说它不仅仅是"普通 SFT"?

虽然代码里都是 loss.backward(),但有两处细节让它具有了"蒸馏"的性质:

A. 损失加权(Loss Weighting)

代码中的 loss_mask[sp_ids] = 10

这是在普通 SFT 中很少见的。在普通微调里,所有 Token 权重通常是 1。这里人为把 <think><answer> 等标签的权重拉高 10 倍,本质上是在做强制约束

  • 目的 :不是为了让模型学知识,而是为了让模型绝对不能搞错推理的框架。这更像是在"蒸馏"大模型的行为模式(Behavioral Cloning)。
B. 数据的"纯度"与"深度"

普通 SFT 的数据是 Question -> Answer

这个脚本跑的数据是 Question -> CoT (思维链) -> Answer

模型在微调过程中,不仅仅是在学习正确答案,更是在模仿大模型(如 GPT-4o 或 R1)的推理逻辑分布。这种将大模型的思考过程迁移到小模型身上的行为,就是标准的"蒸馏"。


SFT 与后续 RL 的关系

在推理模型的开发中,SFT 只是序幕,真正的重头戏是之后的 RL(强化学习)

阶段 任务 目的
SFT (本脚本) 冷启动 / 蒸馏 让模型学会"讲逻辑",保证输出格式不乱,能写出 <think>
RL (如 GRPO) 强化学习 / 进化 不再喂数据,而是让模型自己去试。答对了奖励,答错了惩罚。

简单来说: 没有这步 SFT,模型在 RL 阶段会像没头苍蝇一样乱撞,根本不知道要输出 <think>;有了这步 SFT,模型就有了基础,RL 才能引导它在逻辑上更进一步。


总结: 它通过"暴力"加权特殊标签的方式,强制小模型套上大模型的"思考外壳"。

2. 代码

python 复制代码
import os
import sys

# 设置包名为 trainer,并将上一级目录加入系统搜索路径,以便能正确导入项目内部的 model 和 dataset 模块
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
import time
import math
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import SFTDataset

# 忽略不必要的警告信息(如版本过时等提示)
warnings.filterwarnings('ignore')

def Logger(content):
    """
    自定义日志函数:仅在非分布式模式或分布式模式下的主进程(Rank 0)打印日志,
    避免多卡训练时屏幕输出多份重复内容。
    """
    if not ddp or dist.get_rank() == 0:
        print(content)

def get_lr(current_step, total_steps, lr):
    """
    学习率调度函数:采用余弦退火算法(Cosine Annealing)。
    - 最小学习率为初始学习率的 1/10。
    - 随着训练步数增加,学习率按余弦曲线平滑下降。
    """
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))

def train_epoch(epoch, wandb):
    """
    执行一个训练周期的函数。
    """
    # 1. 提取推理相关的特殊标签 ID,用于后续在 Loss 中加权
    start_of_think_ids = tokenizer('<think>').input_ids
    end_of_think_ids = tokenizer('</think>').input_ids
    start_of_answer_ids = tokenizer('<answer>').input_ids
    end_of_answer_ids = tokenizer('</answer>').input_ids
    
    # 定义交叉熵损失函数,设置 reduction='none' 以便对每个 Token 独立加权
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    
    # 遍历数据加载器中的批次数据
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        # 将数据迁移到指定设备(GPU 或 CPU)
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        
        # 2. 计算并更新当前步骤的学习率
        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # 3. 前向传播(混合精度上下文)
        with ctx:
            res = model(X) # 模型输出结果对象,包含 logits 和可能的 aux_loss
            # 计算原始损失值(形状:[Batch, Seq_Len])
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())
            
            # 4. 【推理蒸馏核心逻辑】特殊标签识别
            # 找到 Y(标签)中所有属于 <think>, </think>, <answer>, </answer> 组成部分的 Token 位置
            sp_ids = torch.isin(Y.view(-1),
                                torch.tensor(start_of_think_ids + end_of_think_ids
                                             + start_of_answer_ids + end_of_answer_ids
                                             ).to(args.device))
            
            # 5. 调整损失掩码权重
            loss_mask = loss_mask.view(-1)
            loss_mask_sum = loss_mask.sum() # 计算当前批次中有效 Token 的总数
            
            # 将特殊标签位置的权重设为 10(普通 Token 默认为 1),强化模型对推理格式的记忆
            loss_mask[sp_ids] = 10
            loss_mask = loss_mask.view(Y.size())
            
            # 对损失值应用掩码并取平均值
            loss = (loss * loss_mask).sum() / loss_mask_sum
            # 若是 MoE 模型,需加上辅助损失(用于平衡专家调用负载)
            loss += res.aux_loss
            # 梯度累加:将损失除以累加步数
            loss = loss / args.accumulation_steps

        # 6. 反向传播与梯度更新
        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            # 取消缩放以进行梯度裁剪
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            # 更新权重并更新缩放因子
            scaler.step(optimizer)
            scaler.update()

            # 梯度清零,释放内存
            optimizer.zero_grad(set_to_none=True)

        # 7. 日志记录
        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            # 计算并显示预估剩余时间
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item() * args.accumulation_steps,
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

            # 如果启用 wandb,则上传训练数据
            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log({"loss": loss * args.accumulation_steps,
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        # 8. 定期保存检查点
        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'

            # 获取状态字典(处理 DDP 包装的情况)
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            # 以 float16 半精度保存模型以节省磁盘空间
            state_dict = {k: v.half() for k, v in state_dict.items()}
            torch.save(state_dict, ckp)
            model.train() # 回到训练模式

def init_model(lm_config):
    """
    初始化模型和分词器。
    """
    tokenizer = AutoTokenizer.from_pretrained('../model')
    model = MiniMindForCausalLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    # 设定预训练权重路径(通常是基于已完成全量 SFT 或 RLHF 的模型)
    ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    # 加载权重,strict=False 允许加载部分匹配的参数
    model.load_state_dict(state_dict, strict=False)
    Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    model = model.to(args.device)
    return model, tokenizer

def init_distributed_mode():
    """
    初始化分布式数据并行环境(DDP)。
    """
    if not ddp: return
    global ddp_local_rank, DEVICE

    dist.init_process_group(backend="nccl") # 使用 NVIDIA NCCL 后端
    ddp_rank = int(os.environ["RANK"]) # 总排名
    ddp_local_rank = int(os.environ["LOCAL_RANK"]) # 当前机器上的 GPU 排名
    ddp_world_size = int(os.environ["WORLD_SIZE"]) # 总 GPU 数量
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)

# --- 程序主入口 ---
if __name__ == "__main__":
    # 解析命令行参数
    parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
    parser.add_argument("--out_dir", type=str, default="../out")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=1e-6)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16") # 推荐使用 bfloat16 以提高数值稳定性
    parser.add_argument("--use_wandb", action="store_true") # 是否开启 Weights & Biases 监控
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
    parser.add_argument("--num_workers", type=int, default=1) # DataLoader 的多进程数
    parser.add_argument("--ddp", action="store_true") # 标记是否使用 DDP 启动
    parser.add_argument("--accumulation_steps", type=int, default=1) # 梯度累加,用于变相增大 Batch Size
    parser.add_argument("--grad_clip", type=float, default=1.0) # 梯度剪裁阈值,防止梯度爆炸
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=1)
    parser.add_argument("--save_interval", type=int, default=50)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--hidden_size', default=512, type=int)
    parser.add_argument('--num_hidden_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=1024, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl")

    args = parser.parse_args()

    # 初始化模型配置
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                             use_moe=args.use_moe)
    
    # 建立输出目录
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    
    device_type = "cuda" if "cuda" in args.device else "cpu"
    args.wandb_run_name = f"MiniMind-Distill-Reasoning-{time.time()}"

    # 设置自动混合精度(AMP)的上下文
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    
    # 检测是否处于 DDP 环境
    ddp = int(os.environ.get("RANK", -1)) != -1
    ddp_local_rank, DEVICE = 0, "cuda:0"
    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank) # 保证每张卡上的随机性不同但可控
        torch.cuda.manual_seed(base_seed + rank)

    # 初始化 wandb 日志系统
    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
    else:
        wandb = None

    # 初始化模型与分词器
    model, tokenizer = init_model(lm_config)

    # 准备数据集和数据加载器
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True, # 锁页内存,加快 GPU 拷贝
        drop_last=False,
        shuffle=False if ddp else True,
        num_workers=args.num_workers,
        sampler=train_sampler
    )

    # 初始化梯度缩放器(混合精度训练必备)
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    # 优化器设置
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # 将模型包装为 DDP 模式
    if ddp:
        # 忽略 RoPE 位置编码相关的参数同步(优化加速)
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

    iter_per_epoch = len(train_loader)
    
    # 循环训练所有 Epoch
    for epoch in range(args.epochs):
        # 如果是分布式模式,需要设置 sampler 的 epoch 以保证数据的随机洗牌
        if ddp: train_loader.sampler.set_epoch(epoch)
        train_epoch(epoch, wandb)
相关推荐
xinxiangwangzhi_3 分钟前
立体匹配--深度学习方法综述(1)
人工智能·深度学习·计算机视觉
DatGuy20 分钟前
Week 37: 深度学习进阶:基于 OpenClaw 的多智能体协同架构
人工智能·深度学习·架构
ForDreamMusk35 分钟前
神经网络的基本原理
人工智能·深度学习
Zhansiqi38 分钟前
day33
人工智能·深度学习·机器学习
宝贝儿好6 小时前
【强化学习实战】第十一章:Gymnasium库的介绍和使用(1)、出租车游戏代码详解(Sarsa & Q learning)
人工智能·python·深度学习·算法·游戏·机器学习
阿_旭11 小时前
基于YOLO26深度学习的交警手势识别系统【python源码+Pyqt5界面+数据集+训练代码】
人工智能·python·深度学习·交警手势识别
love530love12 小时前
Windows 11 源码编译 vLLM 0.16 完全指南(CUDA 12.6 / PyTorch 2.7.1+cu126)
人工智能·pytorch·windows·python·深度学习·comfyui·vllm
有Li13 小时前
CIA-net:用于多模态MRI卵巢肿瘤分割的跨模态交互与聚合网络/文献速递-大模型与图像分割在医疗影像中应用
论文阅读·人工智能·深度学习·计算机视觉·文献
WeeJot嵌入式14 小时前
ICLR 2026低秩Transformer解决方案:多变量时间序列异常检测与定位的数学原理
人工智能·深度学习·transformer