【PyTorch][chapter-35][MLA]

前言:

MLA(Multi-head Latent Attention,多头潜在注意力)旨在提高推理效率和降低计算资源的消。MLA的核心思想在于通过信息转移来优化KV缓存的使用

MLA的技术特点主要包括:

  1. KV压缩与潜在变量:将键(Key)和值(Value)联合压缩为低维潜在向量,显著减少推理时的KV缓存,降低内存占用。计算时通过升维恢复原始信息,平衡压缩效率与计算精度。
  2. 低秩降维技术:对查询(Queries)进行低秩压缩(降维后再升维),减少训练中的激活内存(activation memory),但需注意此操作不影响KV缓存。
  3. 动态序列处理:针对可变长度输入序列优化,支持高效处理不同长度的句子(如长文本对话场景 ROPE)。

目录

  1. KV-cache
  2. MLA 模型简介
  3. MLA+ROPE
  4. MLA 数学原理
  5. PyTorh 代码

一 KV-cache

1.1 MHA (多头注意力)

1.2 KV-cache

在自回归生成过程中,每个新生成的token都会依赖于之前所有token的信息,这就需要在生成每个新token时重新计算整个序列的自注意力。然而,这种计算方式非常低效,因为大量重复的计算被浪费在了已经生成过的token上。

为了缩短inference time, KV-Cache机制正是为了解决这一问题而提出的。它的工作原理是在生成过程中,将已经计算过的键和值向量存储在缓存中,这样在生成后续token时,可以直接从缓存中获取之前token的键和值,而不需要重新计算。具体来说,当生成一个新的token时,模型只需要计算这个新token的查询向量,并与缓存中的键向量计算注意力得分,然后使用这些得分和缓存中的值向量来计算新token的输出表示.

KV-Cache 的大小取决于以下参数:

  • : 注意力头数,每层的注意力头数量。

  • : 每个注意力头的维度,每个注意力头的 Key 和 Value 的维度。

  • l: 输入的层数模

则每个token 对应的 KV-cache 为

不同注意力机制对应的kv-cache


二 MLA(Multi-Layer Adaptation)

多头潜在注意力 (MLA) 是一种新的注意力机制,它通过将键和值压缩为一个较小的共享表示(称为潜在向量)来实现这一点。这可以减小 KV 缓存的大小,同时保持甚至提高性能。

MLA 引入了两项关键创新:

  1. Low-Rank Key-Value Compression
  2. Decoupled Rotary Position Embedding (RoPE)

2.1 MLA 架构

2.2 计算流程

参考:

MLA reduces the KV cache size by compressing the keys and values into a smaller latent vector and decoupling the position information (RoPE). Here's how the cache size is calculated.


Decoupled Rotary Position Embedding (RoPE)

旋转位置编码(Rotary Position Embedding, RoPE)是一种用于编码序列中标记位置的技术。然而,RoPE是位置敏感的,这意味着它依赖于每个标记的具体位置。这在使用低秩压缩时会产生问题,因为位置信息会被混入压缩后的键(keys)和值(values)中,导致在推理过程中难以高效地重用它们。为了解决ROPE问题,使用了下面架构

参考:

KV-cache 的大小(包括了ROPE 部分)


PyTorch代码

常用超参数

复制代码
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 15 18:24:47 2025

@author: cxf
"""

# -*- coding: utf-8 -*-
"""
Created on Thu Mar 13 13:51:48 2025

@author: chengxf2
"""

import torch
import torch.nn  as nn
import torch.nn.functional as F
import math

class Config:
    def __init__(self):
        self.vocab_size = 32000
        #词向量的维度
        self.d_model = 1024
        #number of attention heads 
        self.n_heads = 8
        #dDmension of per head =64
        self.d_head = self.d_model//self.n_heads
        #ROPE dimension, typically 128
        self.d_rope =  self.d_head//2
        #compression dimension KV_cache <<n_head*d_h
        self.d_kv_cache = 4*self.d_head 
        self.seq_len = 10
        self.batch_size = 1
        #256
        

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        #Dimension must be even for Rotary Embedding
        assert dim % 2 == 0, "Dimension must be even for rotary embeddings"
        self.dim = dim//2
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq)
    
    def forward(self, seq_len):
        t = torch.arange(seq_len)
        freqs = torch.einsum("i,j->ij",t, self.inv_freq)
        output = torch.cat((freqs, freqs), dim=-1)
        return output

def rotate_half(x):
    """
    Apply rotary embeddings to the first half of x.
    """
    x1 ,x2 = x.chunk(2,dim=-1)
    output = torch.cat((-x2,x1),dim=-1)
    return output

def apply_rotary(x, cos, sin):
    """
    Apply rotary embeddings to the first half of x.
    """
    #x.shape batch_size, seq_len, head, d_h
    # Split x into two parts: one for rotary embeddings and the other untouched    x_rot, x_base = x.split(cos.shape[-1],dim=-1)
    print("\n apply _rotary ",x.shape)
    print("\n cos x ",cos.shape, x.shape)
    x_rot, x_base = x.split(cos.shape[-1],dim=-1)
    x_rot =(x_rot*cos)+(rotate_half(x_rot)*sin)
    output = torch.cat([x_rot,x_base],dim=-1)
    return output



config = Config()
class MemoryOptimizedMLA(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_head = config.d_head
        self.d_split = config.d_model-config.d_rope
        #down-projection
        self.W_DQ =  nn.Linear(config.d_model,  config.d_kv_cache)
        self.W_DKV = nn.Linear(config.d_model,  config.d_kv_cache)
        print("\n kv cache size ",config.d_kv_cache)
        # RoPE
        self.W_q_rope = nn.Linear(config.d_kv_cache,  config.d_rope)
        self.W_k_rope = nn.Linear(config.d_model,     config.d_rope)
        #step2:  Up Projections
        self.W_UQ = nn.Linear(config.d_kv_cache, self.d_split)
        self.W_UK = nn.Linear(config.d_kv_cache, self.d_split)
        self.W_UV = nn.Linear(config.d_kv_cache, config.d_model)  
        #rotary Embedding
        self.rotary = RotaryEmbedding(config.d_rope//config.n_heads)
        #step3 output
        self.output = nn.Linear(config.d_model, config.d_model)
        
    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        print("\n bat_size %d seq_len: %d d_model: %d "%(batch_size, seq_len, d_model))
        #step1: down-projection Compression
        print("\n step1 : down projection")
        #query compression
        q_c      =  self.W_DQ(x)
        kv_cache =  self.W_DKV(x)
        #print("\n kv-cache",kv_cache.shape,"\t q_c",q_c.shape)
        #Apply RoPE
        print("\n step2 : apply ROPE ")
        rotary_emb = self.rotary(seq_len)
        cos = torch.cos(rotary_emb).view(1, seq_len, 1, -1)  
        sin = torch.sin(rotary_emb).view(1, seq_len, 1, -1)
        q_rot =  self.W_q_rope(q_c)
        q_rot = q_rot.view(batch_size, seq_len, config.n_heads, -1)
        q_rot = apply_rotary(q_rot, cos, sin)
        k_rot_cache =   self.W_k_rope(x)
        k_rot_cache =   k_rot_cache.view(batch_size, seq_len, config.n_heads,-1)
        k_rot_cache =   apply_rotary(k_rot_cache,cos, sin)
        #up-projection
        print("\n step3 : up projection ")
        q_base = self.W_UQ(q_c).view(batch_size, seq_len, config.n_heads, -1)
        k = self.W_UK(kv_cache).view(batch_size, seq_len, config.n_heads, -1)
        v = self.W_UV(kv_cache).view(batch_size, seq_len, config.n_heads, -1)
        # concate
        q = torch.cat([q_base, q_rot], dim=-1)
        k = torch.cat([k, k_rot_cache], dim=-1)
        # Attention computation
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum("bhqk,bkhd->bqhd", attn, v)
        out = self.output(out.contiguous().view(batch_size, seq_len, -1))
        output =  out, (kv_cache, k_rot_cache)
        print("\n output ",out.shape)
        return output
    
net= MemoryOptimizedMLA()
x  = torch.randn((config.batch_size, config.seq_len, config.d_model))
out = net(x)

https://medium.com/@shaiknagurshareef/multi-head-latent-attention-mla-secret-behind-the-success-of-deepseek-large-language-models-66612071d756

DeepSeek's Multi-Head Latent Attention - Lior Sinai

https://www.youtube.com/watch?v=s9R5s4U1WH8

https://medium.com/@atulit23/implementing-multi-head-latent-attention-from-scratch-in-python-1e14d03fbc91

相关推荐
大模型铲屎官1 分钟前
从过拟合到强化学习:机器学习核心知识全解析
人工智能·python·机器学习·llm·scikit-learn·强化学习·过拟合
机器学习Zero1 分钟前
自然语言处理(2)—— NLP之百年风雨路
人工智能·自然语言处理
塔能物联运维17 分钟前
塔能科技:做节能界的“催化剂”,加速工厂能源改造变革
大数据·人工智能
全星00722 分钟前
全星研发管理APQP软件系统:助力汽车零部件企业高效研发,打造核心竞争力
大数据·人工智能·汽车
yu_xiaoxian31 分钟前
BEV学习笔记之-LSS 手撕代码
人工智能·自动驾驶
lihuayong38 分钟前
有了大模型为何还需要Agent智能体
人工智能·ai agent·agent 智能体·agent 原理
一个处女座的程序猿O(∩_∩)O43 分钟前
人工智能中神经网络是如何进行学习的
人工智能·神经网络·学习
网安导师小李1 小时前
Android Studio下载及安装和Gradle的配置(非常详细)从零基础入门到精通,看完这一篇就够了
android·运维·ide·人工智能·安全·web安全·android studio
韩曙亮1 小时前
【AI 大模型】RAG 检索增强生成 ⑤ ( 向量数据库 | 向量数据库 索引结构和搜索算法 | 常见 向量数据库 对比 | 安装并使用 向量数据库 chromadb 案例 )
数据库·人工智能·大模型·openai·向量数据库·ai大模型·chromadb
skywalk81631 小时前
使用 PaddleNLP 在 CPU(支持 AVX 指令)下跑通 llama2-7b或DeepSeek-r1:1.5b 模型(完成度80%)
人工智能·python·大模型·paddlenlp