深度学习 | 注意力机制、自注意力机制

卷积神经网络的思想主要是通过卷积层对图像进行特征提取,从而达到降低计算复杂度的目的,利用的是空间特征信息;循环神级网络主要是对序列数据进行建模,利用的是时间维度的信息。

而第三种 注意力机制 网络,关注的是数据中重要性的维度,研究怎么充分关注更加重要的信息,从而提高计算的准确度和效率。



一、注意力机制

1、注意力机制的发展史

2、生物学中的注意力

从众多信息中选择出对当前任务目标更关键的信息。

3、深度学习中的注意力机制

深度学习中的注意力机制是指人工神经网络中的一种机制,它能够帮助模型更好地关注重要的信息,而忽略不重要的信息。

这种机制的作用类似于生物体中的注意力机制,可以帮助我们过滤出重要的信息,并将这些信息传递给相应的神经元进行处理。

深度学习中的注意力机制与生物学中的注意力机制的关系是相似的,但并不完全相同。在深度学习中,注意力机制主要是通过自然语言处理、视觉和听觉处理等应用来实现的。它通过学习模型来自动判断哪些信息更重要,并将信息传递给相应的神经元进行处理。生物体中的注意力机制是由两个神经系统------前额叶皮层和脑干辅助系统------协同工作完成的,而深度学习中的注意力机制则是通过人工神经网络来实现的。

在自然语言处理中,注意力机制可以帮助模型更好地理解句子的意思,并提取出重要的信息。例如,在翻译任务中,注意力机制可以帮助模型更好地理解源语言的句子,并将其翻译成目标语言。在问答任务中,注意力机制可以帮助模型更好地理解问题的意思,并找到相应的答案。在视觉和听觉处理中,注意力机制也有广泛的应用。

例如,在视觉任务中,注意力机制可以帮助模型更好地理解图像中的重要信息,并进行分类或识别。在听觉任务中,注意力机制可以帮助模型更好地理解音频中的重要信息,并进行分类或识别。

让神经网络能够更加关注图像中的重要特征,而不是整张图像。

4、NLP任务中的注意力

4.1、编解码器架构

大多数注意力机制都附着在Encoder-Decoder框架下;

注意力机制是一种思想,本身并不依赖于任何框架。

中间表示 c 存在信息瓶颈。

4.2、NLP中的注意力机制原理

上下文向量 c 应可访问输入序列所有部分,而不仅是最后一个;

每一时刻产生不同的语言编码向量,表示不同的关注区域。

4.3、注意力机制的类型

隐式注意力:非常深的神经网络已经学会了一种形式的隐式注意。

显式注意力:根据先前输入的记忆"权衡"其对输入的敏感度。

软注意力:函数在其域内平滑变化,因此是可微的。

硬注意力:用随机抽样模型代替了确定性方法,不可微的。可以看做是一种决定是否关注某个区域的开关机制,意味着这个功能在域内有许多突变,因此是不可微的,不能使用标准的梯度下降法,只能使用强化学习来进行训练。

软注意力关注的区域是渐进变化的,硬注意力通常使用图像裁剪来关注区域,只有0和1的表示,关注的区域保留的像素都是1。

4.4、编解码器架构中的注意力机制

e ij 就是一个数,他衡量了每个encoder的隐状态 h j 和解码器前一时刻的状态输出 y i-1 的相互关联关系。

因为有很多个隐变量,所以也有很多个权重。为了让他们形成一个分步。可以进一步使用 softmax 函数再进行一次变换,变成了 α ij。

进而 解码器隐变量 可以写成 注意力权重和编码器隐变量 加权平均的一种形式。

4.5、注意力的可视化



二、注意力机制的计算

1、编解码器中的注意力

注意力其实是一个小型神经网络的输出,是衡量两种隐状态间"对齐"程度的分数。

两个输入是解码器先前的状态 y i-1,以及各个时刻编码器的隐藏状态 h j 。

一个输出就是α ij 。

2、如何计算注意力

注意力就是衡量编码器隐状态与前一时刻解码器输出对齐的分数。

除了使用小型神经网络计算之外,还有很多方式:

最常用的是第二种,它将注意力参数转化成了一小型的全连接网络,显然我们可以扩展他以便使用更多的层。实际上意味着注意力是一组可以训练的权重,用我们标准的反向传播算法进行调整。

通过让解码器具有注意力机制,我们减轻了编码器必须将输入语句的所有信息编码成为固定长度向量的负担,通过这种新方法,信息可以分布在整个序列当中,解码器可以相应的选择检索这些信息。

注意 实际上是解码器来去注意编码器当中的隐状态序列的信息,当然是有代价的,我们牺牲了计算的复杂度,多加了一个神经网络,需要对其进行训练,因此计算复杂度多了一个 O( t ^ 2 ),t 是输入输出长度句子的和。

红色圈即 s ,绿色圈即 h,蓝色圈即 α ij 。

3、全局注意力和局部注意力 / 注意力机制网络实现

在神经网络中,注意力机制可以通过使用注意力机制层来实现。

GlobalAttention:在整个输入序列上计算注意力分数。

LocalAttention:只考虑输入单元/标记的一个子集。也可以被看做是一种硬注意力。

下图直观表达了这种连接关系,并且和卷积网络和全连接网络进行了对比,但是需要注意他们的输入输出是不同的。左边输入的都是前一层的神经元,右边输入的是编码器隐藏层不同时刻的隐状态。

颜色表明这些权重在不断的变化,而在卷积层和全连接层虽然权重也在变化,但是是通过梯度下降缓慢变化的,或者说没有右边变化的剧烈。

4、自注意力机制

Self-Attention:序列自身的注意力

序列元素之间的分数,一种自我关注,照镜子。

对称的 ~

最终的意义是在转换为另一个序列之前,先创造一个更加有意义的序列表示。

5、注意力机制的优点

解决了编码器到解码器之间信息传递的瓶颈问题;

建立编码器状态和解码器间直接联系,消除了梯度消失问题;

提供了更好的可解释性。

6、注意力与transformer

transformer:编码器到解码器两个序列之间的一个转换器

注意力机制某种程度上就是transformer。

注意力机制更强调结果,transformer更强调结构。

7、注意力机制的应用

通用的NLP模型,文本生成、聊天机器人、文本分类等任务

图像分类模型中也可以使用注意力机制,Vision Transformer

8、注意力机制的可视化

注意力机制可以使用可视化工具来展示其在处理信息时对不同信息的关注程度。这种可视化方法通常使用柱状图或热力图来表示。

例如,在处理文本信息时,注意力机制可以通过将每个单词的权重值可视化为柱状图的高度来展示。这样就可以看出,注意力机制对哪些单词的关注程度更高。同样的,注意力机制也可以使用热力图的方式来可视化。例如,在处理图像信息时,注意力机制可以将图像中每个像素的权重值可视化为不同的颜色,从而表示注意力机制对哪些像素的关注程度更高。



三、键值对注意力和多头注意力

1、经典注意力机制的计算过程和局限

经典注意力机制中均使用 编码器隐藏状态 h 来计算注意力分数,如下图所示:

编码器先生成一个个的隐状态(绿圈圈),然后之间和当前状态(红圈圈)一个个的计算权重(蓝圈圈),然后再经过softmax层,之后再分别与对应的编码器状态乘积,然后绿色圈圈相加得到上下文信息,再把这个信息送交给decoder。

这个过程中,直接使用编码器隐藏状态 h 是有局限性的,注意力分数会仅基于隐藏状态在序列中的相对位置而不是他们的内容,这样就限制了模型关注相关信息的能力,导致模型性能不佳。

2、键值对注意力

深度学习注意力机制中,有三个非常重要的概念:查询query,键key和值value,简称Q/K/V。

它们均来自于输入的数据。在 Transformer等注意力模型中,这些值通常是来自于输入的序列的向量表示。具体来说,Q(query)表示要查询的信息。它可以理解为是一个指针,指向我们想要查询的信息。K(key)表示一个参考信息,通常是一个序列中的向量表示。V(value)表示与 K 对应的输出信息。这三个概念的提出其实是借鉴了数据库理论说法。在计算机科学中,查询指的是在数据结构中查找特定数据的过程。键是数据结构中的一个唯一的标识符,可以用来查找值。值是存储在数据结构中的数据。

在注意力机制中,QKV(query、key、value)是一种表示注意力机制的模型。它描述了在处理信息时,人脑的注意力是如何转移的。查询(query)表示人脑对信息的关注点,键(key)表示当前注意力所在的信息单元,值(value)表示当前注意力所在信息单元的内容。

具体来说,当人脑对某个信息感兴趣时,会将注意力聚焦到这个信息上。此时查询就是对这个信息的关注点,键就是这个信息的标识符,值就是这个信息的内容。在 Attention 机制中,Q 和 K 之间的匹配程度决定了输出的权重。意味着 Q 中的信息会去匹配 K 中的信息,并基于匹配的结果来决定 V 中的信息在最终输出中的重要程度。常见的 Attention 机制包括有 Scaled Dot-Product Attention 和 Multi-Head Attention。这些机制都基于 Q、K、V 之间的匹配关系来进行输出计算。

这样一来,使得模型能够学习输入和输出序列之间更加复杂更有意义的对齐。

左边红色向量是 query,类似经典注意力机制的输出的角色,既可以来自输入本身(自注意力机制),也可以来自其他序列(解码器之前时刻的输出)。黄色的向量 key,和紫色的向量 value,其实都是绿色输入向量的线性变换。只有 key 和 query 通过计算相似性求权重分数而 value 的值不受影响,这样一来就实现了相似性和内容的分离。

这么说可能还有点抽象。再帮你深入浅出地理解一下。

注意力机制说白了就是要通过训练得到一个权重,自注意力机制就是要通过权重矩阵来自动地找到词与词之间的关系。因此需要给每个输入定义张量,然后通过张量间的乘法来得到输入之间的关系。在神经网络中,注意力机制的计算公式通常是这样的:

其中, Q 表示查询矩阵,K 表示键矩阵V, 表示值矩阵,d k 表示键的维度。

我们知道数学上两个向量

a和b同向,

a和b垂直,

a和b反向,

所以两个向量的点乘可以表示两个向量的相似度,越相似方向越趋于一致,a点乘b数值越大。在上面自注意力公式中,Q和K的点乘表示Q和K元素之间相似程度,分母进行了归一化,V是输入线性变换后的特征,乘上V就能得到加权后的特征。换句话说,Q和K的引入是为了得到一个所有数值为0-1的权重矩阵,V是为了保留输入特征。

左边是输入,首先将 Q、K、V 通过线性变换映射到隐藏的空间,然后计算 K 和 Q 的点积,并且使用 softmax 函数将其归一化为注意力分数。本质上 QKV 都是输入向量的线性变换,变换矩阵是通过训练得到的。

显然当 key 和 value 相同的情况下,就退化成了普通的经典注意力机制。换句话说整个模型的结构/框架并没有明显变化。只是通过 K-V的分离带来了更多的便利和灵活性。

3、多头注意力机制的基本原理

Multi-Head Attention:多个查询向量

利用多个查询向量,并行的从输入信息K和V选取多组信息,在查询的过程中每个查询向量 q i 都会关注输入信息的不同部分,也就是说从不同的角度分析当前输入信息。

不同查询向量就对应不同的 Q/K/V,最终将所有查询向量的结果进行拼接作为最终的结果。

多头注意力机制增加了模型的多样性和表示能力,同时提高了模型的训练效率(并行处理)。



四、自注意力机制

1、为什么要用这种机制

我们前面讲解的多数注意力机制都是 解码器 前一时刻的状态 和 编码器 隐变量之间的关系,或者说是 Q 与 K 之间的,而 Q 和 K 通常来自于不同的序列。

假设下面的句子是我们要翻译的输入, 其中的 it 指的是什么呢?是animal还是street呢?对人来说是个不复杂的问题,但是对于算法来说不是那么简单。

The animal didn't cross the street because it was too tired.

自注意力机制就允许把 it 和 animal 联系到一起,具体来说就是通过查看输入序列中其他位置以便寻找更好的编码 it 的线索。

不同颜色深度就表示权重的大小,实际上相当于把输入序列进行了一个预处理的过程。

自注意力机制有很多优势:可以让模型聚焦在输入的重要部分而忽略其他不相关信息;它能够处理变长的序列数据(文本、图像等),而且可以在不改变模型结构的情况下加强模型的表示能力,还能减少模型的计算复杂度,因为它只对关键信息进行处理。

2、自注意力机制的计算

(1)获取 Q K V 值

从每个编码器的输入向量,创建 Q K V 向量。具体来说就是把 词嵌入 Embedding 向量乘以训练得到的三个矩阵 W。

前面讲过,Query可以理解成解码器中前一时刻的状态,Key可以理解成编码器的隐状态,二者之间可以求向量的相似度,也就是注意力的分数。

(2)自注意力分数

假设我们正在计算 Thinking 的自注意力,需要根据这个词对输入句子的每一个词进行评分,相当于是看 Thinking 和其他每一个词的关联度。

通过 query 和 各个单词的 key 进行点积运算。 其实和之前的键值对注意力机制可以类比,只不过现在的query和输入其实是一个序列。

(3)softmax 归一化

除以key向量维度的平方根,目的是为了使训练中梯度更加稳定。

这个softmax分数就叫做注意力分布,他表示Thinking对每个位置的关注程度。

(4)注意力加权求和

将每个value向量 乘以 softmax分数。

然后通过加权求和得到自注意力的输出 z1 ,生成的向量再进一步发送到前馈神经网络,实际实现中计算都是以矩阵形式来完成的,效率更高。

3、自注意力的矩阵计算

使用矩阵运算可以一次性计算出所有位置的Attention输出向量。

4、自注意力的理解

计算输入序列每个元素与其余元素之间的相关性得分。



五、注意力池化及代码实现

注意力机制在transformer中效果惊人,除了可以处理自然语言任务,在计算机视觉领域也得到了很好的应用,比如Vit模型,然而Vit模型和传统的CNN大不一样,比如Vit的输入是image page,而不是图像的像素,因此不具备视觉的属性。

所以就有人提出了新的模型结构,就是用Attention层来取代CNN中的池化层。

1、注意力可视化

huggingface.co

为我们提供了一个可视化工具。

import torch
import matplotlib.pyplot as plt
from torch import nn
from matplotlib import ticker
import warnings
warnings.filterwarnings("ignore")

# 绘制注意力热图
def show_attention(axis, attention):
    fig = plt.figure(figsize=(10,10))
    ax=fig.add_subplot(111)
    cax=ax.matshow(attention, cmap='bone')
    if axis is not None:
        ax.set_xticklabels(axis[0])
        ax.set_yticklabels(axis[1])
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()

# 生成一个样例
sentence = ' I love deep learning more than machine learning'
tokens = sentence.split(' ')

attention_weights = torch.eye(8).reshape((8, 8)) + torch.randn((8, 8)) * 0.1  # 生成注意力权重矩阵
attention_weights
复制代码
tensor([[ 9.8757e-01, -3.3269e-02, -6.4129e-03,  7.4444e-02, -9.9374e-03,
          8.6501e-02, -2.7283e-02,  7.9554e-02],
        [-1.4526e-01,  9.9452e-01, -1.0946e-01, -3.0252e-03, -8.7525e-02,
          6.3469e-02, -1.7428e-01,  9.7691e-02],
        [ 6.5311e-02, -1.3188e-01,  9.8884e-01,  1.0128e-01,  4.4508e-02,
         -5.3302e-02,  4.3820e-02, -4.7360e-02],
        [-1.2468e-01, -1.1606e-02,  2.9003e-02,  8.5149e-01,  1.4835e-01,
         -5.7268e-02, -5.0346e-02,  4.6933e-03],
        [-5.2518e-02,  3.6001e-02, -1.3409e-01,  8.4001e-02,  1.0626e+00,
          9.7431e-02, -4.4424e-02, -7.1158e-04],
        [ 2.8226e-02,  2.6271e-01,  4.2067e-02, -4.2097e-02,  1.5363e-01,
          1.1313e+00,  5.2735e-02, -1.6341e-01],
        [ 1.0850e-01, -4.2463e-02, -2.5154e-02,  1.2933e-01, -5.9269e-02,
          8.8615e-02,  1.0243e+00,  1.9031e-01],
        [ 1.4697e-01,  6.3187e-02,  6.1696e-02, -8.7981e-02, -1.0067e-01,
         -8.5357e-02, -3.2430e-02,  9.8060e-01]])
show_attention([tokens, tokens], attention_weights)  # 展示自注意力热图

2、注意力池化

传统的CNN网络中,卷积层后面会接一层池化层,用来调整卷积层的输出,池化层的作用有很多,最主要的作用就是特征降维、压缩、去除冗余信息。

而注意力机制则是对信息权重进行计算,找出重要信息,忽略次要信息。

非参数注意力池化:不需要学习参数矩阵。

通用注意力池化公式中 x 代表查询 q ,x i 代表键 key ,y i 相当于值 value。

参数注意力池化

2.1、数据集生成

# 定义一个映射函数
def func(x):
    return x + torch.sin(x)  # 映射函数 y = x + sin(x)

n = 100  # 样本个数100
x, _ = torch.sort(torch.rand(n) * 10)   # 生成0-10的随机样本并排序
y = func(x) + torch.normal(0.0, 1, (n,))  # 生成训练样本对应的y值, 增加均值为0,标准差为1的扰动
x, y
复制代码
(tensor([0.0932, 0.1369, 0.1464, 0.2588, 0.3617, 0.4888, 0.5750, 0.6377, 0.7767,
         0.9431, 1.2405, 1.3345, 1.3410, 1.3594, 1.4411, 1.5876, 1.7200, 1.8047,
         1.8210, 1.9156, 1.9657, 2.1769, 2.3423, 2.5221, 2.5958, 2.6168, 2.7542,
         3.2111, 3.2732, 3.4693, 3.6167, 3.6792, 3.6793, 3.6878, 3.8277, 3.8367,
         3.8376, 3.8876, 4.0447, 4.1571, 4.2137, 4.2324, 4.3410, 4.4588, 4.4594,
         4.9091, 4.9196, 5.0685, 5.2224, 5.3402, 5.3748, 5.3791, 5.4064, 5.4869,
         5.5518, 5.5796, 5.6379, 5.6441, 5.6606, 5.6656, 5.7890, 6.0018, 6.0525,
         6.1598, 6.2909, 6.3964, 6.4418, 6.6232, 6.6364, 6.7585, 7.3156, 7.4028,
         7.4515, 7.4991, 7.7147, 7.8343, 7.8439, 7.8470, 7.8872, 7.9275, 7.9587,
         8.0238, 8.2713, 8.3269, 8.3630, 8.6349, 8.6708, 8.7167, 8.7238, 8.8672,
         8.9560, 9.0643, 9.3687, 9.3806, 9.4898, 9.4994, 9.5740, 9.6834, 9.7924,
         9.8807]),
 tensor([ 1.8580,  0.4672, -0.6435,  1.1041, -1.2039,  0.8677,  2.2432,  1.7170,
         -0.3424,  1.6842,  0.8811,  2.5315,  0.5781,  2.8359,  2.9742,  2.5276,
          1.7305,  1.4146,  3.4234,  2.2744,  3.0479, -0.1695,  2.6256,  1.7427,
          3.6126,  2.0963,  2.3521,  2.7559,  2.4551,  5.6545,  3.3721,  4.3246,
          3.0190,  3.6240,  4.8242,  4.9076,  4.7384,  4.3644,  4.1528,  3.0237,
          1.3883,  3.9233,  2.7181,  3.4516,  3.0049,  2.6798,  3.2244,  4.3508,
          4.3351,  4.3058,  3.9870,  4.7560,  2.2353,  4.8209,  3.5682,  4.3192,
          5.2604,  5.5443,  4.8445,  6.4016,  4.3775,  6.1800,  4.2178,  7.9355,
          6.1463,  7.0191,  6.2096,  5.5149,  5.8452,  6.4167,  7.8151,  7.2597,
          6.2593,  8.3543,  9.0927,  8.3996,  7.7290,  9.0931,  9.0253,  9.5433,
          9.9514,  9.3971,  8.5425,  9.0645,  8.5238, 10.8705,  9.0520,  9.2222,
          8.5337,  9.9958,  7.9788,  9.3992, 10.2405,  9.6156,  7.5025,  9.4641,
          9.2835,  8.3229,  8.6281,  8.5180]))
# 绘制曲线上的点
x_curve = torch.arange(0, 10, 0.1)  
y_curve = func(x_curve)
plt.plot(x_curve, y_curve)
plt.plot(x, y, 'o')
plt.show()

2.2、非参数注意力池化

# 平均池化
y_hat = torch.repeat_interleave(y.mean(), n) # 将y_train中的元素进行复制,输入张量为y.mean, 重复次数为n
plt.plot(x_curve, y_curve)
plt.plot(x, y, 'o')
plt.plot(x_curve, y_hat)
plt.show()
# nadaraya-watson 核回归
x_nw = x_curve.repeat_interleave(n).reshape((-1, n))
x_nw.shape, x_nw
复制代码
(torch.Size([100, 100]),
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],
         [0.2000, 0.2000, 0.2000,  ..., 0.2000, 0.2000, 0.2000],
         ...,
         [9.7000, 9.7000, 9.7000,  ..., 9.7000, 9.7000, 9.7000],
         [9.8000, 9.8000, 9.8000,  ..., 9.8000, 9.8000, 9.8000],
         [9.9000, 9.9000, 9.9000,  ..., 9.9000, 9.9000, 9.9000]]))
# 带入公式得到注意力权重矩阵
attention_weights = nn.functional.softmax(-(x_nw - x)**2 / 2, dim=1)
attention_weights.shape, attention_weights
复制代码
(torch.Size([100, 100]),
 tensor([[8.0585e-02, 8.0181e-02, 8.0073e-02,  ..., 3.5190e-22, 1.2183e-22,
          5.1098e-23],
         [7.5357e-02, 7.5307e-02, 7.5277e-02,  ..., 8.5860e-22, 3.0050e-22,
          1.2716e-22],
         [7.0189e-02, 7.0451e-02, 7.0490e-02,  ..., 2.0866e-21, 7.3827e-22,
          3.1517e-22],
         ...,
         [5.8540e-22, 8.9043e-22, 9.7487e-22,  ..., 6.4303e-02, 6.4038e-02,
          6.3270e-02],
         [2.3868e-22, 3.6464e-22, 3.9960e-22,  ..., 6.8407e-02, 6.8871e-02,
          6.8649e-02],
         [9.6921e-23, 1.4872e-22, 1.6313e-22,  ..., 7.2478e-02, 7.3769e-02,
          7.4183e-02]]))
# y_hat为注意力权重和y值的乘积,是加权平均值
y_hat = torch.matmul(attention_weights, y)
plt.plot(x_curve, y_curve)
plt.plot(x, y, 'o')
plt.plot(x_curve, y_hat)
plt.show()
show_attention(None, attention_weights) # 展示注意力热图

参考

11-1 注意力机制

11-2 注意力机制的种类

Chapter-11/11.5 注意力池化.ipynb · 梗直哥/Deep-Learning-Code - Gitee.com

相关推荐
KuaFuAI几秒前
微软推出的AI无代码编程微应用平台GitHub Spark和国产AI原生无代码工具CodeFlying比到底咋样?
人工智能·github·aigc·ai编程·codeflying·github spark·自然语言开发软件
Make_magic10 分钟前
Git学习教程(更新中)
大数据·人工智能·git·elasticsearch·计算机视觉
shelly聊AI14 分钟前
语音识别原理:AI 是如何听懂人类声音的
人工智能·语音识别
源于花海17 分钟前
论文学习(四) | 基于数据驱动的锂离子电池健康状态估计和剩余使用寿命预测
论文阅读·人工智能·学习·论文笔记
雷龙发展:Leah17 分钟前
离线语音识别自定义功能怎么用?
人工智能·音频·语音识别·信号处理·模块测试
4v1d21 分钟前
边缘计算的学习
人工智能·学习·边缘计算
风之馨技术录25 分钟前
智谱AI清影升级:引领AI视频进入音效新时代
人工智能·音视频
sniper_fandc35 分钟前
深度学习基础—Seq2Seq模型
人工智能·深度学习
goomind38 分钟前
深度学习模型评价指标介绍
人工智能·python·深度学习·计算机视觉
youcans_39 分钟前
【微软报告:多模态基础模型】(2)视觉理解
人工智能·计算机视觉·大语言模型·多模态·视觉理解