import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from torch.nn import MultiheadAttention
class SelfAttention(nn.Module):
    def __init__(self, emb_dim):
        super(SelfAttention, self).__init__()
        # The three weight matrices W_q, W_k, W_v, implemented as linear projections
        self.emb_dim = emb_dim
        self.Wq = nn.Linear(emb_dim, emb_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, emb_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, emb_dim, bias=False)
        self.fc = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, pad_mask=None):
        # Project the input to Q, K, V
        # [batch_size, seq_len, emb_dim] = [3, 5, 512]
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)

        # Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V
        # torch.bmm is batch matrix multiplication:
        # [3, 5, 512] @ [3, 512, 5] -> [3, 5, 5]
        print(Q.shape, K.shape, V.shape, K.transpose(1, 2).shape)
        att_weights = torch.bmm(Q, K.transpose(1, 2))
        # [batch_size, seq_len, seq_len] = [3, 5, 5]
        print("att_weights1", att_weights)
        att_weights = att_weights / math.sqrt(self.emb_dim)
        print("att_weights2", att_weights, att_weights.shape)

        # Padding mask: pad_mask is [batch_size, seq_len] = [3, 5]; unsqueeze it to
        # [3, 1, 5] so it broadcasts over the [3, 5, 5] scores and masks padded key positions
        if pad_mask is not None:
            att_weights = att_weights.masked_fill(pad_mask.unsqueeze(1), -1e9)

        att_weights = F.softmax(att_weights, dim=-1)
        print("att_weights3", att_weights, att_weights.shape)

        # [batch_size, seq_len, emb_dim] = [3, 5, 512]
        # [3, 5, 5] @ [3, 5, 512] -> [3, 5, 512]
        output = torch.bmm(att_weights, V)
        print("output and att_weights shapes:", output.shape, att_weights.shape)
        output = self.fc(output)
        # torch.Size([3, 5, 512]) torch.Size([3, 5, 5])
        print("output and att_weights shapes:", output.shape, att_weights.shape)
        return output, att_weights
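
# A minimal sketch (not in the original post): on PyTorch 2.0+, F.scaled_dot_product_attention
# computes the same softmax(Q K^T / sqrt(d)) V core, so it can be used to sanity-check the
# bmm-based implementation above. The helper name _check_against_sdpa and the tolerance are
# illustrative choices, not part of the original code.
def _check_against_sdpa(self_att, x):
    Q, K, V = self_att.Wq(x), self_att.Wk(x), self_att.Wv(x)
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(self_att.emb_dim)
    manual = torch.bmm(F.softmax(scores, dim=-1), V)
    # F.scaled_dot_product_attention scales by 1/sqrt of Q's last dim, i.e. 1/sqrt(emb_dim) here
    builtin = F.scaled_dot_product_attention(Q, K, V)
    return torch.allclose(manual, builtin, atol=1e-4)
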
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_dim, num_heads, att_dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.att_dropout = att_dropout  # dropout rate applied to the attention weights, as in the reference article

        # The embedding dimension must be divisible by the number of heads
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.depth = emb_dim // num_heads

        self.Wq = nn.Linear(emb_dim, emb_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, emb_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, emb_dim, bias=False)
        self.fc = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, pad_mask=None):
        # [batch_size, seq_len, emb_dim] = [3, 5, 512]
        batch_size = x.size(0)

        # [batch_size, seq_len, emb_dim] = [3, 5, 512]
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        print(Q.shape, K.shape, V.shape, K.transpose(1, 2).shape)

        # Split into heads: [batch_size, num_heads, seq_len, depth] = [3, 8, 5, 512/8=64]
        Q = Q.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        print(Q.shape, K.shape, V.shape, K.transpose(-2, -1).shape)

        # [batch_size, num_heads, seq_len, seq_len] = [3, 8, 5, 5]
        att_weights = torch.matmul(Q, K.transpose(-2, -1))
        print("att_weights", att_weights.shape)
        att_weights = att_weights / math.sqrt(self.depth)

        # Padding mask: pad_mask is [batch_size, seq_len]; unsqueeze it to
        # [batch_size, 1, 1, seq_len] so it broadcasts over all heads and query positions
        if pad_mask is not None:
            att_weights = att_weights.masked_fill(pad_mask.unsqueeze(1).unsqueeze(2), -1e9)

        att_weights = F.softmax(att_weights, dim=-1)
        print("att_weights2", att_weights.shape)

        # Original author's note: this hand-written multi-head attention performed slightly
        # worse than torch's, probably because torch applies dropout to the attention
        # weights rather than to the fc output.
        if self.att_dropout > 0.0:
            att_weights = F.dropout(att_weights, p=self.att_dropout, training=self.training)

        # [batch_size, num_heads, seq_len, depth] = [3, 8, 5, 64]
        output = torch.matmul(att_weights, V)
        print("output1", output.shape)

        # Concatenate the heads back together: [batch_size, seq_len, emb_dim] = [3, 5, 512]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.emb_dim)
        print("output2", output.shape)
        output = self.fc(output)
        return output, att_weights
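
# A minimal sketch (not in the original post): torch's built-in MultiheadAttention, which the
# note above refers to, covers the same use case with batch_first=True and a key_padding_mask
# (True = padding). Its weights are initialized independently (and packed into in_proj_weight),
# so only the output shapes are expected to match the class above, not the values.
def _compare_with_torch_mha(x, pad_mask, emb_dim=512, num_heads=8):
    torch_mha = MultiheadAttention(emb_dim, num_heads, bias=False, batch_first=True)
    out, att = torch_mha(x, x, x, key_padding_mask=pad_mask)
    # out: [batch_size, seq_len, emb_dim]; att: [batch_size, seq_len, seq_len] (averaged over heads)
    return out.shape, att.shape
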
if __name__ == '__main__':
    batch_size = 3
    seq_len = 5
    emb_dim = 512
    # Vocabulary size for this example: 301
    vocab_size = 301

    # token id 0 is the padding token
    input_ids = torch.tensor([[100, 200, 300, 300, 0],
                              [22, 33, 44, 0, 0],
                              [66, 55, 66, 30, 0]], dtype=torch.long)

    # Boolean pad_mask: True at padding positions, False elsewhere
    pad_mask = input_ids.eq(0)
    print("pad_mask", pad_mask.shape)

    # [batch_size, seq_len, emb_dim] = [3, 5, 512]
    inputs = nn.Embedding(vocab_size, embedding_dim=emb_dim)(input_ids)

    self_att = SelfAttention(emb_dim=emb_dim)
    self_att(inputs, pad_mask=pad_mask)

    multi_att = MultiHeadAttention(emb_dim=emb_dim, num_heads=8)
    multi_att(inputs, pad_mask=pad_mask)
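
    # Illustrative checks using the sketch helpers defined above (see the assumptions noted
    # there): the SDPA check needs PyTorch 2.0+, and the torch MultiheadAttention comparison
    # only validates shapes because its weights differ from the hand-written module's.
    print("matches F.scaled_dot_product_attention:", _check_against_sdpa(self_att, inputs))
    print("torch MultiheadAttention shapes:", _compare_with_torch_mha(inputs, pad_mask, emb_dim=emb_dim, num_heads=8))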