注意力机制中的三种掩码技术及其PyTorch实现

在深度学习中,特别是处理序列数据时,注意力机制是一种非常关键的技术,广泛应用于各种先进的神经网络架构中,如Transformer模型。为了确保模型能够正确处理序列数据,掩码技术发挥了重要作用。本文将介绍三种常见的掩码技术:填充掩码(Padding Mask)、序列掩码(Sequence Mask)和前瞻掩码(Look-ahead Mask),并提供相应的PyTorch代码实现。

1. 填充掩码(Padding Mask)

目的

填充掩码的主要目的是确保模型在处理填充的输入数据时,不会将这些无关的数据当作有效数据处理。在序列处理中,由于不同序列的长度可能不同,通常需要对较短的序列进行填充,以保证所有序列长度一致,便于批处理。然而,这些填充的部分并不包含实际信息,因此应该在模型处理时忽略。

PyTorch实现

python 复制代码
import torch

def create_padding_mask(seq, pad_token=0):
    mask = (seq == pad_token).unsqueeze(1).unsqueeze(2)
    return mask  # (batch_size, 1, 1, seq_len)

# 示例使用
seq = torch.tensor([[7, 6, 0, 0], [1, 2, 3, 0]])
padding_mask = create_padding_mask(seq)
print(padding_mask)
2. 序列掩码(Sequence Mask)

目的

序列掩码用于更广泛地控制模型应该关注的数据部分,不仅可以指示填充位置,还可以用于其他类型的掩蔽需求。例如,在序列到序列的任务中,可能需要隐藏未来信息,以确保模型在解码时不会"窥视"到未来信息。

PyTorch实现

python 复制代码
def create_sequence_mask(seq):
    seq_len = seq.size(1)
    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
seq_len = 4
sequence_mask = create_sequence_mask(torch.zeros(seq_len, seq_len))
print(sequence_mask)
3. 前瞻掩码(Look-ahead Mask)

目的

前瞻掩码主要用于自回归模型中,以确保模型在生成序列时不会"看到"未来的符号。这保证了在给定位置的预测仅依赖于该位置之前的符号,维护了生成过程的时序正确性。

PyTorch实现

python 复制代码
def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask  # (seq_len, seq_len)

# 示例使用
look_ahead_mask = create_look_ahead_mask(4)
print(look_ahead_mask)
掩码在注意力机制中的应用

在注意力机制中,掩码被用来修改注意力得分,以确保模型在计算注意力权重时能够正确地考虑哪些部分应该被忽略。以下是一个使用掩码进行缩放点积注意力计算的示例:

python 复制代码
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))
    dk = q.size()[-1]
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(dk, dtype=torch.float32))
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, v)
    return output, attention_weights

# 示例使用
d_model = 512
batch_size = 2
seq_len = 4
q = torch.rand((batch_size, seq_len, d_model))
k = torch.rand((batch_size, seq_len, d_model))
v = torch.rand((batch_size, seq_len, d_model))
mask = create_look_ahead_mask(seq_len)
attention_output, attention_weights = scaled_dot_product_attention(q, k, v, mask)
print(attention_output)
相关推荐
阿里云大数据AI技术1 分钟前
PAI Physical AI Notebook 详解 1:基于 Isaac 仿真的操作动作数据扩增与模仿学习
人工智能
该用户已不存在3 分钟前
Vibe Coding 入门指南:从想法到产品的完整路径
前端·人工智能·后端
一只鹿鹿鹿4 分钟前
系统安全设计方案书(Word)
开发语言·人工智能·web安全·需求分析·软件系统
Likeadust4 分钟前
视频直播点播平台EasyDSS:助力现代农业驶入数字科技“快车道”
人工智能·科技·音视频
南阳木子5 分钟前
GEO:AI 时代流量新入口,四川嗨它科技如何树立行业标杆? (2025年10月最新版)
人工智能·科技
oe10197 分钟前
好文与笔记分享 A Survey of Context Engineering for Large Language Models(中)
人工智能·笔记·语言模型·agent开发
铁锹少年10 分钟前
当多进程遇上异步:一次 Celery 与 Async SQLAlchemy 的边界冲突
分布式·后端·python·架构·fastapi
梨轻巧12 分钟前
pyside6常用控件:QCheckBox() 单个复选框、多个复选框、三态模式
python
寒秋丶17 分钟前
Milvus:集合(Collections)操作详解(三)
数据库·人工智能·python·ai·ai编程·milvus·向量数据库
寒秋丶19 分钟前
Milvus:Schema详解(四)
数据库·人工智能·python·ai·ai编程·milvus·向量数据库