multi-head attention 多头注意力实现细节

论文中关于多头注意力的描述

1706.03762

代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
 
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
 
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
 
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V: (batch_size, num_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)
 
    def split_heads(self, x, batch_size):
        # x: (batch_size, seq_len, d_model)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)  # (batch_size, seq_len, num_heads, d_k)
        return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, d_k)
 
    def combine_heads(self, x, batch_size):
        # x: (batch_size, num_heads, seq_len, d_k)
        x = x.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, d_k)
        return x.view(batch_size, -1, self.d_model)  # (batch_size, seq_len, d_model)
 
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
 
        Q = self.W_q(Q)  # (batch_size, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)
 
        Q = self.split_heads(Q, batch_size)  # (batch_size, num_heads, seq_len, d_k)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)
 
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.combine_heads(attn_output, batch_size)
 
        return self.W_o(output)  # Final linear projection

会发现其实代码和论文不是完全一样的,论文看起来是每个头有单独的W去乘,但是代码里是所有头共用W再拆分。其实两者是等价的。要注意一下,在multi-head attention中,输入是不被拆分的,它的shape一直是[L,D_model],拆分的是W,把[D_model, D_model]的矩阵拆分成K个[D_k, D_model]的矩阵。

根据矩阵的乘法定义

复制代码
Y = X W = X [W₁  W₂] = [X W₁   X W₂]

乘之前拆分还是乘之后拆分,是一样的。代码用大矩阵来乘,可以加快计算。

相关推荐
User_芊芊君子22 分钟前
【分布式训练】CANN SHMEM跨设备内存通信库:构建高效多机多卡训练的关键组件
分布式·深度学习·神经网络·wpf
聆风吟º32 分钟前
CANN算子开发:ops-nn神经网络算子库的技术解析与实战应用
人工智能·深度学习·神经网络·cann
觉醒大王33 分钟前
强女思维:着急,是贪欲外显的相。
java·论文阅读·笔记·深度学习·学习·自然语言处理·学习方法
笔画人生39 分钟前
# 探索 CANN 生态:深入解析 `ops-transformer` 项目
人工智能·深度学习·transformer
灰灰勇闯IT44 分钟前
领域制胜——CANN 领域加速库(ascend-transformer-boost)的场景化优化
人工智能·深度学习·transformer
小白狮ww1 小时前
要给 OCR 装个脑子吗?DeepSeek-OCR 2 让文档不再只是扫描
人工智能·深度学习·机器学习·ocr·cpu·gpu·deepseek
island13141 小时前
CANN GE(图引擎)深度解析:计算图优化管线、内存静态规划与异构任务的 Stream 调度机制
开发语言·人工智能·深度学习·神经网络
艾莉丝努力练剑1 小时前
深度学习视觉任务:如何基于ops-cv定制图像预处理流程
人工智能·深度学习
禁默1 小时前
大模型推理的“氮气加速系统”:全景解读 Ascend Transformer Boost (ATB)
人工智能·深度学习·transformer·cann
User_芊芊君子1 小时前
CANN大模型加速核心ops-transformer全面解析:Transformer架构算子的高性能实现与优化
人工智能·深度学习·transformer