注意力机制中的三种掩码技术及其PyTorch实现

在深度学习中,特别是处理序列数据时,注意力机制是一种非常关键的技术,广泛应用于各种先进的神经网络架构中,如Transformer模型。为了确保模型能够正确处理序列数据,掩码技术发挥了重要作用。本文将介绍三种常见的掩码技术:填充掩码(Padding Mask)、序列掩码(Sequence Mask)和前瞻掩码(Look-ahead Mask),并提供相应的PyTorch代码实现。

1. 填充掩码(Padding Mask)

目的

填充掩码的主要目的是确保模型在处理填充的输入数据时,不会将这些无关的数据当作有效数据处理。在序列处理中,由于不同序列的长度可能不同,通常需要对较短的序列进行填充,以保证所有序列长度一致,便于批处理。然而,这些填充的部分并不包含实际信息,因此应该在模型处理时忽略。

PyTorch实现

python 复制代码
import torch

def create_padding_mask(seq, pad_token=0):
    mask = (seq == pad_token).unsqueeze(1).unsqueeze(2)
    return mask  # (batch_size, 1, 1, seq_len)

# 示例使用
seq = torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])
padding_mask = create_padding_mask(seq)
print(padding_mask)
2. 序列掩码(Sequence Mask)

目的

序列掩码用于更广泛地控制模型应该关注的数据部分,不仅可以指示填充位置,还可以用于其他类型的掩蔽需求。例如,在序列到序列的任务中,可能需要隐藏未来信息,以确保模型在解码时不会"窥视"到未来信息。

PyTorch实现

python 复制代码
def create_sequence_mask(seq):
    seq_len = seq.size(1)
    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))
print(sequence_mask)
3. 前瞻掩码(Look-ahead Mask)

目的

前瞻掩码主要用于自回归模型中,以确保模型在生成序列时不会"看到"未来的符号。这保证了在给定位置的预测仅依赖于该位置之前的符号,维护了生成过程的时序正确性。

PyTorch实现

python 复制代码
def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
look_ahead_mask = create_look_ahead_mask(4)
print(look_ahead_mask)
掩码在注意力机制中的应用

在注意力机制中,掩码被用来修改注意力得分,以确保模型在计算注意力权重时能够正确地考虑哪些部分应该被忽略。以下是一个使用掩码进行缩放点积注意力计算的示例:

python 复制代码
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))
    dk = q.size()[-1]
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, v)
    return output, attention_weights

# 示例使用
d_model = 512
batch_size = 2
seq_len = 4
q = torch.rand((batch_size, seq_len, d_model))
k = torch.rand((batch_size, seq_len, d_model))
v = torch.rand((batch_size, seq_len, d_model))
mask = create_look_ahead_mask(seq_len)
attention_output, attention_weights = scaled_dot_product_attention(q, k, v, mask)
print(attention_output)
相关推荐
互联网全栈架构11 分钟前
遨游Spring AI:第一盘菜Hello World
java·人工智能·后端·spring
m0_4652157912 分钟前
大语言模型解析
人工智能·语言模型·自然语言处理
张较瘦_1 小时前
[论文阅读] 人工智能+软件工程 | 结对编程中的知识转移新图景
人工智能·软件工程·结对编程
小Q小Q2 小时前
cmake编译LASzip和LAStools
人工智能·计算机视觉
yzx9910132 小时前
基于 Q-Learning 算法和 CNN 的强化学习实现方案
人工智能·算法·cnn
token-go2 小时前
[特殊字符] 革命性AI提示词优化平台正式开源!
人工智能·开源
cooldream20093 小时前
华为云Flexus+DeepSeek征文|基于华为云Flexus X和DeepSeek-R1打造个人知识库问答系统
人工智能·华为云·dify
老胖闲聊6 小时前
Python Copilot【代码辅助工具】 简介
开发语言·python·copilot
Blossom.1186 小时前
使用Python和Scikit-Learn实现机器学习模型调优
开发语言·人工智能·python·深度学习·目标检测·机器学习·scikit-learn
曹勖之7 小时前
基于ROS2,撰写python脚本,根据给定的舵-桨动力学模型实现动力学更新
开发语言·python·机器人·ros2