Notes on MoEs and Transformers

ref:https://huggingface.co/blog/zh/moe#用router-z-loss稳定模型训练

MoEs and Transformers

Transformer-class models have clearly shown that scaling up the number of parameters improves performance, so it is no surprise that Google used GShard to try to push Transformer models beyond 600 billion parameters.

GShard replaces every other feed-forward network (FFN) layer in both the encoder and the decoder with a Mixture-of-Experts (MoE) layer that uses Top-2 gating (the figure in the original post shows the encoder part). This architecture is very effective for large-scale computation: when scaling to multiple devices, the MoE layers are sharded across devices while all other layers are replicated on each device. This point is discussed in more detail in the blog's "Making MoEs go brrr" section.

To keep the load balanced and training efficient, the GShard authors introduced a couple of key changes in addition to an auxiliary loss similar to the one discussed in the previous section:

Random routing: in a Top-2 setup we always pick the top-ranked expert, but the second expert is chosen at random with probability proportional to its router weight (a sketch of this rule follows this list).

Expert capacity: a threshold for how many tokens a single expert can process. If both chosen experts are at capacity, the token overflows: it is passed to the next layer via the residual connection, or in some implementations dropped entirely. Expert capacity is one of the most important concepts in MoEs. Why is it needed? Because all tensor shapes are determined statically at compile time, we cannot know ahead of time how many tokens will be routed to each expert, so a fixed capacity factor is required.
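
To make the capacity computation concrete, here is a minimal sketch, assuming the usual formula (the tokens in a batch divided evenly among experts, scaled by the capacity factor):

def expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float = 1.0) -> int:
    # Each expert handles at most this many tokens per batch; tokens routed
    # beyond it overflow to the next layer via the residual connection
    # (or, in some implementations, are dropped).
    return int(capacity_factor * num_tokens / num_experts)

# e.g. 8 sequences of 512 tokens, 8 experts, capacity factor 1.25:
print(expert_capacity(8 * 512, 8, 1.25))  # 640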
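
And a minimal sketch of the random-routing rule from the first item (my own illustration; implementations differ in how the second expert's probability is normalized):

import torch

def top2_random_routing(gate_probs: torch.Tensor):
    # gate_probs: (num_tokens, num_experts), softmax outputs of the router
    top2_probs, top2_idx = gate_probs.topk(2, dim=-1)
    first = top2_idx[..., 0]  # the top expert is always selected
    # Keep the runner-up with probability proportional to its gate weight
    p_second = top2_probs[..., 1] / (top2_probs[..., 0] + top2_probs[..., 1])
    keep_second = torch.rand_like(p_second) < p_second
    # When the runner-up is skipped, the token goes to the top expert only
    second = torch.where(keep_second, top2_idx[..., 1], first)
    return first, second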

The GShard work also made important contributions to parallel computation patterns for MoEs, but those are beyond the scope of this post.

Note: during inference, only some of the experts are activated, while some computation, such as self-attention, is shared by all tokens. This explains why a 47B model with 8 experts can run with the compute budget of a 12B dense model. With Top-2 gating the model would use up to 14B parameters, but because the self-attention operations are shared across experts, the number of parameters the model actually uses at run time is 12B.
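
A back-of-the-envelope sketch of that accounting (the shared/per-expert split below is chosen only so the totals reproduce the numbers quoted above; it is not Mixtral's actual layout):

def param_counts(p_shared: float, p_expert: float, num_experts: int, top_k: int):
    # p_shared: parameters every token uses (embeddings, attention, norms)
    # p_expert: parameters in a single expert's FFN
    total = p_shared + num_experts * p_expert  # what must be held in memory
    active = p_shared + top_k * p_expert       # what a single token touches
    return total, active

# Illustrative split: 3B shared + 8 experts of 5.5B = 47B total,
# of which a Top-2 forward pass touches 3B + 2 * 5.5B = 14B.
print(param_counts(3e9, 5.5e9, num_experts=8, top_k=2))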

Switch Transformers

Although Mixture-of-Experts (MoE) models showed a lot of promise, they struggled with instability during training and fine-tuning. Switch Transformers is a very exciting piece of work that dives deep into these topics. The authors even released a 1.6-trillion-parameter MoE with 2048 experts on Hugging Face, which you can run with the transformers library. Switch Transformers achieved a 4x pre-training speedup over T5-XXL.

Just as in GShard, the authors replaced the feed-forward network (FFN) layers with MoE layers. The paper proposes a Switch Transformer layer that receives two inputs (two different tokens) and has four experts.

Contrary to the original idea of using at least two experts, Switch Transformers adopts a simplified single-expert (Top-1) strategy, the same routing implemented by the SwitchGate class in the code at the end of these notes. The effects of this approach are:

Reduced gating (router) computation

The batch size of each expert can be at least halved

Lower communication costs

Preserved model quality

Switch Transformers uses an encoder-decoder setup, implementing an MoE counterpart of T5. The GLaM paper explored pushing the scale of these models further, training a model that matches GPT-3 quality using only 1/3 of the compute (since MoE models need less compute to train, they can significantly reduce the carbon footprint). The authors focused on decoder-only models with few-shot and one-shot evaluation rather than fine-tuning. They used Top-2 routing and a much larger capacity factor, and they also explored treating the capacity factor as a dynamic quantity, adjusted according to the compute available during training and evaluation.

Stabilizing training with router z-loss

The balancing loss discussed earlier can introduce instability. There are many ways to stabilize the training of sparse models, but they may cost model quality. For example, introducing dropout improves stability but hurts model quality, while adding more multiplicative components improves quality but reduces stability.

The router z-loss introduced in ST-MoE significantly improves training stability while preserving model quality. It works by penalizing large logits entering the gating network: it pushes the absolute magnitudes of those values to stay small, which reduces round-off error in the computation. This matters especially for gating networks that rely on exponentials.
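
A minimal sketch of this loss, following the formula in the ST-MoE paper (the mean squared log-sum-exp of the per-token router logits; the function name is mine):

import torch

def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    # router_logits: (num_tokens, num_experts), pre-softmax gate inputs.
    # Penalizing the squared log-sum-exp keeps logit magnitudes small,
    # which reduces round-off error in the router's exponentials.
    z = torch.logsumexp(router_logits, dim=-1)  # (num_tokens,)
    return (z ** 2).mean()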

How does the number of experts affect pre-training?

Adding more experts improves sample efficiency and speeds up the model, but these gains diminish as the count grows (especially past 256 or 512 experts), and more VRAM is needed at inference time to hold the whole model. Notably, the Switch Transformers results show that the properties observed at large scale also hold at small scale, even with only 2, 4, or 8 experts per layer.

For open-source Mixture-of-Experts (MoE) models, you can check out the following:

Switch Transformers (Google): a collection of T5-based MoEs with 8 to 2048 experts. The largest model has 1.6 trillion parameters.

NLLB MoE (Meta): an MoE variant of the NLLB translation model.

OpenMoE: a community effort at MoEs based on Llama-style models.

Mixtral 8x7B (Mistral): a high-quality MoE that outperforms Llama 2 70B, with much faster inference. An instruction-tuned model was also released. You can learn more in Mistral's announcement blog post.

REF:https://github.com/kyegomez/SwitchTransformers/blob/main/switch_transformers/model.py

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from zeta.nn import FeedForward, MultiQueryAttention


class SwitchGate(nn.Module):
    """
    SwitchGate module for MoE (Mixture of Experts) model.

    Args:
        dim (int): Input dimension.
        num_experts (int): Number of experts.
        capacity_factor (float, optional): Capacity factor for sparsity. Defaults to 1.0.
        epsilon (float, optional): Small constant to avoid division by zero. Defaults to 1e-6.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """

    def __init__(
        self,
        dim,
        num_experts: int,
        capacity_factor: float = 1.0,
        epsilon: float = 1e-6,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.epsilon = epsilon
        self.w_gate = nn.Linear(dim, num_experts)

    def forward(self, x: Tensor, use_aux_loss: bool = False):
        """
        Forward pass of the SwitchGate module.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, dim).

        Returns:
            tuple: Gate scores of shape (batch_size, seq_len, num_experts)
            and the auxiliary load-balancing loss (None when use_aux_loss
            is False).
        """
        # Compute gate probabilities: (batch_size, seq_len, num_experts)
        gate_scores = F.softmax(self.w_gate(x), dim=-1)

        # Expert capacity: each expert may receive at most
        # capacity_factor * num_tokens / num_experts tokens
        num_tokens = x.size(0) * x.size(1)
        capacity = int(self.capacity_factor * num_tokens / self.num_experts)

        # Determine the top-1 expert for each token
        top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)

        # Mask to enforce sparsity (scatter along the expert dimension)
        mask = torch.zeros_like(gate_scores).scatter_(
            -1, top_k_indices, 1
        )

        # Keep only each token's selected expert
        masked_gate_scores = gate_scores * mask

        # Normalize so that each expert's scores sum to its capacity
        denominators = (
            masked_gate_scores.sum(dim=(0, 1), keepdim=True) + self.epsilon
        )
        scaled_gate_scores = (masked_gate_scores / denominators) * capacity

        if use_aux_loss:
            # Load: probability mass actually routed to each expert.
            # Importance: total router probability assigned to each expert.
            # The mean squared difference pushes toward a balanced load.
            load = masked_gate_scores.sum(dim=(0, 1))  # (num_experts,)
            importance = gate_scores.sum(dim=(0, 1))  # (num_experts,)
            loss = ((load - importance) ** 2).mean()

            return scaled_gate_scores, loss

        return scaled_gate_scores, None


class SwitchMoE(nn.Module):
    """
    A module that implements the Switched Mixture of Experts (MoE) architecture.

    Args:
        dim (int): The input dimension.
        hidden_dim (int): The hidden dimension of the feedforward network.
        output_dim (int): The output dimension.
        num_experts (int): The number of experts in the MoE.
        capacity_factor (float, optional): The capacity factor that controls the capacity of the MoE. Defaults to 1.0.
        mult (int, optional): The multiplier for the hidden dimension of the feedforward network. Defaults to 4.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Attributes:
        dim (int): The input dimension.
        hidden_dim (int): The hidden dimension of the feedforward network.
        output_dim (int): The output dimension.
        num_experts (int): The number of experts in the MoE.
        capacity_factor (float): The capacity factor that controls the capacity of the MoE.
        mult (int): The multiplier for the hidden dimension of the feedforward network.
        experts (nn.ModuleList): The list of feedforward networks representing the experts.
        gate (SwitchGate): The switch gate module.

    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int,
        capacity_factor: float = 1.0,
        mult: int = 4,
        use_aux_loss: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.mult = mult
        self.use_aux_loss = use_aux_loss

        self.experts = nn.ModuleList(
            [
                FeedForward(dim, dim, mult, *args, **kwargs)
                for _ in range(num_experts)
            ]
        )

        self.gate = SwitchGate(
            dim,
            num_experts,
            capacity_factor,
        )

    def forward(self, x: Tensor):
        """
        Forward pass of the SwitchMoE module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor of the MoE.

        """
        # (batch_size, seq_len, num_experts)
        gate_scores, loss = self.gate(
            x, use_aux_loss=self.use_aux_loss
        )

        # Dispatch to experts. For simplicity every expert processes every
        # token here; a genuinely sparse implementation would send each
        # token only to its selected expert(s).
        expert_outputs = [expert(x) for expert in self.experts]

        # Check if any gate scores are nan and handle
        if torch.isnan(gate_scores).any():
            print("NaN in gate scores")
            gate_scores[torch.isnan(gate_scores)] = 0

        # Stack and weight outputs
        stacked_expert_outputs = torch.stack(
            expert_outputs, dim=-1
        )  # (batch_size, seq_len, output_dim, num_experts)
        if torch.isnan(stacked_expert_outputs).any():
            stacked_expert_outputs[
                torch.isnan(stacked_expert_outputs)
            ] = 0

        # Combine expert outputs and gating scores
        moe_output = torch.sum(
            gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1
        )

        return moe_output, loss


class SwitchTransformerBlock(nn.Module):
    """
    SwitchTransformerBlock is a module that represents a single block of the Switch Transformer model.

    Args:
        dim (int): The input dimension of the block.
        heads (int): The number of attention heads.
        dim_head (int): The dimension of each attention head.
        mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
        dropout (float, optional): The dropout rate. Defaults to 0.1.
        num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 3.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Attributes:
        dim (int): The input dimension of the block.
        heads (int): The number of attention heads.
        dim_head (int): The dimension of each attention head.
        mult (int): The multiplier for the hidden dimension in the feed-forward network.
        dropout (float): The dropout rate.
        attn (MultiQueryAttention): The attention layer.
        ffn (SwitchMoE): The SwitchMoE feed-forward layer.

    Examples:
        >>> block = SwitchTransformerBlock(dim=512, heads=8, dim_head=64)
        >>> x = torch.randn(1, 10, 512)
        >>> out = block(x)
        >>> out.shape

    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        mult: int = 4,
        dropout: float = 0.1,
        num_experts: int = 3,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.mult = mult
        self.dropout = dropout

        self.attn = MultiQueryAttention(
            dim, heads, *args, qk_ln=True, **kwargs
        )

        self.ffn = SwitchMoE(
            dim, dim * mult, dim, num_experts, *args, **kwargs
        )
        
        self.add_norm = nn.LayerNorm(dim)


    def forward(self, x: Tensor):
        """
        Forward pass of the SwitchTransformerBlock.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.

        """
        resi = x
        x, _, _ = self.attn(x)
        x = x + resi
        x = self.add_norm(x)
        add_normed = x
        
        ##### MoE #####
        x, _ = self.ffn(x)
        x = x + add_normed
        x = self.add_norm(x)
        return x


class SwitchTransformer(nn.Module):
    """
    SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.

    Args:
        num_tokens (int): The number of tokens in the input vocabulary.
        dim (int): The dimensionality of the token embeddings and hidden states.
        heads (int): The number of attention heads.
        dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.
        mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
        dropout (float, optional): The dropout rate. Defaults to 0.1.
        num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.
        depth (int, optional): The number of SwitchTransformerBlock layers. Defaults to 4.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        num_tokens: int,
        dim: int,
        heads: int,
        dim_head: int = 64,
        mult: int = 4,
        dropout: float = 0.1,
        num_experts: int = 3,
        depth: int = 4,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.mult = mult
        self.dropout = dropout
        self.num_experts = num_experts
        self.depth = depth

        self.embedding = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])
        
        for _ in range(depth):
            self.layers.append(
                SwitchTransformerBlock(
                    dim,
                    heads,
                    dim_head,
                    mult,
                    dropout,
                    num_experts,
                    *args,
                    **kwargs,
                )
            )

        # Norm, project to the vocabulary, then softmax over the logits
        self.to_out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens),
            nn.Softmax(dim=-1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the SwitchTransformer.

        Args:
            x (Tensor): The input tensor of shape (batch_size, sequence_length).

        Returns:
            Tensor: The output tensor of shape (batch_size, sequence_length, num_tokens).
        """
        # Embed tokens through embedding layer
        x = self.embedding(x)
        
        # Pass through each Switch Transformer block in the ModuleList
        for layer in self.layers:
            x = layer(x)

        # Project to output tokens
        x = self.to_out(x)
        return x
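
A quick smoke test of the full stack (a sketch assuming the zeta package that provides zeta.nn is installed):

model = SwitchTransformer(
    num_tokens=32000, dim=512, heads=8, dim_head=64, num_experts=4, depth=2
)
tokens = torch.randint(0, 32000, (1, 10))  # (batch_size, seq_len)
out = model(tokens)
print(out.shape)  # expected: torch.Size([1, 10, 32000])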