transformer-注意力评分函数

目录

[10.2 节使用了高斯核来对查询和键之间的关系建模,10.6中的高斯核指数部分可以视为注意力评分函数,简称评分函数,然后把这个函数的输出结果输入softmax 函数中进行运算,通过上述步骤,将得到与键对应的值的概率分布,最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。](#10.2 节使用了高斯核来对查询和键之间的关系建模,10.6中的高斯核指数部分可以视为注意力评分函数,简称评分函数,然后把这个函数的输出结果输入softmax 函数中进行运算,通过上述步骤,将得到与键对应的值的概率分布,最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。)


评分函数 注意力权重 输出

键 softmax 值

查询

图10-4 计算注意力汇聚的输出为值的加权和

用数学语言描述,假设有一个查询q 属于 Rq和m个键值对(ki, ,,,v1),,,,(km,Vm) 其中ki属于Rk

Vi属于Rv,注意力汇聚函数f就被表示成值的加权和

f(q,(k1, V1)) = Sigma a(q,ki) Vi属于Rv

其中,查询q和键Ki的注意力权重是通过注意力评分函数a将两个向量映射成标量再经过softmax运算得到的。

正如图10-4所示,选择不同的注意力评分函数a会导致不同的注意力汇聚操作,本节将介绍两个流行的评分函数,稍后将用它来实现更复杂的注意力机制。

import math

import torch

from torch import nn

from d2l import torch as d2l

10.3.1 掩蔽softmax操作

上面提到的,softmax操作用于输出一个概率分布为注意力权重,在某些情况下,并非所有的值都应该被纳入注意力汇聚中。例如,为了在9.5节中高校处理小批量数据集,某些文本序列被填充了没有意义的特殊词元。为了仅将有意义的词元作为值来获取注意力汇聚,可以指定一个有效序列长度,以便在计算softmax时过滤掉超出指定范围的位置,下面的masked_softmax函数实现了这样的掩蔽softmax操作,其中任何超出有效长度的位置都被掩蔽并设置为0.

def masked_softmax(X, valid_lens):

通过在最后一个轴上掩蔽元素来执行softmax操作

X:3D张量,valid_lens:1D或2D张量

if valid_lens is None:

return nn.functional.softmax(X, dim = -1)

else:

shape = X.shape

if valid_lens.dim() == 1:

valid_lens = torch.repeat_interleave(valid_lens, shape[1])

else:

valid_lens = valid_lens.reshape(-1)

最后一个轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0

X=d2l.sequence_mask(X.reshape(-1, shape[-1], valid_lens, value=-1e6))

return nn.functional_softmax(X.reshape(shape), dim=-1)

为了掩饰此函数时如何工作的,考虑由两个2x4矩阵表示的样本,这两个样本有效长度分别为2和3,经过掩蔽softmax操作,超出有效长度的值都被掩蔽为0

masked_softmax(torch.rand(2,2,4), torch.tensor([2,3]))

同样,也可以使用二维张量,为矩阵样本中的每一行指定有效长度。

masked_softmax(torch.rand(2,2,4), torch.tensor([1,3],[2,4]))

10.3.2 加性注意力

当查询和键不同长度的向量时,可以使用加性注意力作为评分函数,给定查询q属于Rq和键k属于Rk,加性注意力的评分函数为

可学习的参数是Wq属于Rkxq,Wk属于Rhxk和Wt属于Rh。 查询和键连接起来后输入一个多层感知机MLP中,感知机包含一个隐藏层,其隐藏单元数一个超参数h,通过使用tanh作为激活函数,并且禁用偏置项。

下面来实现加性注意力

class AdditiveAttention(nn.Module)

加性注意力

def init(self, key_size, query_size, num_hiddens, dropout, **kwargs):

super(AdditiveAttention, self).init(**kwargs)

self.W_k = nn.Linear(key_size, num_hiddens, bias = False)

self.W_q = nn.Linear(query_size, num_hiddens, bias = False)

self.W_v = nn.Linear(num_hiddens, 1, bias = False)

def forward(self, queries, keys, values, valid_lens):

queries, keys = self.W_q(queries), self.W_k(keys)

#在维度扩展后

#queries的形状为(batch_size, 查询数,1,num_hidden)

#key 的形状为(batch_size, 1, 键-值对数,num_hiddens)

#使用广播方式求和

features = queries.unsqueeze(2) + keys.unqueeze(1)

features = torch.tanh(features)

#self.w_v仅有一个输出,因此从形状中移除最后的维度

scores的形状为(batch_size, 查询数,键-值对数)

scores = self.w_v(features).squeeze(-1)

self.attention_weights = masked_softmax(scores, valid_lens)

values 的形状为batch_size, 键-值对数,值的维度

return torch.bmm(self.dropout(self.attention_weights), values)

用一个小例子演示上面的additiveAttention类,其中查询,键和值的形状为量大小,步数或词元序列长度,实际输出为2,1,20,注意力汇聚输出的形状为 批量大小,查询的步数,值的维度。

queries, keys = torch.normal(0,1,(2,1,20)), torch.ones(2, 10, 2)

#values的小批量,两个值矩阵是相同的

values = torch.arange(40, dtype = torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)

valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size = 2, query_size=20, num_hiddens=8, dropout = 0.1)

attention.eval()

attention(queries, keys, values, valid_lens)

尽管加性注意力包含了可学习的参数,由于本例中每个键都是相同的,因此注意力权重是均匀的,由指定的有效长度决定。

d2l.show_heatmaps(attention.attention_weights.reshape(1,1,2,10)):

xlabel='keys', ylabel='Queries'

10.3.3 缩放点积注意力

使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度d,假设查询和键的所有元素都是独立的随机变量,并且都满足零均值和单位方差,那么两个向量的点积的均值为0,方差为d,为确保无论向量长度如何,点积的方差在不考虑向量长度的情况下都是1,我们再将点积除以 根号d,则缩放点积注意力评分函数为

在实践中,我们通常从小批量的角度来i考虑提高效率,例如基于n个查询的m个键-值对计算注意力,其中查询和键的长度为d, 值的长度为v,查询Q属于Rnxd。

下面的缩放点积注意力的实现使用了暂退法进行模型正则化。

class DotProdductAttention(nn.Module):

缩放点积注意力

def init(self, dropout, **kwargs):

super(DotProductAttention, self).init(**kwargs)

self.dropout = nn.Dropout(dropout)

#queries 的形状为batch_szie,查询数,d

#keys的形状为batch_size, 键-值对数,d

values的形状为batch_size 键-值对数,值的维度

valid_lens的形状为batch_size, 或者batch_size, valid_lens = None

def forward(self, queries, keys, values, valid_lens = None):

d = queries.shape[-1]

设置transpose_b=True是为了交换keys的最后两个维度。

scores = torch.bmm(queries, keys.transpose(1,2))/math.sqrt(d)

self.attention_weights = masked_softmax(scores, valid_lens)

return torch.bmm(self.dropout(self.attention_weights), values)

为了演示上述的DotProductAttention类,我们使用的先前加性注意力例子中相同的键,值和有效长度。对于点积操作,我们令查询的特征维度与键的特征维度大小相同

queries = torch.normal(0, 1, (2, 1, 2))

attention = DotProductAttention(dropout=0.5)

attention.eval()

attention(queries, keys, values, valid_lens)

与加性注意力演示相同,由于键包含的是相同元素,而这些元素无法通过任何查询进行区分,因此获得了均匀的注意力权重。

d2l.show_heatmaps(attention.attention_weights.reshape(1, 1, 2, 10)),

xlabel='Keys', ylabel='Queries'

小结:注意力汇聚的输出计算可以作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作

当查询和键是不同长度的向量时,可以使用加性注意力评分函数,当他们的长度相同的,使用缩放点积注意力评分函数的计算效率更高。

相关推荐
逐云者1233 小时前
自动驾驶强化学习的价值对齐:奖励函数设计的艺术与科学
人工智能·机器学习·自动驾驶·自动驾驶奖励函数·奖励函数黑客防范·智能驾驶价值对齐
BreezeJuvenile3 小时前
深度学习实验一之图像特征提取和深度学习训练数据标注
人工智能·深度学习
Dev7z3 小时前
舌苔舌象分类图像数据集
人工智能·分类·数据挖掘
万俟淋曦3 小时前
【论文速递】2025年第30周(Jul-20-26)(Robotics/Embodied AI/LLM)
人工智能·深度学习·ai·机器人·论文·robotics·具身智能
高洁013 小时前
大模型-高效优化技术全景解析:微调 量化 剪枝 梯度裁剪与蒸馏 下
人工智能·python·深度学习·神经网络·知识图谱
CoookeCola3 小时前
MovieNet(A holistic dataset for movie understanding) :面向电影理解的多模态综合数据集与工具链
数据仓库·人工智能·目标检测·计算机视觉·数据挖掘
张艾拉 Fun AI Everyday4 小时前
Gartner 2025年新兴技术成熟度曲线
人工智能
菜鸟‍4 小时前
【论文学习】大语言模型(LLM)论文
论文阅读·人工智能·学习
默 语4 小时前
AI驱动软件测试全流程自动化:从理论到实践的深度探索
运维·人工智能·驱动开发·ai·自动化·ai技术·测试全流程