算法手撕面经系列(1)--手撕多头注意力机制

多头注意力机制

一个简单的多头注意力模块可以分解为以下几个步骤:

  1. 先不分多头,对输入张量分别做变换,得到 Q , K , V Q,K,V Q,K,V
  2. 对得到的 Q , K , V Q,K,V Q,K,V按头的个数进行split;
  3. 用 Q , K Q,K Q,K计算向量点积
  4. 考虑是否要添因果mask
  5. 利softmax计算注意力得分矩阵atten
  6. 对注意力得分矩阵施加Dropout
  7. 将atten矩阵和 V V V矩阵相乘
  8. 再过一道最终的输出变换

代码

给出一个 d k = d v = d m o d e l d_k=d_v=d_{model} dk=dv=dmodel的多头注意力实现如下:

python 复制代码
class MHA(nn.Module):
    def __init__(self,C_in,dmodel,num_head=8,p_drop=0.2):
        super(MHA, self).__init__()

        self.QW=nn.Linear(C_in,dmodel)
        self.KW=nn.Linear(C_in,dmodel)
        self.VW=nn.Linear(C_in,dmodel)

        self.dp=nn.Dropout(p_drop)

        self.W_concat=nn.Linear(dmodel,dmodel)

        self.n_head=num_head
        self.p_drop=p_drop
        self.depth=dmodel//num_head

    def forward(self,X,casual=True):
        B,L,C=X.shape
        Q=self.QW(X)
        K=self.KW(X)
        V=self.VW(X)

        Q=Q.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)
        K=K.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)
        V=V.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)

        atten=Q.matmul(K.transpose(2,3))

        if casual:
            mask=torch.triu(torch.ones(L,L))
            atten=torch.where(mask==1,atten,torch.ones_like(atten)*(-2**32+1))
        atten=torch.softmax(atten,dim=-1)

        atten=self.dp(atten)

        out=torch.matmul(atten,V)/self.depth**(1/2)

        out=out.permute(0,2,1,3).reshape(B,L,-1)
        out=self.W_concat(out)

        return out


if __name__=="__main__":
    input=torch.rand(10,5,3)
    model=MHA(3,64,4)
    res=model(input)
相关推荐
哥布林学者4 分钟前
吴恩达深度学习课程一:神经网络和深度学习 第三周:浅层神经网络(三)
深度学习·ai
MIXLLRED12 分钟前
YOLO学习——训练进阶和预测评价指标
深度学习·学习·yolo
草莓熊Lotso18 分钟前
《C++ Web 自动化测试实战:常用函数全解析与场景化应用指南》
前端·c++·python·dubbo
叼菠萝19 分钟前
AI 应用开发三剑客系列:LangChain 如何撑起 LLM 应用开发基石?
python·langchain
程序员小远27 分钟前
软件测试之压力测试详解
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·压力测试
CheungChunChiu29 分钟前
AI 模型部署体系全景:从 PyTorch 到 RKNN 的嵌入式类比解析
人工智能·pytorch·python·模型
Scc_hy36 分钟前
强化学习_Paper_2000_Eligibility Traces for Off-Policy Policy Evaluation
人工智能·深度学习·算法·强化学习·rl
来酱何人38 分钟前
低资源NLP数据处理:少样本/零样本场景下数据增强与迁移学习结合方案
人工智能·深度学习·分类·nlp·bert
王彦臻40 分钟前
YOLOv3 技术总结
深度学习·yolo·目标跟踪
小小测试开发1 小时前
Python SQLAlchemy:告别原生 SQL,用 ORM 优雅操作数据库
数据库·python·sql·sqlalchemy