如何具体理解Self Attention中的Q、K、V以及计算过程

如何具体理解Self Attention中的Q、K、V以及计算过程

一、计算过程理解

1、我们直接用torch实现一个 S e l f A t t e n t i o n Self Attention SelfAttention:

首先定义三个线性变换矩阵, q u e r y , k e y , v a l u e query, key, value query,key,value:

python 复制代码
class BertSelfAttention(nn.Module):
    self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768

注意,这里的 q u e r y , k e y , v a l u e query, key, value query,key,value只是一种操作(线性变换)的名称,实际的 Q / K / V Q/K/V Q/K/V是这三个线性操作的输出,三个变换的输入都是 768 768 768维,输出都是 768 768 768维,也就是三个线性变换矩阵的维度都为 ( 768 , 768 ) (768, 768) (768,768)。

2、假设三种操作的输入都是同一个矩阵,这里暂且定为长度为 6 6 6的句子,每个 t o k e n token token的特征维度是 768 768 768,那么输入就是 ( 6 , 768 ) (6, 768) (6,768),每一行就是一个字的词向量,像这样:


图1 输入词向量矩阵

乘以上面代码中的三种线性变换操作就得到 了 Q / K / V 了Q/K/V 了Q/K/V 三个矩阵,他们的维度为 ( 6 , 768 ) ∗ ( 768 , 768 ) = ( 6 , 768 ) (6, 768)*(768,768) = (6,768) (6,768)∗(768,768)=(6,768),维度其实没变,即此刻的 Q / K / V Q/K/V Q/K/V分别为:


图2 输入词向量矩阵与线性变换矩阵相乘输出Q、K、V矩阵

代码为:

python 复制代码
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(6, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)

3、计算Self Attention

Attention ( Q , K , V ) = Softmax ( Q K T d k ) V \text{Attention}(Q,K,V)=\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=Softmax(dk QKT)V

(1) 首先是 Q Q Q和 K K K矩阵相乘, ( 6 , 768 ) × ( 6 , 768 ) T = ( 6 , 6 ) (6, 768)×(6, 768)^T=(6, 6) (6,768)×(6,768)T=(6,6),如图3:


图3 Q和K矩阵相乘的结果

具体的计算过程,首先用 Q Q Q的第一行,即"我"字的768特征和 K K K中"我"字的768为特征点乘求和,得到输出矩阵 ( 0 , 0 ) (0,0) (0,0)位置的数值,这个数值就代表了"我想吃酸菜鱼"中"我"字对"我"字的注意力权重。最终输出矩阵的第一行就是"我"字对"我想吃酸菜鱼"里面每个字的注意力权重。整个输出矩阵就是"我想吃酸菜鱼"里面每个字对其它字(包括自己)的注意力权重。

(2) 然后是除以 d k \sqrt{d_k} dk ,这个 d i m dim dim就是 768 768 768。

1)至于为什么要除以这个数值?主要是为了缩小点积范围,确保 s o f t m a x softmax softmax梯度稳定性。

2)为什么要 S o f t m a x Softmax Softmax?主要是为了保证注意力权重的非负性,同时增加非线性。

(3) 然后就是刚才的 注意力权重 注意力权重 注意力权重和 V V V矩阵相乘,如图4:


图4 注意力权重和V矩阵相乘

注意力权重 × V A L U E 矩阵 = 最终结果 注意力权重 × VALUE矩阵 = 最终结果 注意力权重×VALUE矩阵=最终结果

首先是"我"这个字对"我想吃酸菜鱼"这句话里面每个字的 注意力权重 注意力权重 注意力权重,和 V V V中"我想吃酸菜鱼"里面每个字的第一维特征进行相乘再求和,这个过程其实就相当于用每个字的权重对每个字的特征进行加权求和。然后再用"我"这个字对"我想吃酸菜鱼"这句话里面每个字的 注意力权重 注意力权重 注意力权重和 V V V中"我想吃酸菜鱼"里面每个字的第二维特征进行相乘再求和,依次类推最终也就得到了 ( 6 , 768 ) (6,768) (6,768) 的结果矩阵,和输入保持一致。

python 复制代码
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(L, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
        
        attention_scores = torch.matmul(Q, K.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        out = torch.matmul(attention_probs, V)
        return out

4、为什么叫自注意力机制?

因为可以看到 Q / K / V Q/K/V Q/K/V都是通过同一句话的输入算出来的,按照上面的流程也就是一句话内每个字对其它字(包括自己)的权重分配。

如果不是自注意力的话, Q Q Q来自于句 A A A, K K K, V V V来自于句 B B B。

5、注意, K / V K/V K/V中,如果同时替换任意两个字的位置,对最终的结果是不会有影响的,也就是说注意力机制是没有位置信息的,不像CNN/RNN/LSTM,这也是为什么要引入位置embeding的原因。

二、整体代码

python 复制代码
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(L, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
        
        attention_scores = torch.matmul(Q, K.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        out = torch.matmul(attention_probs, V)
        return out
相关推荐
春末的南方城市31 分钟前
FLUX的ID保持项目也来了! 字节开源PuLID-FLUX-v0.9.0,开启一致性风格写真新纪元!
人工智能·计算机视觉·stable diffusion·aigc·图像生成
zmjia11133 分钟前
AI大语言模型进阶应用及模型优化、本地化部署、从0-1搭建、智能体构建技术
人工智能·语言模型·自然语言处理
jndingxin1 小时前
OpenCV视频I/O(14)创建和写入视频文件的类:VideoWriter介绍
人工智能·opencv·音视频
AI完全体1 小时前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
GZ_TOGOGO1 小时前
【2024最新】华为HCIE认证考试流程
大数据·人工智能·网络协议·网络安全·华为
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
新缸中之脑1 小时前
Ollama 运行视觉语言模型LLaVA
人工智能·语言模型·自然语言处理
卷心菜小温2 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
胡耀超2 小时前
知识图谱入门——3:工具分类与对比(知识建模工具:Protégé、 知识抽取工具:DeepDive、知识存储工具:Neo4j)
人工智能·知识图谱
陈苏同学2 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm