注意力机制中的三种掩码技术及其PyTorch实现

在深度学习中,特别是处理序列数据时,注意力机制是一种非常关键的技术,广泛应用于各种先进的神经网络架构中,如Transformer模型。为了确保模型能够正确处理序列数据,掩码技术发挥了重要作用。本文将介绍三种常见的掩码技术:填充掩码(Padding Mask)、序列掩码(Sequence Mask)和前瞻掩码(Look-ahead Mask),并提供相应的PyTorch代码实现。

1. 填充掩码(Padding Mask)

目的

填充掩码的主要目的是确保模型在处理填充的输入数据时,不会将这些无关的数据当作有效数据处理。在序列处理中,由于不同序列的长度可能不同,通常需要对较短的序列进行填充,以保证所有序列长度一致,便于批处理。然而,这些填充的部分并不包含实际信息,因此应该在模型处理时忽略。

PyTorch实现

python 复制代码
import torch

def create_padding_mask(seq, pad_token=0):
    mask = (seq == pad_token).unsqueeze(1).unsqueeze(2)
    return mask  # (batch_size, 1, 1, seq_len)

# 示例使用
seq = torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])
padding_mask = create_padding_mask(seq)
print(padding_mask)
2. 序列掩码(Sequence Mask)

目的

序列掩码用于更广泛地控制模型应该关注的数据部分,不仅可以指示填充位置,还可以用于其他类型的掩蔽需求。例如,在序列到序列的任务中,可能需要隐藏未来信息,以确保模型在解码时不会"窥视"到未来信息。

PyTorch实现

python 复制代码
def create_sequence_mask(seq):
    seq_len = seq.size(1)
    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))
print(sequence_mask)
3. 前瞻掩码(Look-ahead Mask)

目的

前瞻掩码主要用于自回归模型中,以确保模型在生成序列时不会"看到"未来的符号。这保证了在给定位置的预测仅依赖于该位置之前的符号,维护了生成过程的时序正确性。

PyTorch实现

python 复制代码
def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
look_ahead_mask = create_look_ahead_mask(4)
print(look_ahead_mask)
掩码在注意力机制中的应用

在注意力机制中,掩码被用来修改注意力得分,以确保模型在计算注意力权重时能够正确地考虑哪些部分应该被忽略。以下是一个使用掩码进行缩放点积注意力计算的示例:

python 复制代码
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))
    dk = q.size()[-1]
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, v)
    return output, attention_weights

# 示例使用
d_model = 512
batch_size = 2
seq_len = 4
q = torch.rand((batch_size, seq_len, d_model))
k = torch.rand((batch_size, seq_len, d_model))
v = torch.rand((batch_size, seq_len, d_model))
mask = create_look_ahead_mask(seq_len)
attention_output, attention_weights = scaled_dot_product_attention(q, k, v, mask)
print(attention_output)
相关推荐
Jerryhut11 分钟前
Bev感知特征空间算法
人工智能
xian_wwq21 分钟前
【学习笔记】基于人工智能的火电机组全局性能一体化优化研究
人工智能·笔记·学习·火电
B站计算机毕业设计之家25 分钟前
基于大数据热门旅游景点数据分析可视化平台 数据大屏 Flask框架 Echarts可视化大屏
大数据·爬虫·python·机器学习·数据分析·spark·旅游
春风LiuK34 分钟前
虚实无界:VRAR如何重塑课堂与突破研究边界
人工智能·程序人生
周纠纠1 小时前
附录1:中文切词
python
Cricyta Sevina1 小时前
Java Collection 集合进阶知识笔记
java·笔记·python·collection集合
歌_顿1 小时前
Embedding 模型word2vec/glove/fasttext/elmo/doc2vec/infersent学习总结
人工智能·算法
胡萝卜3.01 小时前
深入C++可调用对象:从function包装到bind参数适配的技术实现
开发语言·c++·人工智能·机器学习·bind·function·包装器
Echo_NGC22371 小时前
【KL 散度】深入理解 Kullback-Leibler Divergence:AI 如何衡量“像不像”的问题
人工智能·算法·机器学习·散度·kl