注意力机制中的三种掩码技术及其PyTorch实现

在深度学习中,特别是处理序列数据时,注意力机制是一种非常关键的技术,广泛应用于各种先进的神经网络架构中,如Transformer模型。为了确保模型能够正确处理序列数据,掩码技术发挥了重要作用。本文将介绍三种常见的掩码技术:填充掩码(Padding Mask)、序列掩码(Sequence Mask)和前瞻掩码(Look-ahead Mask),并提供相应的PyTorch代码实现。

1. 填充掩码(Padding Mask)

目的

填充掩码的主要目的是确保模型在处理填充的输入数据时,不会将这些无关的数据当作有效数据处理。在序列处理中,由于不同序列的长度可能不同,通常需要对较短的序列进行填充,以保证所有序列长度一致,便于批处理。然而,这些填充的部分并不包含实际信息,因此应该在模型处理时忽略。

PyTorch实现

python 复制代码
import torch

def create_padding_mask(seq, pad_token=0):
    mask = (seq == pad_token).unsqueeze(1).unsqueeze(2)
    return mask  # (batch_size, 1, 1, seq_len)

# 示例使用
seq = torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])
padding_mask = create_padding_mask(seq)
print(padding_mask)
2. 序列掩码(Sequence Mask)

目的

序列掩码用于更广泛地控制模型应该关注的数据部分,不仅可以指示填充位置,还可以用于其他类型的掩蔽需求。例如,在序列到序列的任务中,可能需要隐藏未来信息,以确保模型在解码时不会"窥视"到未来信息。

PyTorch实现

python 复制代码
def create_sequence_mask(seq):
    seq_len = seq.size(1)
    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))
print(sequence_mask)
3. 前瞻掩码(Look-ahead Mask)

目的

前瞻掩码主要用于自回归模型中,以确保模型在生成序列时不会"看到"未来的符号。这保证了在给定位置的预测仅依赖于该位置之前的符号,维护了生成过程的时序正确性。

PyTorch实现

python 复制代码
def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
look_ahead_mask = create_look_ahead_mask(4)
print(look_ahead_mask)
掩码在注意力机制中的应用

在注意力机制中,掩码被用来修改注意力得分,以确保模型在计算注意力权重时能够正确地考虑哪些部分应该被忽略。以下是一个使用掩码进行缩放点积注意力计算的示例:

python 复制代码
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))
    dk = q.size()[-1]
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, v)
    return output, attention_weights

# 示例使用
d_model = 512
batch_size = 2
seq_len = 4
q = torch.rand((batch_size, seq_len, d_model))
k = torch.rand((batch_size, seq_len, d_model))
v = torch.rand((batch_size, seq_len, d_model))
mask = create_look_ahead_mask(seq_len)
attention_output, attention_weights = scaled_dot_product_attention(q, k, v, mask)
print(attention_output)
相关推荐
知识鱼丸17 分钟前
自定义数据集 使用scikit-learn中svm的包实现svm分类
人工智能
说私域40 分钟前
基于开源AI智能名片2 + 1链动模式S2B2C商城小程序视角下的个人IP人设构建研究
人工智能·小程序·开源
山海青风1 小时前
OpenAI 实战进阶教程 - 第七节: 与数据库集成 - 生成 SQL 查询与优化
数据库·人工智能·python·sql
Chatopera 研发团队2 小时前
计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战
人工智能·pytorch·深度学习
纠结哥_Shrek2 小时前
pytorch实现半监督学习
人工智能·pytorch·学习
白白糖2 小时前
深度学习 Pytorch 基础网络手动搭建与快速实现
人工智能·pytorch·深度学习
AI浩2 小时前
【Block总结】HWD,小波下采样,适用分类、分割、目标检测等任务|即插即用
人工智能·目标检测·分类
ZWZhangYu2 小时前
【实践案例】基于大语言模型的海龟汤游戏
人工智能·游戏·语言模型
AI浩2 小时前
【Block总结】Shuffle Attention,新型的Shuffle注意力|即插即用
人工智能·深度学习
三月七(爱看动漫的程序员)3 小时前
模型/O功能之提示词模板
java·前端·javascript·人工智能·语言模型·langchain·prompt