How to Control the Distribution of Attention Scores with Entropy Regularization

First, write a CrossAttention module:

import tensorflow as tf
from tensorflow.keras import layers


# input: Q (B, L, d), KV (B, N, d)
# output: (B, L, dim) plus attention scores (B, num_head, L, N)
# 0 < alpha <= ln(N); the closer alpha is to 0, the closer the attention scores
# are pushed toward a one-hot distribution
class CrossAttention(layers.Layer):
    def __init__(self, num_head, dim, alpha,**kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.num_head = num_head
        self.dim = dim
        self.layernorm = layers.LayerNormalization()

    def build(self, input_shape):
        self.qdk = self.add_weight(name='query_dense_kernel', shape=[input_shape[0][-1], self.num_head, self.dim])
        self.kdk = self.add_weight(name='key_dense_kernel', shape=[input_shape[1][-1], self.num_head, self.dim])
        self.vdk = self.add_weight(name='value_dense_kernel', shape=[input_shape[1][-1], self.num_head, self.dim])
        self.odk = self.add_weight(name='output_dense_kernel', shape=[self.num_head, self.dim, self.dim])
        self.odb = self.add_weight(name='output_dense_bias', shape=[self.dim])

    def call(self, inputs, *args, **kwargs):
        Q, KV = inputs
        # per-head query/key/value projections: (B, L or N, num_head, dim)
        query = tf.einsum("abc, cde->abde", Q, self.qdk)
        key = tf.einsum("abc, cde->abde", KV, self.kdk)
        value = tf.einsum("abc, cde->abde", KV, self.vdk)
        # scale by 1/sqrt(dim)
        query = tf.multiply(query, 1.0 / tf.math.sqrt(float(self.dim)))
        # attention scores: (B, num_head, L, N), softmax over the key axis
        attention_scores = tf.math.softmax(tf.einsum("abcd, aecd->acbe", query, key))

        # entropy regularization: push each row's entropy toward the target alpha
        self.add_loss(
            tf.reduce_mean((-tf.reduce_sum(attention_scores * tf.math.log(attention_scores + 1e-07), axis=-1) - self.alpha)**2))

        # weighted sum of values: (B, L, num_head, dim)
        attention_output = tf.einsum("abcd, aceb->aecd", value, attention_scores)
        # merge heads with the output projection: (B, L, dim)
        output = tf.einsum("abcd, cde->abe", attention_output, self.odk) + self.odb
        return self.layernorm(output + Q), attention_scores
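
Before wiring the layer into a model, a quick smoke test helps (a minimal sketch with arbitrary sizes; the names ca, q, kv are only for illustration): it confirms the output shapes and shows that the entropy term registered with add_loss appears in the layer's losses collection.

ca = CrossAttention(num_head=2, dim=8, alpha=0.5)
q = tf.random.uniform((1, 3, 8))     # (B, L, d)
kv = tf.random.uniform((1, 5, 8))    # (B, N, d)
out, scores = ca((q, kv))
print(out.shape)     # (1, 3, 8)    -> (B, L, dim)
print(scores.shape)  # (1, 2, 3, 5) -> (B, num_head, L, N)
print(ca.losses)     # one scalar entropy-regularization term per forward pass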

The loss function consists of two parts: a prediction loss and a regularization loss.

The regularization loss is added with the add_loss method. Losses added this way can be accessed through model.losses, which returns a collection with one element per regularization loss.
Once the regularization loss has been added via add_loss, the forward pass that produces it must be included in the scope of tf.GradientTape(), so that it is recorded for gradient computation, as in the custom training loop below.

def fit(x, y, epochs, model):
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))
        with tf.GradientTape() as tape:
            logits = model(x, training=True)[0]
            # Compute the loss value for this minibatch.
            loss_value = tf.keras.losses.binary_crossentropy(y, logits)
            print(model.losses)
            loss_value += sum(model.losses)
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        print(
            "attention scores entropy loss: %s" % (sum(model.losses))
        )
        print("loss" % loss_value)

Next, set up a simple task and some data to see the effect of the entropy regularization:

class ToyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # 3 heads, head dim 16, entropy target alpha = 0.2 (close to 0, i.e. near one-hot)
        self.CA = CrossAttention(3, 16, 0.2)

    def build(self, input_shape):
        # map the attended sequence (B, L, 16) to two output probabilities
        self.k = self.add_weight(name="predict_kernel", shape=[input_shape[0][-2], 16, 2])

    def call(self, inputs, *args, **kwargs):
        x, scores = self.CA(inputs)
        return tf.math.sigmoid(tf.einsum("abc, bcd->ad", x, self.k)), scores


if __name__ == '__main__':
    Q = tf.random.uniform((1, 4, 16))
    KV = tf.random.uniform((1, 6, 16))
    labels = tf.constant([[0., 1]])

    model = ToyModel()
    fit((Q, KV), labels, epochs=1000, model=model)

    print(model((Q, KV)))

After the model is trained, print the attention scores: they form a tensor of shape (1, 3, 4, 6), i.e. (batch, num_head, query, key), and every row along the last axis is close to a one-hot distribution.

The entropy of a probability distribution has a minimum of 0 and a maximum of ln(N), where N is the number of categories. The minimum is attained by the one-hot distribution (lowest entropy), the maximum by the uniform distribution (highest entropy). The entropy regularization loss used here is the mean, over all attention rows, of (H(p) - alpha)^2, where H(p) = -Σ_j p_j·log(p_j) is the entropy of one row of attention scores. By adjusting alpha, you can control how closely the attention scores approach a one-hot distribution.
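
A quick numeric check of these bounds (a minimal sketch; N = 6 matches the number of key/value positions in the toy example above):

def row_entropy(p):
    return float(-tf.reduce_sum(p * tf.math.log(p + 1e-7)))

N = 6
print(row_entropy(tf.one_hot(0, N)))       # ~0.0           -> one-hot, minimum entropy
print(row_entropy(tf.fill([N], 1.0 / N)))  # ~1.79 = ln(6)  -> uniform, maximum entropy

The attention scores of the trained model are printed below: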

<tf.Tensor: shape=(1, 3, 4, 6), dtype=float32, numpy=
array([[[[3.27099487e-03, 2.51237601e-02, 9.59103048e-01, 6.17671618e-03, 4.41661570e-03, 1.90874794e-03],
         [3.13497148e-03, 2.45431308e-02, 9.60317731e-01, 5.98660251e-03, 4.21733223e-03, 1.80015271e-03],
         [3.40689556e-03, 2.57629622e-02, 9.57871556e-01, 6.44017057e-03, 4.52079810e-03, 1.99752697e-03],
         [1.10661890e-03, 1.27607975e-02, 9.81542170e-01, 2.43383530e-03, 1.57478068e-03, 5.81745524e-04]],

        [[3.10348310e-02, 8.87579299e-05, 4.77730093e-04, 1.46521628e-03, 9.52835977e-01, 1.40975416e-02],
         [2.63910089e-02, 5.77273713e-05, 3.39534425e-04, 1.06126023e-03, 9.60522711e-01, 1.16277682e-02],
         [3.12446002e-02, 9.48462111e-05, 5.04475611e-04, 1.46811537e-03, 9.52523112e-01, 1.41648324e-02],
         [1.50452955e-02, 1.21340790e-05, 9.36727956e-05, 3.67841218e-04, 9.78706181e-01, 5.77491429e-03]],

        [[2.43717595e-03, 2.99123675e-02, 9.57738578e-01, 8.62001721e-03, 6.63190091e-04, 6.28605427e-04],
         [2.37493636e-03, 2.91979928e-02, 9.58939075e-01, 8.24017916e-03, 6.44378248e-04, 6.03430963e-04],
         [2.84763589e-03, 3.26210111e-02, 9.53170657e-01, 9.78391431e-03, 8.07495147e-04, 7.69288861e-04],
         [1.05220091e-03, 1.88548081e-02, 9.75066125e-01, 4.56135161e-03, 2.38224253e-04, 2.27281591e-04]]]],
      dtype=float32)>
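
As a sanity check (a sketch; scores here is the attention tensor returned as the second output of model((Q, KV))), the entropy of each row lands near the target alpha = 0.2:

_, scores = model((Q, KV))
entropy = -tf.reduce_sum(scores * tf.math.log(scores + 1e-7), axis=-1)
print(entropy)   # shape (1, 3, 4): every entry should be close to alpha = 0.2

Conversely, retraining with alpha close to ln(6) ≈ 1.79 would be expected to push each row toward a uniform distribution.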