Background reading: "Understanding Attention: from the origins to MHA, MQA and GQA" | Linsight
Upgrading an existing model: https://blog.nghuyong.top/2023/09/10/NLP/llm-attention/
PyTorch implementation:
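The upgrade recipe summarized in that link follows the GQA paper: take an existing MHA checkpoint, mean-pool the K/V projection heads within each group down to num_kv_heads heads, then continue training briefly. A minimal sketch of the weight surgery (not the blog's exact code), assuming nn.Linear weights laid out as (num_heads * head_dim, hidden_size):

import torch

def mha_kv_to_gqa(weight, num_heads, num_kv_heads):
    # weight: K or V projection from an MHA checkpoint, shape (num_heads * head_dim, hidden_size).
    # Mean-pool the heads inside each group -> (num_kv_heads * head_dim, hidden_size).
    out_dim, hidden_size = weight.shape
    head_dim = out_dim // num_heads
    group = num_heads // num_kv_heads
    w = weight.view(num_kv_heads, group, head_dim, hidden_size)
    return w.mean(dim=1).reshape(num_kv_heads * head_dim, hidden_size)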
import math

import torch


class BaseAttention(torch.nn.Module):
    """Shared scaled dot-product attention used by all variants below."""

    def __init__(self):
        super(BaseAttention, self).__init__()
        self.softmax = torch.nn.Softmax(dim=-1)

    def attention(self, q, k, v, mask=None, dropout=None):
        # Scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v.
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
        if mask is not None:
            # Additive mask: 0 where attention is allowed, -inf (or a large negative) where it is not.
            attn = attn + mask
        attn = self.softmax(attn)
        if dropout is not None:
            attn = dropout(attn)
        output = torch.matmul(attn, v)
        return output
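As an aside, PyTorch 2.0+ ships a fused kernel for the same computation, torch.nn.functional.scaled_dot_product_attention, which also accepts an additive attn_mask and a dropout probability (newer releases additionally expose an enable_gqa flag for K/V tensors with fewer heads). The manual version above is kept for readability; a rough drop-in sketch:

import torch.nn.functional as F

def fused_attention(q, k, v, mask=None, dropout_p=0.0):
    # Same math as BaseAttention.attention, delegated to PyTorch's fused kernel (PyTorch >= 2.0).
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)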
class Attention(BaseAttention):
    """Single-head self-attention."""

    def __init__(self, hidden_size, dropout=None):
        super(Attention, self).__init__()
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(p=dropout) if dropout is not None else None

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        output = self.attention(q, k, v, mask, self.dropout)
        return output
class MHAttention(BaseAttention):
    """Multi-head attention (MHA): every query head has its own K/V head."""

    def __init__(self, hidden_size, num_heads=32, dropout=None):
        super(MHAttention, self).__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(p=dropout) if dropout is not None else None

    def forward(self, x, mask=None):
        bs, seq_len, hidden_size = x.shape
        # Split hidden_size into num_heads heads: (bs, num_heads, seq_len, head_dim).
        q = self.q_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        k = self.k_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        v = self.v_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        output = self.attention(q, k, v, mask, self.dropout)
        # Merge the heads back: move the head dim next to head_dim before flattening.
        # (A full Transformer block would also apply an output projection here; omitted for brevity.)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, hidden_size)
        return output
class MQAttention(BaseAttention):
    """Multi-query attention (MQA): all query heads share a single K/V head."""

    def __init__(self, hidden_size, num_heads=32, dropout=None):
        super(MQAttention, self).__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        # K and V project to a single head of size head_dim = hidden_size // num_heads.
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads)
        self.dropout = torch.nn.Dropout(p=dropout) if dropout is not None else None

    def forward(self, x, mask=None):
        bs, seq_len, hidden_size = x.shape
        q = self.q_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        # K/V keep a single head: (bs, 1, seq_len, head_dim); it broadcasts against
        # the num_heads query heads inside attention().
        k = self.k_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        v = self.v_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        output = self.attention(q, k, v, mask, self.dropout)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, hidden_size)
        return output
class GQAttention(BaseAttention):
    """Grouped-query attention (GQA): each group of query heads shares one K/V head."""

    def __init__(self, hidden_size, num_heads=32, num_kv_heads=8, dropout=None):
        super(GQAttention, self).__init__()
        assert hidden_size % num_heads == 0 and num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_group = num_heads // num_kv_heads  # query heads per K/V head
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        # K and V project to num_kv_heads heads of size head_dim = hidden_size // num_heads.
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads * num_kv_heads)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size // num_heads * num_kv_heads)
        self.dropout = torch.nn.Dropout(p=dropout) if dropout is not None else None

    def repeat_kv(self, feature, num_group):
        # From the LLaMA 2 source: repeat each K/V head num_group times so that
        # (bs, num_kv_heads, seq_len, head_dim) -> (bs, num_heads, seq_len, head_dim).
        bs, num_kv_heads, seq_len, head_dims = feature.shape
        if num_group == 1:
            return feature
        feature = feature[:, :, None, :, :].expand(bs, num_kv_heads, num_group, seq_len, head_dims)
        return feature.reshape(bs, num_kv_heads * num_group, seq_len, head_dims)

    def forward(self, x, mask=None):
        bs, seq_len, hidden_size = x.shape
        q = self.q_proj(x).view(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        k = self.k_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        v = self.v_proj(x).view(bs, seq_len, -1, hidden_size // self.num_heads).transpose(1, 2)
        k, v = self.repeat_kv(k, self.num_group), self.repeat_kv(v, self.num_group)
        output = self.attention(q, k, v, mask, self.dropout)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, hidden_size)
        return output
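To see what MQA and GQA actually save, compare the K/V projection sizes (which also set the per-token, per-layer KV-cache size at inference) at hidden_size=4096 with 32 query heads, i.e. head_dim=128: MHA keeps 32 K/V heads, MQA keeps 1, and GQA with 8 K/V heads keeps 8. A quick check against the classes above (the sizes are illustrative defaults, not tied to any particular model):

for m in [MHAttention(4096), MQAttention(4096), GQAttention(4096)]:
    # nn.Linear stores weight as (out_features, in_features); out_features = num_kv_heads * head_dim.
    print(type(m).__name__, tuple(m.k_proj.weight.shape))
# MHAttention (4096, 4096)  -> 32 KV heads x 128 dims
# MQAttention (128, 4096)   -> 1 shared KV head
# GQAttention (1024, 4096)  -> 8 KV heads (one per group of 4 query heads)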
# Quick smoke test: every variant maps (bs, seq_len, hidden_size) to the same output shape.
input_data = torch.randn(1, 20, 4096)
for model in [
    Attention(hidden_size=4096, dropout=0.1),
    MHAttention(hidden_size=4096, num_heads=32, dropout=0.1),
    MQAttention(hidden_size=4096, num_heads=32, dropout=0.1),
    GQAttention(hidden_size=4096, num_heads=32, num_kv_heads=4, dropout=0.1),
]:
    output = model(input_data)
    print(type(model).__name__, output.shape)  # torch.Size([1, 20, 4096])
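Since attention() applies the mask additively before the softmax, a causal mask for any of these modules is just a matrix with 0 where attention is allowed and -inf where it is not; a minimal sketch:

# Causal (decoder-style) additive mask: 0 on/below the diagonal, -inf above it.
seq_len = input_data.shape[1]
causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
gqa = GQAttention(hidden_size=4096, num_heads=32, num_kv_heads=4)
print(gqa(input_data, mask=causal_mask).shape)  # torch.Size([1, 20, 4096])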