my-attention
Self-attention lets each token take the semantic influence of the other tokens in the sequence into account.
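In formula form, both cases below are instances of the standard scaled dot-product attention:

\mathrm{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

where Q = x wq, K = x wk, V = x wv, and d_k is the dimension of the key vectors.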
Case 1: the projection reduces the dimension
Input x = [4x6]
dk = 3
wq = [6x3]
wk = [6x3]
wv = [6x3]
q = x @ wq = [4x6] @ [6x3] = [4x3]
k = x @ wk = [4x6] @ [6x3] = [4x3]
r = q @ k.T = [4x3] @ [3x4] = [4x4]
Scale:
r = r / sqrt(dk) = [4x4]
a = softmax(r) = [4x4]
v = x @ wv = [4x6] @ [6x3] = [4x3]
out = a @ v = [4x4] @ [4x3] = [4x3]
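A minimal PyTorch sketch of Case 1, just to confirm the shapes (the tensors are random; the variable names follow the walkthrough above):

import math
import torch

x = torch.randn(4, 6)          # input: 4 tokens, 6 features each
dk = 3
wq = torch.randn(6, dk)        # [6x3] projection matrices
wk = torch.randn(6, dk)
wv = torch.randn(6, dk)

q = x @ wq                     # [4x3]
k = x @ wk                     # [4x3]
v = x @ wv                     # [4x3]
r = q @ k.T / math.sqrt(dk)    # [4x4] scaled scores
a = torch.softmax(r, dim=-1)   # [4x4] attention weights
out = a @ v                    # [4x3]
print(out.shape)               # torch.Size([4, 3])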
Case 2: the dimension is not reduced
Input x = [4x6]
The required output dimension is 6
dk = 6
Randomly initialize wq, wk, wv
wq = [6x6]
wk = [6x6]
wv = [6x6]
q = x @ wq = [4x6] @ [6x6] = [4x6]
k = x @ wk = [4x6] @ [6x6] = [4x6]
r = q @ k.T = [4x6] @ [6x4] = [4x4]
Scale:
r = r / sqrt(dk) = [4x4]
Normalize:
a = softmax(r) = [4x4]
Apply the attention weights to the original values:
v = x @ wv = [4x6] @ [6x6] = [4x6]
out = a @ v = [4x4] @ [4x6] = [4x6]
This guarantees that the output dimension matches the required one.
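The same sketch with dk = 6 shows Case 2, where the output keeps the input width:

import math
import torch

x = torch.randn(4, 6)
dk = 6
wq, wk, wv = torch.randn(6, dk), torch.randn(6, dk), torch.randn(6, dk)   # [6x6] each

q, k, v = x @ wq, x @ wk, x @ wv                    # [4x6] each
a = torch.softmax(q @ k.T / math.sqrt(dk), dim=-1)  # [4x4]
out = a @ v                                         # [4x6], same width as x
print(out.shape)                                    # torch.Size([4, 6])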
Below is a code implementation of the self-attention mechanism (a multi-head version that also applies a causal mask):
import math
import torch
from torch import nn

x = torch.randn(16, 64, 512)   # (batch, sequence length, model dimension)
d_model = 512
h_num = 8                      # number of attention heads


class Self_Attention(nn.Module):
    def __init__(self, d_model, h_num):
        # call the parent constructor
        super(Self_Attention, self).__init__()
        self.d_model = d_model
        self.h_num = h_num
        self.softmax = nn.Softmax(dim=-1)
        # linear projections for queries, keys, values and the output
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, L, D = x.shape
        h_d = self.d_model // self.h_num          # dimension per head
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        # split into heads: (B, L, D) -> (B, h_num, L, h_d)
        q = q.view(B, L, self.h_num, h_d).transpose(1, 2)
        k = k.view(B, L, self.h_num, h_d).transpose(1, 2)
        v = v.view(B, L, self.h_num, h_d).transpose(1, 2)
        # scaled dot-product scores: (B, h_num, L, L)
        r = q @ k.transpose(2, 3) / math.sqrt(h_d)
        # causal mask: each token may only attend to itself and earlier tokens
        mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=x.device))
        r = r.masked_fill(~mask, -10000)          # large negative -> ~0 weight after softmax
        a = self.softmax(r)
        o = a @ v                                 # (B, h_num, L, h_d)
        # merge the heads back: (B, h_num, L, h_d) -> (B, L, D)
        o = o.transpose(1, 2).contiguous().view(B, L, self.d_model)
        return self.w_o(o)


attention = Self_Attention(d_model, h_num)
y = attention(x)
print(y.shape)   # torch.Size([16, 64, 512])
print(y)
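As an optional sanity check (assuming PyTorch 2.0 or newer, where torch.nn.functional.scaled_dot_product_attention is available), the handwritten module can be compared against PyTorch's built-in fused attention with is_causal=True. Reusing the same projection weights, the two outputs should agree up to floating-point tolerance (the -10000 mask value behaves essentially like -inf here).

import torch.nn.functional as F

with torch.no_grad():
    B, L, _ = x.shape
    h_d = d_model // h_num
    q = attention.w_q(x).view(B, L, h_num, h_d).transpose(1, 2)
    k = attention.w_k(x).view(B, L, h_num, h_d).transpose(1, 2)
    v = attention.w_v(x).view(B, L, h_num, h_d).transpose(1, 2)
    o = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # causal mask built in
    y_ref = attention.w_o(o.transpose(1, 2).contiguous().view(B, L, d_model))
    print(torch.allclose(y, y_ref, atol=1e-5))                    # expected: True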