大模型核心技术解析

一、大模型核心组件

1.1 RMSNorm

RMSNorm是LayerNorm的一种特殊情况。如论文绿色标识部分,LayerNorm是weights进行重中心化和缩放,减少模型对噪声的敏感度,提升模型的稳定性和鲁棒性。RMSNorm和LayerNorm不同的是RMSNorm只需要做缩放(re-scaling)不需要重中心化(re-centering),其中是训练参数(初始化为1),维度为hidden_size

1.1.1 代码实现

python 复制代码
import torch.nn as nn
import torch

class RMSNorm(nn.Module):
    def __init__(self, dim:int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weights = nn.Parameter(torch.ones(dim)) # 初始化g_i
        
    def _norm(self, x):
        # torch.rsqrt计算的是根号分之一
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        return self.weights * self._norm(x.float()).type_as(x)

1.1.2 拓展-BatchNorm与LayerNorm的区别

  • BatchNorm:逐通道进行归一化,如对shape:[4, 3, 2, 2]的特征图进行归一化,会计算出3个均值和方差,每次会对4 * 2 * 2个元素进行归一化;适用CV领域
  • LayerNorm:对最后一个维度进行归一化,即对每个token的隐藏层做归一化,如shape:[4, 10, 768]的张量,会计算出4 * 10个均值和方差,每次会对768个元素进行归一化;适用NLP领域

1.2 RoPE

旋转位置编码是一种相对位置编码,不同于绝对位置编码在self-attention之前将预计算好的位置编码加到embedding向量上,旋转位置编码是嵌入到self-attention计算内,具体的是先对query和key进行旋转,然后再做

如图所示,向量内两两进行组合应用旋转,旋转的角度为,其中m为token的位置索引,

对query或者对key进行旋转可表示为:

14式 结合15式展开转化可写成:

其中 都可预先计算16式 为query和key应用旋转后再做矩阵运算的形式,可以看到旋转位置编码有n-m表现相对位置的关系

1.2.1 代码实现

1.2.1.1 预计算sin与cos
python 复制代码
def precompute_freqs_cis(dim: int, max_seq_len: int = int(32*1024), theta: float = 1e6):
    """
    预计算sin(m * \theta)与cos(m * \theta)
    """
	# 计算公式中的theta:1/[1000000^(2(i-1)/d)]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 定义m
    m = torch.arange(end, device=freqs.device)
    # 得到的是 m*theta, shape: (max_seq_len, dim // 2)
    freqs = torch.outer(m, freqs).float()
    # cos(m*theta), 堆叠shape:(max_seq_len, dim)
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    # sin(m*theta), 堆叠shape:(max_seq_len, dim)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin
1.2.1.2 应用RoPE

对query和key应用旋转位置编码

python 复制代码
def rotate_half(x):
    """
	将最后维度切割成两部分,构造出x_1和x_{d/2 + 1}旋转
	有别于论文中x1和x2旋转,x3和x4旋转,以此类推...
	如:最后一个维度为10,则x_1和x_6进行旋转
    x_1 * cosm\theta_1    		-	x_(d/2 + 1) * sinm\theta_1
    ...								...
    x_(d/2 + 1)* cosm\theta_1	+	x_1 * sinm\theta_1
	"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
	
	
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
	应用旋转位置编码
	query和key shape: (batch_size, seq_len, num_of_head, head_dim)
	cos和sin   shape: (seq_len, head_dim)
	"""
	# unsqueeze后shape: (seq_len, 1, head_dim)
	# 在计算的时候再进行broadcast进行维度的补齐
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

1.3 Self-Attention

自注意机制是Transfomer的核心,能并行计算token与token之间的相关度,注意力计算公式如下:

CausalLM(自回归模型,仅有decoder)在计算token相关度时,当前token只能看到"自己"和"之前"的token,所以需要将当前token之后的注意力分数进行mask,在计算softmax时,对应的注意力分数就是无限接近于0,图示如下:

注意力机制也分为几种,有:

  • MHA(Multi-Head Attention):query head与key head、value head是一一对应的
  • GQA(Grouped-Query Attention):对query head进行分组,组内query head对应一个key head、value head
  • MQA(Multi-Query Attention):所有的query head对应一组key head、value head

其中,MHA的效果最佳,但是最占显存(kv-cache),MQA的效果是最差的,但最节省显存,GQA效果介于MHA和MQA之间,在计算效果与显存压力之间找到了平衡。假设2个query head对应1个key head/value head,则kv-cache占用的显存可节省一半

1.3.1 代码实现

python 复制代码
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
	"""
	如果是GQA模式,则需要扩充key head和value head
	让query head和key head、value head数量是一样的
	"""
    bs, slen, num_key_value_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :] # python的基础语法
        .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
        .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        # num_attention_heads:是多头注意力中多少个heads,即query head的数量
        # num_key_value_heads:是key/value head
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        # 验证一下参数是否合理,需要整除
        assert args.num_attention_heads % self.num_key_value_heads == 0
        # 换个名字
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        # 计算grouped query的分组
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        # 设置多头注意力的一些W矩阵
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

    def forward(self,
                x: torch.Tensor,
                position_embeddings, # 接收预先计算的 cos 和 sin
                past_key_value = None, # 之前时刻的 K 和 V
                use_cache= False,
                attention_mask = None):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # reshape
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        pre_cos, pre_sin = position_embeddings
        # 在 Q 和 K 身上应用 ROPE
        xq, xk = apply_rotary_pos_emb(xq, xk, pre_cos[:seq_len], pre_sin[:seq_len])

        # 关于 kv_cache
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)

        if use_cache:
            past_kv = (xk, xv)
        else:
            past_kv = None

        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2),
        )

        # 使用 self-attention 公式
        scaled_scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores + mask
        look_ahead_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=scaled_scores.device), diagonal=1
        )
        masked_scores = (scaled_scores + look_ahead_mask).unsqueeze(0).unsqueeze(0)

        if attention_mask is not None:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # attention_mask 中的 0 值会变成非常小的负数 -1e9
            # 将 1 保持为 0 ,这样做在后续的 softmax 操作中,这些非常小的负数值会接近零
            # 从而在 softmax 之后几乎为零,实现忽略这些位置的效果
            extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
            masked_scores = masked_scores + extended_attention_mask

        scores = F.softmax(masked_scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv

1.3.2 拓展

  1. 在计算注意力时,为什么要对注意力分数进行缩放,即除以 ?

缩放时为了保证训练的稳定性,如果不进行缩放,会产生如下问题:

  • 注意力分数:query和key的点积实际上是每个元素的相乘再相加,如果隐藏层的维度d很大,则计算点积有些元素的值会极大,再进行softmax后有些注意力分数会无限接近于1,而其他注意力分数会无限接近于0
  • 梯度消失:softmax在输入值极大时,梯度会极小,在反向传播时,模型几乎无法更新参数,学习停滞(图示红框为softmax偏导

1.4 FFN

FFN(Feed Forword Network)是对隐藏层进行维度扩充后再进行维度还原。

1.4.1 代码实现

python 复制代码
from transformers.activations import ACT2FN

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.intermediate_size is None:
            intermediate_size = int(config.hidden_size * 8 / 3)
            config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.dropout(self.down_proj(self.up_proj(x) * self.act_fn(self.gate_proj(x))))

二、大模型MoE结构

2.1 概述

MoE(Mixture of Experts)混合专家模型主要应用在FFN,即每个expert都是一个FFN。序列中的每个token会和每个expert计算出一个分数,取分数最高的topk个expert进行前馈计算,计算之后需要乘以对应的分数,然后把topk个专家得出的结果再相加,如图所示为k=1的情况:

传统MoE中,每个专家本身还是一个较大的神经网络,很容易学习到通用、重复的知识,导致多个专家干同样的事,造成参数浪费。在DeepseekMoE的论文(2401.06066)中,将MoE的专家数进行了细化(图b),处理特定的,细分领域的知识,专家更专;同时剥离出共享专家(图c),处理共用的通用知识

2.2 代码实现

在代码实现过程中,通过抽取出MoEGate 来处理token与每个专家的计算(本质是做线性运算),得到每个token选取的topk个专家的索引权重(softmax后的值),同时计算出aux loss,aux loss分为两个级别:

  • sequence级别:将序列看成一个整体,如果某条序列的token都选取了同一个expert,那么惩罚这条序列,鼓励其使用其他的多个expert
  • token级别:统计每个token选择了哪些expert,如果大部分token都选择了同一个expert,则进行惩罚,鼓励选择其他expert

通过定义MoEFeedForward 封装整个MoE的计算,包含通过MoEGate来为每个token选出topk个expert,然后进行前馈计算,得到结果再乘以对应的权重。如果有共享expert 的话,则还需要和共享expert进行前馈计算,最终将topk个专家计算结果和共享专家**(无需乘权重)**计算结果进行相加

python 复制代码
class MoEGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts # 表示总的可选专家数量

        self.scoring_func = config.scoring_func # 选择使用哪种评分方式(一般就是'softmax')
        # 为了让MoE表现的更均衡,我们可以设置关于MoE的权重,回头加到total loss身上
        self.alpha = config.aux_loss_alpha  # 控制辅助损失项的权重
        self.seq_aux = config.seq_aux  # 计算关于MOE是否balance的损失时有两种方式(1,token level;2,sequence level)
        
        self.norm_topk_prob = config.norm_topk_prob # 是否对 topK 的概率进行归一化
        self.gating_dim = config.hidden_size # 输入向量的维度
        # 定义一个可学习的门控矩阵,形状为 [n_routed_experts, hidden_size]
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        # 调用 初始化函数对上面这个 weight 进行初始化
        self.reset_parameters()

    def register_parameter(self):
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        # 这块是核心逻辑,输入是一个batch的隐藏状态,输出是每个token的专家分配结果和辅助损失
        bsz, seq_len, h = hidden_states.shape
        # 把输入展平成二维数组,方便处理每个token; 二维数组对应的形状就是 [bsz*seq_len, hidden_size]
        hidden_states = hidden_states.view(-1, h)
        # 计算每个token对每个专家expert的原始分数logits,形状是 [total_tokens, n_routed_experts]
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupportted scoring fucntion for MOE gating: {self.scoring_func}')
        
        # 对每个token,在expert维度上选出 topK 个得分最高的专家
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # 是否启用了norm_topk_prob,对topK的权重做归一化,使其总和为1,防止除以零加一个小数值
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # 如果处于训练模式并且启用了辅助损失,则开始构建辅助损失项
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores # 所有expert的得分,也就还没取topK
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1) # 展平之后的topK专家索引

            if self.seq_aux:
                # 按照sequence级别计算辅助损失
                # 每条sequence看作一个整体,如果某条sequence所有token都只用了expert 0,那么则惩罚这条sequence,鼓励其使用其它多个expert
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                # 构建一个专家被选择的频率矩阵 ce
                ce = torch.zeros((bsz, self.n_routed_experts), device=hidden_states.device)
                # 使用 scatter_add_ 来统计每个batch每个expert被选中了多少次
                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len* aux_topk, 
                                                                     device=hidden_states.device)).div_(
                                                                         seq_len * aux_topk / self.n_routed_experts
                                                                     )
                # 然后做一个平均,并且与平均得分相乘,作为辅助损失
                # 目的是防止某些expert被频繁选中,造成负载不均
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # 按照token级别计算辅助损失
                # 分布统计每个token选择了哪个expert,如果大部分token都选择expert 0,则惩罚它,鼓励选择其它expert
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                # 计算每个expert 的平均得分
                Pi = scores_for_aux.mean(0)
                # 计算每个expert被选中的频率
                fi = ce * self.n_routed_experts
                # 辅助损失是两者相乘的结果
                aux_loss = (Pi * fi).sum() * self.alpha


        # topk_idx: 每个token被分配到 topK 个 expert 的索引
        # topk_weight: 每个 expert 对应的权重
        # aux_loss: 辅助损失项,用于平衡专家之间的负载
        return topk_idx, topk_weight, aux_loss


# 定义MOEFeedForward类
class MOEFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts > 0:
            self.shared_experts = nn.ModuleList([
                FeedForward(config)
                for _ in range(config.n_shared_experts)
            ])

    def forward(self, x):
        identity = x  # 做 skip connection
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # 使用门控机制专家的选择
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        if self.training:
            # 对每个token,复制 num_experts_per_tok 多份,
            # 这样做的目的是为了将每个token同时传入其top-K个被选中的专家里面进行计算
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # 创建一个与x形状相同但是类型为 float16 的空张量,用于存储每个token经过对应专家处理后的结果
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                # flat_topk_idx 是一个索引张量,表示每个token被分配给了哪个专家
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)
            # 将输出按照token和专家维度重新组织
            # 使用 topk_weight 权重对每个专家的输出进行加权求和
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            # 把最终输出恢复成原始输入的形状
            y = y.view(*orig_shape)
        else:
            # 在推理阶段使用更高效的函数 moe_infer 处理 MOE 部分
            # 通常是为了减少内存冗余或计算冗余,例如合并多个token,一起处理
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        
        # 如果启用了共享专家,它们会作用在所有的token上
        if self.config.n_shared_experts > 0:
            for expert in self.shared_experts:
                y = y + expert(identity)

        # 通常这个损失会加到 total_loss = task_loss + config.aux_loss_coeff * model.aux_loss
        self.aux_loss = aux_loss

        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # tokens_per_expert = [6, 15, 20, 26] 这四个数值分别代表4个专家处理的token数量
        tokens_idxs = idxs // self.config.num_experts_per_tok
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 代表着 token_idxs[:6]
        # 属于0号专家的;每个token有可能被多个专家处理,取决于 config.num_experts_per_tok

        for i, end_idx in enumerate(tokens_per_expert):
            # 计算当前专家处理token的起始索引
            start_idx = 0 if i==0 else tokens_per_expert[i-1]
            # 如果没有token被分配给这个专家,跳过该专家
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = tokens_idxs[start_idx:end_idx]
            # 从原始的输入x中获取这些token的嵌入
            expert_tokens = x[exp_token_idx]
            # 输入到当前专家网络中进行前向传播;
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            # 对专家输出进行加权
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # 使用 scatter_add_ 将专家输出加到最终的输出张量上面去,加权之后的求和
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache

三、模型训练

3.1 预训练

预训练(Pre-train,简称PT)是直接让模型在语料库上进行自监督的训练,目的是为了让模型学到人类语言的语法、语义和世界知识。

大模型是生成模型,是依据前面的t个token,生成t+1时刻的token,所以在预训练样本的构造中,特征和label之间需要偏置一个token,例如有语料【我爱你。】,则特征的部分是【我爱你】,label的部分是【爱你。】

3.2 SFT

SFT(Supervised Fine-Tuning)有别于预训练阶段的自监督训练,SFT训练是需要带label的样本,是有监督训练。预训练得到的模型拥有丰富的语言知识和模式识别能力,SFT的目的是让模型能够准确的理解人类的指令和需求,按照明确的指示完成任务

现阶段SFT的方式有全参微调(FT)和高效微调(PEFT),PEFT算法又包含如Prompt-tuning、P-tuning、Adapter、LoRA、QLoRA等算法,本文着重介绍LoRA(QLoRA原理差不多)

LoRA的核心思想是冻结 原始模型的权重,并为模型的线性层(如注意力计算的投影矩阵)注入可训练 的**低秩分解矩阵,**这种方法以极少的参数量,实现媲美全参微调的性能,而不增加推理的耗时。假设一个线性层可训练参数矩阵的shape为(100, 100),则参数量为10000,其可以分解为两个低秩矩阵的乘积,假设秩r=10,则可训练参数量仅为100 * 10 * 2 = 2000,可大大节省训练显存

具体到训练样本构造方面,通常有两部分,一部分是prompt,另一部分是AI回答部分。在计算损失时,prompt部分的损失通常通过loss_mask进行mask掉,只计算AI回答部分的损失。

3.3 DPO

DPO(Direct Preference Optimization)是一种人类偏好对齐算法,与PPO(Proximal Policy Optimization)两阶段训练(需要先基于人类打标偏好数据训练出奖励模型;再使用奖励模型给策略模型的输出进行打分迭代训练)不同的是,DPO是端对端的,直接最小化损失函数

相较于PPO优势在于:

  • 计算效率高:不需要进行策略梯度估计和价值模型的学习
  • 训练过程简单:可以直接在单词前向传播计算损失
  • 计算开销小:无需额外的网络结构和训练步骤

3.3.1 损失函数解读

  • 是策略模型,是参考模型
  • 是chosen的输出,是rejected的输出
  • 是超参数,论文中的值为0.1;是sigmoid函数

基于上述符号的含义与对数函数的性质对公式进行变换:

=

= ········ ①

代码部分也是基于式①进行实现的,其中具体的 实现为:

  • 模型输出logits,对logits的最后一个维度进行log_softmax
  • 依据给出的label(input_ids),选取label对应位置的概率值probs
  • probs乘以loss_mask,在计算损失的时候只关注loss_mask为1(prompt位置为0,assistant位置为1)位置的损失值

3.3.2 损失代码实现

python 复制代码
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps,    # 策略模型对chosen输出的对数概率
             policy_rejected_logps,  # 策略模型对rejected输出的对数概率
             reference_chosen_logps, # 参考模型对chosen输出的对数概率
             reference_rejected_logps, # 参考模型对rejected输出的对数概率
             beta=0.1):

    # 1. 计算策略模型与参考模型的对数概率差(隐式奖励)
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    # 2. 利用beta缩放差值,并计算损失
    # 对应公式:logits = beta * (pi_logratios - ref_logratios)
    logits = pi_logratios - ref_logratios
    # 使用二元交叉熵,目标是希望chosen优于rejected(标签为1)
    losses = -F.logsigmoid(beta * logits)

    return losses.mean()

3.4 GRPO

GRPO(Group Relative Policy Optimization)是一种简化的 PPO 变体,专门为 LLM 设计。GRPO 的核心思想是:不需要 Value Model,使用组内相对奖励代替绝对奖励;简化训练流程,只需要 Policy Model 和 Reference Model;提高训练稳定性,减少奖励崩塌的风险

3.4.1 训练过程解析

  1. **采样阶段:**对于每个问题,使用当前策略生成多个答案。这些答案构成一个组,用于计算相对奖励
  2. 奖励计算: 对每个生成的答案计算奖励。奖励可以是准确性奖励、答案长度奖励、步骤奖励或它们的组合
  3. 相对奖励: 计算组内平均奖励 ,然后计算相对奖励,这样做的好处是减少奖励方差,使训练更稳定
  4. **策略更新:**使用相对奖励更新策略,同时添加KL散度惩罚,防止策略偏离参考模型太远
  5. **重复:**重复上述步骤,直到完成所有训练轮次
python 复制代码
# 假设我们有一个问题
question = "What is 48 + 24?"

# 生成4个答案
answers = [
    "48 + 24 = 72. Final Answer: 72",      # 正确
    "48 + 24 = 72. Final Answer: 72",      # 正确
    "48 + 24 = 70. Final Answer: 70",      # 错误
    "Let me think... 72. Final Answer: 72" # 正确但冗长
]

# 计算奖励(假设使用准确率 + 长度惩罚)
rewards = [1.0, 1.0, 0.0, 0.8]  # 第4个答案因为冗长被惩罚

# 计算组内平均奖励
avg_reward = (1.0 + 1.0 + 0.0 + 0.8) / 4 = 0.7

# 计算相对奖励
relative_rewards = [
    1.0 - 0.7 = 0.3,   # 正确且简洁,相对奖励为正
    1.0 - 0.7 = 0.3,   # 正确且简洁,相对奖励为正
    0.0 - 0.7 = -0.7,  # 错误,相对奖励为负
    0.8 - 0.7 = 0.1    # 正确但冗长,相对奖励较小
]

# 策略更新:增加前两个答案的概率,减少第三个答案的概率

3.4.2 损失函数

代码实现:

python 复制代码
import torch
import torch.nn.functional as F

def grpo_loss(
    log_prob,  # 当前策略的对数概率 (new_log_probs),形状: (group_size, ...)
    old_log_prob,  # 旧策略的对数概率 (old_log_probs),形状: (group_size, ...)
    rewards,  # 奖励值,在可验证奖励场景下通常是二值的 (0或1),形状: (group_size,)
    ref_log_prob,  # 参考策略的对数概率,形状: (group_size, ...)
    beta=0.1  # KL正则化系数
):
    # 1. 计算重要性比率 (importance ratio)
    #    注意:许多实现直接操作概率比
    ratio = torch.exp(log_prob - old_log_prob)
    
    # 2. 计算未裁剪的替代目标 (surrogate objective)
    #    优势 (Advantage) 通常通过组内奖励归一化得到
    #    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    #    为简化,这里直接使用归一化的奖励作为优势,计算相对奖励
    normalized_rewards = (rewards - rewards.mean())
    surrogate = ratio * normalized_rewards.unsqueeze(-1)
    
    # 3. 计算策略损失 (policy loss)
    policy_loss = -torch.mean(surrogate)
    
    # 4. 计算KL散度正则项 (KL penalty)
    #    KL(当前策略 || 参考策略) 的估计
    #    kl = torch.exp(log_prob) * (log_prob - ref_log_prob)
    kl_div = torch.mean(torch.exp(log_prob) * (log_prob - ref_log_prob))
    
    # 5. 总损失
    total_loss = policy_loss + beta * kl_div
    return total_loss, policy_loss, kl_div

四、模型蒸馏

4.1 原理

模型蒸馏是一种将大模型(Teacher Model)的知识迁移到小模型(Student Model)的技术。通过这种方式,可以让小模型学习到大模型的"暗知识(Dark Knowledge)",从而以较小的成本达到接近大模型的性能

这里的"暗知识"就是Teacher Model输出的"软标签(soft labels)",包含了类别之间的相似性关系,而不仅仅是哪个类别是正确的。例如,教师模型可能认为某个样本属于类别 狗 的概率为 78% ,类别 猫 的概率为 20% ,类别 车 的概率为 2%,可以看出狗和猫是具有一定的相似性的,而狗和车相似性很低。这种分布信息可以帮助学⽣模型更好地理解不同类别的关系

软标签是通过模型输出得到logits ,再除以Temperature ,再经过softmax 函数,即:

  • Temperature = 1,则输出更加"尖锐",极端情况下接近于one-hot编码,表示确定性很强
  • Temperature > 1,则输出更加"平缓",小概率类别也有非零值,有助于传递"暗知识"

如图所示:

  • Teacher Model:基于输入x得到 Temperature=t 的软标签
  • Student Model:输出包含两部分,一部分是基于输入x得到 Temperature=t 的软标签,另一部分是基于输入x得到 Temperature=1 的硬标签
  • Teacher Model与Student Model得到的软标签计算KL散度损失,++鼓励Stuent Model学习Teacher Model的分布++
  • Student Model计算出来的硬标签与真实标签计算交叉熵损失,确保++模型仍然能够正确预测真实标签++

基于上,总损失如下,其中是一个超参数,用于控制KL损失和交叉熵损失的权重:

注:KL损失乘以 是因为在反向传播时, logits 的梯度会受到Temperature的影响。如果不补偿,⾼温会导致梯度变⼩,训练变慢。 Hinton 在论⽂附录中证明:为了保持梯度幅度与 T=1 时相当,需要将 KL 损失乘以该值

4.2 优势

  • 学⽣模型不仅需要预测正确的类别,还需要模仿教师模型的输出分布,这有助于提⾼模型的泛化能⼒
  • 学⽣模型的结构通常⽐教师模型更简单,因此推理速度更快,计算成本更低

4.3 问题与解决方案

|-----------|--------------------------------|
| 问题 | 解决方案 |
| 教师太大,推理慢 | 使⽤缓存:提前⽣成所有 teacher logits 并保存 |
| 学⽣太⼩,⽆法拟合 | 增加投影层、使⽤更强的数据增强 |
| vocab不⼀致 | 截断或扩展 vocab ,保持⼀致 |
| 训练不稳定 | 使⽤梯度裁剪、低学习率、 EMA 平滑 |

参考

hello agent - hello-agents/docs/chapter11/第十一章 Agentic-RL.md

相关推荐
NAGNIP3 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab4 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab4 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP8 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年8 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼8 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS8 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
warm3snow9 小时前
Claude Code 黑客马拉松:5 个获奖项目,没有一个是"纯码农"做的
ai·大模型·llm·agent·skill·mcp
天翼云开发者社区10 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈10 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能