multi-head attention 多头注意力实现细节

论文中关于多头注意力的描述

1706.03762

代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
 
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
 
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
 
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V: (batch_size, num_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)
 
    def split_heads(self, x, batch_size):
        # x: (batch_size, seq_len, d_model)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)  # (batch_size, seq_len, num_heads, d_k)
        return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, d_k)
 
    def combine_heads(self, x, batch_size):
        # x: (batch_size, num_heads, seq_len, d_k)
        x = x.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, d_k)
        return x.view(batch_size, -1, self.d_model)  # (batch_size, seq_len, d_model)
 
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
 
        Q = self.W_q(Q)  # (batch_size, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)
 
        Q = self.split_heads(Q, batch_size)  # (batch_size, num_heads, seq_len, d_k)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)
 
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.combine_heads(attn_output, batch_size)
 
        return self.W_o(output)  # Final linear projection

会发现其实代码和论文不是完全一样的,论文看起来是每个头有单独的W去乘,但是代码里是所有头共用W再拆分。其实两者是等价的。要注意一下,在multi-head attention中,输入是不被拆分的,它的shape一直是L,D_model,拆分的是W,把D_model, D_model的矩阵拆分成K个D_k, D_model的矩阵。

根据矩阵的乘法定义

复制代码
Y = X W = X [W₁  W₂] = [X W₁   X W₂]

乘之前拆分还是乘之后拆分,是一样的。代码用大矩阵来乘,可以加快计算。

相关推荐
Kobebryant-Manba19 小时前
RNN从0实现
pytorch·rnn·深度学习
qingyulee21 小时前
循环神经网络
人工智能·rnn·深度学习
叫我:松哥1 天前
基于机器学习的中文文本抑郁症风险检测系统,包括NLP与传统机器学习的抑郁症识别,准确率92%
人工智能·深度学习·机器学习·自然语言处理·flask·nlp·bootstrap
AI语宙漫游指南1 天前
从 CV 扩散到 NLP:详解 Google DiffusionGemma 架构、推理机制与优劣
深度学习·llm
逻辑君1 天前
认知神经科学研究报告【20260089】
人工智能·深度学习·机器学习
装不满的克莱因瓶1 天前
掌握语义分割经典模型 FCN——从像素分类到端到端分割的奠基之作
人工智能·python·深度学习·算法·机器学习·分类·数据挖掘
chen_zn951 天前
VLA 的 Co-training:通过多源数据提升机器人泛化能力
人工智能·深度学习·具身智能·vla
大模型最新论文速读1 天前
06-15 · LLM 最新论文速览
论文阅读·人工智能·深度学习·自然语言处理
_codemonster1 天前
手语识别损失函数
人工智能·深度学习·机器学习