Transformer——多头注意力机制(Pytorch)

  1. 原理图

  2. 代码

python 复制代码
import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    def __init__(self, embed_size, heads):
        super(Multi_Head_Self_Attention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self,queries, keys, values, mask):
        N = queries.shape[0]  # batch_size
        query_len = queries.shape[1]  # sequence_length
        key_len = keys.shape[1]  # sequence_length 
        value_len = values.shape[1]  # sequence_length

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces
        # batch_size, sequence_length, embed_size(512) --> 
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # batch_size, sequence_length, heads(8), head_dim(64) --> 
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** (1/2))

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))
        # batch_size, heads(8), sequence_length, sequence_length
        attention = torch.softmax(score, dim=-1)

        out = torch.matmul(attention, values)
        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # 为了方便送入后面的网络
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)

        return out
    

batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)  
K = torch.randn(batch_size, sequence_length, embed_size)  
V = torch.randn(batch_size, sequence_length, embed_size)  

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)
相关推荐
狮子座明仔4 小时前
阶跃星辰重磅发布:32B参数模型如何实现“深度研究“自动化?
人工智能·深度学习·自动化
Yan-英杰5 小时前
2025 AI数据准备:EasyLink让多模态非结构化数据处理变简单
人工智能·深度学习·神经网络·机器学习·ai·大模型
棒棒的皮皮6 小时前
【深度学习】YOLO-Python基础认知与算法演进
python·深度学习·yolo·计算机视觉
人工智能培训7 小时前
10分钟了解向量数据库(1)
人工智能·深度学习·算法·机器学习·大模型·智能体搭建
不吃香菜的鱼7 小时前
PyTorch-CUDA-v2.9镜像自动混合精度训练配置指南
pytorch·cuda·自动混合精度
新职语7 小时前
打造个人AI实验室:低成本使用PyTorch-CUDA-v2.8云实例
pytorch·cuda·云实例
小程故事多_807 小时前
从零吃透PyTorch,最易懂的入门全指南
人工智能·pytorch·python
大叔and小萝莉7 小时前
PyTorch-v2.8新特性解析:性能提升背后的秘密
pytorch· torch.compile· 性能优化
lifetime‵(+﹏+)′7 小时前
5060显卡Windows配置Anaconda中的CUDA及Pytorch
人工智能·pytorch·windows
老鱼说AI7 小时前
万字长文警告!一次性搞定GAN(生成对抗网络):从浅入深原理级精析 + PyTorch代码逐行讲解实现
人工智能·深度学习·神经网络·生成对抗网络·计算机视觉·ai作画·超分辨率重建