LLM Inference: An MLA Implementation

1. Overall Flow

First, a diagram to give an overall picture of the MLA computation:
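In brief, and using the variable names from the code below, MLA low-rank-compresses the hidden state $x$ before expanding it per head (a textual summary of the figure, not a substitute for it):

$$
\begin{aligned}
[q^{\mathrm{nope}};\ q^{\mathrm{rope}}] &= W^{UQ}\,\mathrm{RMSNorm}(W^{DQ}\,x) \\
[c^{KV};\ k^{\mathrm{rope}}] &= W^{DKV}\,x, \qquad [k^{\mathrm{nope}};\ v] = W^{UKV}\,\mathrm{RMSNorm}(c^{KV}) \\
q &= [q^{\mathrm{nope}};\ \mathrm{RoPE}(q^{\mathrm{rope}})], \qquad k = [k^{\mathrm{nope}};\ \mathrm{RoPE}(k^{\mathrm{rope}})]
\end{aligned}
$$

Here $W^{DQ}, W^{UQ}$ correspond to `wq_a`, `wq_b` and $W^{DKV}, W^{UKV}$ to `wkv_a`, `wkv_b`; the RoPE key is shared across all heads. At inference time only $c^{KV}$ and the shared RoPE key need to be cached, which is exactly what the fused branch of the implementation exploits.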

2. Implementation Code

import math
import torch
import torch.nn as nn


# RMS normalization
class RMSNorm(nn.Module):
    """

    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.float()  # compute the statistics in fp32 for stability
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


def rotate_half(x):
    # split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    # q/k: [bs, seq_len, n_heads, head_dim]; cos/sin: [1, seq_len, head_dim].
    # unsqueeze_dim=2 inserts a head axis so cos/sin broadcast across heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


# rotary position embedding (RoPE)
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super(RotaryEmbedding, self).__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # standard RoPE frequencies: theta_i = 10000^(-2i/dim)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float().unsqueeze(1)
        freqs = t @ inv_freq.unsqueeze(0)  # [max_seq_len, dim // 2]
        freqs = torch.cat((freqs, freqs), dim=-1)  # [max_seq_len, dim]

        self.register_buffer("cos_cached", freqs.cos())
        self.register_buffer("sin_cached", freqs.sin())

    def forward(self, q, k):
        # q: [bs, seq_len, n_heads, head_dim]; slice the cached tables to the current seq_len
        cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)
        sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    def __init__(self,
                 dim,
                 n_heads,
                 q_lora_rank,
                 kv_lora_rank,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 max_seq_len,
                 max_batch_size,
                 mode):
        super().__init__()
        self.dim = dim  # hidden size
        self.n_heads = n_heads  # number of attention heads
        self.q_lora_rank = q_lora_rank  # rank that q is compressed down to
        self.kv_lora_rank = kv_lora_rank  # rank that k/v are compressed down to
        self.qk_nope_head_dim = qk_nope_head_dim    # per-head q/k dim without RoPE
        self.qk_rope_head_dim = qk_rope_head_dim    # per-head q/k dim carrying RoPE
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # total per-head q/k dim: non-RoPE part plus RoPE part
        self.v_head_dim = v_head_dim  # per-head value dim (equal to the non-RoPE k dim here)
        self.mode = mode  # 'naive' caches full per-head k/v; any other value uses the fused (weight-absorbed) path
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # down-projection for q
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # up-projection for q
        # parameter count of the factorization vs. a full projection (e.g. dim=4096, q_lora_rank=128,
        # n_heads * qk_head_dim=4864): 4096*128 + 128*4864 = 524,288 + 622,592 = 1,146,880 vs. 4096*4864 = 19,922,944

        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # joint down-projection for k/v plus the shared RoPE key
        # equivalent to two separate projections:
        # nn.Linear(self.dim, self.kv_lora_rank)
        # nn.Linear(self.dim, self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        # up-projection for k/v; bias=False because the fused path below reads only .weight,
        # and a bias cannot be absorbed into the latent attention
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False)

        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)

        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # rotary position embedding
        # 'naive' mode: no matrix absorption; cache the full per-head k and v
        if self.mode == 'naive':
            self.register_buffer('k_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),
                                 persistent=False)
            self.register_buffer('v_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),
                                 persistent=False)
        # fused mode: absorb the up-projections and cache only the compressed kv latent plus the shared RoPE key
        else:
            self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),
                                 persistent=False)
            self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),
                                 persistent=False)
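        # Per-token cache footprint: naive mode stores n_heads * (qk_head_dim + v_head_dim)
        # floats, while the fused mode stores only kv_lora_rank + qk_rope_head_dim,
        # which is the KV-cache saving that motivates MLA.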

    def forward(self, x, mask=None):

        bs, seq_len, _ = x.shape

        q = self.wq_a(x)  # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]
        q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
                                   dim=-1)  # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]

        kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],
                               dim=-1)  # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]

        k_pe = k_pe.unsqueeze(2)  # k_pe: [bs, seq_len, 1, qk_rope_head_dim]; one RoPE key shared by all heads in the layer
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)
        if self.mode == 'naive':

            q = torch.cat([q_nope, q_pe], dim=-1)  # [bs, seq_len, n_heads, qk_head_dim]

            kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank]
            kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]
            kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
            # k shape:[bs, seq_len, n_heads, qk_head_dim]
            self.k_cache[:bs, :seq_len, :, :] = k
            self.v_cache[:bs, :seq_len, :, :] = v
            # attention scores: contract per-head q with the cached k over head_dim,
            # scaled by sqrt of the total q/k head dim
            scores = torch.einsum("bshd,bthd->bsht", q,
                                  self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_head_dim)
            # scores: [bs, seq_len, n_heads, seq_len]

        else:
            k_pe = k_pe.squeeze(2)
            wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
            wkv_b = wkv_b.view(self.n_heads, -1,
                               self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope,
                                  wkv_b[:, :self.qk_nope_head_dim])  # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]
            # q @ k^T = (x @ wq) @ (c @ wkv_b[:, :qk_nope_head_dim]^T)^T
            #         = x @ wq @ wkv_b[:, :qk_nope_head_dim]^T @ c^T, where c is the compressed k/v latent.
            # Absorbing wkv_b[:, :qk_nope_head_dim]^T into the query projection lets c stand in for the
            # original k, so attention runs directly on the compressed k/v and the cache only stores the latent.
            kv = self.kv_norm(kv)
            self.kv_cache[:bs, :seq_len, :] = kv  # kv shape:[bs, seq_len, kv_lora_rank]
            self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe shape:[bs, seq_len, qk_rope_head_dim]
            scores_nope = torch.einsum("bshc,btc->bsht", q_nope,
                                       self.kv_cache[:bs, :seq_len, :])  # contract over the latent dim c
            scores_pe = torch.einsum("bshr,btr->bsht", q_pe,
                                     self.pe_cache[:bs, :seq_len, :])  # contract over the RoPE dim r; the shared key broadcasts across heads
            scores = (scores_nope + scores_pe) / math.sqrt(
                self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]

        if mask is not None:
            # mask: [bs, seq_len, seq_len]; unsqueeze(2) broadcasts it across heads
            scores += mask.unsqueeze(2)

        scores = scores.softmax(dim=-1)  # normalize over key positions

        if self.mode == 'naive':
            x = torch.einsum("bsht,bthd->bshd", scores,
                             self.v_cache[:bs, :seq_len])  # attention-weighted sum of cached v
        else:

            # scores @ v = scores @ c @ wkv_b[:, -v_head_dim:]^T: aggregate the latents first,
            # then apply the absorbed value up-projection
            x = torch.einsum("bsht,btc->bshc", scores,
                             self.kv_cache[:bs, :seq_len])  # x: [bs, seq_len, n_heads, kv_lora_rank]
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])

        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x) 

        return x


if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)

    x = torch.randn(1, 4, 16)

    dim = 16
    n_heads = 2
    q_lora_rank = 10
    kv_lora_rank = 6
    qk_nope_head_dim = 8
    qk_rope_head_dim = 4
    v_head_dim = 8
    max_seq_len = 10
    max_batch_size = 4
    mode = 'none'  # any value other than 'naive' selects the fused (weight-absorbed) path

    mla = MLA(dim=dim,
              n_heads=n_heads,
              q_lora_rank=q_lora_rank,
              kv_lora_rank=kv_lora_rank,
              qk_nope_head_dim=qk_nope_head_dim,
              qk_rope_head_dim=qk_rope_head_dim,
              v_head_dim=v_head_dim,
              max_seq_len=max_seq_len,
              max_batch_size=max_batch_size,
              mode=mode)

    print(mla(x))
    print(mla.kv_cache)
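
    # Sanity check (a sketch, not from the original post): the fused path is an
    # exact algebraic rewrite of the naive path, so a second module that shares
    # the same weights but runs in 'naive' mode should produce the same output.
    mla_naive = MLA(dim=dim, n_heads=n_heads, q_lora_rank=q_lora_rank,
                    kv_lora_rank=kv_lora_rank, qk_nope_head_dim=qk_nope_head_dim,
                    qk_rope_head_dim=qk_rope_head_dim, v_head_dim=v_head_dim,
                    max_seq_len=max_seq_len, max_batch_size=max_batch_size,
                    mode='naive')
    # the caches are non-persistent buffers, so the state dicts hold only weights
    mla_naive.load_state_dict(mla.state_dict())
    print(torch.allclose(mla(x), mla_naive(x), atol=1e-5))  # expected: True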

References:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn
