Multi-Head Attention (MHA)
Multi-head attention computes several attention heads in parallel, so the model can attend to features at different positions of the input sequence simultaneously. The core idea is to project the input into multiple subspaces, compute attention weights in each subspace separately, and then aggregate the results, which strengthens the model's ability to capture complex patterns.
```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size          # embedding dimension
        self.heads = heads                    # number of attention heads
        self.head_dim = embed_size // heads   # dimension of each head

        # The embedding dimension must be divisible by the number of heads.
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Linear projections that map each head's slice to queries, keys and values.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Final output projection.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embeddings into multiple heads.
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the per-head linear projections.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Dot-product attention scores.
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape:    (N, key_len, heads, head_dim)
        # energy shape:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Apply the mask, if provided (positions where mask == 0 are blocked).
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Softmax over the key dimension, scaled by sqrt(head_dim) as in
        # standard scaled dot-product attention.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of the values.
        # attention shape: (N, heads, query_len, key_len)
        # values shape:    (N, value_len, heads, head_dim)
        # out shape:       (N, query_len, heads, head_dim), then flatten the last two dims
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # Final fully connected output projection.
        out = self.fc_out(out)
        return out


# Example usage:
if __name__ == "__main__":
    embed_size = 512   # embedding dimension
    heads = 8          # number of attention heads
    batch_size = 64    # batch size
    seq_length = 10    # sequence length

    mha = MultiHeadAttention(embed_size, heads)
    values = keys = query = torch.randn(batch_size, seq_length, embed_size)
    print(mha(values, keys, query, mask=None).shape)  # torch.Size([64, 10, 512])
```
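The `mask` argument is accepted above but never exercised in the example. A minimal usage sketch, reusing the `MultiHeadAttention` class defined above and assuming the common causal (lower-triangular) convention in which positions marked 0 are blocked, might look like this; the `(1, 1, seq_len, seq_len)` mask shape relies on broadcasting over batch and heads and is an illustrative choice, not part of the original example:

```python
import torch

# Hypothetical continuation of the example above: build a causal mask and
# pass it to the MultiHeadAttention module defined earlier.
seq_length, batch_size, embed_size, heads = 10, 64, 512, 8
mha = MultiHeadAttention(embed_size, heads)  # class defined above
x = torch.randn(batch_size, seq_length, embed_size)

# Positions where the mask is 0 are filled with -1e20 before the softmax,
# so each query position can only attend to itself and earlier positions.
causal_mask = torch.tril(torch.ones(seq_length, seq_length)).view(1, 1, seq_length, seq_length)
print(mha(x, x, x, mask=causal_mask).shape)  # torch.Size([64, 10, 512])
```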
Multi-Query Attention (MQA)
Multi-Query Attention (MQA) is an efficiency-oriented variant of multi-head attention (MHA). Its core idea is to share a single set of key (Key) and value (Value) projections across all heads while keeping an independent projection for each query (Query) head. This markedly reduces the parameter count, the computational cost, and in particular the memory needed for the KV cache at inference time, while retaining much of the parallelism of multi-head attention.
```python
import torch
import torch.nn as nn


class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        """
        Multi-Query Attention implementation.

        Args:
            hidden_size (int): dimension of the input features, i.e. the last dimension of hidden_state.
            num_heads (int): number of attention heads.
            dropout (float): dropout probability, defaults to 0.0.
        """
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # dimension of each head

        # Linear projections for Q, K, V.
        self.query = nn.Linear(hidden_size, hidden_size)    # independent query per head
        self.key = nn.Linear(hidden_size, self.head_dim)    # single key shared by all heads
        self.value = nn.Linear(hidden_size, self.head_dim)  # single value shared by all heads
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]  # batch size
        query_len = query.shape[1]
        key_len = key.shape[1]
        value_len = value.shape[1]

        # Project and split the queries into heads; K and V keep a single shared head.
        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)
        key = self.key(key).view(N, key_len, 1, self.head_dim)
        value = self.value(value).view(N, value_len, 1, self.head_dim)

        # Rearrange for the dot product.
        query = query.transpose(1, 2)  # shape: (N, num_heads, query_len, head_dim)
        key = key.transpose(1, 2)      # shape: (N, 1, key_len, head_dim), broadcast over heads
        value = value.transpose(1, 2)  # shape: (N, 1, value_len, head_dim), broadcast over heads

        # Scaled dot-product attention scores.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # scores shape: (N, num_heads, query_len, key_len)

        # Apply the mask, if provided.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over the key dimension.
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # attention_weights shape: (N, num_heads, query_len, key_len)

        # Weighted sum of the shared values (broadcast across heads).
        context = torch.matmul(attention_weights, value)
        # context shape: (N, num_heads, query_len, head_dim)

        # Merge the heads back together.
        context = context.transpose(1, 2).contiguous().view(N, query_len, self.hidden_size)
        # context shape: (N, query_len, hidden_size)

        # Final output projection.
        output = self.out_projection(context)
        # output shape: (N, query_len, hidden_size)
        return output


# Example usage:
if __name__ == "__main__":
    hidden_size = 512   # embedding dimension
    num_heads = 8       # number of attention heads
    batch_size = 64     # batch size
    seq_length = 10     # sequence length

    mqa = MultiQueryAttention(hidden_size, num_heads)
    query = key = value = torch.randn(batch_size, seq_length, hidden_size)
    print(mqa(query, key, value, mask=None).shape)  # torch.Size([64, 10, 512])
```
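To make the saving concrete, here is a back-of-the-envelope comparison of the per-layer KV cache for the example sizes above. The figures are illustrative arithmetic only, not benchmarks, and assume one key and one value vector cached per position:

```python
# Illustrative KV-cache arithmetic for the example sizes above (not a benchmark).
batch_size, seq_length, num_heads, head_dim = 64, 10, 8, 64

# MHA caches one key and one value vector per head and position.
mha_kv_elems = 2 * batch_size * seq_length * num_heads * head_dim
# MQA caches a single shared key/value head per position.
mqa_kv_elems = 2 * batch_size * seq_length * 1 * head_dim

print(mha_kv_elems, mqa_kv_elems, mha_kv_elems // mqa_kv_elems)  # 655360 81920 8
```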
Grouped Query Attention (GQA)
Grouped Query Attention (GQA) is a compromise between multi-head attention (MHA) and multi-query attention (MQA). Its core idea is to divide the query heads into groups, with the query heads inside each group sharing one set of keys (Key) and values (Value), which reduces parameters and computation while preserving the parallelism of multiple heads. GQA strikes a balance between parameter efficiency and model quality and is well suited to the efficient deployment of large models.
```python
import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):
        """
        Grouped Query Attention implementation.

        Args:
            hidden_size (int): dimension of the input features.
            num_heads (int): number of query heads.
            group_size (int): number of query heads in each group.
            dropout (float): dropout probability.
        """
        super(GroupedQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % group_size == 0, "num_heads must be divisible by group_size"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads  # dimension of each head

        # Linear projections for Q, K, V.
        self.query = nn.Linear(hidden_size, hidden_size)                     # independent query per head
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)    # one key head per group
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)  # one value head per group
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        N = query.shape[0]  # batch size
        query_len = query.shape[1]
        key_len = key.shape[1]
        value_len = value.shape[1]

        # Project and split into heads (queries) and groups (keys/values).
        query = self.query(query).view(N, query_len, self.num_heads, self.head_dim)
        key = self.key(key).view(N, key_len, self.group_num, self.head_dim)
        value = self.value(value).view(N, value_len, self.group_num, self.head_dim)

        # Rearrange for the dot product.
        query = query.transpose(1, 2)  # shape: (N, num_heads, query_len, head_dim)
        key = key.transpose(1, 2)      # shape: (N, group_num, key_len, head_dim)
        value = value.transpose(1, 2)  # shape: (N, group_num, value_len, head_dim)

        # Repeat each K/V head group_size times so that every query head in a
        # group attends to its group's shared keys and values.
        key = key.repeat_interleave(self.group_size, dim=1)      # (N, num_heads, key_len, head_dim)
        value = value.repeat_interleave(self.group_size, dim=1)  # (N, num_heads, value_len, head_dim)

        # Scaled dot-product attention scores.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # scores shape: (N, num_heads, query_len, key_len)

        # Apply the mask, if provided.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over the key dimension.
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # attention_weights shape: (N, num_heads, query_len, key_len)

        # Weighted sum of the values.
        context = torch.matmul(attention_weights, value)
        # context shape: (N, num_heads, query_len, head_dim)

        # Merge the heads back together.
        context = context.transpose(1, 2).contiguous().view(N, query_len, self.hidden_size)
        # context shape: (N, query_len, hidden_size)

        # Final output projection.
        output = self.out_projection(context)
        # output shape: (N, query_len, hidden_size)
        return output


# Example usage:
if __name__ == "__main__":
    hidden_size = 512   # embedding dimension
    num_heads = 8       # number of query heads
    group_size = 2      # query heads per group
    batch_size = 64     # batch size
    seq_length = 10     # sequence length

    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    query = key = value = torch.randn(batch_size, seq_length, hidden_size)
    print(gqa(query, key, value, mask=None).shape)  # torch.Size([64, 10, 512])
```
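The "compromise" character of GQA can be seen directly in the width of the K/V projections used by the three modules above. The following arithmetic is illustrative only and uses the example sizes from this section:

```python
# Illustrative comparison of K/V projection widths for the three variants
# (hidden_size=512, num_heads=8, group_size=2 as in the examples above).
hidden_size, num_heads, group_size = 512, 8, 2
head_dim = hidden_size // num_heads
group_num = num_heads // group_size

mha_kv_width = num_heads * head_dim  # 512: one K/V head per query head
mqa_kv_width = 1 * head_dim          # 64:  a single K/V head shared by all query heads
gqa_kv_width = group_num * head_dim  # 256: one K/V head per group of query heads

print(mha_kv_width, mqa_kv_width, gqa_kv_width)
```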
Multi-Head Latent Attention (MLA)
Multi-Head Latent Attention (MLA) is an efficient attention mechanism that combines low-rank parameterization with rotary position embeddings (RoPE). Its core idea is to compress the query (Q), key (K), and value (V) projections through low-rank projections and to decouple content from positional information in the attention computation, reducing computational cost while preserving the ability to model long-range dependencies. MLA is particularly suited to large-scale model deployment, balancing efficiency and quality. The code below implements the RoPE component; a simplified sketch of the low-rank projection path follows it.
```python
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        """
        Rotary position embedding (RoPE) module.

        Args:
            hidden_size (int): model dimension.
            num_heads (int): number of attention heads.
            base (int): frequency base.
            max_len (int): maximum sequence length.
        """
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        cos_pos, sin_pos = self._compute_pos_emb()
        self.register_buffer("cos_pos_cache", cos_pos)
        self.register_buffer("sin_pos_cache", sin_pos)

    def _compute_pos_emb(self):
        # Frequency factors for half of the head dimension.
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        # Position indices.
        positions = torch.arange(self.max_len).float()
        # Outer product of positions and frequencies: (max_len, head_dim // 2).
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)
        # cos / sin tables used for the rotation.
        cos_pos = pos_emb.cos()
        sin_pos = pos_emb.sin()
        return cos_pos, sin_pos

    def forward(self, x, seq_len=None):
        # x shape: (batch, seq_len, num_heads, head_dim)
        if seq_len is None:
            seq_len = x.size(1)
        # Cached position encodings, broadcast over the head dimension:
        # (seq_len, 1, head_dim // 2).
        cos_pos = self.cos_pos_cache[:seq_len].unsqueeze(1)
        sin_pos = self.sin_pos_cache[:seq_len].unsqueeze(1)
        # Split x into two halves; (x1[i], x2[i]) form a rotation pair.
        x1, x2 = x.chunk(2, dim=-1)
        # Apply the 2-D rotation to each pair.
        rotated_x1 = x1 * cos_pos - x2 * sin_pos
        rotated_x2 = x1 * sin_pos + x2 * cos_pos
        # Concatenate the rotated halves back together.
        rotated_x = torch.cat([rotated_x1, rotated_x2], dim=-1)
        return rotated_x


# Example usage:
if __name__ == "__main__":
    hidden_size = 512   # model dimension
    num_heads = 8       # number of attention heads
    batch_size = 64     # batch size
    seq_length = 10     # sequence length

    rope = RotaryEmbedding(hidden_size, num_heads)
    # Per-head input: (batch, seq_len, num_heads, head_dim)
    x = torch.randn(batch_size, seq_length, num_heads, hidden_size // num_heads)
    print(rope(x).shape)  # torch.Size([64, 10, 8, 64])
```
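The `RotaryEmbedding` module above covers only the positional half of MLA. As a complement, here is a minimal, illustrative sketch of the other half described in the paragraph above: compressing keys and values through a low-rank ("latent") projection. It deliberately omits the decoupled RoPE branch, and the class and parameter names (`SimplifiedMLA`, `kv_latent_dim`, `kv_down`, `k_up`, `v_up`) are placeholders chosen for this sketch, not a reference implementation:

```python
import torch
import torch.nn as nn


class SimplifiedMLA(nn.Module):
    """Sketch of the low-rank KV-compression idea behind MLA (RoPE branch omitted)."""

    def __init__(self, hidden_size, num_heads, kv_latent_dim):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Queries keep a full per-head projection.
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        # Keys/values are first compressed into a small latent vector ...
        self.kv_down = nn.Linear(hidden_size, kv_latent_dim)
        # ... and then expanded back to per-head keys and values.
        self.k_up = nn.Linear(kv_latent_dim, hidden_size)
        self.v_up = nn.Linear(kv_latent_dim, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, mask=None):
        N, seq_len, _ = x.shape
        q = self.q_proj(x).view(N, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Only the small latent c_kv would need to be cached during decoding.
        c_kv = self.kv_down(x)
        k = self.k_up(c_kv).view(N, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_up(c_kv).view(N, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Standard scaled dot-product attention over the reconstructed K/V.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(N, seq_len, -1)
        return self.out_proj(out)


# Example usage:
if __name__ == "__main__":
    mla = SimplifiedMLA(hidden_size=512, num_heads=8, kv_latent_dim=128)
    x = torch.randn(64, 10, 512)
    print(mla(x).shape)  # torch.Size([64, 10, 512])
```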