大模型核心技术解析

一、大模型核心组件

1.1 RMSNorm

RMSNorm是LayerNorm的一种特殊情况。如论文绿色标识部分,LayerNorm是weights进行重中心化和缩放,减少模型对噪声的敏感度,提升模型的稳定性和鲁棒性。RMSNorm和LayerNorm不同的是RMSNorm只需要做缩放(re-scaling)不需要重中心化(re-centering),其中是训练参数(初始化为1),维度为hidden_size

1.1.1 代码实现

python 复制代码
import torch.nn as nn
import torch

class RMSNorm(nn.Module):
    def __init__(self, dim:int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weights = nn.Parameter(torch.ones(dim)) # 初始化g_i
        
    def _norm(self, x):
        # torch.rsqrt计算的是根号分之一
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        return self.weights * self._norm(x.float()).type_as(x)

1.1.2 拓展-BatchNorm与LayerNorm的区别

  • BatchNorm:逐通道进行归一化,如对shape:[4, 3, 2, 2]的特征图进行归一化,会计算出3个均值和方差,每次会对4 * 2 * 2个元素进行归一化;适用CV领域
  • LayerNorm:对最后一个维度进行归一化,即对每个token的隐藏层做归一化,如shape:[4, 10, 768]的张量,会计算出4 * 10个均值和方差,每次会对768个元素进行归一化;适用NLP领域

1.2 RoPE

旋转位置编码是一种相对位置编码,不同于绝对位置编码在self-attention之前将预计算好的位置编码加到embedding向量上,旋转位置编码是嵌入到self-attention计算内,具体的是先对query和key进行旋转,然后再做

如图所示,向量内两两进行组合应用旋转,旋转的角度为,其中m为token的位置索引,

对query或者对key进行旋转可表示为:

14式 结合15式展开转化可写成:

其中 都可预先计算16式 为query和key应用旋转后再做矩阵运算的形式,可以看到旋转位置编码有n-m表现相对位置的关系

1.2.1 代码实现

1.2.1.1 预计算sin与cos
python 复制代码
def precompute_freqs_cis(dim: int, max_seq_len: int = int(32*1024), theta: float = 1e6):
    """
    预计算sin(m * \theta)与cos(m * \theta)
    """
	# 计算公式中的theta:1/[1000000^(2(i-1)/d)]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 定义m
    m = torch.arange(end, device=freqs.device)
    # 得到的是 m*theta, shape: (max_seq_len, dim // 2)
    freqs = torch.outer(m, freqs).float()
    # cos(m*theta), 堆叠shape:(max_seq_len, dim)
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    # sin(m*theta), 堆叠shape:(max_seq_len, dim)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin
1.2.1.2 应用RoPE

对query和key应用旋转位置编码

python 复制代码
def rotate_half(x):
    """
	将最后维度切割成两部分,构造出x_1和x_{d/2 + 1}旋转
	有别于论文中x1和x2旋转,x3和x4旋转,以此类推...
	如:最后一个维度为10,则x_1和x_6进行旋转
    x_1 * cosm\theta_1    		-	x_(d/2 + 1) * sinm\theta_1
    ...								...
    x_(d/2 + 1)* cosm\theta_1	+	x_1 * sinm\theta_1
	"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
	
	
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
	应用旋转位置编码
	query和key shape: (batch_size, seq_len, num_of_head, head_dim)
	cos和sin   shape: (seq_len, head_dim)
	"""
	# unsqueeze后shape: (seq_len, 1, head_dim)
	# 在计算的时候再进行broadcast进行维度的补齐
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

1.3 Self-Attention

自注意机制是Transfomer的核心,能并行计算token与token之间的相关度,注意力计算公式如下:

CausalLM(自回归模型,仅有decoder)在计算token相关度时,当前token只能看到"自己"和"之前"的token,所以需要将当前token之后的注意力分数进行mask,在计算softmax时,对应的注意力分数就是无限接近于0,图示如下:

注意力机制也分为几种,有:

  • MHA(Multi-Head Attention):query head与key head、value head是一一对应的
  • GQA(Grouped-Query Attention):对query head进行分组,组内query head对应一个key head、value head
  • MQA(Multi-Query Attention):所有的query head对应一组key head、value head

其中,MHA的效果最佳,但是最占显存(kv-cache),MQA的效果是最差的,但最节省显存,GQA效果介于MHA和MQA之间,在计算效果与显存压力之间找到了平衡。假设2个query head对应1个key head/value head,则kv-cache占用的显存可节省一半

1.3.1 代码实现

python 复制代码
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
	"""
	如果是GQA模式,则需要扩充key head和value head
	让query head和key head、value head数量是一样的
	"""
    bs, slen, num_key_value_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :] # python的基础语法
        .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
        .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        # num_attention_heads:是多头注意力中多少个heads,即query head的数量
        # num_key_value_heads:是key/value head
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        # 验证一下参数是否合理,需要整除
        assert args.num_attention_heads % self.num_key_value_heads == 0
        # 换个名字
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        # 计算grouped query的分组
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        # 设置多头注意力的一些W矩阵
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

    def forward(self,
                x: torch.Tensor,
                position_embeddings, # 接收预先计算的 cos 和 sin
                past_key_value = None, # 之前时刻的 K 和 V
                use_cache= False,
                attention_mask = None):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # reshape
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        pre_cos, pre_sin = position_embeddings
        # 在 Q 和 K 身上应用 ROPE
        xq, xk = apply_rotary_pos_emb(xq, xk, pre_cos[:seq_len], pre_sin[:seq_len])

        # 关于 kv_cache
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)

        if use_cache:
            past_kv = (xk, xv)
        else:
            past_kv = None

        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2),
        )

        # 使用 self-attention 公式
        scaled_scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores + mask
        look_ahead_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=scaled_scores.device), diagonal=1
        )
        masked_scores = (scaled_scores + look_ahead_mask).unsqueeze(0).unsqueeze(0)

        if attention_mask is not None:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # attention_mask 中的 0 值会变成非常小的负数 -1e9
            # 将 1 保持为 0 ,这样做在后续的 softmax 操作中,这些非常小的负数值会接近零
            # 从而在 softmax 之后几乎为零,实现忽略这些位置的效果
            extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
            masked_scores = masked_scores + extended_attention_mask

        scores = F.softmax(masked_scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv

1.3.2 拓展

  1. 在计算注意力时,为什么要对注意力分数进行缩放,即除以 ?

缩放时为了保证训练的稳定性,如果不进行缩放,会产生如下问题:

  • 注意力分数:query和key的点积实际上是每个元素的相乘再相加,如果隐藏层的维度d很大,则计算点积有些元素的值会极大,再进行softmax后有些注意力分数会无限接近于1,而其他注意力分数会无限接近于0
  • 梯度消失:softmax在输入值极大时,梯度会极小,在反向传播时,模型几乎无法更新参数,学习停滞(图示红框为softmax偏导

1.4 FFN

FFN(Feed Forword Network)是对隐藏层进行维度扩充后再进行维度还原。

1.4.1 代码实现

python 复制代码
from transformers.activations import ACT2FN

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.intermediate_size is None:
            intermediate_size = int(config.hidden_size * 8 / 3)
            config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.dropout(self.down_proj(self.up_proj(x) * self.act_fn(self.gate_proj(x))))

二、大模型MoE结构

2.1 概述

MoE(Mixture of Experts)混合专家模型主要应用在FFN,即每个expert都是一个FFN。序列中的每个token会和每个expert计算出一个分数,取分数最高的topk个expert进行前馈计算,计算之后需要乘以对应的分数,然后把topk个专家得出的结果再相加,如图所示为k=1的情况:

传统MoE中,每个专家本身还是一个较大的神经网络,很容易学习到通用、重复的知识,导致多个专家干同样的事,造成参数浪费。在DeepseekMoE的论文(2401.06066)中,将MoE的专家数进行了细化(图b),处理特定的,细分领域的知识,专家更专;同时剥离出共享专家(图c),处理共用的通用知识

2.2 代码实现

在代码实现过程中,通过抽取出MoEGate 来处理token与每个专家的计算(本质是做线性运算),得到每个token选取的topk个专家的索引权重(softmax后的值),同时计算出aux loss,aux loss分为两个级别:

  • sequence级别:将序列看成一个整体,如果某条序列的token都选取了同一个expert,那么惩罚这条序列,鼓励其使用其他的多个expert
  • token级别:统计每个token选择了哪些expert,如果大部分token都选择了同一个expert,则进行惩罚,鼓励选择其他expert

通过定义MoEFeedForward 封装整个MoE的计算,包含通过MoEGate来为每个token选出topk个expert,然后进行前馈计算,得到结果再乘以对应的权重。如果有共享expert 的话,则还需要和共享expert进行前馈计算,最终将topk个专家计算结果和共享专家**(无需乘权重)**计算结果进行相加

python 复制代码
class MoEGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts # 表示总的可选专家数量

        self.scoring_func = config.scoring_func # 选择使用哪种评分方式(一般就是'softmax')
        # 为了让MoE表现的更均衡,我们可以设置关于MoE的权重,回头加到total loss身上
        self.alpha = config.aux_loss_alpha  # 控制辅助损失项的权重
        self.seq_aux = config.seq_aux  # 计算关于MOE是否balance的损失时有两种方式(1,token level;2,sequence level)
        
        self.norm_topk_prob = config.norm_topk_prob # 是否对 topK 的概率进行归一化
        self.gating_dim = config.hidden_size # 输入向量的维度
        # 定义一个可学习的门控矩阵,形状为 [n_routed_experts, hidden_size]
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        # 调用 初始化函数对上面这个 weight 进行初始化
        self.reset_parameters()

    def register_parameter(self):
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        # 这块是核心逻辑,输入是一个batch的隐藏状态,输出是每个token的专家分配结果和辅助损失
        bsz, seq_len, h = hidden_states.shape
        # 把输入展平成二维数组,方便处理每个token; 二维数组对应的形状就是 [bsz*seq_len, hidden_size]
        hidden_states = hidden_states.view(-1, h)
        # 计算每个token对每个专家expert的原始分数logits,形状是 [total_tokens, n_routed_experts]
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupportted scoring fucntion for MOE gating: {self.scoring_func}')
        
        # 对每个token,在expert维度上选出 topK 个得分最高的专家
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # 是否启用了norm_topk_prob,对topK的权重做归一化,使其总和为1,防止除以零加一个小数值
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # 如果处于训练模式并且启用了辅助损失,则开始构建辅助损失项
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores # 所有expert的得分,也就还没取topK
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1) # 展平之后的topK专家索引

            if self.seq_aux:
                # 按照sequence级别计算辅助损失
                # 每条sequence看作一个整体,如果某条sequence所有token都只用了expert 0,那么则惩罚这条sequence,鼓励其使用其它多个expert
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                # 构建一个专家被选择的频率矩阵 ce
                ce = torch.zeros((bsz, self.n_routed_experts), device=hidden_states.device)
                # 使用 scatter_add_ 来统计每个batch每个expert被选中了多少次
                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len* aux_topk, 
                                                                     device=hidden_states.device)).div_(
                                                                         seq_len * aux_topk / self.n_routed_experts
                                                                     )
                # 然后做一个平均,并且与平均得分相乘,作为辅助损失
                # 目的是防止某些expert被频繁选中,造成负载不均
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # 按照token级别计算辅助损失
                # 分布统计每个token选择了哪个expert,如果大部分token都选择expert 0,则惩罚它,鼓励选择其它expert
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                # 计算每个expert 的平均得分
                Pi = scores_for_aux.mean(0)
                # 计算每个expert被选中的频率
                fi = ce * self.n_routed_experts
                # 辅助损失是两者相乘的结果
                aux_loss = (Pi * fi).sum() * self.alpha


        # topk_idx: 每个token被分配到 topK 个 expert 的索引
        # topk_weight: 每个 expert 对应的权重
        # aux_loss: 辅助损失项,用于平衡专家之间的负载
        return topk_idx, topk_weight, aux_loss


# 定义MOEFeedForward类
class MOEFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts > 0:
            self.shared_experts = nn.ModuleList([
                FeedForward(config)
                for _ in range(config.n_shared_experts)
            ])

    def forward(self, x):
        identity = x  # 做 skip connection
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # 使用门控机制专家的选择
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        if self.training:
            # 对每个token,复制 num_experts_per_tok 多份,
            # 这样做的目的是为了将每个token同时传入其top-K个被选中的专家里面进行计算
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # 创建一个与x形状相同但是类型为 float16 的空张量,用于存储每个token经过对应专家处理后的结果
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                # flat_topk_idx 是一个索引张量,表示每个token被分配给了哪个专家
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)
            # 将输出按照token和专家维度重新组织
            # 使用 topk_weight 权重对每个专家的输出进行加权求和
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            # 把最终输出恢复成原始输入的形状
            y = y.view(*orig_shape)
        else:
            # 在推理阶段使用更高效的函数 moe_infer 处理 MOE 部分
            # 通常是为了减少内存冗余或计算冗余,例如合并多个token,一起处理
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        
        # 如果启用了共享专家,它们会作用在所有的token上
        if self.config.n_shared_experts > 0:
            for expert in self.shared_experts:
                y = y + expert(identity)

        # 通常这个损失会加到 total_loss = task_loss + config.aux_loss_coeff * model.aux_loss
        self.aux_loss = aux_loss

        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # tokens_per_expert = [6, 15, 20, 26] 这四个数值分别代表4个专家处理的token数量
        tokens_idxs = idxs // self.config.num_experts_per_tok
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 代表着 token_idxs[:6]
        # 属于0号专家的;每个token有可能被多个专家处理,取决于 config.num_experts_per_tok

        for i, end_idx in enumerate(tokens_per_expert):
            # 计算当前专家处理token的起始索引
            start_idx = 0 if i==0 else tokens_per_expert[i-1]
            # 如果没有token被分配给这个专家,跳过该专家
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = tokens_idxs[start_idx:end_idx]
            # 从原始的输入x中获取这些token的嵌入
            expert_tokens = x[exp_token_idx]
            # 输入到当前专家网络中进行前向传播;
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            # 对专家输出进行加权
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # 使用 scatter_add_ 将专家输出加到最终的输出张量上面去,加权之后的求和
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache

三、模型训练

3.1 预训练

预训练(Pre-train,简称PT)是直接让模型在语料库上进行自监督的训练,目的是为了让模型学到人类语言的语法、语义和世界知识。

大模型是生成模型,是依据前面的t个token,生成t+1时刻的token,所以在预训练样本的构造中,特征和label之间需要偏置一个token,例如有语料【我爱你。】,则特征的部分是【我爱你】,label的部分是【爱你。】

3.2 SFT

SFT(Supervised Fine-Tuning)有别于预训练阶段的自监督训练,SFT训练是需要带label的样本,是有监督训练。预训练得到的模型拥有丰富的语言知识和模式识别能力,SFT的目的是让模型能够准确的理解人类的指令和需求,按照明确的指示完成任务

现阶段SFT的方式有全参微调(FT)和高效微调(PEFT),PEFT算法又包含如Prompt-tuning、P-tuning、Adapter、LoRA、QLoRA等算法,本文着重介绍LoRA(QLoRA原理差不多)

LoRA的核心思想是冻结 原始模型的权重,并为模型的线性层(如注意力计算的投影矩阵)注入可训练 的**低秩分解矩阵,**这种方法以极少的参数量,实现媲美全参微调的性能,而不增加推理的耗时。假设一个线性层可训练参数矩阵的shape为(100, 100),则参数量为10000,其可以分解为两个低秩矩阵的乘积,假设秩r=10,则可训练参数量仅为100 * 10 * 2 = 2000,可大大节省训练显存

具体到训练样本构造方面,通常有两部分,一部分是prompt,另一部分是AI回答部分。在计算损失时,prompt部分的损失通常通过loss_mask进行mask掉,只计算AI回答部分的损失。

3.3 DPO

DPO(Direct Preference Optimization)是一种人类偏好对齐算法,与PPO(Proximal Policy Optimization)两阶段训练(需要先基于人类打标偏好数据训练出奖励模型;再使用奖励模型给策略模型的输出进行打分迭代训练)不同的是,DPO是端对端的,直接最小化损失函数

相较于PPO优势在于:

  • 计算效率高:不需要进行策略梯度估计和价值模型的学习
  • 训练过程简单:可以直接在单词前向传播计算损失
  • 计算开销小:无需额外的网络结构和训练步骤

3.3.1 损失函数解读

  • 是策略模型,是参考模型
  • 是chosen的输出,是rejected的输出
  • 是超参数,论文中的值为0.1;是sigmoid函数

基于上述符号的含义与对数函数的性质对公式进行变换:

=

= ········ ①

代码部分也是基于式①进行实现的,其中具体的 实现为:

  • 模型输出logits,对logits的最后一个维度进行log_softmax
  • 依据给出的label(input_ids),选取label对应位置的概率值probs
  • probs乘以loss_mask,在计算损失的时候只关注loss_mask为1(prompt位置为0,assistant位置为1)位置的损失值

3.3.2 损失代码实现

python 复制代码
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps,    # 策略模型对chosen输出的对数概率
             policy_rejected_logps,  # 策略模型对rejected输出的对数概率
             reference_chosen_logps, # 参考模型对chosen输出的对数概率
             reference_rejected_logps, # 参考模型对rejected输出的对数概率
             beta=0.1):

    # 1. 计算策略模型与参考模型的对数概率差(隐式奖励)
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    # 2. 利用beta缩放差值,并计算损失
    # 对应公式:logits = beta * (pi_logratios - ref_logratios)
    logits = pi_logratios - ref_logratios
    # 使用二元交叉熵,目标是希望chosen优于rejected(标签为1)
    losses = -F.logsigmoid(beta * logits)

    return losses.mean()

3.4 GRPO

GRPO(Group Relative Policy Optimization)是一种简化的 PPO 变体,专门为 LLM 设计。GRPO 的核心思想是:不需要 Value Model,使用组内相对奖励代替绝对奖励;简化训练流程,只需要 Policy Model 和 Reference Model;提高训练稳定性,减少奖励崩塌的风险

3.4.1 训练过程解析

  1. **采样阶段:**对于每个问题,使用当前策略生成多个答案。这些答案构成一个组,用于计算相对奖励
  2. 奖励计算: 对每个生成的答案计算奖励。奖励可以是准确性奖励、答案长度奖励、步骤奖励或它们的组合
  3. 相对奖励: 计算组内平均奖励 ,然后计算相对奖励,这样做的好处是减少奖励方差,使训练更稳定
  4. **策略更新:**使用相对奖励更新策略,同时添加KL散度惩罚,防止策略偏离参考模型太远
  5. **重复:**重复上述步骤,直到完成所有训练轮次
python 复制代码
# 假设我们有一个问题
question = "What is 48 + 24?"

# 生成4个答案
answers = [
    "48 + 24 = 72. Final Answer: 72",      # 正确
    "48 + 24 = 72. Final Answer: 72",      # 正确
    "48 + 24 = 70. Final Answer: 70",      # 错误
    "Let me think... 72. Final Answer: 72" # 正确但冗长
]

# 计算奖励(假设使用准确率 + 长度惩罚)
rewards = [1.0, 1.0, 0.0, 0.8]  # 第4个答案因为冗长被惩罚

# 计算组内平均奖励
avg_reward = (1.0 + 1.0 + 0.0 + 0.8) / 4 = 0.7

# 计算相对奖励
relative_rewards = [
    1.0 - 0.7 = 0.3,   # 正确且简洁,相对奖励为正
    1.0 - 0.7 = 0.3,   # 正确且简洁,相对奖励为正
    0.0 - 0.7 = -0.7,  # 错误,相对奖励为负
    0.8 - 0.7 = 0.1    # 正确但冗长,相对奖励较小
]

# 策略更新:增加前两个答案的概率,减少第三个答案的概率

3.4.2 损失函数

代码实现:

python 复制代码
import torch
import torch.nn.functional as F

def grpo_loss(
    log_prob,  # 当前策略的对数概率 (new_log_probs),形状: (group_size, ...)
    old_log_prob,  # 旧策略的对数概率 (old_log_probs),形状: (group_size, ...)
    rewards,  # 奖励值,在可验证奖励场景下通常是二值的 (0或1),形状: (group_size,)
    ref_log_prob,  # 参考策略的对数概率,形状: (group_size, ...)
    beta=0.1  # KL正则化系数
):
    # 1. 计算重要性比率 (importance ratio)
    #    注意:许多实现直接操作概率比
    ratio = torch.exp(log_prob - old_log_prob)
    
    # 2. 计算未裁剪的替代目标 (surrogate objective)
    #    优势 (Advantage) 通常通过组内奖励归一化得到
    #    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    #    为简化,这里直接使用归一化的奖励作为优势,计算相对奖励
    normalized_rewards = (rewards - rewards.mean())
    surrogate = ratio * normalized_rewards.unsqueeze(-1)
    
    # 3. 计算策略损失 (policy loss)
    policy_loss = -torch.mean(surrogate)
    
    # 4. 计算KL散度正则项 (KL penalty)
    #    KL(当前策略 || 参考策略) 的估计
    #    kl = torch.exp(log_prob) * (log_prob - ref_log_prob)
    kl_div = torch.mean(torch.exp(log_prob) * (log_prob - ref_log_prob))
    
    # 5. 总损失
    total_loss = policy_loss + beta * kl_div
    return total_loss, policy_loss, kl_div

四、模型蒸馏

4.1 原理

模型蒸馏是一种将大模型(Teacher Model)的知识迁移到小模型(Student Model)的技术。通过这种方式,可以让小模型学习到大模型的"暗知识(Dark Knowledge)",从而以较小的成本达到接近大模型的性能

这里的"暗知识"就是Teacher Model输出的"软标签(soft labels)",包含了类别之间的相似性关系,而不仅仅是哪个类别是正确的。例如,教师模型可能认为某个样本属于类别 狗 的概率为 78% ,类别 猫 的概率为 20% ,类别 车 的概率为 2%,可以看出狗和猫是具有一定的相似性的,而狗和车相似性很低。这种分布信息可以帮助学⽣模型更好地理解不同类别的关系

软标签是通过模型输出得到logits ,再除以Temperature ,再经过softmax 函数,即:

  • Temperature = 1,则输出更加"尖锐",极端情况下接近于one-hot编码,表示确定性很强
  • Temperature > 1,则输出更加"平缓",小概率类别也有非零值,有助于传递"暗知识"

如图所示:

  • Teacher Model:基于输入x得到 Temperature=t 的软标签
  • Student Model:输出包含两部分,一部分是基于输入x得到 Temperature=t 的软标签,另一部分是基于输入x得到 Temperature=1 的硬标签
  • Teacher Model与Student Model得到的软标签计算KL散度损失,++鼓励Stuent Model学习Teacher Model的分布++
  • Student Model计算出来的硬标签与真实标签计算交叉熵损失,确保++模型仍然能够正确预测真实标签++

基于上,总损失如下,其中是一个超参数,用于控制KL损失和交叉熵损失的权重:

注:KL损失乘以 是因为在反向传播时, logits 的梯度会受到Temperature的影响。如果不补偿,⾼温会导致梯度变⼩,训练变慢。 Hinton 在论⽂附录中证明:为了保持梯度幅度与 T=1 时相当,需要将 KL 损失乘以该值

4.2 优势

  • 学⽣模型不仅需要预测正确的类别,还需要模仿教师模型的输出分布,这有助于提⾼模型的泛化能⼒
  • 学⽣模型的结构通常⽐教师模型更简单,因此推理速度更快,计算成本更低

4.3 问题与解决方案

|-----------|--------------------------------|
| 问题 | 解决方案 |
| 教师太大,推理慢 | 使⽤缓存:提前⽣成所有 teacher logits 并保存 |
| 学⽣太⼩,⽆法拟合 | 增加投影层、使⽤更强的数据增强 |
| vocab不⼀致 | 截断或扩展 vocab ,保持⼀致 |
| 训练不稳定 | 使⽤梯度裁剪、低学习率、 EMA 平滑 |

参考

hello agent - hello-agents/docs/chapter11/第十一章 Agentic-RL.md

相关推荐
wjykp1 分钟前
part6 PyTorch
人工智能·pytorch·python
liliangcsdn6 分钟前
DDIM扩散模型改进采样策略的推理探索
人工智能·深度学习·自然语言处理
读创商闻11 分钟前
中广融投让传统文化 “活” 起来
大数据·网络·人工智能
长桥夜波11 分钟前
【第二十五周】机器学习笔记
人工智能·笔记·机器学习
智驱力人工智能12 分钟前
在安全与尊严之间 特殊人员离岗检测系统的技术实现与伦理实践 高风险人员脱岗预警 人员离岗实时合规检测 监狱囚犯脱岗行为AI分析方案
人工智能·深度学习·opencv·算法·目标检测·cnn·边缘计算
Pith_12 分钟前
模式识别与机器学习复习笔记(上)
人工智能·笔记·机器学习
大任视点14 分钟前
云南首家现代农事综合服务中心在普洱思茅落户
大数据·人工智能
融云14 分钟前
融云 2025 回顾:「韧性」生长,「邪修」破局
人工智能·融云im·im选型
qdprobot15 分钟前
开源的在线串口调试助手支持macOS苹果电脑Windows系统Linux 浏览器webSerial
linux·运维·服务器·人工智能·mixly·小智ai·webserial