MHA实现

实现时容易出错的地方:

1、pad_mask

复制代码
错误实现:
pad_mask = pad_mask.unsqueeze(1).expand(-1,self.num_heads,-1) # bs,seq -> bs,num_heads,seq
pad_mask=pad_mask.unsqueeze(2).expand(-1,-1,causal_mask.size(2),-1)# bs,num_heads,seq,seq

正确实现:
pad_mask_1 = pad_mask.unsqueeze(-1).float() #bs,num_heads,seq,1
pad_mask_2 = pad_mask.unsqueeze(2).float() # bs,num_heads,1,seq
pad_mask = torch.matmul(pad_mask_1,pad_mask_2)#bs,num_heads,seq,seq

2、计算att_weights

复制代码
错误实现:存在全行都是掩码的情况(字符pad,与其他字符att,都是要掩码掉的),这种情况,直接计算softmax会有问题
att_scores = att_scores.masked_fill(combine_mask.bool()==False,-2**32)
att_weights = F.softmax(att_scores,dim=-1)

正确实现:
att_scores = att_scores.masked_fill(combine_mask.bool()==False,float('-inf'))
att_weights = F.softmax(att_scores,dim=-1)

# 检查是否存在全掩码行,当float('-inf')改成-2**32,也要注意全行为掩码的情况,也要强制把全掩码的行置为0
all_inf_mask = (att_scores == float('-inf')).all(dim=-1, keepdim=True)
att_weights = att_weights.masked_fill(all_inf_mask, 0.0)  # 强制赋0避免NaN

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self,embed_dim,num_heads):
        super().__init__()
        self.embed_dim=embed_dim
        self.num_heads=num_heads
        self.q_proj=nn.Linear(embed_dim,embed_dim)
        self.k_proj=nn.Linear(embed_dim,embed_dim)
        self.v_proj=nn.Linear(embed_dim,embed_dim)
        self.head_dim = embed_dim//num_heads
        self.scale = 1 / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        if self.head_dim*num_heads !=embed_dim:
            raise ValueError('embed_dim cannot be evenly divided by num_heads')
    def split_head(self,x):
        bs,seq,_ = x.size()
        return x.view(bs,seq,self.num_heads,self.head_dim).transpose(1,2)
    def combine_head(self,x):
        bs,_,seq,_ = x.size()
        return x.transpose(1,2).contiguous().view(bs,seq,self.head_dim*self.num_heads)
    def causal_mask(self,x):
        bs,seq,_=x.size()
        mask=torch.tril(torch.ones(bs,self.num_heads,seq,seq))
        return mask
    def forward(self,input_x,pad_mask):
        Q=self.q_proj(input_x) # bt,seq,dim
        K=self.k_proj(input_x)
        V=self.v_proj(input_x)

        q=self.split_head(Q) #bt,num_heads,seq,dim
        k=self.split_head(K)
        v=self.split_head(V)
        pad_mask = pad_mask.unsqueeze(1).expand(-1,self.num_heads,-1).float() # bs,seq -> bs,num_heads,seq

        att_scores = torch.matmul(q,k.transpose(2,3))*self.scale

        causal_mask = self.causal_mask(input_x)
        if pad_mask is not None:
            # pad_mask=pad_mask.unsqueeze(2).expand(-1,-1,causal_mask.size(2),-1)
            pad_mask_1 = pad_mask.unsqueeze(-1) #bs,num_heads,seq,1
            pad_mask_2 = pad_mask.unsqueeze(2) # bs,num_heads,1,seq
            pad_mask = torch.matmul(pad_mask_1,pad_mask_2)
            combine_mask = causal_mask.bool() & pad_mask.bool()
        att_scores = att_scores.masked_fill(combine_mask.bool()==False,float('-inf'))
        att_weights = F.softmax(att_scores,dim=-1)

        # 检查是否存在全掩码行,当float('-inf')改成-2**32,也要注意全行为掩码的情况,也要强制把全掩码的行置为0
        all_inf_mask = (att_scores == float('-inf')).all(dim=-1, keepdim=True)
        att_weights = att_weights.masked_fill(all_inf_mask, 0.0)  # 强制赋0避免NaN

        output = torch.matmul(att_weights,v)
        output = self.combine_head(output)
        return output

if __name__ == '__main__':
    # 参数设置
    embed_dim = 8
    num_heads = 2
    batch_size = 2
    seq_len = 3

    # 创建模型
    model = MultiHeadAttention(embed_dim, num_heads)

    # 随机输入张量
    input_x = torch.randn(batch_size, seq_len, embed_dim)

    # 填充掩码(假设第二个样本的最后一个位置是填充)
    pad_mask = torch.tensor([
        [True, False, False],  # 第一个样本填充2
        [True, True, False]  # 第二个样本第三个位置是填充
    ],dtype=torch.bool)

    # 扩展维度并取反
    expanded_mask = pad_mask.unsqueeze(-1)
    mask = ~expanded_mask

    # 应用 masked_fill
    input_x = input_x.masked_fill(mask, 0)

    # 前向传播
    print('input_x:',input_x)
    print('pad_mask',pad_mask)
    output = model(input_x, pad_mask)
    print(output)

att_weights输出:

tensor(\[\[\[1.0000, 0.0000, 0.0000,

0.0000, 0.0000, 0.0000,

0.0000, 0.0000, 0.0000],

\[1.0000, 0.0000, 0.0000,

0.0000, 0.0000, 0.0000,

0.0000, 0.0000, 0.0000]],

\[\[1.0000, 0.0000, 0.0000,

0.7197, 0.2803, 0.0000,

0.0000, 0.0000, 0.0000],

\[1.0000, 0.0000, 0.0000,

0.6969, 0.3031, 0.0000,

0.0000, 0.0000, 0.0000]]], grad_fn=<MaskedFillBackward0>)

output:

tensor([[[ 0.1226, 0.1810, -0.7282, -0.5597, -0.2594, -0.9893, 0.0600,

-0.1083],

[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

0.0000],

[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

0.0000]],

[[-0.8089, 0.5351, -0.2352, -0.7241, 0.9083, -1.1411, 0.3962,

0.4717],

[-0.5592, -0.0366, -0.0783, -0.4553, 0.5360, -0.5977, 0.1961,

0.0568],

[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

0.0000]]], grad_fn=<ViewBackward0>)

相关推荐
老毛肚3 小时前
jeecg-boot-base-core 02 day
javascript·python
yaoxin5211234 小时前
434. Java 日期时间 API - Period 基于日期的时间段
java·开发语言·python
DreamLife☼4 小时前
OpenBCI-脑机接口在康复医疗中的应用
深度学习·cnn·脑电·康复·fes·openbci·外骨骼
岁月宁静5 小时前
RAG 文档摄入全链路,从原理到生产落地
vue.js·人工智能·python
硅谷秋水5 小时前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶
JaydenAI5 小时前
[对比学习LangChain和MAF-07]如何引入人机交互的审批流程
python·ai·langchain·c#·agent·hitl·maf
郭泽斌之心5 小时前
MQL5 EA 怎么和外部程序通信?文件三件套协议:参数热更新不重启、状态心跳、远程触发
人工智能·经验分享·深度学习·ea·fay数字人·easydeal
AI人工智能+5 小时前
智能文档抽取系统以专业的文档解析底座和大模型智能语义理解能力为核心,洞察文档的语义内涵与逻辑结构
深度学习·自然语言处理·ocr·文档抽取
nap-joker5 小时前
用于转录组信息精确肿瘤学和药物机制分析的多模态可解释深度学习
人工智能·深度学习·药物敏感性·多层级生物网络·细胞异质性·可解释性多模态
神奇元创5 小时前
商用级光路加速卡:大模型推理的极速落地方案
python·神经网络·fpga开发·dsp开发