主要组成部分:
1. 定义注意力层:
定义一个Attention_Layer类,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。
2. 定义前向传播:
定义了注意力层的前向传播过程,包括计算注意力权重和输出。
3. 数据准备
生成一个随机的数据集,包含3个句子,每个句子10个词,每个词128个特征。
4. 实例化注意力层:
实例化一个注意力层,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。
5. 前向传播
将数据传递给注意力层的前向传播方法。
6. 分析结果
获取第一个句子的注意力权重。
7. 可视化注意力权重
使用matplotlib库可视化了注意力权重。
python
**主要函数和类:**
Attention_Layer类:定义了注意力层的结构和前向传播过程。
forward方法:定义了注意力层的前向传播过程。
torch.from_numpy函数:将numpy数组转换为PyTorch张量。
matplotlib库:用于可视化注意力权重。
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# 定义注意力层
class Attention_Layer(nn.Module):
def __init__(self, hidden_dim, is_bi_rnn):
super(Attention_Layer,self).__init__()
self.hidden_dim = hidden_dim
self.is_bi_rnn = is_bi_rnn
if is_bi_rnn:
self.Q_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)
self.K_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)
self.V_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)
else:
self.Q_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)
self.K_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)
self.V_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)
def forward(self, inputs, lens):
# 获取输入的大小
size = inputs.size()
Q = self.Q_linear(inputs)
K = self.K_linear(inputs).permute(0, 2, 1)
V = self.V_linear(inputs)
max_len = max(lens)
sentence_lengths = torch.Tensor(lens)
mask = torch.arange(sentence_lengths.max().item())[None, :] < sentence_lengths[:, None]
mask = mask.unsqueeze(dim = 1)
mask = mask.expand(size[0], max_len, max_len)
padding_num = torch.ones_like(mask)
padding_num = -2**31 * padding_num.float()
alpha = torch.matmul(Q, K)
alpha = torch.where(mask, alpha, padding_num)
alpha = F.softmax(alpha, dim = 2)
out = torch.matmul(alpha, V)
return out
# 准备数据
data = np.random.rand(3, 10, 128) # 3个句子,每个句子10个词,每个词128个特征
lens = [7, 10, 4] # 每个句子的长度
# 实例化注意力层
hidden_dim = 64
is_bi_rnn = True
att_L = Attention_Layer(hidden_dim, is_bi_rnn)
# 前向传播
att_out = att_L(torch.from_numpy(data).float(), lens)
# 分析结果
attention_weights = att_out[0, :, :].detach().numpy() # 获取第一个句子的注意力权重
# 可视化注意力权重
plt.imshow(attention_weights, cmap='hot', interpolation='nearest')
plt.colorbar()
plt.show()
data:image/s3,"s3://crabby-images/c0cd9/c0cd96a6fe4f37f84a702ea9a46c54a703d3ef05" alt=""