多头注意力机制
一个简单的多头注意力模块可以分解为以下几个步骤:
- 先不分多头,对输入张量分别做变换,得到 Q , K , V Q,K,V Q,K,V
- 对得到的 Q , K , V Q,K,V Q,K,V按头的个数进行split;
- 用 Q , K Q,K Q,K计算向量点积
- 考虑是否要添因果mask
- 利softmax计算注意力得分矩阵atten
- 对注意力得分矩阵施加Dropout
- 将atten矩阵和 V V V矩阵相乘
- 再过一道最终的输出变换
代码
给出一个 d k = d v = d m o d e l d_k=d_v=d_{model} dk=dv=dmodel的多头注意力实现如下:
python
class MHA(nn.Module):
def __init__(self,C_in,dmodel,num_head=8,p_drop=0.2):
super(MHA, self).__init__()
self.QW=nn.Linear(C_in,dmodel)
self.KW=nn.Linear(C_in,dmodel)
self.VW=nn.Linear(C_in,dmodel)
self.dp=nn.Dropout(p_drop)
self.W_concat=nn.Linear(dmodel,dmodel)
self.n_head=num_head
self.p_drop=p_drop
self.depth=dmodel//num_head
def forward(self,X,casual=True):
B,L,C=X.shape
Q=self.QW(X)
K=self.KW(X)
V=self.VW(X)
Q=Q.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)
K=K.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)
V=V.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)
atten=Q.matmul(K.transpose(2,3))
if casual:
mask=torch.triu(torch.ones(L,L))
atten=torch.where(mask==1,atten,torch.ones_like(atten)*(-2**32+1))
atten=torch.softmax(atten,dim=-1)
atten=self.dp(atten)
out=torch.matmul(atten,V)/self.depth**(1/2)
out=out.permute(0,2,1,3).reshape(B,L,-1)
out=self.W_concat(out)
return out
if __name__=="__main__":
input=torch.rand(10,5,3)
model=MHA(3,64,4)
res=model(input)