注意力机制中的三种掩码技术及其PyTorch实现

在深度学习中,特别是处理序列数据时,注意力机制是一种非常关键的技术,广泛应用于各种先进的神经网络架构中,如Transformer模型。为了确保模型能够正确处理序列数据,掩码技术发挥了重要作用。本文将介绍三种常见的掩码技术:填充掩码(Padding Mask)、序列掩码(Sequence Mask)和前瞻掩码(Look-ahead Mask),并提供相应的PyTorch代码实现。

1. 填充掩码(Padding Mask)

目的

填充掩码的主要目的是确保模型在处理填充的输入数据时,不会将这些无关的数据当作有效数据处理。在序列处理中,由于不同序列的长度可能不同,通常需要对较短的序列进行填充,以保证所有序列长度一致,便于批处理。然而,这些填充的部分并不包含实际信息,因此应该在模型处理时忽略。

PyTorch实现

python 复制代码
import torch

def create_padding_mask(seq, pad_token=0):
    mask = (seq == pad_token).unsqueeze(1).unsqueeze(2)
    return mask  # (batch_size, 1, 1, seq_len)

# 示例使用
seq = torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])
padding_mask = create_padding_mask(seq)
print(padding_mask)
2. 序列掩码(Sequence Mask)

目的

序列掩码用于更广泛地控制模型应该关注的数据部分,不仅可以指示填充位置,还可以用于其他类型的掩蔽需求。例如,在序列到序列的任务中,可能需要隐藏未来信息,以确保模型在解码时不会"窥视"到未来信息。

PyTorch实现

python 复制代码
def create_sequence_mask(seq):
    seq_len = seq.size(1)
    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))
print(sequence_mask)
3. 前瞻掩码(Look-ahead Mask)

目的

前瞻掩码主要用于自回归模型中,以确保模型在生成序列时不会"看到"未来的符号。这保证了在给定位置的预测仅依赖于该位置之前的符号,维护了生成过程的时序正确性。

PyTorch实现

python 复制代码
def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
look_ahead_mask = create_look_ahead_mask(4)
print(look_ahead_mask)
掩码在注意力机制中的应用

在注意力机制中,掩码被用来修改注意力得分,以确保模型在计算注意力权重时能够正确地考虑哪些部分应该被忽略。以下是一个使用掩码进行缩放点积注意力计算的示例:

python 复制代码
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))
    dk = q.size()[-1]
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, v)
    return output, attention_weights

# 示例使用
d_model = 512
batch_size = 2
seq_len = 4
q = torch.rand((batch_size, seq_len, d_model))
k = torch.rand((batch_size, seq_len, d_model))
v = torch.rand((batch_size, seq_len, d_model))
mask = create_look_ahead_mask(seq_len)
attention_output, attention_weights = scaled_dot_product_attention(q, k, v, mask)
print(attention_output)
相关推荐
黎燃2 分钟前
当 YOLO 遇见编剧:用自然语言生成技术把“目标检测”写成“目标剧情”
人工智能
算家计算3 分钟前
AI教母李飞飞团队发布最新空间智能模型!一张图生成无限3D世界,元宇宙越来越近了
人工智能·资讯
掘金一周6 分钟前
Flutter Riverpod 3.0 发布,大规模重构下的全新状态管理框架 | 掘金一周 9.18
前端·人工智能·后端
CoovallyAIHub22 分钟前
开源的消逝与新生:从 TensorFlow 的落幕到开源生态的蜕变
pytorch·深度学习·llm
用户51914958484535 分钟前
C#记录类型与集合的深度解析:从默认实现到自定义比较器
人工智能·aigc
IT_陈寒4 小时前
React 18实战:7个被低估的Hooks技巧让你的开发效率提升50%
前端·人工智能·后端
数据智能老司机5 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
逛逛GitHub5 小时前
飞书多维表“独立”了!功能强大的超出想象。
人工智能·github·产品
机器之心5 小时前
刚刚,DeepSeek-R1论文登上Nature封面,通讯作者梁文锋
人工智能·openai
数据智能老司机6 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言