To shorten inference time, the KV-Cache mechanism caches the key and value vectors that have already been computed during generation, so that when producing subsequent tokens the model can fetch the keys and values of earlier tokens from the cache instead of recomputing them. Concretely, when a new token is generated, the model only needs to compute that token's query vector, score it against the cached key vectors, and then combine those attention scores with the cached value vectors to produce the new token's output representation.
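As a minimal single-head sketch of this decoding loop (the function name decode_step, the weight tensors, and the single-head layout are illustrative assumptions, not part of the MLA code below):

import torch
import torch.nn.functional as F

def decode_step(x_new, W_q, W_k, W_v, k_cache, v_cache):
    # x_new: (1, d) hidden state of the newly generated token
    q = x_new @ W_q                                        # only the new query is computed
    k_cache = torch.cat([k_cache, x_new @ W_k], dim=0)     # append the new key, keep the old ones
    v_cache = torch.cat([v_cache, x_new @ W_v], dim=0)     # append the new value
    scores = (q @ k_cache.T) / k_cache.shape[-1] ** 0.5    # attend over all cached positions
    out = F.softmax(scores, dim=-1) @ v_cache              # (1, d) output for the new token
    return out, k_cache, v_cache

d = 16
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
k_cache = v_cache = torch.empty(0, d)                      # the cache starts empty
for t in range(3):                                         # three decoding steps
    out, k_cache, v_cache = decode_step(torch.randn(1, d), W_q, W_k, W_v, k_cache, v_cache)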
MLA reduces the KV-cache size by compressing the keys and values into a much smaller latent vector and by decoupling the position information (RoPE), as sketched below; the resulting cache size is worked out at the end of the next section.
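As a minimal sketch of the compression step only (the sizes and weight names here are illustrative assumptions; the full MLA module in section 4 adds RoPE and multi-head attention), keys and values are rebuilt from one shared low-dimensional latent per token, and only that latent needs to be cached:

import torch
import torch.nn as nn

d_model, d_latent = 1024, 256            # assumed sizes, with d_latent well below d_model
W_DKV = nn.Linear(d_model, d_latent)     # down-projection; its output is what gets cached
W_UK = nn.Linear(d_latent, d_model)      # up-projection to keys (weights only, not cached)
W_UV = nn.Linear(d_latent, d_model)      # up-projection to values (weights only, not cached)

x = torch.randn(1, 10, d_model)          # (batch, seq_len, d_model)
c_kv = W_DKV(x)                          # (1, 10, 256): the only per-token KV state stored
k, v = W_UK(c_kv), W_UV(c_kv)            # keys/values rebuilt on the fly at attention time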
3 Decoupled Rotary Position Embedding (RoPE)
Rotary Position Embedding (RoPE) is a technique for encoding token positions in a sequence. However, RoPE is position-sensitive: it depends on the specific position of each token. This causes a problem under low-rank compression, because the position information gets mixed into the compressed keys and values, making them hard to reuse efficiently during inference. To resolve this, MLA decouples RoPE, using the architecture sketched below.
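In the notation of the code in section 4 (a sketch of the idea rather than the paper's exact formulas), each key is the concatenation of a content part k_c = W_UK(kv_cache), rebuilt from the cached latent with no position information in it, and a small rotated part k_rot = RoPE(W_k_rope(x)) computed directly from the input. The attention score then decomposes as q·k = q_base·k_c + q_rot·k_rot, so the cache only has to hold kv_cache and k_rot; because RoPE never touches the compressed latent, the content part can be re-derived from the latent at any time without re-applying any rotations.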
For reference, the size of the KV-cache (including the RoPE part):
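With the hyperparameters from the code below (d_model = 1024, d_kv_cache = 512, d_rope = 64), a rough per-token, per-layer count (my own back-of-the-envelope estimate, not a figure quoted from the paper) looks like this:

d_model, d_kv_cache, d_rope = 1024, 512, 64

mha_cache_per_token = 2 * d_model              # standard MHA stores full keys and values: 2048 numbers
mla_cache_per_token = d_kv_cache + d_rope      # MLA stores the latent plus the rotated key slice: 576 numbers
print(mha_cache_per_token / mla_cache_per_token)   # roughly 3.6x smaller per token, per layer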
4 PyTorch Code
Common hyperparameters
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 15 18:24:47 2025
@author: cxf
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Config:
    def __init__(self):
        self.vocab_size = 32000
        # embedding (word vector) dimension
        self.d_model = 1024
        # number of attention heads
        self.n_heads = 8
        # dimension per head: 1024 // 8 = 128
        self.d_head = self.d_model // self.n_heads
        # decoupled RoPE dimension: 64
        self.d_rope = self.d_head // 2
        # compression (latent) dimension, smaller than n_heads * d_head (here 512 vs 1024)
        self.d_kv_cache = 4 * self.d_head
        self.seq_len = 10
        self.batch_size = 1
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # dimension must be even for rotary embeddings
        assert dim % 2 == 0, "Dimension must be even for rotary embeddings"
        # only the first half of the per-head RoPE slice gets rotated
        self.dim = dim // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # positions must share the buffer's dtype for the einsum below
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        # outer product of positions and inverse frequencies
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        output = torch.cat((freqs, freqs), dim=-1)
        return output
def rotate_half(x):
    """
    Rotate x by swapping the two halves of its last dimension and negating the second half.
    """
    x1, x2 = x.chunk(2, dim=-1)
    output = torch.cat((-x2, x1), dim=-1)
    return output
def apply_rotary(x, cos, sin):
    """
    Apply rotary embeddings to the leading slice of x; the remaining features pass through unchanged.
    """
    # x shape: (batch_size, seq_len, n_heads, d_h)
    # split x into the slice that receives rotary embeddings and the untouched remainder
    x_rot, x_base = x.split(cos.shape[-1], dim=-1)
    x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin)
    output = torch.cat([x_rot, x_base], dim=-1)
    return output
config = Config()
class MemoryOptimizedMLA(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_head = config.d_head
        # width of the non-RoPE (content) part across all heads
        self.d_split = config.d_model - config.d_rope
        # step 1: down-projections (compression); only their outputs need caching
        self.W_DQ = nn.Linear(config.d_model, config.d_kv_cache)
        self.W_DKV = nn.Linear(config.d_model, config.d_kv_cache)
        print("\n kv cache size ", config.d_kv_cache)
        # decoupled RoPE projections
        self.W_q_rope = nn.Linear(config.d_kv_cache, config.d_rope)
        self.W_k_rope = nn.Linear(config.d_model, config.d_rope)
        # step 2: up-projections back to per-head content dimensions
        self.W_UQ = nn.Linear(config.d_kv_cache, self.d_split)
        self.W_UK = nn.Linear(config.d_kv_cache, self.d_split)
        self.W_UV = nn.Linear(config.d_kv_cache, config.d_model)
        # rotary embedding over the per-head RoPE slice
        self.rotary = RotaryEmbedding(config.d_rope // config.n_heads)
        # step 3: output projection
        self.output = nn.Linear(config.d_model, config.d_model)
    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        print("\n batch_size: %d seq_len: %d d_model: %d " % (batch_size, seq_len, d_model))
        # step 1: down-projection (compression)
        print("\n step1 : down projection")
        # query compression
        q_c = self.W_DQ(x)
        # compressed KV latent: this is what gets cached
        kv_cache = self.W_DKV(x)
        # step 2: apply RoPE to the decoupled slices
        print("\n step2 : apply RoPE")
        rotary_emb = self.rotary(seq_len)
        cos = torch.cos(rotary_emb).view(1, seq_len, 1, -1)
        sin = torch.sin(rotary_emb).view(1, seq_len, 1, -1)
        q_rot = self.W_q_rope(q_c)
        q_rot = q_rot.view(batch_size, seq_len, config.n_heads, -1)
        q_rot = apply_rotary(q_rot, cos, sin)
        # rotated key slice: the second (small) piece of the cache
        k_rot_cache = self.W_k_rope(x)
        k_rot_cache = k_rot_cache.view(batch_size, seq_len, config.n_heads, -1)
        k_rot_cache = apply_rotary(k_rot_cache, cos, sin)
        # step 3: up-projection back to per-head keys and values
        print("\n step3 : up projection")
        q_base = self.W_UQ(q_c).view(batch_size, seq_len, config.n_heads, -1)
        k = self.W_UK(kv_cache).view(batch_size, seq_len, config.n_heads, -1)
        v = self.W_UV(kv_cache).view(batch_size, seq_len, config.n_heads, -1)
        # concatenate content and RoPE parts
        q = torch.cat([q_base, q_rot], dim=-1)
        k = torch.cat([k, k_rot_cache], dim=-1)
        # attention computation
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum("bhqk,bkhd->bqhd", attn, v)
        out = self.output(out.contiguous().view(batch_size, seq_len, -1))
        print("\n output ", out.shape)
        # return the attention output together with the two cached tensors
        return out, (kv_cache, k_rot_cache)
net = MemoryOptimizedMLA()
x = torch.randn((config.batch_size, config.seq_len, config.d_model))
out, (kv_cache, k_rot_cache) = net(x)
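A quick check of the returned shapes (derived from the Config values above, not captured from an actual run):

# out:          (1, 10, 1024) - one d_model-sized representation per token
# kv_cache:     (1, 10, 512)  - compressed latent actually kept in the cache
# k_rot_cache:  (1, 10, 8, 8) - decoupled, rotated key slice (64 values per token)
print(out.shape, kv_cache.shape, k_rot_cache.shape)

Per token and per layer this is 512 + 64 = 576 cached values, matching the estimate in section 3.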