TensorFlow中的掩码传递

复制代码
from keras import layers
import tensorflow as tf

# The number of lines in the vocabulary.txt
vocabulary_size = 15


def positional_encoding(length, depth):
    depth = depth / 2.
    positions = tf.range(length, dtype=tf.float32)[:, tf.newaxis]  # (seq, 1)
    depths = tf.range(depth, dtype=tf.float32)[tf.newaxis, :] / depth  # (1, depth)
    angle_rates = 1. / (10000 ** depths)  # (1, depth)
    angle_rads = positions * angle_rates  # (pos, depth)

    pos_encoding = tf.concat(
        [tf.sin(angle_rads), tf.cos(angle_rads)],
        axis=-1)
    return tf.cast(pos_encoding, dtype=tf.float32)


class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, model_dim=512):
        super().__init__()
        self.model_dim = model_dim
        self.Text_Embedding = layers.Embedding(input_dim=vocabulary_size,
                                               output_dim=model_dim,
                                               mask_zero=True)

    def build(self, input_shape):
        self.pos_encoding = positional_encoding(length=input_shape[-1], depth=self.model_dim)

    # By overriding the compute_mask method,
    # it is possible to ensure that the mask of the Embedding layer
    # is passed to all subsequent MultiHeadAttention layers.
    # If this method is not overridden,
    # the mask will be blocked by the positional_encoding method.
    def compute_mask(self, *args, **kwargs):
        return self.Text_Embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        x = self.Text_Embedding(inputs)
        x *= tf.math.sqrt(tf.cast(self.model_dim, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :, :]
        return x


# The number of multi-head self-attention layers is 12,
# with 8 attention heads per layer,
# and the dimension of all latent layers is 512.
class Text_Tokenizer(layers.Layer):
    def __init__(self, model_dim=36, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.Text_Embedding = PositionalEmbedding()
        self.MHAs = [layers.MultiHeadAttention(num_heads=4, key_dim=model_dim) for _ in range(num_layers)]

    # input : (B, L=77)
    # output : (B, L=77, 512)
    def call(self, inputs, *args, **kwargs):
        # (B, L=77, 512)
        x = self.Text_Embedding(inputs)
        for i in range(self.num_layers):
            print(x._keras_mask)
            x = self.MHAs[i](query=x,
                             value=x,
                             key=x)

        return x


# forward propagation
if __name__ == '__main__':
    inputs = tf.constant([[3, 14, 2, 0, 0, 0]])
    model = Text_Tokenizer()
    # _keras_mask是tensor的潜在属性
    print(model(inputs)._keras_mask)

上述定义了一个自定义层,该层中包含了Embedding,Position_encodering,以及2个自注意力层。当Embedding启用掩码机制,如果在PositionalEmbedding类中不去重写compute_mask方法,则Embedding的掩码将被Position_encodering方法给阻断无法传递给后续的自注意力层,从而导致自注意力层在计算注意力分数时关注到一些无意义的位置。
._keras_mask是tensor的潜在属性,存储着tensor对应得mask。如果PositionalEmbedding类中重写了compute_mask方法,则Text_Tokenizer类中每个子层的输出都将拥有_keras_mask属性,Embedding得掩码会一步步传递下去,可以通过_keras_mask属性打印每层输出所携带得掩码,甚至能打印模型最终输出所携带得掩码。
总结:如果模型中某一层采用了掩码机制,且希望后续层都需要遵循同样得掩码,则都应该显示重写compute_mask方法,从而能使得掩码能够传递。keras.layers下实现得网络层都默认能实现掩码传递。而自己自定义的层一定要根据是否需要掩码来判断是否要重写compute_mask方法,如果你自定义了一个层,并希望后续的层能沿用该层的掩码,则一定要重写compute_mask方法。如果你自定义的层中包含的全是keras.layers下的实现层,则可以不需要重写compute_mask方法,因为keras.layers下的实现层都默认支持掩码传递。

相关推荐
爱学习的小鱼gogo7 分钟前
pyhton 螺旋矩阵(指针-矩阵-中等)含源码(二十六)
python·算法·矩阵·指针·经验·二维数组·逆序
文火冰糖的硅基工坊11 分钟前
[嵌入式系统-146]:五次工业革命对应的机器人形态的演进、主要功能的演进以及操作系统的演进
前端·网络·人工智能·嵌入式硬件·机器人
猫头虎17 分钟前
openAI发布的AI浏览器:什么是Atlas?(含 ChatGPT 浏览功能)macOS 离线下载安装Atlas完整教程
人工智能·macos·chatgpt·langchain·prompt·aigc·agi
老六哥_AI助理指南22 分钟前
为什么AI会改变单片机的未来?
人工智能·单片机·嵌入式硬件
SEO_juper33 分钟前
2026 AI可见性:构建未来-proof策略的顶级工具
人工智能·搜索引擎·百度·工具·数字营销
sivdead36 分钟前
当前智能体的几种形式
人工智能·后端·agent
AIGC_北苏37 分钟前
大语言模型,一个巨大的矩阵
人工智能·语言模型·矩阵
算家计算1 小时前
DeepSeek-OCR本地部署教程:DeepSeek突破性开创上下文光学压缩,10倍效率重构文本处理范式
人工智能·开源·deepseek
言之。1 小时前
Andrej Karpathy 演讲【PyTorch at Tesla】
人工智能·pytorch·python
算家计算1 小时前
快手推出“工具+模型+平台”AI编程生态!大厂挤占AI赛道,中小企业如何突围?
人工智能·ai编程·资讯