ref:https://huggingface.co/blog/zh/moe#用router-z-loss稳定模型训练
MoEs and Transformers
Transformer-class models have shown clearly that scaling up the number of parameters improves performance, so it is no surprise that Google used GShard to push Transformers beyond 600 billion parameters.
GShard replaces every other feed-forward network (FFN) layer in both the encoder and the decoder with a mixture-of-experts (MoE) layer that uses top-2 gating. The figure below shows the encoder part. This setup is very effective at scale: when the model is distributed across multiple devices, the MoE layer is sharded across devices while all other layers are replicated on each device. We discuss this in more detail in the "Making MoEs go brrr" section.
To keep the load balanced and training efficient, the GShard authors introduced a few key changes in addition to an auxiliary loss similar to the one discussed in the previous section:
Random routing: in a top-2 setup, the top expert is always picked, but the second expert is sampled with probability proportional to its router weight.
Expert capacity: a threshold on how many tokens a single expert can process. If both chosen experts are at capacity, the token overflows: it is passed on to the next layer through the residual connection, or in some implementations dropped entirely. Expert capacity is one of the most important concepts in MoEs. Why is it needed? Because all tensor shapes are fixed at compile time and we cannot know in advance how many tokens will be routed to each expert, a fixed capacity factor is required (a small capacity sketch follows below).
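To make expert capacity concrete, here is a minimal sketch, not GShard's actual implementation: capacity is derived from the usual `capacity_factor * num_tokens / num_experts` rule, and tokens that exceed their expert's capacity are simply flagged (in the real model they would take the residual path). The function name and shapes are illustrative assumptions.

```python
import torch

def apply_expert_capacity(expert_ids: torch.Tensor, num_experts: int,
                          capacity_factor: float = 1.0) -> torch.Tensor:
    """Return a boolean mask of tokens that fit within their expert's capacity."""
    num_tokens = expert_ids.numel()
    # Shapes must be static at compile time, so the capacity is fixed up front.
    capacity = int(capacity_factor * num_tokens / num_experts)
    keep = torch.zeros(num_tokens, dtype=torch.bool)
    counts = torch.zeros(num_experts, dtype=torch.long)
    for i, e in enumerate(expert_ids.tolist()):  # earlier tokens get priority
        if counts[e] < capacity:
            keep[i] = True
            counts[e] += 1
    return keep

# 8 tokens, 4 experts, capacity factor 1.0 -> at most 2 tokens per expert;
# the third token routed to expert 0 overflows and would take the residual path.
keep = apply_expert_capacity(torch.tensor([0, 0, 0, 1, 2, 2, 3, 1]), num_experts=4)
```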
GShard also made important contributions to parallel computation patterns for MoEs, but discussing them is beyond the scope of this post.
Note: during inference, only some of the experts are activated, while some computation, such as self-attention, is shared across all tokens. This is why a 47B model with 8 experts can run with the compute of a roughly 12B dense model. With top-2 gating the model would use up to 14B parameters, but because the self-attention operations (shared across experts) are counted only once, the number of parameters actually used at runtime is about 12B.
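For a back-of-the-envelope feel for these numbers, the sketch below assumes a purely illustrative split between shared parameters (attention, embeddings, norms) and per-expert FFN parameters; the figures are assumptions, not Mixtral's published breakdown.

```python
# Illustrative figures only (assumed, not an official parameter breakdown)
shared_params = 1.6e9        # attention, embeddings, norms: counted once
per_expert_params = 5.65e9   # one expert's FFN parameters
num_experts, top_k = 8, 2

total_params = shared_params + num_experts * per_expert_params   # parameters stored
active_params = shared_params + top_k * per_expert_params        # parameters used per token
print(f"total ≈ {total_params / 1e9:.1f}B, active per token ≈ {active_params / 1e9:.1f}B")
```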
Switch Transformers
Although MoEs showed a lot of promise, they struggle with instability during training and fine-tuning. Switch Transformers is a very exciting work that dives deep into these topics. The authors even released a 1.6-trillion-parameter MoE with 2048 experts on Hugging Face, which you can run with the transformers library. Switch Transformers achieved a 4x pre-training speed-up over T5-XXL.
Just as in GShard, the authors replaced the FFN layers with MoE layers. The paper proposes a Switch Transformer layer that receives two inputs (two different tokens) and has four experts.
In contrast to the original idea of using at least two experts, Switch Transformers adopts a simplified single-expert strategy. The effects of this approach (see the short routing sketch after this list) are:
The router (gating) computation is reduced
The batch size of each expert can be at least halved
Communication costs are reduced
Model quality is preserved
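What top-1 ("switch") routing looks like in practice, as a minimal sketch assuming a flat batch of token vectors: the router picks a single expert per token, and that expert's gate probability scales its output, so only one expert FFN runs per token.

```python
import torch
import torch.nn.functional as F

dim, num_experts = 512, 4
tokens = torch.randn(16, dim)               # (num_tokens, dim), assumed shapes
router = torch.nn.Linear(dim, num_experts)

probs = F.softmax(router(tokens), dim=-1)       # (num_tokens, num_experts)
expert_weight, expert_id = probs.max(dim=-1)    # a single expert per token
# expert_id says which FFN processes each token; expert_weight scales its output,
# so the router does one argmax instead of ranking two experts per token.
```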
Switch Transformers uses an encoder-decoder setup and implements an MoE counterpart of T5. The GLaM paper explores scaling these models further by training a model that matches GPT-3 quality with only 1/3 of the compute (and because MoE models need less compute to train, they can significantly reduce the carbon footprint). The authors focus on decoder-only models and on few-shot and one-shot evaluation rather than fine-tuning. They use top-2 routing and much larger capacity factors. They also explore the capacity factor as a dynamic quantity that can be adjusted according to how much compute is available during training and evaluation.
Stabilizing training with Router z-loss
The balancing loss discussed earlier can lead to instability issues. There are many ways to stabilize sparse model training, but often at the cost of model quality. For example, introducing dropout improves stability but hurts quality, while adding more multiplicative components improves quality but reduces stability.
The Router z-loss introduced in ST-MoE significantly improves training stability while preserving model quality. It works by penalizing large logits entering the gating network: keeping the absolute magnitude of these values small reduces round-off errors in the computation, which matters especially for gating networks that rely on exponential functions.
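In code, the router z-loss is just the squared log-sum-exp of the router logits, averaged over tokens. A sketch, assuming `router_logits` holds the pre-softmax gate values:

```python
import torch

def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """router_logits: (num_tokens, num_experts) raw inputs to the gating softmax."""
    z = torch.logsumexp(router_logits, dim=-1)  # one value per token
    return (z ** 2).mean()                      # penalizes large-magnitude logits
```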
How does the number of experts affect pretraining?
Adding more experts improves sample efficiency and speeds things up, but these gains diminish as the count grows (especially beyond 256 or 512 experts), and more VRAM is needed at inference time to hold the whole model. Notably, the Switch Transformers authors found that the properties observed at large scale also hold at small scale, even with only 2, 4, or 8 experts per layer.
For open-source MoE models, you can check out the following:
Switch Transformers (Google): a collection of T5-based MoEs with 8 to 2048 experts; the largest model has 1.6 trillion parameters (a loading sketch follows this list).
NLLB MoE (Meta): an MoE variant of the NLLB translation model.
OpenMoE: a community effort to build MoEs on top of Llama-based models.
Mixtral 8x7B (Mistral): a high-quality MoE that outperforms Llama 2 70B and offers much faster inference. An instruction-tuned model was also released. You can read more in Mistral's announcement blog post.
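As a quick way to try one of these, the sketch below loads a small Switch Transformers checkpoint with the transformers library; the "google/switch-base-8" checkpoint name and the T5-style span-corruption prompt are assumptions about what is available on the Hub.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = AutoModelForSeq2SeqLM.from_pretrained("google/switch-base-8")

# Switch checkpoints are pre-trained with span corruption, so we ask the model
# to fill in a sentinel token rather than follow an instruction.
inputs = tokenizer("The capital of France is <extra_id_0>.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```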
REF:https://github.com/kyegomez/SwitchTransformers/blob/main/switch_transformers/model.py
python
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from zeta.nn import FeedForward, MultiQueryAttention
class SwitchGate(nn.Module):
"""
SwitchGate module for MoE (Mixture of Experts) model.
Args:
dim (int): Input dimension.
num_experts (int): Number of experts.
capacity_factor (float, optional): Capacity factor for sparsity. Defaults to 1.0.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
def __init__(
self,
dim,
num_experts: int,
capacity_factor: float = 1.0,
epsilon: float = 1e-6,
*args,
**kwargs,
):
super().__init__()
self.dim = dim
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.epsilon = epsilon
self.w_gate = nn.Linear(dim, num_experts)
def forward(self, x: Tensor, use_aux_loss=False):
"""
Forward pass of the SwitchGate module.
        Args:
            x (Tensor): Input tensor.
            use_aux_loss (bool): Whether to also compute the auxiliary balancing loss.
        Returns:
            Tuple[Tensor, Optional[Tensor]]: Gate scores and the auxiliary loss
            (None when use_aux_loss is False).
"""
# Compute gate scores
gate_scores = F.softmax(self.w_gate(x), dim=-1)
        # Fixed per-expert capacity derived from the capacity factor
        capacity = int(self.capacity_factor * x.size(0))
        # Determine the top-1 expert for each token
        top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)
        # One-hot mask over the expert (last) dimension to enforce sparsity
        mask = torch.zeros_like(gate_scores).scatter_(
            -1, top_k_indices, 1
        )
# Combine gating scores with the mask
masked_gate_scores = gate_scores * mask
# Denominators
denominators = (
masked_gate_scores.sum(0, keepdim=True) + self.epsilon
)
# Norm gate scores to sum to the capacity
gate_scores = (masked_gate_scores / denominators) * capacity
        if use_aux_loss:
            # Load: how many tokens were routed to each expert (from the top-1 mask)
            load = mask.sum(dim=tuple(range(mask.dim() - 1)))
            # Importance: total gate score mass each expert received
            importance = masked_gate_scores.sum(
                dim=tuple(range(masked_gate_scores.dim() - 1))
            )
            # Aux loss is the mean squared difference between load and importance
            loss = ((load - importance) ** 2).mean()
            return gate_scores, loss
        return gate_scores, None
class SwitchMoE(nn.Module):
"""
A module that implements the Switched Mixture of Experts (MoE) architecture.
Args:
dim (int): The input dimension.
hidden_dim (int): The hidden dimension of the feedforward network.
output_dim (int): The output dimension.
num_experts (int): The number of experts in the MoE.
capacity_factor (float, optional): The capacity factor that controls the capacity of the MoE. Defaults to 1.0.
mult (int, optional): The multiplier for the hidden dimension of the feedforward network. Defaults to 4.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Attributes:
dim (int): The input dimension.
hidden_dim (int): The hidden dimension of the feedforward network.
output_dim (int): The output dimension.
num_experts (int): The number of experts in the MoE.
capacity_factor (float): The capacity factor that controls the capacity of the MoE.
mult (int): The multiplier for the hidden dimension of the feedforward network.
experts (nn.ModuleList): The list of feedforward networks representing the experts.
gate (SwitchGate): The switch gate module.
"""
def __init__(
self,
dim: int,
hidden_dim: int,
output_dim: int,
num_experts: int,
capacity_factor: float = 1.0,
mult: int = 4,
use_aux_loss: bool = False,
*args,
**kwargs,
):
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.mult = mult
self.use_aux_loss = use_aux_loss
self.experts = nn.ModuleList(
[
FeedForward(dim, dim, mult, *args, **kwargs)
for _ in range(num_experts)
]
)
self.gate = SwitchGate(
dim,
num_experts,
capacity_factor,
)
def forward(self, x: Tensor):
"""
Forward pass of the SwitchMoE module.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor of the MoE.
"""
# (batch_size, seq_len, num_experts)
gate_scores, loss = self.gate(
x, use_aux_loss=self.use_aux_loss
)
# Dispatch to experts
expert_outputs = [expert(x) for expert in self.experts]
# Check if any gate scores are nan and handle
if torch.isnan(gate_scores).any():
print("NaN in gate scores")
gate_scores[torch.isnan(gate_scores)] = 0
# Stack and weight outputs
stacked_expert_outputs = torch.stack(
expert_outputs, dim=-1
) # (batch_size, seq_len, output_dim, num_experts)
if torch.isnan(stacked_expert_outputs).any():
stacked_expert_outputs[
torch.isnan(stacked_expert_outputs)
] = 0
# Combine expert outputs and gating scores
moe_output = torch.sum(
gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1
)
return moe_output, loss
class SwitchTransformerBlock(nn.Module):
"""
SwitchTransformerBlock is a module that represents a single block of the Switch Transformer model.
Args:
dim (int): The input dimension of the block.
heads (int): The number of attention heads.
dim_head (int): The dimension of each attention head.
mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
dropout (float, optional): The dropout rate. Defaults to 0.1.
        num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 3.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Attributes:
dim (int): The input dimension of the block.
heads (int): The number of attention heads.
dim_head (int): The dimension of each attention head.
mult (int): The multiplier for the hidden dimension in the feed-forward network.
dropout (float): The dropout rate.
attn_layers (nn.ModuleList): List of MultiQueryAttention layers.
ffn_layers (nn.ModuleList): List of SwitchMoE layers.
Examples:
>>> block = SwitchTransformerBlock(dim=512, heads=8, dim_head=64)
>>> x = torch.randn(1, 10, 512)
>>> out = block(x)
>>> out.shape
"""
def __init__(
self,
dim: int,
heads: int,
dim_head: int,
mult: int = 4,
dropout: float = 0.1,
num_experts: int = 3,
*args,
**kwargs,
):
super().__init__()
self.dim = dim
self.heads = heads
self.dim_head = dim_head
self.mult = mult
self.dropout = dropout
        # Multi-query attention with LayerNorm applied to queries and keys
        self.attn = MultiQueryAttention(
            dim, heads, *args, qk_ln=True, **kwargs
        )
self.ffn = SwitchMoE(
dim, dim * mult, dim, num_experts, *args, **kwargs
)
self.add_norm = nn.LayerNorm(dim)
def forward(self, x: Tensor):
"""
Forward pass of the SwitchTransformerBlock.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
resi = x
x, _, _ = self.attn(x)
x = x + resi
x = self.add_norm(x)
add_normed = x
##### MoE #####
x, _ = self.ffn(x)
x = x + add_normed
x = self.add_norm(x)
return x
class SwitchTransformer(nn.Module):
"""
SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.
Args:
num_tokens (int): The number of tokens in the input vocabulary.
dim (int): The dimensionality of the token embeddings and hidden states.
heads (int): The number of attention heads.
dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.
mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
dropout (float, optional): The dropout rate. Defaults to 0.1.
        num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.
        depth (int, optional): The number of SwitchTransformerBlock layers. Defaults to 4.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
"""
def __init__(
self,
num_tokens: int,
dim: int,
heads: int,
dim_head: int = 64,
mult: int = 4,
dropout: float = 0.1,
num_experts: int = 3,
depth: int = 4,
*args,
**kwargs,
):
super().__init__()
self.num_tokens = num_tokens
self.dim = dim
self.heads = heads
self.dim_head = dim_head
self.mult = mult
self.dropout = dropout
self.num_experts = num_experts
self.depth = depth
self.embedding = nn.Embedding(num_tokens, dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
SwitchTransformerBlock(
dim,
heads,
dim_head,
mult,
dropout,
num_experts,
*args,
**kwargs,
)
)
        # Project hidden states to vocabulary logits, then normalize to probabilities
        self.to_out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens),
            nn.Softmax(dim=-1),
        )
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the SwitchTransformer.
Args:
x (Tensor): The input tensor of shape (batch_size, sequence_length).
Returns:
Tensor: The output tensor of shape (batch_size, sequence_length, num_tokens).
"""
# Embed tokens through embedding layer
x = self.embedding(x)
        # Pass through each SwitchTransformerBlock (attention + MoE FFN)
for layer in self.layers:
x = layer(x)
# Project to output tokens
x = self.to_out(x)
return x
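# A minimal usage sketch for the model above (assumes the `zeta` package that
# provides FeedForward and MultiQueryAttention is installed; sizes are arbitrary).
if __name__ == "__main__":
    model = SwitchTransformer(
        num_tokens=100, dim=512, heads=8, dim_head=64, num_experts=4, depth=2
    )
    token_ids = torch.randint(0, 100, (1, 10))  # (batch_size, seq_len) of token ids
    out = model(token_ids)
    print(out.shape)  # expected: torch.Size([1, 10, 100])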