自注意力机制中的gen_nopeek_mask()函数

"no-peek"掩码通常用于在自注意力机制中,确保模型在生成序列时只能注意到当前位置之前的信息,而不能"窥视"未来的信息

python 复制代码
def gen_nopeek_mask(length):    
    mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask.to(device)
  1. torch.triu(torch.ones(length, length)) == 1: 创建一个大小为 (length, length) 的上三角矩阵,其中上三角的元素为1,下三角的元素为0。

  2. .transpose(0, 1): 将矩阵进行转置,得到对角线上方的区域。

  3. mask = mask.float(): 将布尔类型的矩阵转换为浮点数类型。

  4. .masked_fill(mask == 0, float('-inf')): 将矩阵中值为0的位置用负无穷(-∞)填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将趋近于零,表示模型在这些位置不应该关注。

  5. .masked_fill(mask == 1, float(0.0)): 将矩阵中值为1的位置用0填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将保持为1,表示模型在这些位置应该关注。

最终,mask 是一个上三角矩阵,其中对角线及其以下的元素为负无穷,而对角线以上的元素为0。这样的矩阵在自注意力机制中被用作掩码,确保模型在生成每个位置时只关注之前的位置,而不会使用未来的信息。

让我们使用一个具体的长度来演示 gen_nopeek_mask 函数,比如 length = 4。以下是运行这个函数的示例:

python 复制代码
import torch

def gen_nopeek_mask(length):
    mask = (torch.triu(torch.ones(length, length)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# 生成长度为 4 的 nopeek mask
mask_example = gen_nopeek_mask(4)
print(mask_example)
复制代码
运行这个示例,将得到一个 4x4 的矩阵,其中包含了上三角区域以及对角线以下的部分:
bash 复制代码
tensor([[ 0., -inf, -inf, -inf],
        [ 0.,  0., -inf, -inf],
        [ 0.,  0.,  0., -inf],
        [ 0.,  0.,  0.,  0.]])

这个矩阵是一个示例的 "no-peek" 掩码。在这个掩码中,对角线以下和对角线上的元素被设置为负无穷和零,以确保在自注意力机制中,模型只能关注当前位置之前的信息。这种掩码通常在 Transformer 模型中的解码器中使用。

将矩阵中值为0的位置用负无穷(-∞)填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将趋近于0,表示模型在这些位置不应该关注

将矩阵中值为1的位置用0填充。这样,在计算注意力权重时,这些位置的值经过 softmax 函数后将保持为1,表示模型在这些位置应该关注

相关推荐
思绪无限9 分钟前
YOLOv5至YOLOv12升级:常见车型识别系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·yolo·目标检测·目标跟踪·yolov12·yolo全家桶
lkforce10 分钟前
MiniMind学习笔记(零)--基础概念
人工智能·算法·机器学习·token·分词器·minimind·词汇表
Ndmzi13 分钟前
每次启动claude 后powershell字体颜色就自动修改了,退出后也不会恢复原状,这是什么原因?
人工智能
Baihai_IDP13 分钟前
以 Nano-vLLM 为例,深入理解 LLM 推理引擎(Part 2)
人工智能·面试·llm
刘~浪地球14 分钟前
当AI开始“制造“:智能工厂是提升效率还是取代工人?
人工智能·制造
BFT白芙堂15 分钟前
基于 AR 阻抗可视化的 Franka Research3 机械臂遥操作设计与应用
人工智能·深度学习·机器学习·机器人·ar·franka
踩着两条虫19 分钟前
VTJ.PRO 新手入门:从环境搭建到 AI 生成首个 Vue3 应用
前端·javascript·数据库·vue.js·人工智能·低代码
2013编程爱好者21 分钟前
【AI】豆包+千问下载以及使用指南
人工智能·千问·豆包
山科智能信息处理实验室22 分钟前
(ITES 2025)教育推荐系统综述:主流技术、应用场景与未来趋势
人工智能
OneThingAI25 分钟前
网心技术 | Claude Managed Agents 让 Harness 变成服务
人工智能·claude·onethingai·网心科技