Transformer——多头注意力机制(Pytorch)

  1. 原理图

  2. 代码

python 复制代码
import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    def __init__(self, embed_size, heads):
        super(Multi_Head_Self_Attention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self,queries, keys, values, mask):
        N = queries.shape[0]  # batch_size
        query_len = queries.shape[1]  # sequence_length
        key_len = keys.shape[1]  # sequence_length 
        value_len = values.shape[1]  # sequence_length

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces
        # batch_size, sequence_length, embed_size(512) --> 
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # batch_size, sequence_length, heads(8), head_dim(64) --> 
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** (1/2))

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))
        # batch_size, heads(8), sequence_length, sequence_length
        attention = torch.softmax(score, dim=-1)

        out = torch.matmul(attention, values)
        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # 为了方便送入后面的网络
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)

        return out
    

batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)  
K = torch.randn(batch_size, sequence_length, embed_size)  
V = torch.randn(batch_size, sequence_length, embed_size)  

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)
相关推荐
千天夜17 分钟前
激活函数解析:神经网络背后的“驱动力”
人工智能·深度学习·神经网络
m0_5236742125 分钟前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘
小言从不摸鱼2 小时前
【AI大模型】ELMo模型介绍:深度理解语言模型的嵌入艺术
人工智能·深度学习·语言模型·自然语言处理·transformer
铖铖的花嫁12 小时前
基于RNNs(LSTM, GRU)的红点位置检测(pytorch)
pytorch·gru·lstm
python15612 小时前
基于驾驶员面部特征的疲劳检测系统
python·深度学习·目标检测
YRr YRr12 小时前
ubuntu20.04 解决Pytorch默认安装CPU版本的问题
人工智能·pytorch·python
LittroInno13 小时前
Tofu AI视频处理模块视频输入配置方法
人工智能·深度学习·计算机视觉·tofu
代码猪猪傻瓜coding13 小时前
pytorch torch.tile用法
人工智能·pytorch·python
铭瑾熙13 小时前
深度学习之 LSTM
人工智能·深度学习·lstm
sduerfh15 小时前
pytorch3d导入maya相机位姿踩坑
pytorch·3d·maya