MHA实现

实现时容易出错的地方:

1、pad_mask

复制代码
错误实现:
pad_mask = pad_mask.unsqueeze(1).expand(-1,self.num_heads,-1) # bs,seq -> bs,num_heads,seq
pad_mask=pad_mask.unsqueeze(2).expand(-1,-1,causal_mask.size(2),-1)# bs,num_heads,seq,seq

正确实现:
pad_mask_1 = pad_mask.unsqueeze(-1).float() #bs,num_heads,seq,1
pad_mask_2 = pad_mask.unsqueeze(2).float() # bs,num_heads,1,seq
pad_mask = torch.matmul(pad_mask_1,pad_mask_2)#bs,num_heads,seq,seq

2、计算att_weights

复制代码
错误实现:存在全行都是掩码的情况(字符pad,与其他字符att,都是要掩码掉的),这种情况,直接计算softmax会有问题
att_scores = att_scores.masked_fill(combine_mask.bool()==False,-2**32)
att_weights = F.softmax(att_scores,dim=-1)

正确实现:
att_scores = att_scores.masked_fill(combine_mask.bool()==False,float('-inf'))
att_weights = F.softmax(att_scores,dim=-1)

# 检查是否存在全掩码行,当float('-inf')改成-2**32,也要注意全行为掩码的情况,也要强制把全掩码的行置为0
all_inf_mask = (att_scores == float('-inf')).all(dim=-1, keepdim=True)
att_weights = att_weights.masked_fill(all_inf_mask, 0.0)  # 强制赋0避免NaN

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self,embed_dim,num_heads):
        super().__init__()
        self.embed_dim=embed_dim
        self.num_heads=num_heads
        self.q_proj=nn.Linear(embed_dim,embed_dim)
        self.k_proj=nn.Linear(embed_dim,embed_dim)
        self.v_proj=nn.Linear(embed_dim,embed_dim)
        self.head_dim = embed_dim//num_heads
        self.scale = 1 / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        if self.head_dim*num_heads !=embed_dim:
            raise ValueError('embed_dim cannot be evenly divided by num_heads')
    def split_head(self,x):
        bs,seq,_ = x.size()
        return x.view(bs,seq,self.num_heads,self.head_dim).transpose(1,2)
    def combine_head(self,x):
        bs,_,seq,_ = x.size()
        return x.transpose(1,2).contiguous().view(bs,seq,self.head_dim*self.num_heads)
    def causal_mask(self,x):
        bs,seq,_=x.size()
        mask=torch.tril(torch.ones(bs,self.num_heads,seq,seq))
        return mask
    def forward(self,input_x,pad_mask):
        Q=self.q_proj(input_x) # bt,seq,dim
        K=self.k_proj(input_x)
        V=self.v_proj(input_x)

        q=self.split_head(Q) #bt,num_heads,seq,dim
        k=self.split_head(K)
        v=self.split_head(V)
        pad_mask = pad_mask.unsqueeze(1).expand(-1,self.num_heads,-1).float() # bs,seq -> bs,num_heads,seq

        att_scores = torch.matmul(q,k.transpose(2,3))*self.scale

        causal_mask = self.causal_mask(input_x)
        if pad_mask is not None:
            # pad_mask=pad_mask.unsqueeze(2).expand(-1,-1,causal_mask.size(2),-1)
            pad_mask_1 = pad_mask.unsqueeze(-1) #bs,num_heads,seq,1
            pad_mask_2 = pad_mask.unsqueeze(2) # bs,num_heads,1,seq
            pad_mask = torch.matmul(pad_mask_1,pad_mask_2)
            combine_mask = causal_mask.bool() & pad_mask.bool()
        att_scores = att_scores.masked_fill(combine_mask.bool()==False,float('-inf'))
        att_weights = F.softmax(att_scores,dim=-1)

        # 检查是否存在全掩码行,当float('-inf')改成-2**32,也要注意全行为掩码的情况,也要强制把全掩码的行置为0
        all_inf_mask = (att_scores == float('-inf')).all(dim=-1, keepdim=True)
        att_weights = att_weights.masked_fill(all_inf_mask, 0.0)  # 强制赋0避免NaN

        output = torch.matmul(att_weights,v)
        output = self.combine_head(output)
        return output

if __name__ == '__main__':
    # 参数设置
    embed_dim = 8
    num_heads = 2
    batch_size = 2
    seq_len = 3

    # 创建模型
    model = MultiHeadAttention(embed_dim, num_heads)

    # 随机输入张量
    input_x = torch.randn(batch_size, seq_len, embed_dim)

    # 填充掩码(假设第二个样本的最后一个位置是填充)
    pad_mask = torch.tensor([
        [True, False, False],  # 第一个样本填充2
        [True, True, False]  # 第二个样本第三个位置是填充
    ],dtype=torch.bool)

    # 扩展维度并取反
    expanded_mask = pad_mask.unsqueeze(-1)
    mask = ~expanded_mask

    # 应用 masked_fill
    input_x = input_x.masked_fill(mask, 0)

    # 前向传播
    print('input_x:',input_x)
    print('pad_mask',pad_mask)
    output = model(input_x, pad_mask)
    print(output)

att_weights输出:

tensor([[[[1.0000, 0.0000, 0.0000],

0.0000, 0.0000, 0.0000\], \[0.0000, 0.0000, 0.0000\]\], \[\[1.0000, 0.0000, 0.0000\], \[0.0000, 0.0000, 0.0000\], \[0.0000, 0.0000, 0.0000\]\]\], \[\[\[1.0000, 0.0000, 0.0000\], \[0.7197, 0.2803, 0.0000\], \[0.0000, 0.0000, 0.0000\]\], \[\[1.0000, 0.0000, 0.0000\], \[0.6969, 0.3031, 0.0000\], \[0.0000, 0.0000, 0.0000\]\]\]\], grad_fn=\)

output:

tensor([[[ 0.1226, 0.1810, -0.7282, -0.5597, -0.2594, -0.9893, 0.0600,

-0.1083],

0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000\], \[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000\]\], \[\[-0.8089, 0.5351, -0.2352, -0.7241, 0.9083, -1.1411, 0.3962, 0.4717\], \[-0.5592, -0.0366, -0.0783, -0.4553, 0.5360, -0.5977, 0.1961, 0.0568\], \[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000\]\]\], grad_fn=\)

相关推荐
AndrewHZ5 分钟前
【图像处理基石】GIS图像处理入门:4个核心算法与Python实现(附完整代码)
图像处理·python·算法·计算机视觉·gis·cv·地理信息系统
帮帮志37 分钟前
目录【系列文章目录】-(关于帮帮志,关于作者)
java·开发语言·python·链表·交互
二王一个今2 小时前
Python打包成exe(windows)或者app(mac)
开发语言·python·macos
一勺菠萝丶2 小时前
Mac 上用 Homebrew 安装 JDK 8(适配 zsh 终端)完整教程
java·python·macos
C嘎嘎嵌入式开发7 小时前
(2)100天python从入门到拿捏
开发语言·python
Stanford_11067 小时前
如何利用Python进行数据分析与可视化的具体操作指南
开发语言·c++·python·微信小程序·微信公众平台·twitter·微信开放平台
white-persist9 小时前
Python实例方法与Python类的构造方法全解析
开发语言·前端·python·原型模式
Java 码农9 小时前
Centos7 maven 安装
java·python·centos·maven
格林威9 小时前
常规线扫描镜头有哪些类型?能做什么?
人工智能·深度学习·数码相机·算法·计算机视觉·视觉检测·工业镜头
lyx33136967599 小时前
#深度学习基础:神经网络基础与PyTorch
pytorch·深度学习·神经网络·参数初始化