Transformer模型:Decoder的self-attention mask实现

前言

这是对Transformer模型Word Embedding、Postion Embedding、Encoder self-attention mask、intra-attention mask内容的续篇。

视频链接:20、Transformer模型Decoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili

文章链接:Transformer模型:WordEmbedding实现-CSDN博客

Transformer模型:Postion Embedding实现-CSDN博客

Transformer模型:Encoder的self-attention mask实现-CSDN博客

Transformer模型:intra-attention mask实现-CSDN博客


正文

首先介绍一下Deoder的self-attention mask,它与前面的两个mask不一样地方在于Decoder是生成一个单词之后,将改单词作为输入给到Decoder中继续生成下一个,也就是相当于下三角矩阵,一次多一个,直到完成整个预测。

先生成一个下三角矩阵:

python 复制代码
tri_matrix = [torch.tril(torch.ones(L, L)) for L in tgt_len]

这里生成的两个下三角矩阵的维度是不一样的,首先要统一维度:

python 复制代码
valid_decoder_tri_matrix = [F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)) for L in tgt_len]

然后就是将它转为1个3维的张量形式,过程跟先前类似,这里就不一步步拆解了:

python 复制代码
valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)),0) for L in tgt_len])

后续掩码过程还是跟前两篇一样,这里也不多解释了:

python 复制代码
invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix
mask_decoder_self_attention = invalid_decoder_tri_matrix.to(torch.bool)
score2 = torch.randn(batch_size, max_tgt_seg_len, max_tgt_seg_len)
mask_score3 = score2.masked_fill(mask_decoder_self_attention, -1e9)
prob3 = F.softmax(mask_score3, -1)

代码

python 复制代码
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# 句子数
batch_size = 2

# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10

# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
max_position_len = 12

# 模型的维度
model_dim = 8

# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)

# 单词索引构成的句子
src_seq = torch.cat(
    [torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_seg_len - L)), 0) for L in src_len])
tgt_seq = torch.cat(
    [torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max_tgt_seg_len - L)), 0) for L in tgt_len])

# Part1:构造Word Embedding
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

# 构造Pos序列跟i序列
pos_mat = torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, 8, 2) / model_dim)

# Part2:构造Position Embedding
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

# 构建位置索引
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in tgt_len]).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

# Part3:构造encoder self-attention mask
valid_encoder_pos = torch.unsqueeze(
    torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
score = torch.randn(batch_size, max_src_seg_len, max_src_seg_len)
mask_score1 = score.masked_fill(mask_encoder_self_attention, -1e9)
prob1 = F.softmax(mask_score1, -1)

# Part4:构造intra-attention mask
valid_encoder_pos = torch.unsqueeze(
    torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2)
valid_decoder_pos = torch.unsqueeze(
    torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_tgt_seg_len - L)), 0) for L in tgt_len]), 2)

valid_cross_pos_matrix = torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)
mask_score2 = score.masked_fill(mask_cross_attention, -1e9)
prob2 = F.softmax(mask_score2, -1)

# Part5:构造Decoder self-attention mask
valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)),0) for L in tgt_len])
invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix
mask_decoder_self_attention = invalid_decoder_tri_matrix.to(torch.bool)
score2 = torch.randn(batch_size, max_tgt_seg_len, max_tgt_seg_len)
mask_score3 = score2.masked_fill(mask_decoder_self_attention, -1e9)
prob3 = F.softmax(mask_score3, -1)
相关推荐
大写-凌祁5 小时前
零基础入门深度学习:从理论到实战,GitHub+开源资源全指南(2025最新版)
人工智能·深度学习·开源·github
焦耳加热5 小时前
阿德莱德大学Nat. Commun.:盐模板策略实现废弃塑料到单原子催化剂的高值转化,推动环境与能源催化应用
人工智能·算法·机器学习·能源·材料工程
深空数字孪生5 小时前
储能调峰新实践:智慧能源平台如何保障风电消纳与电网稳定?
大数据·人工智能·物联网
wan5555cn5 小时前
多张图片生成视频模型技术深度解析
人工智能·笔记·深度学习·算法·音视频
格林威6 小时前
机器视觉检测的光源基础知识及光源选型
人工智能·深度学习·数码相机·yolo·计算机视觉·视觉检测
今天也要学习吖7 小时前
谷歌nano banana官方Prompt模板发布,解锁六大图像生成风格
人工智能·学习·ai·prompt·nano banana·谷歌ai
Hello123网站7 小时前
glean-企业级AI搜索和知识发现平台
人工智能·产品运营·ai工具
AKAMAI7 小时前
Queue-it 为数十亿用户增强在线体验
人工智能·云原生·云计算
索迪迈科技7 小时前
INDEMIND亮相2025科技创变者大会,以机器人空间智能技术解锁具身智能新边界
人工智能·机器人·扫地机器人·空间智能·陪伴机器人
栒U7 小时前
一文从零部署vLLM+qwen0.5b(mac本地版,不可以实操GPU单元)
人工智能·macos·vllm