MHA实现

实现时容易出错的地方:

1、pad_mask

复制代码
错误实现:
pad_mask = pad_mask.unsqueeze(1).expand(-1,self.num_heads,-1) # bs,seq -> bs,num_heads,seq
pad_mask=pad_mask.unsqueeze(2).expand(-1,-1,causal_mask.size(2),-1)# bs,num_heads,seq,seq

正确实现:
pad_mask_1 = pad_mask.unsqueeze(-1).float() #bs,num_heads,seq,1
pad_mask_2 = pad_mask.unsqueeze(2).float() # bs,num_heads,1,seq
pad_mask = torch.matmul(pad_mask_1,pad_mask_2)#bs,num_heads,seq,seq

2、计算att_weights

复制代码
错误实现:存在全行都是掩码的情况(字符pad,与其他字符att,都是要掩码掉的),这种情况,直接计算softmax会有问题
att_scores = att_scores.masked_fill(combine_mask.bool()==False,-2**32)
att_weights = F.softmax(att_scores,dim=-1)

正确实现:
att_scores = att_scores.masked_fill(combine_mask.bool()==False,float('-inf'))
att_weights = F.softmax(att_scores,dim=-1)

# 检查是否存在全掩码行,当float('-inf')改成-2**32,也要注意全行为掩码的情况,也要强制把全掩码的行置为0
all_inf_mask = (att_scores == float('-inf')).all(dim=-1, keepdim=True)
att_weights = att_weights.masked_fill(all_inf_mask, 0.0)  # 强制赋0避免NaN

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self,embed_dim,num_heads):
        super().__init__()
        self.embed_dim=embed_dim
        self.num_heads=num_heads
        self.q_proj=nn.Linear(embed_dim,embed_dim)
        self.k_proj=nn.Linear(embed_dim,embed_dim)
        self.v_proj=nn.Linear(embed_dim,embed_dim)
        self.head_dim = embed_dim//num_heads
        self.scale = 1 / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        if self.head_dim*num_heads !=embed_dim:
            raise ValueError('embed_dim cannot be evenly divided by num_heads')
    def split_head(self,x):
        bs,seq,_ = x.size()
        return x.view(bs,seq,self.num_heads,self.head_dim).transpose(1,2)
    def combine_head(self,x):
        bs,_,seq,_ = x.size()
        return x.transpose(1,2).contiguous().view(bs,seq,self.head_dim*self.num_heads)
    def causal_mask(self,x):
        bs,seq,_=x.size()
        mask=torch.tril(torch.ones(bs,self.num_heads,seq,seq))
        return mask
    def forward(self,input_x,pad_mask):
        Q=self.q_proj(input_x) # bt,seq,dim
        K=self.k_proj(input_x)
        V=self.v_proj(input_x)

        q=self.split_head(Q) #bt,num_heads,seq,dim
        k=self.split_head(K)
        v=self.split_head(V)
        pad_mask = pad_mask.unsqueeze(1).expand(-1,self.num_heads,-1).float() # bs,seq -> bs,num_heads,seq

        att_scores = torch.matmul(q,k.transpose(2,3))*self.scale

        causal_mask = self.causal_mask(input_x)
        if pad_mask is not None:
            # pad_mask=pad_mask.unsqueeze(2).expand(-1,-1,causal_mask.size(2),-1)
            pad_mask_1 = pad_mask.unsqueeze(-1) #bs,num_heads,seq,1
            pad_mask_2 = pad_mask.unsqueeze(2) # bs,num_heads,1,seq
            pad_mask = torch.matmul(pad_mask_1,pad_mask_2)
            combine_mask = causal_mask.bool() & pad_mask.bool()
        att_scores = att_scores.masked_fill(combine_mask.bool()==False,float('-inf'))
        att_weights = F.softmax(att_scores,dim=-1)

        # 检查是否存在全掩码行,当float('-inf')改成-2**32,也要注意全行为掩码的情况,也要强制把全掩码的行置为0
        all_inf_mask = (att_scores == float('-inf')).all(dim=-1, keepdim=True)
        att_weights = att_weights.masked_fill(all_inf_mask, 0.0)  # 强制赋0避免NaN

        output = torch.matmul(att_weights,v)
        output = self.combine_head(output)
        return output

if __name__ == '__main__':
    # 参数设置
    embed_dim = 8
    num_heads = 2
    batch_size = 2
    seq_len = 3

    # 创建模型
    model = MultiHeadAttention(embed_dim, num_heads)

    # 随机输入张量
    input_x = torch.randn(batch_size, seq_len, embed_dim)

    # 填充掩码(假设第二个样本的最后一个位置是填充)
    pad_mask = torch.tensor([
        [True, False, False],  # 第一个样本填充2
        [True, True, False]  # 第二个样本第三个位置是填充
    ],dtype=torch.bool)

    # 扩展维度并取反
    expanded_mask = pad_mask.unsqueeze(-1)
    mask = ~expanded_mask

    # 应用 masked_fill
    input_x = input_x.masked_fill(mask, 0)

    # 前向传播
    print('input_x:',input_x)
    print('pad_mask',pad_mask)
    output = model(input_x, pad_mask)
    print(output)

att_weights输出:

tensor([[[[1.0000, 0.0000, 0.0000],

0.0000, 0.0000, 0.0000\], \[0.0000, 0.0000, 0.0000\]\], \[\[1.0000, 0.0000, 0.0000\], \[0.0000, 0.0000, 0.0000\], \[0.0000, 0.0000, 0.0000\]\]\], \[\[\[1.0000, 0.0000, 0.0000\], \[0.7197, 0.2803, 0.0000\], \[0.0000, 0.0000, 0.0000\]\], \[\[1.0000, 0.0000, 0.0000\], \[0.6969, 0.3031, 0.0000\], \[0.0000, 0.0000, 0.0000\]\]\]\], grad_fn=\)

output:

tensor([[[ 0.1226, 0.1810, -0.7282, -0.5597, -0.2594, -0.9893, 0.0600,

-0.1083],

0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000\], \[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000\]\], \[\[-0.8089, 0.5351, -0.2352, -0.7241, 0.9083, -1.1411, 0.3962, 0.4717\], \[-0.5592, -0.0366, -0.0783, -0.4553, 0.5360, -0.5977, 0.1961, 0.0568\], \[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000\]\]\], grad_fn=\)

相关推荐
坐吃山猪几秒前
Python多环境管理指南
开发语言·python
伊织code4 分钟前
PyTorch API 4 - 分布式通信、分布式张量
pytorch·python·ai·api·-·4·分布式通信、分布式张量
小声读源码8 分钟前
【部署】win10的wsl环境下调试dify的api后端服务
vscode·python·docker·uv·dify·remote-ssh·pyenv
Code_流苏32 分钟前
《Python星球日记》 第54天:卷积神经网络进阶
python·cnn·数据增强·图像分类·alexnet·lenet-5·vgg
suyukangchen34 分钟前
深入理解 Linux 阻塞IO与Socket数据结构
linux·数据结构·python
泡芙萝莉酱36 分钟前
各省份发电量数据(2005-2022年)-社科数据
大数据·人工智能·深度学习·数据挖掘·数据分析·毕业论文·数据统计
sword devil9001 小时前
Python实用工具:pdf转doc
python·pdf
zhuiQiuMX1 小时前
笔试阶段性心得总结
java·python
未名编程2 小时前
【Flask开发踩坑实录】pip 安装报错:“No matching distribution found” 的根本原因及解决方案!
python·flask·pip
q567315233 小时前
Node.js数据抓取技术实战示例
爬虫·python·scrapy·node.js