一、大模型核心组件
1.1 RMSNorm
RMSNorm是LayerNorm的一种特殊情况。如论文绿色标识部分,LayerNorm是weights进行重中心化和缩放,减少模型对噪声的敏感度,提升模型的稳定性和鲁棒性。RMSNorm和LayerNorm不同的是RMSNorm只需要做缩放(re-scaling)不需要重中心化(re-centering),其中是训练参数(初始化为1),维度为hidden_size

1.1.1 代码实现
python
import torch.nn as nn
import torch
class RMSNorm(nn.Module):
def __init__(self, dim:int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weights = nn.Parameter(torch.ones(dim)) # 初始化g_i
def _norm(self, x):
# torch.rsqrt计算的是根号分之一
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor):
return self.weights * self._norm(x.float()).type_as(x)
1.1.2 拓展-BatchNorm与LayerNorm的区别
- BatchNorm:逐通道进行归一化,如对shape:[4, 3, 2, 2]的特征图进行归一化,会计算出3个均值和方差,每次会对4 * 2 * 2个元素进行归一化;适用CV领域
- LayerNorm:对最后一个维度进行归一化,即对每个token的隐藏层做归一化,如shape:[4, 10, 768]的张量,会计算出4 * 10个均值和方差,每次会对768个元素进行归一化;适用NLP领域
1.2 RoPE
旋转位置编码是一种相对位置编码,不同于绝对位置编码在self-attention之前将预计算好的位置编码加到embedding向量上,旋转位置编码是嵌入到self-attention计算内,具体的是先对query和key进行旋转,然后再做
如图所示,向量内两两进行组合应用旋转,旋转的角度为,其中m为token的位置索引,

对query或者对key进行旋转可表示为:

对14式 结合15式展开转化可写成:

其中 和
都可预先计算 。16式 为query和key应用旋转后再做矩阵运算的形式,可以看到旋转位置编码有n-m表现相对位置的关系
1.2.1 代码实现
1.2.1.1 预计算sin与cos
python
def precompute_freqs_cis(dim: int, max_seq_len: int = int(32*1024), theta: float = 1e6):
"""
预计算sin(m * \theta)与cos(m * \theta)
"""
# 计算公式中的theta:1/[1000000^(2(i-1)/d)]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 定义m
m = torch.arange(end, device=freqs.device)
# 得到的是 m*theta, shape: (max_seq_len, dim // 2)
freqs = torch.outer(m, freqs).float()
# cos(m*theta), 堆叠shape:(max_seq_len, dim)
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
# sin(m*theta), 堆叠shape:(max_seq_len, dim)
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
return freqs_cos, freqs_sin
1.2.1.2 应用RoPE
对query和key应用旋转位置编码
python
def rotate_half(x):
"""
将最后维度切割成两部分,构造出x_1和x_{d/2 + 1}旋转
有别于论文中x1和x2旋转,x3和x4旋转,以此类推...
如:最后一个维度为10,则x_1和x_6进行旋转
x_1 * cosm\theta_1 - x_(d/2 + 1) * sinm\theta_1
... ...
x_(d/2 + 1)* cosm\theta_1 + x_1 * sinm\theta_1
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
应用旋转位置编码
query和key shape: (batch_size, seq_len, num_of_head, head_dim)
cos和sin shape: (seq_len, head_dim)
"""
# unsqueeze后shape: (seq_len, 1, head_dim)
# 在计算的时候再进行broadcast进行维度的补齐
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
1.3 Self-Attention
自注意机制是Transfomer的核心,能并行计算token与token之间的相关度,注意力计算公式如下:
CausalLM(自回归模型,仅有decoder)在计算token相关度时,当前token只能看到"自己"和"之前"的token,所以需要将当前token之后的注意力分数进行mask,在计算softmax时,对应的注意力分数就是无限接近于0,图示如下:


注意力机制也分为几种,有:
- MHA(Multi-Head Attention):query head与key head、value head是一一对应的
- GQA(Grouped-Query Attention):对query head进行分组,组内query head对应一个key head、value head
- MQA(Multi-Query Attention):所有的query head对应一组key head、value head

其中,MHA的效果最佳,但是最占显存(kv-cache),MQA的效果是最差的,但最节省显存,GQA效果介于MHA和MQA之间,在计算效果与显存压力之间找到了平衡。假设2个query head对应1个key head/value head,则kv-cache占用的显存可节省一半
1.3.1 代码实现
python
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
如果是GQA模式,则需要扩充key head和value head
让query head和key head、value head数量是一样的
"""
bs, slen, num_key_value_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :] # python的基础语法
.expand(bs, slen, num_key_value_heads, n_rep, head_dim)
.reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args):
super().__init__()
# num_attention_heads:是多头注意力中多少个heads,即query head的数量
# num_key_value_heads:是key/value head
self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
# 验证一下参数是否合理,需要整除
assert args.num_attention_heads % self.num_key_value_heads == 0
# 换个名字
self.n_local_heads = args.num_attention_heads
self.n_local_kv_heads = self.num_key_value_heads
# 计算grouped query的分组
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.hidden_size // args.num_attention_heads
# 设置多头注意力的一些W矩阵
self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
def forward(self,
x: torch.Tensor,
position_embeddings, # 接收预先计算的 cos 和 sin
past_key_value = None, # 之前时刻的 K 和 V
use_cache= False,
attention_mask = None):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# reshape
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
pre_cos, pre_sin = position_embeddings
# 在 Q 和 K 身上应用 ROPE
xq, xk = apply_rotary_pos_emb(xq, xk, pre_cos[:seq_len], pre_sin[:seq_len])
# 关于 kv_cache
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
if use_cache:
past_kv = (xk, xv)
else:
past_kv = None
xq, xk, xv = (
xq.transpose(1, 2),
repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2),
)
# 使用 self-attention 公式
scaled_scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
# scores + mask
look_ahead_mask = torch.triu(
torch.full((seq_len, seq_len), float('-inf'), device=scaled_scores.device), diagonal=1
)
masked_scores = (scaled_scores + look_ahead_mask).unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# attention_mask 中的 0 值会变成非常小的负数 -1e9
# 将 1 保持为 0 ,这样做在后续的 softmax 操作中,这些非常小的负数值会接近零
# 从而在 softmax 之后几乎为零,实现忽略这些位置的效果
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
masked_scores = masked_scores + extended_attention_mask
scores = F.softmax(masked_scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = scores @ xv
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.o_proj(output))
return output, past_kv
1.3.2 拓展
- 在计算注意力时,为什么要对注意力分数进行缩放,即除以
?
缩放时为了保证训练的稳定性,如果不进行缩放,会产生如下问题:
- 注意力分数:query和key的点积实际上是每个元素的相乘再相加,如果隐藏层的维度d很大,则计算点积有些元素的值会极大,再进行softmax后有些注意力分数会无限接近于1,而其他注意力分数会无限接近于0
- 梯度消失:softmax在输入值极大时,梯度会极小,在反向传播时,模型几乎无法更新参数,学习停滞(图示红框为softmax偏导)

1.4 FFN
FFN(Feed Forword Network)是对隐藏层进行维度扩充后再进行维度还原。
1.4.1 代码实现
python
from transformers.activations import ACT2FN
class FeedForward(nn.Module):
def __init__(self, config):
super().__init__()
if config.intermediate_size is None:
intermediate_size = int(config.hidden_size * 8 / 3)
config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.dropout = nn.Dropout(config.dropout)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
return self.dropout(self.down_proj(self.up_proj(x) * self.act_fn(self.gate_proj(x))))
二、大模型MoE结构
2.1 概述
MoE(Mixture of Experts)混合专家模型主要应用在FFN,即每个expert都是一个FFN。序列中的每个token会和每个expert计算出一个分数,取分数最高的topk个expert进行前馈计算,计算之后需要乘以对应的分数,然后把topk个专家得出的结果再相加,如图所示为k=1的情况:

传统MoE中,每个专家本身还是一个较大的神经网络,很容易学习到通用、重复的知识,导致多个专家干同样的事,造成参数浪费。在DeepseekMoE的论文(2401.06066)中,将MoE的专家数进行了细化(图b),处理特定的,细分领域的知识,专家更专;同时剥离出共享专家(图c),处理共用的通用知识

2.2 代码实现
在代码实现过程中,通过抽取出MoEGate 来处理token与每个专家的计算(本质是做线性运算),得到每个token选取的topk个专家的索引 和权重(softmax后的值),同时计算出aux loss,aux loss分为两个级别:
- sequence级别:将序列看成一个整体,如果某条序列的token都选取了同一个expert,那么惩罚这条序列,鼓励其使用其他的多个expert
- token级别:统计每个token选择了哪些expert,如果大部分token都选择了同一个expert,则进行惩罚,鼓励选择其他expert
通过定义MoEFeedForward 封装整个MoE的计算,包含通过MoEGate来为每个token选出topk个expert,然后进行前馈计算,得到结果再乘以对应的权重。如果有共享expert 的话,则还需要和共享expert进行前馈计算,最终将topk个专家计算结果和共享专家**(无需乘权重)**计算结果进行相加
python
class MoEGate(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts # 表示总的可选专家数量
self.scoring_func = config.scoring_func # 选择使用哪种评分方式(一般就是'softmax')
# 为了让MoE表现的更均衡,我们可以设置关于MoE的权重,回头加到total loss身上
self.alpha = config.aux_loss_alpha # 控制辅助损失项的权重
self.seq_aux = config.seq_aux # 计算关于MOE是否balance的损失时有两种方式(1,token level;2,sequence level)
self.norm_topk_prob = config.norm_topk_prob # 是否对 topK 的概率进行归一化
self.gating_dim = config.hidden_size # 输入向量的维度
# 定义一个可学习的门控矩阵,形状为 [n_routed_experts, hidden_size]
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
# 调用 初始化函数对上面这个 weight 进行初始化
self.reset_parameters()
def register_parameter(self):
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
# 这块是核心逻辑,输入是一个batch的隐藏状态,输出是每个token的专家分配结果和辅助损失
bsz, seq_len, h = hidden_states.shape
# 把输入展平成二维数组,方便处理每个token; 二维数组对应的形状就是 [bsz*seq_len, hidden_size]
hidden_states = hidden_states.view(-1, h)
# 计算每个token对每个专家expert的原始分数logits,形状是 [total_tokens, n_routed_experts]
logits = F.linear(hidden_states, self.weight, None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f'unsupportted scoring fucntion for MOE gating: {self.scoring_func}')
# 对每个token,在expert维度上选出 topK 个得分最高的专家
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
# 是否启用了norm_topk_prob,对topK的权重做归一化,使其总和为1,防止除以零加一个小数值
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
# 如果处于训练模式并且启用了辅助损失,则开始构建辅助损失项
if self.training and self.alpha > 0.0:
scores_for_aux = scores # 所有expert的得分,也就还没取topK
aux_topk = self.top_k
topk_idx_for_aux_loss = topk_idx.view(bsz, -1) # 展平之后的topK专家索引
if self.seq_aux:
# 按照sequence级别计算辅助损失
# 每条sequence看作一个整体,如果某条sequence所有token都只用了expert 0,那么则惩罚这条sequence,鼓励其使用其它多个expert
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
# 构建一个专家被选择的频率矩阵 ce
ce = torch.zeros((bsz, self.n_routed_experts), device=hidden_states.device)
# 使用 scatter_add_ 来统计每个batch每个expert被选中了多少次
ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len* aux_topk,
device=hidden_states.device)).div_(
seq_len * aux_topk / self.n_routed_experts
)
# 然后做一个平均,并且与平均得分相乘,作为辅助损失
# 目的是防止某些expert被频繁选中,造成负载不均
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
# 按照token级别计算辅助损失
# 分布统计每个token选择了哪个expert,如果大部分token都选择expert 0,则惩罚它,鼓励选择其它expert
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
# 计算每个expert 的平均得分
Pi = scores_for_aux.mean(0)
# 计算每个expert被选中的频率
fi = ce * self.n_routed_experts
# 辅助损失是两者相乘的结果
aux_loss = (Pi * fi).sum() * self.alpha
# topk_idx: 每个token被分配到 topK 个 expert 的索引
# topk_weight: 每个 expert 对应的权重
# aux_loss: 辅助损失项,用于平衡专家之间的负载
return topk_idx, topk_weight, aux_loss
# 定义MOEFeedForward类
class MOEFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_routed_experts)
])
self.gate = MoEGate(config)
if config.n_shared_experts > 0:
self.shared_experts = nn.ModuleList([
FeedForward(config)
for _ in range(config.n_shared_experts)
])
def forward(self, x):
identity = x # 做 skip connection
orig_shape = x.shape
bsz, seq_len, _ = x.shape
# 使用门控机制专家的选择
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
# 对每个token,复制 num_experts_per_tok 多份,
# 这样做的目的是为了将每个token同时传入其top-K个被选中的专家里面进行计算
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
# 创建一个与x形状相同但是类型为 float16 的空张量,用于存储每个token经过对应专家处理后的结果
y = torch.empty_like(x, dtype=torch.float16)
for i, expert in enumerate(self.experts):
# flat_topk_idx 是一个索引张量,表示每个token被分配给了哪个专家
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)
# 将输出按照token和专家维度重新组织
# 使用 topk_weight 权重对每个专家的输出进行加权求和
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
# 把最终输出恢复成原始输入的形状
y = y.view(*orig_shape)
else:
# 在推理阶段使用更高效的函数 moe_infer 处理 MOE 部分
# 通常是为了减少内存冗余或计算冗余,例如合并多个token,一起处理
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
# 如果启用了共享专家,它们会作用在所有的token上
if self.config.n_shared_experts > 0:
for expert in self.shared_experts:
y = y + expert(identity)
# 通常这个损失会加到 total_loss = task_loss + config.aux_loss_coeff * model.aux_loss
self.aux_loss = aux_loss
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
# tokens_per_expert = [6, 15, 20, 26] 这四个数值分别代表4个专家处理的token数量
tokens_idxs = idxs // self.config.num_experts_per_tok
# token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 代表着 token_idxs[:6]
# 属于0号专家的;每个token有可能被多个专家处理,取决于 config.num_experts_per_tok
for i, end_idx in enumerate(tokens_per_expert):
# 计算当前专家处理token的起始索引
start_idx = 0 if i==0 else tokens_per_expert[i-1]
# 如果没有token被分配给这个专家,跳过该专家
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = tokens_idxs[start_idx:end_idx]
# 从原始的输入x中获取这些token的嵌入
expert_tokens = x[exp_token_idx]
# 输入到当前专家网络中进行前向传播;
expert_out = expert(expert_tokens).to(expert_cache.dtype)
# 对专家输出进行加权
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# 使用 scatter_add_ 将专家输出加到最终的输出张量上面去,加权之后的求和
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
return expert_cache
三、模型训练
3.1 预训练
预训练(Pre-train,简称PT)是直接让模型在语料库上进行自监督的训练,目的是为了让模型学到人类语言的语法、语义和世界知识。
大模型是生成模型,是依据前面的t个token,生成t+1时刻的token,所以在预训练样本的构造中,特征和label之间需要偏置一个token,例如有语料【我爱你。】,则特征的部分是【我爱你】,label的部分是【爱你。】

3.2 SFT
SFT(Supervised Fine-Tuning)有别于预训练阶段的自监督训练,SFT训练是需要带label的样本,是有监督训练。预训练得到的模型拥有丰富的语言知识和模式识别能力,SFT的目的是让模型能够准确的理解人类的指令和需求,按照明确的指示完成任务
现阶段SFT的方式有全参微调(FT)和高效微调(PEFT),PEFT算法又包含如Prompt-tuning、P-tuning、Adapter、LoRA、QLoRA等算法,本文着重介绍LoRA(QLoRA原理差不多)
LoRA的核心思想是冻结 原始模型的权重,并为模型的线性层(如注意力计算的投影矩阵)注入可训练 的**低秩分解矩阵,**这种方法以极少的参数量,实现媲美全参微调的性能,而不增加推理的耗时。假设一个线性层可训练参数矩阵的shape为(100, 100),则参数量为10000,其可以分解为两个低秩矩阵的乘积,假设秩r=10,则可训练参数量仅为100 * 10 * 2 = 2000,可大大节省训练显存

具体到训练样本构造方面,通常有两部分,一部分是prompt,另一部分是AI回答部分。在计算损失时,prompt部分的损失通常通过loss_mask进行mask掉,只计算AI回答部分的损失。
3.3 DPO
DPO(Direct Preference Optimization)是一种人类偏好对齐算法,与PPO(Proximal Policy Optimization)两阶段训练(需要先基于人类打标偏好数据训练出奖励模型;再使用奖励模型给策略模型的输出进行打分迭代训练)不同的是,DPO是端对端的,直接最小化损失函数

相较于PPO优势在于:
- 计算效率高:不需要进行策略梯度估计和价值模型的学习
- 训练过程简单:可以直接在单词前向传播计算损失
- 计算开销小:无需额外的网络结构和训练步骤
3.3.1 损失函数解读

是策略模型,
是参考模型
是chosen的输出,
是rejected的输出
是超参数,论文中的值为0.1;
是sigmoid函数
基于上述符号的含义与对数函数的性质对公式进行变换:
=
= ········ ①
代码部分也是基于式①进行实现的,其中具体的 实现为:
- 模型输出logits,对logits的最后一个维度进行log_softmax
- 依据给出的label(input_ids),选取label对应位置的概率值probs
- probs乘以loss_mask,在计算损失的时候只关注loss_mask为1(prompt位置为0,assistant位置为1)位置的损失值
3.3.2 损失代码实现
python
import torch.nn.functional as F
def dpo_loss(policy_chosen_logps, # 策略模型对chosen输出的对数概率
policy_rejected_logps, # 策略模型对rejected输出的对数概率
reference_chosen_logps, # 参考模型对chosen输出的对数概率
reference_rejected_logps, # 参考模型对rejected输出的对数概率
beta=0.1):
# 1. 计算策略模型与参考模型的对数概率差(隐式奖励)
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
# 2. 利用beta缩放差值,并计算损失
# 对应公式:logits = beta * (pi_logratios - ref_logratios)
logits = pi_logratios - ref_logratios
# 使用二元交叉熵,目标是希望chosen优于rejected(标签为1)
losses = -F.logsigmoid(beta * logits)
return losses.mean()
3.4 GRPO
GRPO(Group Relative Policy Optimization)是一种简化的 PPO 变体,专门为 LLM 设计。GRPO 的核心思想是:不需要 Value Model,使用组内相对奖励代替绝对奖励;简化训练流程,只需要 Policy Model 和 Reference Model;提高训练稳定性,减少奖励崩塌的风险
3.4.1 训练过程解析

- **采样阶段:**对于每个问题,使用当前策略生成多个答案。这些答案构成一个组,用于计算相对奖励
- 奖励计算: 对每个生成的答案计算奖励
。奖励可以是准确性奖励、答案长度奖励、步骤奖励或它们的组合
- 相对奖励: 计算组内平均奖励
,然后计算相对奖励
,这样做的好处是减少奖励方差,使训练更稳定
- **策略更新:**使用相对奖励更新策略,同时添加KL散度惩罚,防止策略偏离参考模型太远
- **重复:**重复上述步骤,直到完成所有训练轮次
python
# 假设我们有一个问题
question = "What is 48 + 24?"
# 生成4个答案
answers = [
"48 + 24 = 72. Final Answer: 72", # 正确
"48 + 24 = 72. Final Answer: 72", # 正确
"48 + 24 = 70. Final Answer: 70", # 错误
"Let me think... 72. Final Answer: 72" # 正确但冗长
]
# 计算奖励(假设使用准确率 + 长度惩罚)
rewards = [1.0, 1.0, 0.0, 0.8] # 第4个答案因为冗长被惩罚
# 计算组内平均奖励
avg_reward = (1.0 + 1.0 + 0.0 + 0.8) / 4 = 0.7
# 计算相对奖励
relative_rewards = [
1.0 - 0.7 = 0.3, # 正确且简洁,相对奖励为正
1.0 - 0.7 = 0.3, # 正确且简洁,相对奖励为正
0.0 - 0.7 = -0.7, # 错误,相对奖励为负
0.8 - 0.7 = 0.1 # 正确但冗长,相对奖励较小
]
# 策略更新:增加前两个答案的概率,减少第三个答案的概率
3.4.2 损失函数

代码实现:
python
import torch
import torch.nn.functional as F
def grpo_loss(
log_prob, # 当前策略的对数概率 (new_log_probs),形状: (group_size, ...)
old_log_prob, # 旧策略的对数概率 (old_log_probs),形状: (group_size, ...)
rewards, # 奖励值,在可验证奖励场景下通常是二值的 (0或1),形状: (group_size,)
ref_log_prob, # 参考策略的对数概率,形状: (group_size, ...)
beta=0.1 # KL正则化系数
):
# 1. 计算重要性比率 (importance ratio)
# 注意:许多实现直接操作概率比
ratio = torch.exp(log_prob - old_log_prob)
# 2. 计算未裁剪的替代目标 (surrogate objective)
# 优势 (Advantage) 通常通过组内奖励归一化得到
# adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# 为简化,这里直接使用归一化的奖励作为优势,计算相对奖励
normalized_rewards = (rewards - rewards.mean())
surrogate = ratio * normalized_rewards.unsqueeze(-1)
# 3. 计算策略损失 (policy loss)
policy_loss = -torch.mean(surrogate)
# 4. 计算KL散度正则项 (KL penalty)
# KL(当前策略 || 参考策略) 的估计
# kl = torch.exp(log_prob) * (log_prob - ref_log_prob)
kl_div = torch.mean(torch.exp(log_prob) * (log_prob - ref_log_prob))
# 5. 总损失
total_loss = policy_loss + beta * kl_div
return total_loss, policy_loss, kl_div
四、模型蒸馏
4.1 原理
模型蒸馏是一种将大模型(Teacher Model)的知识迁移到小模型(Student Model)的技术。通过这种方式,可以让小模型学习到大模型的"暗知识(Dark Knowledge)",从而以较小的成本达到接近大模型的性能
这里的"暗知识"就是Teacher Model输出的"软标签(soft labels)",包含了类别之间的相似性关系,而不仅仅是哪个类别是正确的。例如,教师模型可能认为某个样本属于类别 狗 的概率为 78% ,类别 猫 的概率为 20% ,类别 车 的概率为 2%,可以看出狗和猫是具有一定的相似性的,而狗和车相似性很低。这种分布信息可以帮助学⽣模型更好地理解不同类别的关系
软标签是通过模型输出得到logits ,再除以Temperature ,再经过softmax 函数,即:
- Temperature = 1,则输出更加"尖锐",极端情况下接近于one-hot编码,表示确定性很强
- Temperature > 1,则输出更加"平缓",小概率类别也有非零值,有助于传递"暗知识"

如图所示:
- Teacher Model:基于输入x得到 Temperature=t 的软标签
- Student Model:输出包含两部分,一部分是基于输入x得到 Temperature=t 的软标签,另一部分是基于输入x得到 Temperature=1 的硬标签
- Teacher Model与Student Model得到的软标签计算KL散度损失,++鼓励Stuent Model学习Teacher Model的分布++
- Student Model计算出来的硬标签与真实标签计算交叉熵损失,确保++模型仍然能够正确预测真实标签++
基于上,总损失如下,其中是一个超参数,用于控制KL损失和交叉熵损失的权重:
注:KL损失乘以 是因为在反向传播时, logits 的梯度会受到Temperature的影响。如果不补偿,⾼温会导致梯度变⼩,训练变慢。 Hinton 在论⽂附录中证明:为了保持梯度幅度与 T=1 时相当,需要将 KL 损失乘以该值
4.2 优势
- 学⽣模型不仅需要预测正确的类别,还需要模仿教师模型的输出分布,这有助于提⾼模型的泛化能⼒
- 学⽣模型的结构通常⽐教师模型更简单,因此推理速度更快,计算成本更低
4.3 问题与解决方案
|-----------|--------------------------------|
| 问题 | 解决方案 |
| 教师太大,推理慢 | 使⽤缓存:提前⽣成所有 teacher logits 并保存 |
| 学⽣太⼩,⽆法拟合 | 增加投影层、使⽤更强的数据增强 |
| vocab不⼀致 | 截断或扩展 vocab ,保持⼀致 |
| 训练不稳定 | 使⽤梯度裁剪、低学习率、 EMA 平滑 |
参考
hello agent - hello-agents/docs/chapter11/第十一章 Agentic-RL.md