1. 介绍
在 DeepSeek-R1 等推理模型的训练范式中,冷启动 SFT 是第一步。
- 普通 SFT:教模型学会"说话"和"听指令"(比如:请帮我写个请假条)。
- 冷启动 SFT(本脚本) :教模型学会**"思考的格式"**。小模型(如 MiniMind)最初并不知道
<think>标签是什么意思。通过这一步,我们先给它"喂"几千条高质量的推理链数据,让它形成一种肌肉记忆:看到问题 -> 开启<think>-> 进行逻辑推演 -> 开启<answer>-> 给出结论。
为什么说它不仅仅是"普通 SFT"?
虽然代码里都是 loss.backward(),但有两处细节让它具有了"蒸馏"的性质:
A. 损失加权(Loss Weighting)
代码中的 loss_mask[sp_ids] = 10。
这是在普通 SFT 中很少见的。在普通微调里,所有 Token 权重通常是 1。这里人为把 <think>、<answer> 等标签的权重拉高 10 倍,本质上是在做强制约束。
- 目的 :不是为了让模型学知识,而是为了让模型绝对不能搞错推理的框架。这更像是在"蒸馏"大模型的行为模式(Behavioral Cloning)。
B. 数据的"纯度"与"深度"
普通 SFT 的数据是 Question -> Answer。
这个脚本跑的数据是 Question -> CoT (思维链) -> Answer。
模型在微调过程中,不仅仅是在学习正确答案,更是在模仿大模型(如 GPT-4o 或 R1)的推理逻辑分布。这种将大模型的思考过程迁移到小模型身上的行为,就是标准的"蒸馏"。
SFT 与后续 RL 的关系
在推理模型的开发中,SFT 只是序幕,真正的重头戏是之后的 RL(强化学习)。
| 阶段 | 任务 | 目的 |
|---|---|---|
| SFT (本脚本) | 冷启动 / 蒸馏 | 让模型学会"讲逻辑",保证输出格式不乱,能写出 <think>。 |
| RL (如 GRPO) | 强化学习 / 进化 | 不再喂数据,而是让模型自己去试。答对了奖励,答错了惩罚。 |
简单来说: 没有这步 SFT,模型在 RL 阶段会像没头苍蝇一样乱撞,根本不知道要输出 <think>;有了这步 SFT,模型就有了基础,RL 才能引导它在逻辑上更进一步。
总结: 它通过"暴力"加权特殊标签的方式,强制小模型套上大模型的"思考外壳"。
2. 代码
python
import os
import sys
# 设置包名为 trainer,并将上一级目录加入系统搜索路径,以便能正确导入项目内部的 model 和 dataset 模块
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import math
import warnings
import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import SFTDataset
# 忽略不必要的警告信息(如版本过时等提示)
warnings.filterwarnings('ignore')
def Logger(content):
"""
自定义日志函数:仅在非分布式模式或分布式模式下的主进程(Rank 0)打印日志,
避免多卡训练时屏幕输出多份重复内容。
"""
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
"""
学习率调度函数:采用余弦退火算法(Cosine Annealing)。
- 最小学习率为初始学习率的 1/10。
- 随着训练步数增加,学习率按余弦曲线平滑下降。
"""
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def train_epoch(epoch, wandb):
"""
执行一个训练周期的函数。
"""
# 1. 提取推理相关的特殊标签 ID,用于后续在 Loss 中加权
start_of_think_ids = tokenizer('<think>').input_ids
end_of_think_ids = tokenizer('</think>').input_ids
start_of_answer_ids = tokenizer('<answer>').input_ids
end_of_answer_ids = tokenizer('</answer>').input_ids
# 定义交叉熵损失函数,设置 reduction='none' 以便对每个 Token 独立加权
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
# 遍历数据加载器中的批次数据
for step, (X, Y, loss_mask) in enumerate(train_loader):
# 将数据迁移到指定设备(GPU 或 CPU)
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
# 2. 计算并更新当前步骤的学习率
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 3. 前向传播(混合精度上下文)
with ctx:
res = model(X) # 模型输出结果对象,包含 logits 和可能的 aux_loss
# 计算原始损失值(形状:[Batch, Seq_Len])
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
# 4. 【推理蒸馏核心逻辑】特殊标签识别
# 找到 Y(标签)中所有属于 <think>, </think>, <answer>, </answer> 组成部分的 Token 位置
sp_ids = torch.isin(Y.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids
+ start_of_answer_ids + end_of_answer_ids
).to(args.device))
# 5. 调整损失掩码权重
loss_mask = loss_mask.view(-1)
loss_mask_sum = loss_mask.sum() # 计算当前批次中有效 Token 的总数
# 将特殊标签位置的权重设为 10(普通 Token 默认为 1),强化模型对推理格式的记忆
loss_mask[sp_ids] = 10
loss_mask = loss_mask.view(Y.size())
# 对损失值应用掩码并取平均值
loss = (loss * loss_mask).sum() / loss_mask_sum
# 若是 MoE 模型,需加上辅助损失(用于平衡专家调用负载)
loss += res.aux_loss
# 梯度累加:将损失除以累加步数
loss = loss / args.accumulation_steps
# 6. 反向传播与梯度更新
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
# 取消缩放以进行梯度裁剪
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
# 更新权重并更新缩放因子
scaler.step(optimizer)
scaler.update()
# 梯度清零,释放内存
optimizer.zero_grad(set_to_none=True)
# 7. 日志记录
if step % args.log_interval == 0:
spend_time = time.time() - start_time
# 计算并显示预估剩余时间
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
# 如果启用 wandb,则上传训练数据
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
# 8. 定期保存检查点
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/reason_{lm_config.hidden_size}{moe_path}.pth'
# 获取状态字典(处理 DDP 包装的情况)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
# 以 float16 半精度保存模型以节省磁盘空间
state_dict = {k: v.half() for k, v in state_dict.items()}
torch.save(state_dict, ckp)
model.train() # 回到训练模式
def init_model(lm_config):
"""
初始化模型和分词器。
"""
tokenizer = AutoTokenizer.from_pretrained('../model')
model = MiniMindForCausalLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
# 设定预训练权重路径(通常是基于已完成全量 SFT 或 RLHF 的模型)
ckp = f'{args.save_dir}/rlhf_{lm_config.hidden_size}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
# 加载权重,strict=False 允许加载部分匹配的参数
model.load_state_dict(state_dict, strict=False)
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
model = model.to(args.device)
return model, tokenizer
def init_distributed_mode():
"""
初始化分布式数据并行环境(DDP)。
"""
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl") # 使用 NVIDIA NCCL 后端
ddp_rank = int(os.environ["RANK"]) # 总排名
ddp_local_rank = int(os.environ["LOCAL_RANK"]) # 当前机器上的 GPU 排名
ddp_world_size = int(os.environ["WORLD_SIZE"]) # 总 GPU 数量
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
# --- 程序主入口 ---
if __name__ == "__main__":
# 解析命令行参数
parser = argparse.ArgumentParser(description="MiniMind Distill Reasoning")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=1e-6)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16") # 推荐使用 bfloat16 以提高数值稳定性
parser.add_argument("--use_wandb", action="store_true") # 是否开启 Weights & Biases 监控
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1) # DataLoader 的多进程数
parser.add_argument("--ddp", action="store_true") # 标记是否使用 DDP 启动
parser.add_argument("--accumulation_steps", type=int, default=1) # 梯度累加,用于变相增大 Batch Size
parser.add_argument("--grad_clip", type=float, default=1.0) # 梯度剪裁阈值,防止梯度爆炸
parser.add_argument("--warmup_iters", type=int, default=0)
parser.add_argument("--log_interval", type=int, default=1)
parser.add_argument("--save_interval", type=int, default=50)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=1024, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/r1_mix_1024.jsonl")
args = parser.parse_args()
# 初始化模型配置
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
use_moe=args.use_moe)
# 建立输出目录
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
device_type = "cuda" if "cuda" in args.device else "cpu"
args.wandb_run_name = f"MiniMind-Distill-Reasoning-{time.time()}"
# 设置自动混合精度(AMP)的上下文
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
# 检测是否处于 DDP 环境
ddp = int(os.environ.get("RANK", -1)) != -1
ddp_local_rank, DEVICE = 0, "cuda:0"
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
rank = dist.get_rank()
torch.manual_seed(base_seed + rank) # 保证每张卡上的随机性不同但可控
torch.cuda.manual_seed(base_seed + rank)
# 初始化 wandb 日志系统
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
# 初始化模型与分词器
model, tokenizer = init_model(lm_config)
# 准备数据集和数据加载器
train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True, # 锁页内存,加快 GPU 拷贝
drop_last=False,
shuffle=False if ddp else True,
num_workers=args.num_workers,
sampler=train_sampler
)
# 初始化梯度缩放器(混合精度训练必备)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
# 优化器设置
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
# 将模型包装为 DDP 模式
if ddp:
# 忽略 RoPE 位置编码相关的参数同步(优化加速)
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
# 循环训练所有 Epoch
for epoch in range(args.epochs):
# 如果是分布式模式,需要设置 sampler 的 epoch 以保证数据的随机洗牌
if ddp: train_loader.sampler.set_epoch(epoch)
train_epoch(epoch, wandb)