

第一种注意力机制
python
# 注意力机制
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attn(nn.Module):
def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
"""初始化函数中的参数有5个
query_size代表query的最后一维大小
key_size代表key的最后一维大小, value_size1代表value的导数第二维大小
value = (1, value_size1, value_size2)
value_size2代表value的倒数第一维大小, output_size输出的最后一维大小
"""
super(Attn, self).__init__()
# 将以下参数传入类中
self.query_size = query_size
self.key_size = key_size
self.value_size1 = value_size1
self.value_size2 = value_size2
self.output_size = output_size
# 初始化注意力机制实现第一步中需要的线性层
self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
# 初始化注意力机制实现第三步中需要的线性层
self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)
def forward(self, Q, K, V):
"""forward函数的输入参数有三个
分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量
"""
# 第一步, 按照计算规则进行计算,
# 我们采用常见的第一种计算规则
# 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果
attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
print(' Q、K进行softmax后注意力权重长这样子\n', attn_weights)
# 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算,
# 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算
attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
# 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法,
# 需要将Q与第一步的计算结果再进行拼接
output = torch.cat((Q[0], attn_applied[0]), 1)
# 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
# 因为要保证输出也是3维张量, 因此使用unsqueeze(0)扩展维度
output = self.attn_combine(output).unsqueeze(0)
return output, attn_weights
if __name__ == '__main__':
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
# 批次,行,列
Q = torch.randn(1, 1, 32)
print(' Q长这样子\n', Q)
K = torch.randn(1, 1, 32)
print(' K长这样子\n', K)
V = torch.randn(1, 32, 64)
print(' V长这样子\n', V.shape)
print(' V长这样子\n', V)
print('**************************************************************************************')
print(Q[0])
print(K[0])
print(' Q、K拼接后长这样子\n', torch.cat((Q[0], K[0])))
print('**************************************************************************************')
out, attn_weights = attn(Q, K, V)
print(' 最终输出结果\n', out.shape)
print(out)
print(' 注意力权重\n', attn_weights.shape)
print(attn_weights)
运行结果
python
Q长这样子
tensor([[[-0.0843, 0.6108, -0.5214, 0.4358, 1.6302, -0.6159, 1.6340,
0.1276, 1.5854, -0.2922, -0.9621, 0.1989, -0.0558, 1.9234,
-0.5138, -0.7876, 1.9724, -0.0659, 0.5300, 1.1414, 1.1585,
0.9155, 0.0557, -0.7387, -0.5724, -0.0478, 0.4301, 1.2947,
-0.4314, -0.0663, 0.3610, 0.6614]]])
K长这样子
tensor([[[-1.6176, -0.2296, 0.0839, 0.1775, 0.3062, -0.2145, -0.6811,
-1.3397, -0.4235, 0.4637, -1.3447, -0.5441, -0.9798, 0.8265,
-0.2740, 0.9446, -2.4202, 1.1822, 1.8531, 2.0389, -0.4581,
-0.7546, 2.1168, 2.1271, 0.3378, 1.4806, -1.2704, 0.0628,
-1.2798, 0.0615, -0.0730, -0.8597]]])
V长这样子
torch.Size([1, 32, 64])
V长这样子
tensor([[[ 1.0952, 0.3832, 0.1141, ..., -1.3869, -0.0160, -1.3580],
[ 0.3251, 0.3406, 0.1589, ..., -0.8902, 2.0466, -0.5664],
[-0.6364, -1.0243, 0.1915, ..., 0.6893, -0.8892, 0.2788],
...,
[ 0.3980, 1.6673, 0.4893, ..., -0.7628, -0.0612, -0.2004],
[ 0.2605, 0.6287, -2.1606, ..., -0.8923, -0.4310, 1.8570],
[-0.2593, -1.3517, -0.4209, ..., 0.7520, 0.6580, -0.9260]]])
**************************************************************************************
tensor([[-0.0843, 0.6108, -0.5214, 0.4358, 1.6302, -0.6159, 1.6340, 0.1276,
1.5854, -0.2922, -0.9621, 0.1989, -0.0558, 1.9234, -0.5138, -0.7876,
1.9724, -0.0659, 0.5300, 1.1414, 1.1585, 0.9155, 0.0557, -0.7387,
-0.5724, -0.0478, 0.4301, 1.2947, -0.4314, -0.0663, 0.3610, 0.6614]])
tensor([[-1.6176, -0.2296, 0.0839, 0.1775, 0.3062, -0.2145, -0.6811, -1.3397,
-0.4235, 0.4637, -1.3447, -0.5441, -0.9798, 0.8265, -0.2740, 0.9446,
-2.4202, 1.1822, 1.8531, 2.0389, -0.4581, -0.7546, 2.1168, 2.1271,
0.3378, 1.4806, -1.2704, 0.0628, -1.2798, 0.0615, -0.0730, -0.8597]])
Q、K拼接后长这样子
tensor([[-0.0843, 0.6108, -0.5214, 0.4358, 1.6302, -0.6159, 1.6340, 0.1276,
1.5854, -0.2922, -0.9621, 0.1989, -0.0558, 1.9234, -0.5138, -0.7876,
1.9724, -0.0659, 0.5300, 1.1414, 1.1585, 0.9155, 0.0557, -0.7387,
-0.5724, -0.0478, 0.4301, 1.2947, -0.4314, -0.0663, 0.3610, 0.6614],
[-1.6176, -0.2296, 0.0839, 0.1775, 0.3062, -0.2145, -0.6811, -1.3397,
-0.4235, 0.4637, -1.3447, -0.5441, -0.9798, 0.8265, -0.2740, 0.9446,
-2.4202, 1.1822, 1.8531, 2.0389, -0.4581, -0.7546, 2.1168, 2.1271,
0.3378, 1.4806, -1.2704, 0.0628, -1.2798, 0.0615, -0.0730, -0.8597]])
**************************************************************************************
Q、K进行softmax后注意力权重长这样子
tensor([[0.0566, 0.0220, 0.0399, 0.0829, 0.0233, 0.0175, 0.0365, 0.0313, 0.0234,
0.0187, 0.0387, 0.0832, 0.0358, 0.0130, 0.0148, 0.0245, 0.0242, 0.0210,
0.0087, 0.0236, 0.0509, 0.0279, 0.0473, 0.0297, 0.0309, 0.0546, 0.0165,
0.0305, 0.0257, 0.0106, 0.0164, 0.0193]], grad_fn=<SoftmaxBackward0>)
最终输出结果
torch.Size([1, 1, 64])
tensor([[[-1.5883e-01, -9.7250e-02, -1.4577e-01, 6.2508e-02, -3.0917e-01,
-4.8471e-01, 1.2058e-01, -3.9673e-01, 4.7531e-01, 2.4023e-01,
-4.5470e-01, 9.8248e-02, -1.7717e-01, 3.3285e-01, 5.4367e-01,
-1.0387e-01, 1.0913e-01, 1.9735e-01, 3.9441e-01, -4.1193e-01,
5.5962e-02, -3.7915e-01, -1.1829e-01, -1.2722e-01, 1.2517e-01,
4.2707e-01, 1.6100e-01, -3.4799e-02, -1.5643e-01, -2.1065e-02,
-1.9389e-02, 8.9914e-05, -2.5389e-01, -1.1194e-01, -2.6804e-01,
6.9662e-01, -3.6186e-01, 6.3613e-01, 1.2927e-01, -1.0210e+00,
-9.3159e-01, -4.4763e-01, -3.8813e-01, -2.8905e-01, 5.0221e-01,
-2.9630e-01, 1.9712e-01, 3.4796e-01, 2.0145e-01, 2.1066e-01,
4.6304e-01, 3.5566e-01, 3.7207e-01, 2.1636e-01, 9.2869e-02,
-3.1811e-01, -4.5739e-01, -4.8703e-01, -4.9259e-02, 3.0813e-01,
-4.4769e-01, 2.3227e-01, 9.7959e-02, -3.2980e-02]]],
grad_fn=<UnsqueezeBackward0>)
注意力权重
torch.Size([1, 32])
tensor([[0.0566, 0.0220, 0.0399, 0.0829, 0.0233, 0.0175, 0.0365, 0.0313, 0.0234,
0.0187, 0.0387, 0.0832, 0.0358, 0.0130, 0.0148, 0.0245, 0.0242, 0.0210,
0.0087, 0.0236, 0.0509, 0.0279, 0.0473, 0.0297, 0.0309, 0.0546, 0.0165,
0.0305, 0.0257, 0.0106, 0.0164, 0.0193]], grad_fn=<SoftmaxBackward0>)
Process finished with exit code 0
第二种注意力机制
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attn(nn.Module):
def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
"""初始化函数中的参数有5个
query_size代表query的最后一维大小
key_size代表key的最后一维大小, value_size1代表value的导数第二维大小
value = (1, value_size1, value_size2)
value_size2代表value的倒数第一维大小, output_size输出的最后一维大小
"""
super(Attn, self).__init__()
# 将以下参数传入类中
self.query_size = query_size
self.key_size = key_size
self.value_size1 = value_size1
self.value_size2 = value_size2
self.output_size = output_size
# 初始化注意力机制实现第一步中需要的线性层
self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
# 初始化注意力机制实现第三步中需要的线性层
self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)
def forward(self, Q, K, V):
"""forward函数的输入参数有三个
分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量
"""
# 按照公式计算注意力权重
# 将Q,K进行纵轴拼接
combined = torch.cat((Q[0], K[0]), 1)
# 做一次线性变化
linear_result = self.attn(combined)
# 使用tanh函数激活
tanh_result = torch.tanh(linear_result)
# 进行内部求和
sum_result = torch.sum(tanh_result, dim=1, keepdim=True)
# 使用softmax处理获得结果
attn_weights = F.softmax(sum_result, dim=1)
print(' Q、K进行softmax后注意力权重长这样子\n', attn_weights)
# 将得到的权重矩阵与V做张量乘法
attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
# 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法,
# 需要将Q与第一步的计算结果再进行拼接
output = torch.cat((Q[0], attn_applied[0]), 1)
# 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
# 因为要保证输出也是3维张量, 因此使用unsqueeze(0)扩展维度
output = self.attn_combine(output).unsqueeze(0)
return output, attn_weights
if __name__ == '__main__':
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
# 批次,行,列
Q = torch.randn(1, 1, 32)
print(' Q长这样子\n', Q)
K = torch.randn(1, 1, 32)
print(' K长这样子\n', K)
V = torch.randn(1, 32, 64)
print(' V长这样子\n', V.shape)
print(' V长这样子\n', V)
print('**************************************************************************************')
print(Q[0])
print(K[0])
print(' Q、K拼接后长这样子\n', torch.cat((Q[0], K[0])))
print('**************************************************************************************')
out, attn_weights = attn(Q, K, V)
print(' 最终输出结果\n', out.shape)
print(out)
print(' 注意力权重\n', attn_weights.shape)
print(attn_weights)
第三种注意力机制(缩放点积注意力)
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attn(nn.Module):
def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
"""初始化函数中的参数有5个
query_size代表query的最后一维大小
key_size代表key的最后一维大小, value_size1代表value的导数第二维大小
value = (1, value_size1, value_size2)
value_size2代表value的倒数第一维大小, output_size输出的最后一维大小
"""
super(Attn, self).__init__()
# 将以下参数传入类中
self.query_size = query_size
self.key_size = key_size
self.value_size1 = value_size1
self.value_size2 = value_size2
self.output_size = output_size
# 这里原代码初始化的线性层在新公式中不再使用,可删除相关初始化,不过保留也不影响后续计算逻辑正确性
# self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
# self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)
def forward(self, Q, K, V):
"""forward函数的输入参数有三个
分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量
"""
# 计算缩放点积注意力权重
# 转置K,将其形状从(batch_size, seq_length, key_size)变为(batch_size, key_size, seq_length)
K_transposed = K.transpose(1, 2)
# 计算Q与K的转置的点积
dot_product = torch.bmm(Q, K_transposed)
# 除以缩放系数,这里缩放系数为键向量维度的平方根
scaling_factor = torch.sqrt(torch.tensor(self.key_size, dtype=torch.float))
scaled_dot_product = dot_product / scaling_factor
# 使用softmax处理获得注意力权重
attn_weights = F.softmax(scaled_dot_product, dim=2)
print(' Q、K进行softmax后注意力权重长这样子\n', attn_weights)
# 将得到的权重矩阵与V做张量乘法
attn_applied = torch.bmm(attn_weights, V)
return attn_applied, attn_weights
if __name__ == '__main__':
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
# 批次,行,列
Q = torch.randn(1, 1, 32)
print(' Q长这样子\n', Q)
K = torch.randn(1, 1, 32)
print(' K长这样子\n', K)
V = torch.randn(1, 32, 64)
print(' V长这样子\n', V.shape)
print(' V长这样子\n', V)
print('**************************************************************************************')
print(Q[0])
print(K[0])
print(' Q、K拼接后长这样子\n', torch.cat((Q[0], K[0])))
print('**************************************************************************************')
out, attn_weights = attn(Q, K, V)
print(' 最终输出结果\n', out.shape)
print(out)
print(' 注意力权重\n', attn_weights.shape)
print(attn_weights)