Improvements to PosWiseFFN: MoE, PKM, UltraMem

Starting with PosWiseFFN

import torch.nn as nn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.GELU(),
            nn.Linear(d_ff, d_model, bias=False))
        # Register the LayerNorm once so its parameters are trained and follow the module's device
        self.norm = nn.LayerNorm(d_model)

    def forward(self, inputs):                                  # inputs: [batch_size, seq_len, d_model]
        residual = inputs
        output = self.fc(inputs)
        return self.norm(output + residual)              # [batch_size, seq_len, d_model]

If the attention dimension is d_model, the PosWiseFFN is usually just two matrices with a GELU in between, and d_ff is 4 times d_model: the first matrix has weight [d_model, 4*d_model] and the second has weight [4*d_model, d_model].
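A quick shape check makes these sizes concrete. This is only a sketch with a made-up d_model of 8; note that PyTorch's nn.Linear stores its weight as [out_features, in_features], so the stored shapes are the transposes of the mathematical shapes quoted above.

```python
import torch
import torch.nn as nn

d_model = 8          # illustrative size, not from the article
d_ff = 4 * d_model   # the usual 4x expansion

up = nn.Linear(d_model, d_ff, bias=False)
down = nn.Linear(d_ff, d_model, bias=False)

# nn.Linear computes x @ weight.T, so weight is stored as [out_features, in_features]
print(tuple(up.weight.shape))    # (32, 8)
print(tuple(down.weight.shape))  # (8, 32)

# Without biases the block holds 2 * d_model * d_ff = 8 * d_model**2 weights
n_params = up.weight.numel() + down.weight.numel()
print(n_params)  # 512
```

The quadratic growth in d_model is why simply widening d_ff gets expensive quickly.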

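The PosWiseFFN structure can also be read as a query-key-value lookup. Treat the first matrix as the keys and the second matrix as the values. The input of shape [batch_size, seq_len, d_model] then acts as the query: it is first multiplied with the keys to give dots of shape [batch_size, seq_len, 4*d_model], and after GELU these dots are multiplied with the second matrix of shape [4*d_model, d_model]. That second step is in effect a soft selection: a weighted sum of the 4*d_model value vectors, with the activated scores deciding which ones matter. So the question is: can d_ff, currently 4*d_model, be made much larger? Figure 1 here is Figure 1 of Large Memory Layers with Product Keys; the |K| in that figure corresponds to 4*d_model in PosWiseFFN.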
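The key-value reading can be written out directly. The sketch below uses made-up sizes and random matrices (the names `keys`, `values` and `weights` are illustrative, not from any library) and checks that the second matmul is the same thing as an explicit weighted sum over the value rows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_model = 2, 3, 8   # illustrative sizes
d_ff = 4 * d_model

keys = torch.randn(d_model, d_ff)     # first matrix: one key per column
values = torch.randn(d_ff, d_model)   # second matrix: one value per row
x = torch.randn(batch_size, seq_len, d_model)  # the queries

dots = x @ keys            # [batch_size, seq_len, d_ff]: one score per key
weights = F.gelu(dots)     # the activation sets how much each value contributes
out = weights @ values     # [batch_size, seq_len, d_model]

# The same result as an explicit weighted sum over the d_ff value rows
out_explicit = (weights.unsqueeze(-1) * values).sum(dim=-2)
print(torch.allclose(out, out_explicit, atol=1e-4))
```

Every query is scored against all d_ff keys here, which is the cost that grows when d_ff is scaled up.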

Put simply, the PKM described below improves this query-key-value lookup by borrowing the idea of PQ (Product Quantization).

PKM (Product Key Memory; the "Product" here is the same "Product" as in Product Quantization)

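In Figure 1 of Large Memory Layers with Product Keys, q has shape [..., d_model] and k has shape [d_model, |K|]. Figure 2 shows how the paper handles a very large |K|. The d_model-dimensional q is split into q1 and q2, each of dimension d_model/2. Likewise, the keys are factored into two sets of sub-keys of dimension d_model/2: sub-key set 1 (the unprimed $c_1$, $c_2$, $c_3$ in the figure) and sub-key set 2 (the primed $c'_1$, $c'_2$, $c'_3$). Each set only needs $\sqrt{|K|}$ sub-keys, because concatenating every pair yields $\sqrt{|K|} \times \sqrt{|K|} = |K|$ full keys. Each half produces its own top-k, and the final k are then selected from the resulting $k^2$ candidate pairs. This is the idea behind Product Quantization.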
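The two-stage search can be checked against brute force on a toy problem. This is a minimal sketch with made-up sizes and random sub-keys (all variable names are illustrative); it assumes the score of a full key is the dot product with the concatenated sub-keys, which equals the sum of the two half scores.

```python
import torch

torch.manual_seed(0)
d, n_sub, k = 8, 16, 4    # illustrative sizes; n_sub ** 2 = 256 full keys

q = torch.randn(d)
sub_keys_1 = torch.randn(n_sub, d // 2)   # c_1, c_2, ...
sub_keys_2 = torch.randn(n_sub, d // 2)   # c'_1, c'_2, ...
q1, q2 = q[: d // 2], q[d // 2 :]

# Step 1: top-k within each half, 2 * n_sub dot products instead of n_sub ** 2
s1, i1 = (sub_keys_1 @ q1).topk(k)
s2, i2 = (sub_keys_2 @ q2).topk(k)

# Step 2: score the k * k candidate pairs; a pair's score is the sum of its halves
cand_scores = (s1[:, None] + s2[None, :]).flatten()
cand_ids = (i1[:, None] * n_sub + i2[None, :]).flatten()   # flat id = i * n_sub + j

# Step 3: keep the best k of the k ** 2 candidates
top_scores, pos = cand_scores.topk(k)
top_ids = cand_ids[pos]

# Brute force over all n_sub ** 2 concatenated keys, for comparison
full_keys = torch.cat([
    sub_keys_1[:, None, :].expand(n_sub, n_sub, d // 2),
    sub_keys_2[None, :, :].expand(n_sub, n_sub, d // 2),
], dim=-1).reshape(n_sub * n_sub, d)
brute_scores, brute_ids = (full_keys @ q).topk(k)

print(sorted(top_ids.tolist()) == sorted(brute_ids.tolist()))
```

The candidate set is guaranteed to contain the true top-k: if a pair (i, j) is among the k best overall, then i must be among the k best of its half, otherwise k pairs sharing j would already beat it.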

Code walkthrough

The code comes from https://github.com/lucidrains/product-key-memory/tree/master. It makes good use of einops; a few annotations are added below:

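import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce

# init_, MaskedBatchNorm1D, gumbel_noise and coor_descent_topk are helpers
# provided elsewhere in the same repository and are not reproduced here

class PKM(nn.Module):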
    def __init__(
        self,
        dim,
        heads = 4,
        num_keys = 128,
        topk = 32,
        dim_head = 128,
        input_dropout = 0.,
        query_dropout = 0.,
        value_dropout = 0.,
        attn_dropout = 0.,
        use_layernorm = True,
        pre_layernorm = False,
        differentiable_topk = False,
        concat_values_and_combine = False,
        norm_output = False,
        non_competitive_gates = False # Csordas et al. claims non-competitive gates work even better
    ):
        super().__init__()
        self.topk = topk
        self.heads = heads
        self.num_keys = num_keys
        dim_query = dim_head * heads * 2
        self.to_queries = nn.Linear(dim, dim_query, bias = False)

        # pre-layernorm pattern
        self.pre_layernorm = nn.LayerNorm(dim) if pre_layernorm else nn.Identity()

        # batchnorm would break causality
        self.use_layernorm = use_layernorm

        if use_layernorm:
            self.norm = nn.LayerNorm(dim_head)
        else:
            self.norm = MaskedBatchNorm1D(nn.BatchNorm1d(dim_head))

        # keys
        self.keys = nn.Parameter(torch.zeros(heads, num_keys, 2, dim_head))
        init_(self.keys)

        # values
        self.concat_values_and_combine = concat_values_and_combine
        if concat_values_and_combine:
            values = nn.Embedding(num_keys ** 2, dim_head)

            self.values = nn.Sequential(
                values,
                Reduce('b (h k) d -> b h d', 'sum', h = heads),
                Rearrange('b n d -> b (n d)'),
                nn.Linear(dim_head * heads, dim, bias = False)
            )
        else:
            values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')
            self.values = values
        init_(values.weight)

        # dropouts
        self.input_dropout = nn.Dropout(input_dropout)
        self.query_dropout = nn.Dropout(query_dropout)
        self.value_dropout = nn.Dropout(value_dropout)
        self.attn_dropout = nn.Dropout(attn_dropout)

        # non competitive gates
        self.gate_activation = nn.Softmax(dim = -1) if not non_competitive_gates else nn.ReLU()
        # use a differentiable topk, based on coordinate descent
        self.differentiable_topk = differentiable_topk
        # https://arxiv.org/abs/2302.06461
        # claims to boost performance of softmax key / value networks by simply layernorming the output
        self.output_norm = nn.LayerNorm(dim) if norm_output else nn.Identity()

    def forward(
        self,
        x,
        input_mask = None,
        gumbel_noise_scale = 0.,
        **kwargs
    ):
        b, t, h = *x.shape[:2], self.heads

        x = self.pre_layernorm(x)
        x = self.input_dropout(x)

        queries = self.to_queries(x)

        # queries shape legend: b = batch_size, t = target_seq_len, p = partition (the two query halves), h = num_heads, d = head_dim
        queries = rearrange(queries, 'b t (p h d) -> (b p h) t d', p = 2, h = h)

        # norm and dropout queries
        norm_kwargs = dict(mask = input_mask) if not self.use_layernorm else dict()
        queries = self.norm(queries, **norm_kwargs)
        queries = self.query_dropout(queries)

        queries = rearrange(queries, '(b p h) t d -> p b t h d', p = 2, h = h)

        # similarity to keys
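        # keys.shape: (heads, num_keys, 2, dim_head); n here indexes the num_keys sub-keys of each half
        # conceptually the keys form a 2D grid: every full key is a pair (sub-key from half 0, sub-key from half 1)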
        dots = einsum('p b t h d, h n p d -> b t h p n', queries, self.keys)

        # gumbel noise
        if gumbel_noise_scale > 0.:
            dots = dots + gumbel_noise(dots) * gumbel_noise_scale

        # topk scores
        if self.differentiable_topk:
            scores, indices, *_ = coor_descent_topk(dots, k = self.topk, fused = True)
        else:
            scores, indices = dots.topk(k = self.topk, dim = -1)
        # scores are factorized
        (scores_x, scores_y), (indices_x, indices_y) = map(lambda t: t.chunk(2, dim = 3), (scores, indices))

        all_topk = self.topk ** 2

        all_scores = rearrange((
            rearrange(scores_x, '... k -> ... k 1') +
            rearrange(scores_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        all_indices = rearrange((
            rearrange(indices_x, '... k -> ... k 1') * self.num_keys +
            rearrange(indices_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        final_topk, final_indices = all_scores.topk(self.topk, dim=-1)
        value_indices = all_indices.gather(-1, final_indices)

        # attention

        attn = self.gate_activation(final_topk)
        attn = self.attn_dropout(attn)

        value_indices, attn = map(lambda t: rearrange(t, 'b t h k -> (b t) (h k)'), (value_indices, attn))

        # aggregate

        if self.concat_values_and_combine:
            out = self.values(value_indices)
        else:
            out = self.values(value_indices, per_sample_weights = attn)

        out = self.value_dropout(out)

        # maybe layernorm the output

        out = self.output_norm(out)

        return rearrange(out, '(b t) d -> b t d', b = b)
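The final aggregation step relies on nn.EmbeddingBag with per_sample_weights. As a rough illustration with made-up sizes and hand-picked indices (not taken from the repository), the sketch below shows that this call is equivalent to gathering the selected value rows, scaling each by its gate weight, and summing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_keys, dim, topk = 4, 6, 3          # illustrative sizes
values = nn.EmbeddingBag(num_keys ** 2, dim, mode='sum')

# One "bag" per token: topk flat value indices (i * num_keys + j) and their gate weights
value_indices = torch.tensor([[0, 5, 15], [3, 3, 9]])     # [tokens, topk]
attn = torch.softmax(torch.randn(2, topk), dim=-1)        # [tokens, topk]

out = values(value_indices, per_sample_weights=attn)      # [tokens, dim]

# Manual equivalent: gather the rows, scale by the weights, sum over the bag
manual = (values.weight[value_indices] * attn.unsqueeze(-1)).sum(dim=1)
print(torch.allclose(out, manual, atol=1e-6))
```

In the PKM forward pass each bag is heads * topk wide, so the values picked by all heads are summed into a single output vector per token.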

UltraMem

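This comes from ULTRA-SPARSE MEMORY NETWORK. When ByteDance published it, the publicity claimed that it "effectively solves the high memory-access cost of MoE inference, with inference 2 to 6 times faster than the MoE architecture and inference cost reduced by up to 83%". At first glance you might think they had sped up DeepSeekMoE by another 2 to 6 times, but the MoE in question is actually the one from the MoE paper described below. UltraMem is really a refinement of the PKM idea. ByteDance has not released the source code, and there is no telling whether their not-so-bright Doubao uses it yet. For now I will just excerpt a few core ideas and read it properly once the code is out.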

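To address drawback 1 and drawback 3, the paper replaces PQ with TDQKR (Tucker Decomposed Query-Key Retrieval) below, a method based on SVD decomposition: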

MoE

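This MoE is not the same as the MoE in MoE-architecture LLMs; it is an improvement to PosWiseFFN, from Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Below is a screenshot from the paper; one look is enough to get the general idea: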
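The top-1 routing idea of Switch Transformers can be sketched roughly as follows. This is a simplified illustration, not the paper's implementation: the class name `SwitchFFN` and all sizes are made up, GELU is used only to match the FFN above, and the capacity factor and auxiliary load-balancing loss are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwitchFFN(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, d_model, d_ff, num_experts):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts, bias=False)
        # Each expert is an ordinary position-wise FFN
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x):                        # x: [batch_size, seq_len, d_model]
        tokens = x.reshape(-1, x.shape[-1])      # [num_tokens, d_model]
        probs = F.softmax(self.router(tokens), dim=-1)
        gate, expert_id = probs.max(dim=-1)      # top-1: exactly one expert per token
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            mask = expert_id == e
            if mask.any():
                # Scale by the router probability so the router also receives gradient
                out[mask] = gate[mask].unsqueeze(-1) * expert(tokens[mask])
        return out.reshape_as(x), expert_id.reshape(x.shape[:-1])

torch.manual_seed(0)
layer = SwitchFFN(d_model=8, d_ff=32, num_experts=4)
x = torch.randn(2, 5, 8)
y, expert_id = layer(x)
print(y.shape, expert_id.shape)
```

Total parameters grow with the number of experts, while each token still passes through only one FFN, so per-token compute stays roughly constant.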

Appendix:

  1. https://mp.weixin.qq.com/s/BPGbzAQ5AKPj7yqrOCCuGQ?token=2117558689&lang=zh_CN
  2. https://team.doubao.com/zh/publication/ultra-sparse-memory-network?view_from=research
  3. https://www.cls.cn/detail/1940788