万字长文拆解大模型训练:预训练→微调→RLHF,ChatGPT 是怎么炼成的

前言:今天就拿 ChatGPT 的训练流程当例子,给大家从头到尾捋一遍大模型的训练 pipeline。看完这篇,你就能理解为什么 GPT-4 比 GPT-3 强这么多,以及为什么微调一个模型要花那么多钱


先上一张全局图:大模型训练的三级火箭

打个比方:

阶段 比喻 具体说明
预训练 上大学,学基础知识 读万卷书,打下坚实的语言和知识基础
微调 职业培训,学专业技能 专攻问答、对话,让模型学会"怎么回答"
对齐 职前培训,学公司文化 学会理解人类偏好,回答更有帮助、更安全

一、预训练:让模型"读遍天下书"

1.1 预训练在做什么?

预训练是大模型训练的第一步,也是最耗时、最烧钱的一步。

核心任务是:给定一段文本,让模型预测下一个词是什么

这就是著名的 Next Token Prediction (下一个 token 预测),也叫 语言建模任务(Language Modeling)

python 复制代码
# 预训练任务示例
# 输入: "今天天气真"
# 目标: 预测下一个词是"好"

# 模型看到的训练数据格式:
# [CLS] 今 天 天 气 真 [MASK]
#              ↑
#          模型需要预测的词

1.2 训练数据从哪来?

预训练需要海量的文本数据。主流的数据来源:

数据源 占比(估计) 特点
网页爬取 (Common Crawl) ~60% 量大,但噪音多,需要清洗
网络书籍 (Books) ~15% 质量较高,叙事性强
学术论文 (ArXiv) ~5% 专业术语多,逻辑性强
维基百科 (Wikipedia) ~5% 结构化,质量高
代码 (GitHub) ~5-15% 学习编程逻辑
其他 ~10% 新闻、对话等

数据清洗是核心竞争力:为什么 GPT-4 比很多开源模型强?除了模型架构,数据质量和清洗流程也是关键因素。OpenAI 在数据预处理上投入了大量人力,包括去重、过滤低质量内容、质量评分等。

1.3 预训练的技术细节

训练目标 :预测下一个 token,本质上是一个分类任务

python 复制代码
import torch
import torch.nn as nn

class PretrainingLoss(nn.Module):
    """
    预训练语言模型损失函数
    给定输入序列,预测下一个 token
    """
    def __init__(self, model, vocab_size):
        super().__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # 忽略 padding

    def forward(self, input_ids, labels=None):
        """
        input_ids: (batch_size, seq_len)
        labels: (batch_size, seq_len) - 就是 input_ids 右移一位
        """
        # 模型输出 logits: (batch_size, seq_len, vocab_size)
        logits = self.model(input_ids)

        # 计算交叉熵损失
        # 预测第 i 个 token 时,使用第 i-1 个 token 的表示
        # 所以 labels 是 input_ids 的 shifted 版本
        shift_logits = logits[:, :-1, :]  # 去掉最后一个位置
        shift_labels = labels[:, 1:]      # 去掉第一个位置

        loss = self.loss_fn(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1)
        )

        return loss

# 训练循环
for batch in dataloader:
    input_ids = batch['input_ids']
    labels = batch['labels']  # 右移后的 token IDs

    loss = pretrain_loss(input_ids, labels)
    loss.backward()
    optimizer.step()
    scheduler.step()

模型规模:预训练模型的参数规模通常很大:

模型 参数量 训练 Token 数
GPT-3 175B 300B
LLaMA 2 7B~70B 2T
PaLM 540B 780B
GPT-4 未公开(估计 1~1.8T) 未公开

为什么模型要这么大? 这就是著名的"涌现能力"(Emergent Abilities)。当模型规模超过某个阈值后,会突然涌现出一些小模型不具备的能力,比如复杂推理、多步计算等。具体阈值因任务而异,大模型在 10B~100B 参数区间往往会有质的飞跃。

1.4 预训练的挑战

计算资源:训练一个 175B 的模型需要数千张 A100/H100 GPU,耗时数周甚至数月。

灾难性遗忘:大模型学了很多知识后,可能会忘记之前学过的一些东西。

训练不稳定:大模型训练过程中容易出现 loss spike、梯度爆炸等问题。

python 复制代码
# 预训练常见问题及解决方案
problems_and_solutions = {
    "梯度爆炸": "梯度裁剪 (gradient clipping), 混合精度训练",
    "loss spike": "学习率重启 (warmup + cosine decay)",
    "显存不足": "ZeRO 优化, 流水线并行, 张量并行",
    "训练太慢": "混合专家 (MoE), Flash Attention",
}

二、微调:让模型学会"好好说话"

2.1 预训练模型有什么问题?

经过预训练后,模型其实已经很强了------它学会了语言的规律,掌握了大量知识。但是:

它不知道怎么回答问题!

预训练模型本质上是在做"完形填空":给定一段话,预测下一个词。它不知道什么是"问题",什么是"回答",更不知道什么回答是"好的"。

python 复制代码
# 预训练模型的典型输出(幻觉问题)
用户: "请介绍一下北京"
预训练模型: "北京是中国的首都......"  # 可能继续胡编乱造

# 用户真正想要的
用户: "请介绍一下北京"
微调后模型: "好的!北京是中国的首都,位于华北平原北部..."
          # 格式规范,内容可靠,有礼貌

2.2 什么是 SFT(监督微调)?

Supervised Fine-Tuning (SFT) ,也叫指令微调(Instruction Tuning)

核心思想:用人工标注的高质量问答数据,教模型学会"怎么回答问题"。

python 复制代码
# SFT 训练数据格式
sft_data = [
    {
        "instruction": "请介绍一下北京",
        "input": "",
        "output": "北京是中国的首都,位于华北平原北部..."
    },
    {
        "instruction": "帮我写一首关于春天的诗",
        "input": "",
        "output": "春眠不觉晓,处处闻啼鸟。夜来风雨声,花落知多少。"
    },
    {
        "instruction": "这段代码有什么问题?",
        "input": "def foo():\n    print('hello')\n    return",
        "output": "这个函数的问题是缺少文档字符串..."
    }
]

# 训练时拼接成固定格式
prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}"

2.3 SFT 的训练过程

python 复制代码
# SFT 训练代码示例
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 只训练部分参数,节省显存
for param in model.parameters():
    param.requires_grad = False

# 只打开最后几层的 gradient
for param in model.lm_head.parameters():
    param.requires_grad = True

training_args = TrainingArguments(
    output_dir="./sft_model",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    learning_rate=2e-5,      # SFT 学习率要比预训练小
    warmup_ratio=0.03,
    fp16=True,               # 混合精度
    logging_steps=10,
    save_steps=500,
)

# 训练循环
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=sft_dataset,
    data_collator=data_collator,
)

trainer.train()

SFT 的坑

  • 学习率要小:SFT 的学习率通常比预训练小 1~2 个数量级,否则容易灾难性遗忘预训练学到的知识
  • 数据质量 > 数据数量:1000 条高质量标注数据往往比 10000 条低质量数据效果好
  • 不要训太久:训太久模型会"复读",就是车轱辘话来回说

2.4 微调数据从哪来?

高质量的 SFT 数据来之不易,主要来源:

来源 优点 缺点
人工标注 质量可控,可定制 成本高,速度慢
GPT-4 生成 成本相对低,量大 需要精心设计 prompt,质量不稳定
开源数据集 可直接用,省时省力 可能不符合你的业务场景

常见开源 SFT 数据集:

  • Alpaca (Stanford):用 GPT-3.5 生成 5.2 万条
  • Vicuna:ShareGPT 真实用户对话
  • WizardLM:复杂指令数据集
  • Baize:ChatGPT 自问自答

三、对齐训练:让 AI 更懂"人心"

3.1 为什么要对齐?

SFT 之后,模型已经能回答问题了。但还存在两个问题:

① 模型可能产生有害内容:暴力、色情、虚假信息...

② 模型可能不符合人类偏好

  • 用户问:"怎么偷东西?"
  • SFT 模型:"偷东西是违法的,以下是步骤..." ❌ (直接给出违法内容)
  • 对齐后:"偷东西是违法的,建议通过正当途径获得财务..." ✅ (更有帮助且安全)

③ 回答风格问题

  • 用户问:"你好"
  • SFT 模型:"你好!有什么可以帮助你的吗?" ❌ (太正式)
  • 对齐后:"嗨!今天想聊点啥?" ✅ (更自然友好)

对齐训练的核心目标就是:让模型的输出更符合人类期望------有帮助(Helpful)、诚实(Honest)、无害(Harmless)

3.2 RLHF:人类反馈强化学习

Reinforcement Learning from Human Feedback (RLHF) 是 OpenAI 在 InstructGPT 论文中提出的对齐方法,也是 ChatGPT 背后的核心技术。

RLHF 分为三个步骤:

第一步:训练 Reward Model

让人类对多个回答进行排序,然后训练一个奖励模型来预测"人类会觉得哪个回答更好"。

python 复制代码
# Reward Model 的训练
class RewardModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # 用 [CLS] token 的表示做奖励预测
        self.reward_head = nn.Linear(base_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # 取最后一层 [CLS] token 的表示
        cls_output = outputs.last_hidden_state[:, 0, :]
        reward = self.reward_head(cls_output)
        return reward

# 训练数据:人类排序的对回答
# 假设 prompt="什么是量子计算"
# 回答A (rank=3): "量子计算是一种..."  ← 最好
# 回答B (rank=2): "量子计算嘛..."       ← 一般
# 回答C (rank=1): "不知道"             ← 最差

# 损失函数:让排序正确的概率最大化
def reward_model_loss(rewards_chosen, rewards_rejected):
    """
    rewards_chosen: 人类偏好的回答的奖励分数
    rewards_rejected: 人类不偏好的回答的奖励分数
    """
    # 偏好回答的分数应该比不偏好的高
    diff = rewards_chosen - rewards_rejected
    loss = -torch.log(torch.sigmoid(diff))
    return loss.mean()

Reward Model 的质量直接决定对齐效果:如果 RM 学的不好,后面的 PPO 微调就会方向跑偏。所以 OpenAI 雇了大量人类标注员来做排序标注,据说每条数据成本不低。

第二步:PPO 微调

拿到 Reward Model 后,用 Proximal Policy Optimization (PPO) 算法微调语言模型。

核心思想:让模型生成的回答,获得 RM 的高分,同时保持与 SFT 模型不要太远

python 复制代码
# PPO 训练核心逻辑
import torch
import torch.nn.functional as F

class PPOTrainer:
    def __init__(self, model, ref_model, reward_model, ppo_config):
        self.model = model           # 待优化的模型
        self.ref_model = ref_model   # SFT 模型(参考,不优化)
        self.reward_model = reward_model  # 奖励模型

    def step(self, queries, responses, rewards):
        """
        queries: 用户问题
        responses: 模型生成的回答
        rewards: RM 给的分数
        """
        # 1. 计算当前策略的 log prob
        log_probs = self.model.get_log_probs(queries, responses)

        # 2. 计算参考策略的 log prob(SFT 模型的输出)
        ref_log_probs = self.ref_model.get_log_probs(queries, responses)

        # 3. 计算策略梯度
        # reward 是 RM 给的分数
        # ratio 是新旧策略的概率比(PPO 核心)
        ratio = torch.exp(log_probs - ref_log_probs)

        # PPO 裁剪目标函数
        surr1 = ratio * rewards
        surr2 = torch.clamp(ratio, 1 - ppo_config.epsilon, 1 + ppo_config.epsilon) * rewards
        policy_loss = -torch.min(surr1, surr2).mean()

        # 4. KL 散度惩罚:防止新策略偏离 SFT 太远
        kl_penalty = (log_probs - ref_log_probs).mean()

        # 5. 总损失
        total_loss = policy_loss - kl_penalty * ppo_config.kl_coef

        total_loss.backward()
        self.optimizer.step()
        return total_loss.item()

PPO 的核心思想(通俗解释)

PPO 的目标有两个:

① 追求高分 :让 RM 打出更高的分 ② 不要太离谱:新生成的策略不能和 SFT 差太多

python 复制代码
# PPO 目标函数
# MAXIMIZE: RM(responses) - β * KL(new_policy || old_policy)

# 这个 β 是 KL 惩罚系数,太大 → 模型不敢优化;太小 → 模型偏离太远

PPO 用了一个巧妙的裁剪机制(Clipped Objective):

  • 如果新策略比旧策略好太多(ratio > 1 + ε),就限制更新幅度,防止过度优化
  • 如果新策略变差了,就允许较大幅度地调整

这个设计让 PPO 训练过程更稳定,不会因为一步走错就崩掉。

3.3 DPO:更简单的对齐方式

RLHF 虽然效果好,但训练过程太复杂了------要同时维护四个模型(Ref Model、RM Model、PPO Model、Critic),调参困难,训练不稳定。

于是 2023 年,Direct Preference Optimization (DPO) 横空出世,用一个更简单的方式解决了这个问题。

核心思想 :DPO 把 RLHF 的强化学习过程转化成了直接的分类问题

python 复制代码
# DPO 损失函数
def dpo_loss(policy_logps, reference_logps, chosen_logps, rejected_logps, beta=0.1):
    """
    policy_logps: 当前模型对 chosen/rejected 的 log prob
    reference_logps: 参考模型(SFT)的 log prob
    chosen_logps: 对偏好回答的 log prob
    rejected_logps: 对不偏好回答的 log prob

    核心思想:直接优化"偏好回答 vs 不偏好回答"的对数几率
    """
    # 计算相对 log prob
    chosen_logps = chosen_logps - reference_logps
    rejected_logps = rejected_logps - reference_logps

    # DPO 损失:最大化偏好回答 vs 不偏好回答的差距
    # 等价于最小化这个损失
    log_ratio = chosen_logps - rejected_logps
    loss = -torch.log(torch.sigmoid(beta * log_ratio)).mean()

    return loss

# DPO vs RLHF 对比
compare = {
    "RLHF": {
        "模型数量": "4 个(Ref + RM + Policy + Critic)",
        "训练稳定性": "较难,需要 KL 约束防止跑偏",
        "实现复杂度": "高,涉及 PPO 算法很多细节",
        "计算成本": "高,需要同时运行多个模型",
    },
    "DPO": {
        "模型数量": "2 个(Ref + Policy)",
        "训练稳定性": "较稳定,端到端优化",
        "实现复杂度": "低,只需要做分类任务的 BCE 损失",
        "计算成本": "中等,比 RLHF 低不少",
    }
}

我的经验:实际项目中,如果数据质量和分布差不多,DPO 往往能接近 RLHF 的效果,而且训练更稳定、更好调参。但如果 RM 模型训练得特别好,RLHF 的上限可能更高。OpenAI 最新的模型据说还是用 RLHF,但 DPO 已经成为很多开源模型(如 Llama 2 的对齐阶段)的首选。


四、三种训练方式对比

维度 预训练 (Pre-training) 微调 (SFT) 对齐 (RLHF/DPO)
目标 学习语言规律、世界知识 学会回答问题 符合人类偏好
数据 万亿 token 自监督 万条标注问答 人类排序偏好
算力 极高 中等 较高
时间 数周~数月 数天~数周 数天~数周
模型输入 任意文本 Instruction + Answer Prompt + Response
Loss Next Token CE Next Token CE Reward / Preference

写在最后

大模型训练的这三个阶段,就像培养一个孩子:

  • 预训练:让他上小学中学,学基础知识
  • 微调:送他去职业培训班,学专业技能
  • 对齐:职前培训,教他职场礼仪和职业道德

每一步都不可或缺。现在你知道为什么 ChatGPT 能说人话了吧?背后是多少算力、数据和工程师的心血

觉得有帮助的话,点赞收藏!有问题评论区见

相关推荐
ApachePulsar1 小时前
多元协议,总线归一:为何协议灵活性对 AI 智能体至关重要
人工智能
晓风伴月1 小时前
Command、Skill、Automation、Connector、Plugin分工详解
人工智能
虾..2 小时前
大模型认识
人工智能·llm·rag
“码”力全开2 小时前
解耦流媒体与AI推理:基于Docker与GB28181/RTSP的边缘计算中台,全量源码交付如何帮集成商节省95%开发成本?
人工智能·docker·边缘计算
hsg772 小时前
简述:ImageNet2010样本分类列表
人工智能·分类
2601_959477912 小时前
Vatee平台平台运行稳定吗?
大数据·人工智能·安全
土拨鼠烧电路2 小时前
第4章:寄生虫时代——当AI学会呼吸
人工智能·microsoft
bylander2 小时前
【技术调研】华为《智能世界2035》白皮书调研报告
人工智能·华为
jimmyleeee2 小时前
人工智能基础知识笔记四十一:Claude 成本节约完全指南:从计费机制到工具实战
人工智能·笔记