# 注意力机制 (attention mechanism)
import torch
import torch.nn as nn
class Attention(nn.Module):
    """Pool a sequence into a single vector using learned attention weights.

    Each timestep is scored by a bias-free linear layer. The scores are
    softmax-normalised over the sequence dimension, and the timesteps are summed
    using those weights.

    Args:
        hidden_dim: Size of the last dimension of ``encoder_outputs``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Maps each hidden_dim-sized timestep vector to one scalar score.
        self.attention = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, encoder_outputs, mask=None):
        """Compute attention weights and the weighted sum of ``encoder_outputs``.

        Args:
            encoder_outputs: Tensor of shape
                ``(batch_size, sequence_length, hidden_dim)``.
            mask: Optional tensor of shape ``(batch_size, sequence_length)``.
                Truthy entries are attended to. Falsy entries (e.g. padding)
                receive zero weight. A row whose entries are all falsy yields
                NaN weights, because the softmax is taken over all ``-inf``.
                When omitted, every timestep is attended to.

        Returns:
            A tuple ``(context, attn_weights)``. ``context`` has shape
            ``(batch_size, hidden_dim)``. ``attn_weights`` has shape
            ``(batch_size, sequence_length, 1)`` and sums to 1 over dim 1.
        """
        scores = self.attention(encoder_outputs)  # (batch_size, sequence_length, 1)
        if mask is not None:
            # Broadcast the (batch, seq) mask over the trailing score dimension;
            # -inf becomes exactly 0 after the softmax.
            scores = scores.masked_fill(~mask.bool().unsqueeze(-1), float("-inf"))
        # Normalise over the sequence dimension, not the singleton score dimension.
        attn_weights = torch.softmax(scores, dim=1)  # (batch_size, sequence_length, 1)
        context = torch.sum(attn_weights * encoder_outputs, dim=1)  # (batch_size, hidden_dim)
        return context, attn_weights
# Example usage: pool random "encoder outputs" with the Attention layer.
batch_size = 2
sequence_length = 5
hidden_dim = 10
# Random tensor of shape (batch_size, sequence_length, hidden_dim). No seed is
# set here. This call and the layer's nn.Linear initialisation both draw from
# the global RNG, so the printed values change between runs (and with statement order).
encoder_outputs = torch.randn(batch_size, sequence_length, hidden_dim)
attention_layer = Attention(hidden_dim)
# context: (batch_size, hidden_dim); attn_weights: (batch_size, sequence_length, 1).
context, attn_weights = attention_layer(encoder_outputs)
print("Context:", context)
print("Attention Weights:", attn_weights)
# 自注意力机制 (self-attention mechanism)
import torch
import torch.nn as nn
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single linear layer produces queries, keys and values for all heads at
    once. Each head attends over the whole sequence, and the per-head outputs
    are concatenated and passed through an output projection.

    Args:
        embed_dim: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        # Must be ``__init__`` and must call ``nn.Module.__init__`` before any
        # submodule is assigned, or PyTorch cannot register the parameters.
        super().__init__()
        # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
        if num_heads <= 0:
            raise ValueError("num_heads must be positive")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Joint projection producing q, k and v for every head.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, embed_dim)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch_size, num_heads, seq_length, seq_length)``. ``True``
                marks (query, key) pairs that may attend; ``False`` pairs get
                zero weight. For example, ``torch.ones(s, s).tril().bool()`` is
                a causal mask. A query row with no allowed key yields NaN. When
                omitted, every position may attend to every other.

        Returns:
            A tuple ``(output, attn_weights)``. ``output`` has shape
            ``(batch_size, seq_length, embed_dim)``. ``attn_weights`` has shape
            ``(batch_size, num_heads, seq_length, seq_length)``.
        """
        batch_size, seq_length, embed_dim = x.size()
        qkv = self.qkv_proj(x)  # (batch_size, seq_length, embed_dim * 3)
        # Split features per head first; each head's 3 * head_dim slice is then
        # cut into that head's q, k and v.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, 3 * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)  # each (batch_size, num_heads, seq_length, head_dim)
        # Scaled dot-product: divide by sqrt(head_dim) as in standard attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Disallowed pairs become -inf, i.e. exactly 0 after the softmax.
            scores = scores.masked_fill(~mask.bool(), float("-inf"))
        attn_weights = self.softmax(scores)  # (batch_size, num_heads, seq_length, seq_length)
        attn_output = torch.matmul(attn_weights, v)  # (batch_size, num_heads, seq_length, head_dim)
        # Re-interleave the heads back into the feature dimension.
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim)
        output = self.o_proj(attn_output)
        return output, attn_weights
# 示例用法 (example usage)
# Demo hyperparameters: embed_dim (16) is divisible by num_heads (4),
# which gives head_dim = 4.
batch_size = 2
seq_length = 5
embed_dim = 16
num_heads = 4
# Random input of shape (batch_size, seq_length, embed_dim). No seed is set
# here, so the printed values differ between runs.
x = torch.randn(batch_size, seq_length, embed_dim)
self_attention_layer = MultiHeadSelfAttention(embed_dim, num_heads)
# Expected shapes: output (batch_size, seq_length, embed_dim);
# attn_weights (batch_size, num_heads, seq_length, seq_length).
output, attn_weights = self_attention_layer(x)
print("Output:", output)
print("Attention Weights:", attn_weights)