# -*- coding: utf-8 -*-
"""
Created on Fri Feb 21 15:02:18 2025
@author: chengxf2
"""
import torch.nn as nn
import torch.nn.functional as F
import torch
def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    # Head dimension of the query, used for scaling
    d_k = query.size(-1)
    # Dot-product attention scores: (..., seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # If a mask is provided, apply it to the scores
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Softmax over the scores to obtain the attention weights
    p_attention = F.softmax(scores, dim=-1)
    # If a dropout module is provided, apply it to the weights
    if dropout is not None:
        p_attention = dropout(p_attention)
    # Weighted sum of the values using the attention weights
    return torch.matmul(p_attention, value)
class Attention(nn.Module):
    def __init__(self, d_model=512, num_heads=8, num_kv_heads=2, dropout=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.num_kv_heads = num_kv_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Linear projections
        self.query = nn.Linear(d_model, self.head_dim * self.num_heads)
        self.key = nn.Linear(d_model, self.head_dim * self.num_kv_heads)
        self.value = nn.Linear(d_model, self.head_dim * self.num_kv_heads)
        # Output projection
        self.proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        batch, seq_len, d_model = inputs.shape
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)
        # shape = (B, seq_len, num_heads, head_dim) for q,
        #         (B, seq_len, num_kv_heads, head_dim) for k and v
        q = q.view(batch, seq_len, -1, self.head_dim)
        k = k.view(batch, seq_len, -1, self.head_dim)
        v = v.view(batch, seq_len, -1, self.head_dim)
        print("default q.shape", q.shape)
        print("default k.shape", k.shape)
        print("default v.shape", v.shape)
        # Grouped Query Attention:
        # [batch, seq_len, num_kv_heads, head_dim] -> [batch, seq_len, num_heads, head_dim]
        if self.num_kv_heads != self.num_heads:
            k = torch.repeat_interleave(k, self.num_queries_per_kv, dim=2)
            v = torch.repeat_interleave(v, self.num_queries_per_kv, dim=2)
        # shape = (B, num_heads, seq_len, head_dim)
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        print("q.shape", q.shape)
        print("k.shape", k.shape)
        print("v.shape", v.shape)
        output = scaled_dot_product_attention(
            q,
            k,
            v,  # argument order matters: query, key, value
            None,
            self.attn_dropout,
        )
        print("v.shape", v.shape)
        print("output.shape", output.shape)
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        # Final projection back into the residual stream
        output = self.proj(output)
        return output
net = Attention()
batch_size = 2
seq_len = 5
d_model = 512
x = torch.randn(batch_size, seq_len, d_model)
net(x)
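For reference, with the default configuration (d_model=512, num_heads=8, num_kv_heads=2, hence head_dim=64) and the toy input above, the script should print the following shapes:

default q.shape torch.Size([2, 5, 8, 64])
default k.shape torch.Size([2, 5, 2, 64])
default v.shape torch.Size([2, 5, 2, 64])
q.shape torch.Size([2, 8, 5, 64])
k.shape torch.Size([2, 8, 5, 64])
v.shape torch.Size([2, 8, 5, 64])
v.shape torch.Size([2, 8, 5, 64])
output.shape torch.Size([2, 8, 5, 64])

The Key and Value tensors start with only 2 heads and are expanded to 8 by repeat_interleave, so each KV head is shared by 4 query heads; this sharing is what shrinks the KV cache relative to standard MHA.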
7. Multi-Head Latent Attention (MLA)
Multi-Head Latent Attention (MLA) achieves performance superior to MHA while significantly reducing the KV cache, which boosts inference efficiency. Instead of reducing the number of KV heads as in MQA and GQA, MLA jointly compresses the Keys and Values into a latent vector.
Low-Rank Key-Value Joint Compression
Instead of caching both the Key and Value matrices, MLA jointly compresses them into a single low-rank latent vector. Because the compression dimension is much smaller than the Key/Value projection dimension used in MHA, far fewer values need to be cached per token.
Comparison of DeepSeek's Multi-Head Latent Attention (MLA) with MHA, MQA, and GQA.
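As a rough illustration of the joint KV compression idea (a minimal sketch, not DeepSeek's exact formulation: it reuses the imports and scaled_dot_product_attention from above, and the names LatentKVAttention, kv_down, k_up, v_up and the latent width d_latent are assumptions for this example; the real MLA also decouples rotary position embeddings, which is omitted here):

class LatentKVAttention(nn.Module):
    """Simplified MLA-style attention: cache only a small latent vector per token."""
    def __init__(self, d_model=512, num_heads=8, d_latent=64, dropout=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Queries are projected as usual
        self.query = nn.Linear(d_model, d_model)
        # Joint down-projection of the hidden state into a low-rank latent c_kv
        self.kv_down = nn.Linear(d_model, d_latent)
        # Up-projections reconstruct per-head Keys and Values from the latent
        self.k_up = nn.Linear(d_latent, d_model)
        self.v_up = nn.Linear(d_latent, d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        batch, seq_len, d_model = inputs.shape
        # At inference time only c_kv (batch, seq_len, d_latent) would be cached,
        # instead of full K and V of size (batch, seq_len, d_model) each.
        c_kv = self.kv_down(inputs)
        # shape = (B, num_heads, seq_len, head_dim)
        q = self.query(inputs).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_up(c_kv).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_up(c_kv).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        output = scaled_dot_product_attention(q, k, v, None, self.attn_dropout)
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.proj(output)

With d_model=512 and d_latent=64, the per-token cache shrinks from 2*512 values (K plus V) to 64, roughly a 16x reduction in this toy setting, while all heads still get their own reconstructed Keys and Values.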