DeepSeek V3 源码:从入门到放弃!

从入门到放弃

花了几天时间,看懂了DeepSeek V3 源码的逻辑。源码的逻辑是不难的,但为什么模型结构需要这样设计,为什么参数需要这样设置呢?知其然,但不知其所以然。除了模型结构以外,模型的训练数据、训练脚本和训练经验,也是DeepSeek V3能够训练出来的关键,但这些是DeepSeek母公司的核心机密,我们无从得知。

因此,看懂了源码,算是入门 了DeepSeek V3,因为没有条件知道更多重要细节,因此不得不放弃重现整个模型的训练。

Paper 和源码

Paper URL: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf

Code URL: https://github.com/deepseek-ai/DeepSeek-V3

B站大牛详解视频链接

https://www.bilibili.com/video/BV1RtNLeqEeu/

模型逻辑

下面这张图,代表了DeepSeek的核心逻辑。左边是Transformer的逻辑结构,可以认为有N个左边这样的Block结构不断重复,组成Transformer模型。每个Block中,分成两个部分,Attention 和 Feed-Forward Network。对这两个部分使用不同的网络结构,我们就得到了不同的模型。

DeepSeek V3 的 Attention 用的是 Multi-Head Latent Attention(MLA) ,Feed-Forward Network 用的是DeepSeekMoE。

MLA

Multi-Head Latent Attention(MLA)即多头潜在注意力,是DeepSeek模型中引入的一种创新注意力机制,旨在优化传统多头注意力(Multi-Head Attention,MHA)的计算效率和内存占用。具体介绍如下:

核心创新点

  • 低秩键值压缩
    • KV的低秩压缩:不直接存储原始的Key和Value,而是先将隐藏状态投影到一个更小的压缩潜在向量。在推理时,只需缓存该压缩潜在向量,而不是完整的Key和Value,从而大大降低了KV缓存的存储需求。
    • Query的低秩压缩:对Query也进行低秩压缩,虽然不会减少KV缓存的大小,但可以减少训练时的激活存储需求,进而降低计算成本。
  • 解耦旋转位置嵌入(RoPE)
    • 额外引入"解耦查询":将查询拆分为两个部分,一部分不经过RoPE变换,代表非位置敏感的特征信息;另一部分专门用于嵌入RoPE位置编码信息。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,减少了计算开销,也减小了KV缓存大小,降低了GPU内存占用,提高了推理速度,特别适用于长序列任务和大规模Transformer。

推理过程中的优化

将上投影矩阵吸收到里面,简化查询计算,并优化注意力分数的计算,减少了计算步骤,提升了计算效率。避免了先计算Value向量,减少了矩阵运算的开销,使推理更快。

整体优势

  • 降低内存占用:通过对键值进行低秩联合压缩以及解耦RoPE等策略,显著减少了KV缓存的存储需求,降低了GPU内存占用。
  • 提高计算效率:减少了训练和推理过程中的计算量,加快了模型的推理速度,在保持甚至提高模型性能的同时,提升了模型的运行效率。
  • 增强模型适应性:特别适用于长序列任务和大规模Transformer模型,能够更好地处理长序列输入,提高模型在各种自然语言处理任务中的表现。

MLA 有物理意义吗?

Multi-Head Latent Attention(MLA)能够起作用主要源于其独特的技术设计,在数学和信息处理层面有清晰的逻辑,不过它是一种抽象的算法概念,并不直接对应具体的物理意义,以下是对其作用原理的分析:

起作用的原因

  • 低秩压缩的有效性
    • 信息浓缩与降噪:通过低秩键值压缩,MLA将高维的Key和Value信息投影到低维的潜在向量空间,这一过程类似于对原始信息进行浓缩,提取出最关键、最具代表性的特征,去除了一些可能的噪声和冗余信息,使得模型能够更聚焦于重要信息,从而提高信息处理的效率和准确性。
    • 减少计算量和存储需求:低秩压缩大大降低了数据的维度,减少了模型训练和推理过程中的计算量和存储需求,使得模型能够更高效地运行,尤其是在处理大规模数据和长序列数据时,这种优势更为明显。
  • 解耦旋转位置嵌入的优势
    • 位置信息与内容信息的分离:传统的位置编码方式将位置信息和内容信息混合在一起进行处理,而MLA的解耦旋转位置嵌入将查询拆分为位置敏感和非位置敏感两部分,使模型能够更清晰地分离和处理位置信息与内容信息,更好地捕捉文本中的长距离依赖关系。
    • 共享RoPE变换的Key:所有注意力头共用一个旋转变换后的Key,不仅减少了计算开销,还使得模型能够从更宏观的角度利用位置信息,增强了模型对序列数据整体结构的理解和把握能力。
  • 多头机制的协同作用
    • 捕捉多维度信息:MLA中的多头机制允许模型同时从多个不同的角度和维度去捕捉输入数据中的信息,每个头可以关注到输入序列的不同方面,通过多个头的并行计算和协同工作,模型能够更全面、更深入地理解输入数据,提高模型的表示能力和泛化能力。

难以直接赋予物理意义的原因及近似理解

  • 抽象的算法概念:MLA是一种基于数学和计算机科学的算法概念,主要用于处理和分析数据中的模式和关系,它不像物理概念那样具有直接可观测的物理实体或现象与之对应,更多地是在数据空间和计算逻辑中发挥作用。
  • 类比物理现象理解:可以进行一些类比来帮助理解。比如低秩压缩类似于物理中的能量聚集,将分散的能量(信息)聚集到关键的"点"上;解耦旋转位置嵌入有点像物理中对不同性质力的分解,将位置信息和内容信息这两种"力"分开处理;多头机制如同多个物理传感器从不同方向和角度对环境进行感知,然后综合这些感知信息来对整个系统进行理解和判断。

DeepSeekMoE

DeepSeekMoE是由深度求索(DeepSeek)研发的基于混合专家系统(Mixture of Experts,MoE)的技术架构,以下是具体介绍:

架构原理

  • 混合专家系统核心:采用MoE架构,核心在于通过动态路由机制,把输入数据分配给最相关的专家处理。比如在自然语言处理中,有的专家专门处理情感分析,有的处理主题建模。
  • 结合多头潜在注意力机制:与MLA相结合,MLA通过引入潜在向量,减少键值缓存(KV cache)需求,提升推理效率。
  • Transformer架构基础:以Transformer架构为基础,每个Transformer块由一个注意力模块和一个前馈网络(FFN)组成,在注意力机制和FFN方面采用创新架构。

技术优势

  • 降低算力需求:MoE的动态分配机制和MLA减少KV缓存需求等特点,使模型在训练和推理时对算力的要求降低。
  • 保持高性能:在参数量减少的情况下仍能保持高性能,例如DeepSeek-V2以236B总参数、21B激活,大致可以达到70B-110B Dense的模型能力。
  • 减少计算量:自研Sparse结构DeepSeekMoE进一步降低了计算量。
  • 长上下文理解能力强:支持超100万token的上下文窗口,显著优于行业平均水平,适用于长文档分析、代码开发等复杂场景的连贯交互。

DeepSeekMoE的物理意义是什么?

DeepSeekMoE作为一种人工智能技术架构,没有严格意义上的物理意义,但可以从一些角度进行类比和理解:

从系统资源分配角度

  • 资源按需分配类比:可以将DeepSeekMoE的专家网络和动态路由机制类比为一个智能电力分配系统。在这个系统中,不同的电器设备(任务)需要不同的电量(计算资源)来运行。专家网络就像不同功率的发电机,而动态路由机制则像是智能电表和分配器,它会根据每个电器设备的实际需求,将电力(计算资源)精准地分配给需要的设备,避免了资源的浪费,提高了整个系统的能源利用效率。
  • 负载均衡类比:类似于在一个大型物流中心,不同的仓库区域(专家)负责存储和处理不同类型的货物(数据)。当有货物运输任务时,调度系统(动态路由)会根据货物的特点和仓库的负载情况,合理地安排货物存储到哪个仓库,确保每个仓库都能在其承载能力范围内高效运作,不会出现某个仓库过度拥挤而其他仓库闲置的情况,实现了负载均衡,提高了物流中心的整体运营效率。

从信息处理角度

  • 多维度信息处理类比:可以把DeepSeekMoE处理信息的过程想象成一个由多个不同专业的侦探(专家)组成的侦探团队在调查一个复杂案件。每个侦探都有自己独特的专业技能和视角,比如有的擅长调查线索,有的擅长分析人物关系,有的擅长破解密码等。当面对案件(输入数据)时,队长(路由器)会根据案件的具体情况,分配合适的侦探去处理相应的部分,最后将各个侦探的调查结果综合起来,形成对整个案件的全面了解和判断,从而更高效地解决复杂问题。
  • 特征提取与融合类比:如同在一个化学实验中,不同的化学试剂(专家)可以与不同的物质发生反应,提取出特定的化学特征。DeepSeekMoE中的专家网络就像这些化学试剂,它们各自对输入数据进行处理,提取出不同的特征。然后通过融合机制,将这些特征像混合化学物质一样进行整合,得到更全面、更有价值的信息,用于后续的分析和决策。

从模型架构角度

  • 积木搭建类比:把DeepSeekMoE的架构比作搭建积木。每个专家网络就像不同形状和功能的积木块,有的积木块负责搭建基础结构,有的负责构建上层建筑,有的负责添加装饰等。路由器则像是搭建者的手,根据要搭建的目标模型的需求,选择合适的积木块进行组合,最终搭建出一个复杂而功能强大的模型结构,实现对各种自然语言处理任务的高效处理。
  • 人体神经系统类比:可以将DeepSeekMoE类比为人体的神经系统。专家网络类似于人体的不同神经细胞或神经中枢,它们各自负责处理特定类型的信息,如视觉神经细胞负责处理视觉信息,听觉神经细胞负责处理听觉信息等。路由器就像神经系统中的神经递质或信号传导机制,它负责将外界的刺激信号(输入数据)准确地传递给相应的神经细胞,并将各个神经细胞处理后的信号进行整合和传递,使人体能够做出协调的反应和决策,实现对外部世界的感知和交互。

代码逻辑

整体 - Transformer

下面这段代码是典型的 Transformer 实现,核心可以看 forward 函数逻辑:

  1. 进行 Embeding;
  2. 经过各个 Block;
  3. 归一化并输出。
    对应的代码:
dart 复制代码
# 通过嵌入层将输入标记转换为向量表示
h = self.embed(tokens)
# 依次通过每个Transformer块进行处理
for layer in self.layers:
    h = layer(h, start_pos, freqs_cis, mask)
# 对输出进行层归一化,并取最后一个时间步的输出
h = self.norm(h)[:, -1]
# 通过输出投影层得到对数概率
logits = self.head(h)

完整代码:

python 复制代码
# 定义Transformer类,继承自PyTorch的nn.Module类
class Transformer(Module):
    """
    Transformer模型,包含位置嵌入、多个层以及输出投影。

    属性:
        max_seq_len (int): Transformer允许的最大序列长度。
        embed (nn.Module): 用于输入标记的嵌入层,将输入的标记转换为向量表示。
        layers (torch.nn.ModuleList): 存储多个Transformer块的列表,每个块包含多头注意力和前馈网络。
        norm (nn.Module): 层归一化层,在所有Transformer块之后应用,用于稳定训练。
        head (nn.Module): 输出投影层,将模型的输出映射到词汇表大小,用于预测下一个标记。
        freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。
    """
    def __init__(self, args):
        """
        初始化Transformer模型。

        参数:
            args: 模型参数对象,包含Transformer的各种参数,如词汇表大小、维度、层数等。
        """
        # 获取全局变量world_size和rank,分别表示分布式训练中的进程总数和当前进程的编号
        global world_size, rank
        # 如果分布式训练已初始化,则获取进程总数,否则默认为1
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        # 如果分布式训练已初始化,则获取当前进程编号,否则默认为0
        rank = dist.get_rank() if dist.is_initialized() else 0
        # 根据参数设置线性层的数据类型
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        # 调用父类的初始化方法
        super().__init__()
        # 保存最大序列长度
        self.max_seq_len = args.max_seq_len
        # 初始化嵌入层,将输入标记转换为向量表示
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        # 初始化一个空的ModuleList,用于存储Transformer块
        self.layers = torch.nn.ModuleList()
        # 循环创建指定数量的Transformer块,并添加到layers列表中
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        # 初始化层归一化层
        self.norm = RMSNorm(args.dim)
        # 初始化输出投影层,将模型的输出映射到词汇表大小
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
        # 预计算旋转位置嵌入所需的复指数值,并将其注册为缓冲区,不参与模型参数的更新
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens, start_pos=0):
        """
        Transformer模型的前向传播过程。

        参数:
            tokens (torch.Tensor): 输入的标记ID张量,形状为 (batch_size, seq_len)。
            start_pos (int, 可选): 旋转位置嵌入的起始位置,默认为0。

        返回:
            torch.Tensor: 对数概率张量,形状为 (batch_size, vocab_size),表示每个标记的预测概率。
        """
        # 获取输入序列的长度
        seqlen = tokens.size(1)
        # 通过嵌入层将输入标记转换为向量表示
        h = self.embed(tokens)
        # 从预计算的复指数值中截取当前序列所需的部分
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
        # 初始化掩码为None
        mask = None
        # 如果序列长度大于1,则创建一个上三角掩码,用于屏蔽未来的标记
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        # 依次通过每个Transformer块进行处理
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        # 对输出进行层归一化,并取最后一个时间步的输出
        h = self.norm(h)[:, -1]
        # 通过输出投影层得到对数概率
        logits = self.head(h)
        # 如果使用分布式训练,则收集所有进程的对数概率
        if world_size > 1:
            # 创建一个列表,用于存储所有进程的对数概率
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            # 收集所有进程的对数概率
            dist.all_gather(all_logits, logits)
            # 将所有进程的对数概率拼接在一起
            logits = torch.cat(all_logits, dim=-1)
        return logits

单个 - Block

核心代码非常简单MLA(attention) + MOE(Feed-Forward Network):

python 复制代码
# 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
# 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接
x = x + self.ffn(self.ffn_norm(x))

全部代码:

python 复制代码
# 定义一个Transformer块类,继承自PyTorch的nn.Module类
class Block(Module):
    """
    Transformer块,结合了注意力层和前馈网络层。

    属性:
        attn (nn.Module): 注意力层(采用多头潜在注意力机制,即MLA),用于捕捉输入序列中不同位置之间的依赖关系。
        ffn (nn.Module): 前馈网络层(可以是多层感知机MLP或者混合专家模型MoE),对注意力层的输出进行非线性变换。
        attn_norm (nn.Module): 用于注意力层的层归一化层,对输入到注意力层的数据进行归一化处理,稳定训练过程。
        ffn_norm (nn.Module): 用于前馈网络层的层归一化层,对输入到前馈网络层的数据进行归一化处理。
    """
    def __init__(self, layer_id, args):
        """
        初始化Transformer块。

        参数:
            layer_id (int): 当前块在Transformer模型中的层索引,用于确定使用哪种前馈网络结构。
            args: 模型参数对象,包含了块的各种参数,如维度、层数等。
        """
        # 调用父类的初始化方法
        super().__init__()
        # 初始化注意力层,使用多头潜在注意力机制(MLA)
        self.attn = MLA(args)
        # 根据当前层的索引来决定使用MLP还是MoE作为前馈网络
        # 如果当前层索引小于密集层的数量,则使用MLP
        # 否则使用混合专家模型(MoE)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        # 初始化用于注意力层的层归一化层
        self.attn_norm = RMSNorm(args.dim)
        # 初始化用于前馈网络层的层归一化层
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x, start_pos, freqs_cis, mask=None):
        """
        Transformer块的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量,包含了序列的特征信息。
            start_pos (int): 序列中的起始位置,用于旋转位置嵌入。
            freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置嵌入,帮助模型捕捉序列中的位置信息。
            mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置,避免模型关注到不应该关注的信息。

        返回:
            torch.Tensor: 经过当前Transformer块计算后的输出张量。
        """
        # 首先对输入进行层归一化,然后通过注意力层进行计算,最后将结果与输入进行残差连接
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        # 接着对上述结果进行层归一化,再通过前馈网络层进行计算,最后将结果与之前的结果进行残差连接
        x = x + self.ffn(self.ffn_norm(x))
        return x

Attention 模块

经典的QKV计算公式。解释可以自行搜索,或者参考:Transformer结构和注意力机制

和传统的QKV相比,可以认为是做了压缩,主要是为了减小 KV Cache。

代码,就是做了一堆下采样上采样和矩阵的组合变换。最终目的是减少计算量和显存使用量。

python 复制代码
# 定义多头注意力层类,继承自PyTorch的nn.Module类
class MLA(Module):
    """
    多头注意力层(MLA)。

    属性:
        dim (int): 输入特征的维度。
        n_heads (int): 注意力头的数量。
        n_local_heads (int): 分布式系统中本地注意力头的数量。
        q_lora_rank (int): 查询(query)的低秩投影的秩。
        kv_lora_rank (int): 键(key)和值(value)的低秩投影的秩。
        qk_nope_head_dim (int): 非位置相关的查询/键投影的维度。
        qk_rope_head_dim (int): 旋转位置编码的查询/键投影的维度。
        qk_head_dim (int): 查询/键投影的总维度。
        v_head_dim (int): 值投影的维度。
        softmax_scale (float): 注意力计算中softmax函数的缩放因子。
    """
    def __init__(self, args):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入特征的维度
        self.dim = args.dim
        # 保存注意力头的数量
        self.n_heads = args.n_heads
        # 计算分布式系统中本地注意力头的数量
        self.n_local_heads = args.n_heads // world_size
        # 保存查询的低秩投影的秩
        self.q_lora_rank = args.q_lora_rank
        # 保存键和值的低秩投影的秩
        self.kv_lora_rank = args.kv_lora_rank
        # 保存非位置相关的查询/键投影的维度
        self.qk_nope_head_dim = args.qk_nope_head_dim
        # 保存旋转位置编码的查询/键投影的维度
        self.qk_rope_head_dim = args.qk_rope_head_dim
        # 计算查询/键投影的总维度
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        # 保存值投影的维度
        self.v_head_dim = args.v_head_dim

        # 如果查询的低秩投影的秩为0,直接使用列并行线性层进行查询投影
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        # 否则,使用低秩分解的方式进行查询投影
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        # 对输入进行线性变换得到键和值的低秩表示
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        # 对键和值的低秩表示进行归一化
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        # 对归一化后的键和值进行线性变换
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        # 对多头注意力的输出进行行并行线性变换
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        # 计算softmax函数的缩放因子
        self.softmax_scale = self.qk_head_dim ** -0.5
        # 如果最大序列长度大于原始序列长度,对缩放因子进行调整
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        # 如果注意力实现方式为朴素方式
        if attn_impl == "naive":
            # 注册键缓存
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            # 注册值缓存
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        # 否则
        else:
            # 注册键值缓存
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            # 注册位置编码缓存
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x, start_pos, freqs_cis, mask=None):
        """
        多头注意力层(MLA)的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量,形状为 (batch_size, seq_len, dim)。
            start_pos (int): 序列中用于缓存的起始位置。
            freqs_cis (torch.Tensor): 预计算的复指数值,用于旋转位置编码。
            mask (Optional[torch.Tensor]): 掩码张量,用于在注意力计算中排除某些位置。

        返回:
            torch.Tensor: 输出张量,形状与输入相同。
        """
        # 获取输入张量的批次大小、序列长度
        bsz, seqlen, _ = x.size()
        # 计算序列的结束位置
        end_pos = start_pos + seqlen
        # 如果查询的低秩投影的秩为0,直接通过线性层得到查询
        if self.q_lora_rank == 0:
            q = self.wq(x)
        # 否则,通过低秩分解的方式得到查询
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        # 调整查询的形状,将其划分为多个头
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        # 将查询划分为非位置相关部分和位置相关部分
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # 对位置相关部分应用旋转位置编码
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        # 通过线性层得到键和值的低秩表示
        kv = self.wkv_a(x)
        # 将键和值的低秩表示划分为低秩部分和位置编码部分
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # 对位置编码部分应用旋转位置编码
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        # 如果注意力实现方式为朴素方式
        if attn_impl == "naive":
            # 将非位置相关部分和位置相关部分拼接得到完整的查询
            q = torch.cat([q_nope, q_pe], dim=-1)
            # 对键和值的低秩表示进行归一化和线性变换
            kv = self.wkv_b(self.kv_norm(kv))
            # 调整键和值的形状,将其划分为多个头
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            # 将键和值划分为非位置相关部分和值部分
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            # 将非位置相关部分和位置编码部分拼接得到完整的键
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            # 将键存入缓存
            self.k_cache[:bsz, start_pos:end_pos] = k
            # 将值存入缓存
            self.v_cache[:bsz, start_pos:end_pos] = v
            # 计算注意力分数
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        # 否则
        else:
            # 获取键和值的线性变换层的权重
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) 
            # 调整权重的形状
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            # 计算非位置相关部分的注意力分数
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            # 将键和值的低秩表示归一化后存入缓存
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            # 将位置编码部分存入缓存
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            # 计算注意力分数
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

        # 如果存在掩码,将掩码加到注意力分数上
        if mask is not None:
            scores += mask.unsqueeze(1)
        # 对注意力分数应用softmax函数
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

        # 如果注意力实现方式为朴素方式
        if attn_impl == "naive":
            # 通过注意力分数和值缓存计算输出
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        # 否则
        else:
            # 通过注意力分数和键值缓存计算中间结果
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            # 通过中间结果和权重计算输出
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        # 对输出进行线性变换
        x = self.wo(x.flatten(2))
        return x

代码解释总结

这段代码定义了一个多头注意力层(MLA)类。在初始化时,根据传入的参数设置各种维度、低秩投影的秩等,并初始化相应的线性层和归一化层,同时根据注意力实现方式注册不同的缓存。在前向传播过程中,对输入进行处理得到查询、键和值,应用旋转位置编码,根据不同的注意力实现方式计算注意力分数,最后通过注意力分数和缓存得到输出并进行线性变换。

Feed-Forward Network

这个MoE分成两个部分,左边是一些可以分享的专家,就是每次都需要去计算的,右边的是根据分数来选择的。如何选择,是通过一个门控机制来选择。这个门控是如何设计的,代码里有实现,但论文和代码都没有对它物理意义的解释。门控机制,简单来说,就是设计了一个网络,来选出K个候选。

核心逻辑:

  1. 通过门控确定本轮要用到的本地专家:weights, indices = self.gate(x)
  2. 用选择的每个本地专家进行计算:y[idx] += expert(x[idx]) * weights[idx, top, None]
  3. 用共享专家进行计算:z = self.shared_experts(x)
  4. 将本地专家的输出和共享专家的输出相加,并恢复到原始形状: return (y + z).view(shape)

全部代码:

python 复制代码
# 定义混合专家(Mixture-of-Experts, MoE)模块类,继承自PyTorch的nn.Module类
class MoE(nn.Module):
    """
    混合专家(Mixture-of-Experts, MoE)模块。

    属性:
        dim (int): 输入特征的维度。
        n_routed_experts (int): 模型中专家的总数。
        n_local_experts (int): 在分布式系统中本地处理的专家数量。
        n_activated_experts (int): 每个输入激活的专家数量。
        gate (nn.Module): 门控机制,用于将输入路由到不同的专家。
        experts (nn.ModuleList): 专家模块列表,包含多个专家网络。
        shared_experts (nn.Module): 共享专家模块,应用于所有输入。
    """
    def __init__(self, args):
        """
        初始化MoE模块。

        参数:
            args: 模型参数对象,包含MoE模块的相关参数。
        """
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入特征的维度
        self.dim = args.dim
        # 确保专家总数能被分布式系统中的进程数整除
        assert args.n_routed_experts % world_size == 0, f"专家数量必须能被进程数整除 (进程数={world_size})"
        # 保存模型中专家的总数
        self.n_routed_experts = args.n_routed_experts
        # 计算本地处理的专家数量
        self.n_local_experts = args.n_routed_experts // world_size
        # 保存每个输入激活的专家数量
        self.n_activated_experts = args.n_activated_experts
        # 计算本地专家在所有专家中的起始索引
        self.experts_start_idx = rank * self.n_local_experts
        # 计算本地专家在所有专家中的结束索引
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        # 初始化门控机制
        self.gate = Gate(args)
        # 初始化专家模块列表,本地负责的专家使用Expert模块,其他位置置为None
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        # 初始化共享专家模块
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x):
        """
        MoE模块的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量。

        返回:
            torch.Tensor: 经过专家路由和计算后的输出张量。
        """
        # 保存输入张量的原始形状
        shape = x.size()
        # 将输入张量展平为二维张量,方便后续处理
        x = x.view(-1, self.dim)
        # 通过门控机制得到每个输入分配到各个专家的权重和对应的专家索引
        weights, indices = self.gate(x)
        # 初始化输出张量,形状与输入相同,初始值全为0
        y = torch.zeros_like(x)
        # 统计每个专家被分配到的输入数量
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        # 遍历本地负责的专家
        for i in range(self.experts_start_idx, self.experts_end_idx):
            # 如果该专家没有被分配到输入,则跳过
            if counts[i] == 0:
                continue
            # 获取当前专家模块
            expert = self.experts[i]
            # 找出分配到当前专家的输入的索引
            idx, top = torch.where(indices == i)
            # 将这些输入通过当前专家模块进行计算,并乘以对应的权重,累加到输出张量中
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        # 将输入通过共享专家模块进行计算
        z = self.shared_experts(x)
        # 如果使用分布式训练,对本地专家的输出进行全局归约操作
        if world_size > 1:
            dist.all_reduce(y)
        # 将本地专家的输出和共享专家的输出相加,并恢复到原始形状
        return (y + z).view(shape)

代码解释总结

这段代码定义了一个混合专家(MoE)模块。在初始化时,根据传入的参数设置专家的数量、门控机制、专家模块列表和共享专家模块。在前向传播过程中,首先通过门控机制将输入路由到不同的专家,然后对本地负责的专家进行计算并累加结果,同时将输入通过共享专家模块进行计算,最后将两部分结果相加并恢复原始形状。如果使用分布式训练,还会对本地专家的输出进行全局归约操作。

Gate

不考虑分组路由来看看它的核心逻辑,实际上就是线下变换,然后激活,选择K个极值(如果用了分组,就是选择K的方式发生了一些变化):

python 复制代码
# 通过线性变换计算每个输入对应各个专家的分数
scores = linear(x, self.weight)
# 根据评分函数类型对分数进行处理
if self.score_func == "softmax":
    scores = scores.softmax(dim=-1, dtype=torch.float32)
else:
    scores = scores.sigmoid()
# 选择分数最高的若干专家
indices = torch.topk(scores, self.topk, dim=-1)[1]
# 根据选择的专家索引,从原始分数中获取对应的权重
weights = scores.gather(1, indices)
# 如果评分函数是sigmoid,对权重进行归一化
if self.score_func == "sigmoid":
    weights /= weights.sum(dim=-1, keepdim=True)
# 对权重进行缩放
weights *= self.route_scale
return weights.type_as(x), indices
python 复制代码
# 定义门控机制类,用于在混合专家(MoE)模型中对输入进行路由
class Gate(nn.Module):
    """
    混合专家(MoE)模型中用于输入路由的门控机制。

    属性:
        dim (int): 输入特征的维度。
        topk (int): 每个输入激活的顶级专家数量。
        n_groups (int): 用于路由的分组数量。
        topk_groups (int): 输入将被路由到的分组数量。
        score_func (str): 评分函数,取值为 'softmax' 或 'sigmoid'。
        route_scale (float): 路由权重的缩放因子。
        weight (torch.nn.Parameter): 门控机制的可学习权重。
        bias (Optional[torch.nn.Parameter]): 门控机制的可选偏置项。
    """
    def __init__(self, args):
        """
        初始化门控机制模块。

        参数:
            args: 模型参数对象,包含门控机制的相关参数。
        """
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入特征的维度
        self.dim = args.dim
        # 保存每个输入激活的顶级专家数量
        self.topk = args.n_activated_experts
        # 保存用于路由的分组数量
        self.n_groups = args.n_expert_groups
        # 保存输入将被路由到的分组数量
        self.topk_groups = args.n_limited_groups
        # 保存评分函数类型
        self.score_func = args.score_func
        # 保存路由权重的缩放因子
        self.route_scale = args.route_scale
        # 初始化可学习权重
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # 根据输入特征维度决定是否初始化偏置项
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x):
        """
        门控机制的前向传播过程。

        参数:
            x (torch.Tensor): 输入张量。

        返回:
            Tuple[torch.Tensor, torch.Tensor]: 路由权重和选择的专家索引。
        """
        # 通过线性变换计算每个输入对应各个专家的分数
        scores = linear(x, self.weight)
        # 根据评分函数类型对分数进行处理
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()
        # 保存原始分数,后续计算权重时使用
        original_scores = scores
        # 如果存在偏置项,将其加到分数上
        if self.bias is not None:
            scores = scores + self.bias
        # 如果分组数量大于1,进行分组路由操作
        if self.n_groups > 1:
            # 调整分数的形状,以便按组处理
            scores = scores.view(x.size(0), self.n_groups, -1)
            # 根据是否有偏置项,计算每个组的分数表示
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            # 选择分数最高的若干组
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            # 创建掩码,用于屏蔽未选择的组
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            # 将屏蔽组的分数设为负无穷
            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
        # 选择分数最高的若干专家
        indices = torch.topk(scores, self.topk, dim=-1)[1]
        # 根据选择的专家索引,从原始分数中获取对应的权重
        weights = original_scores.gather(1, indices)
        # 如果评分函数是sigmoid,对权重进行归一化
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        # 对权重进行缩放
        weights *= self.route_scale
        return weights.type_as(x), indices

代码解释总结

这段代码定义了一个门控机制(Gate)类,用于在混合专家(MoE)模型中对输入进行路由。在初始化时,根据传入的参数设置各种属性,如输入维度、激活专家数量、分组数量等,并初始化可学习的权重和偏置项。在前向传播过程中,首先计算每个输入对应各个专家的分数,然后根据评分函数类型进行处理,接着根据分组情况进行分组路由操作,选择激活的专家并计算对应的权重,最后返回路由权重和选择的专家索引。

相关推荐
瑞瑞大大3 分钟前
简单介绍下Manus功能
人工智能
小杨4046 分钟前
python入门系列六(文件操作)
人工智能·python·pycharm
deephub13 分钟前
Chain of Draft: 借鉴人类草稿思维让大型语言模型更快地思考
人工智能·语言模型·自然语言处理·思维链
碣石潇湘无限路40 分钟前
【AI】基于扩散方案的大语言模型研究报告
人工智能·语言模型·自然语言处理
EasyCVR1 小时前
EasyRTC嵌入式音视频通话SDK:基于ICE与STUN/TURN的实时音视频通信解决方案
人工智能·音视频·webrtc·实时音视频·h.265
非优秀程序员1 小时前
使用Python给自己网站生成llms.txt
人工智能·后端·架构
二川bro1 小时前
AI 人工智能深度解析:从基础到前沿,全面掌握未来科技
人工智能·科技
非优秀程序员1 小时前
人工智能时代,如何让你的网站更好被大模型收录,获得新的自然流量并成为互联网的信息来源
人工智能·机器学习·架构
Dipeak数巅科技1 小时前
数巅科技携手智慧足迹深耕行业大模型应用
大数据·人工智能·商业智能bi
AI34561 小时前
AI壁纸进阶宝典:让创作效率与质量飞速提升的法门
人工智能