【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码------交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码------交叉注意力(Cross-Attention)


文章目录


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:

可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 交叉注意力的起源与提出

交叉注意力(Cross-Attention)是在深度学习中提出的一种重要注意力机制,用于在多个输入之间建立关联,主要用于多模态任务中(如图像和文本、视频和音频的联合处理)。

与常规的自注意力机制不同,交叉注意力专注于从两个不同的输入特征空间中提取和结合关键信息。这种机制最初在自然语言处理和计算机视觉的融合任务中得到应用,例如在多模态Transformer、机器翻译和图像-文本任务(如CLIP、DALL·E、VQA等)中。

  • 提出背景:交叉注意力通常用于处理两种不同类型的数据,通过这种机制,一个输入可以对另一个输入进行查询,捕捉和增强跨模态之间的关联。相比自注意力(仅在同一个输入中找到相关性),交叉注意力能够有效地捕捉多模态数据的交互信息。

2. 交叉注意力的原理

交叉注意力的核心思想是将一个输入(例如图像)作为查询(Query),另一个输入(例如文本)作为键(Key)和值(Value),通过注意力机制让查询能够从键和值中选择和关注相关信息。

交叉注意力的步骤:

  • 查询、键、值的生成: 假设有两个不同的输入数据 X1 和 X2,分别生成对应的 Query、Key 和 Value 矩阵。对于 X1,我们可以生成 Query 矩阵,而对于 X2,则可以生成 Key 和 Value 矩阵。
  • 注意力计算 : 与自注意力类似,交叉注意力通过计算 Query 和 Key 的相似性来获得注意力权重:

    其中 Q 来自 X1,而 K 和 V 来自 X2 。通过这种计算,Query 可以从X2 中提取与其最相关的信息,这种机制实现了两个输入数据之间的特征融合和信息传递。
  • 权重与输出 : 计算出的注意力权重应用到 X2的 Value 矩阵上,得到 X1在
    X2上的相关信息。这种机制实现了两个输入数据之间的特征融合和信息传递。

3. 交叉注意力的数学表示

假设有两个输入特征 X 1 ∈ R T 1 × d X_1∈R^{T_1×d} X1∈RT1×d和 X 2 ∈ R T 2 × d X_2∈R^{T_2×d} X2∈RT2×d,其中 T 1 T_1 T1和 T 2 T_2 T2分别表示两个输入的长度(如序列长度或特征维度), d d d 表示特征维度。

Query、Key 和 Value 的生成:

  • 对于 X 1 X_1 X1:生成查询矩阵 Q = W q X 1 Q=W_qX_1 Q=WqX1
  • 对于 X 2 X_2 X2:生成键矩阵 K = W k X 2 K=W_kX_2 K=WkX2和值矩阵 V = W v X 2 V=W_vX_2 V=WvX2

注意力计算:

其中, W q W_q Wq, W k W_k Wk, W v W_v Wv ∈ R d × d ∈R^{d×d} ∈Rd×d是线性变换矩阵, d d d 是键的维度。

结果输出: 注意力权重应用于 V V V 后的结果,即:

4. 交叉注意力的应用场景与发展

交叉注意力在以下场景中得到广泛应用:

  • 多模态学习:交叉注意力在视觉和语言任务中的多模态联合建模中尤为常见,如图像与文本的对齐(CLIP)、视觉问答(VQA)和跨模态生成任务(如DALL·E)。
  • 机器翻译:交叉注意力在Transformer中的"解码器"部分用于让生成的序列(目标语言)参考源语言的表示,这大大提高了翻译质量。
  • Transformer架构的扩展:在诸如BERT、GPT等基于Transformer的模型中,交叉注意力也被用于各种任务,例如文本生成、序列到序列任务等。

发展过程中,交叉注意力机制已经被改进和扩展。例如,层次化交叉注意力(Hierarchical Cross-Attention)通过在不同层次上融合多模态信息,进一步提升了模型在多模态任务上的性能。

5. 代码实现

下面是一个基于PyTorch的交叉注意力机制的简单实现,用于展示如何在两个不同的输入(例如图像和文本)之间计算交叉注意力。

csharp 复制代码
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # 线性变换,用于生成 Q, K, V 矩阵
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        # 输出的线性变换
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x1, x2):
        # x1 是 Query,x2 是 Key 和 Value
        B, T1, C = x1.shape  # x1 的形状: [batch_size, seq_len1, dim]
        _, T2, _ = x2.shape  # x2 的形状: [batch_size, seq_len2, dim]

        # 生成 Q, K, V 矩阵
        Q = self.q_proj(x1).view(B, T1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力得分
        attn_scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = self.softmax(attn_scores)  # 注意力权重
        attn_weights = self.dropout(attn_weights)  # dropout 防止过拟合

        # 使用注意力权重加权值矩阵
        attn_output = attn_weights @ V
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T1, C)

        # 输出线性变换
        output = self.out_proj(attn_output)
        return output

# 测试交叉注意力机制
if __name__ == "__main__":
    B, T1, T2, C = 2, 10, 20, 64  # batch_size, seq_len1, seq_len2, channels
    x1 = torch.randn(B, T1, C)  # Query 输入
    x2 = torch.randn(B, T2, C)  # Key 和 Value 输入

    cross_attn = CrossAttention(dim=C, num_heads=4)
    output = cross_attn(x1, x2)
    
    print("输出形状:", output.shape)  # 输出应该为 [batch_size, seq_len1, channels]

6. 代码解释

CrossAttention 类:该类实现了交叉注意力机制,允许将两个不同的输入(x1x2)进行交叉信息融合。

  • q_proj, k_proj, v_proj:三个线性层,用于将输入分别映射到 Query、Key 和 Value 空间。
  • num_headshead_dim:定义了多头注意力机制的头数和每个头的维度。

forward 函数:实现前向传播过程。

  • Q, K, V:分别从 x1x2 中生成 Query、Key 和 Value 矩阵,形状为 [batch_size, num_heads, seq_len, head_dim]
  • attn_scores:计算 Query 和 Key 的点积,得到注意力得分。
  • attn_weights:通过 softmax 对得分进行归一化,得到注意力权重。
  • attn_output:利用注意力权重对 Value 矩阵进行加权求和,得到最终的注意力输出。

测试部分:随机生成两个输入张量 x1x2,并测试交叉注意力的输出形状,确保与预期一致。

7. 总结

交叉注意力在多模态学习中起到了至关重要的作用,能够有效融合不同类型的数据,使得模型可以同时处理图像、文本等多种信息。通过捕捉模态之间的相关性,交叉注意力为多模态任务中的特征融合提供了强大的工具。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:

可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

相关推荐
通信.萌新33 分钟前
OpenCV边沿检测(Python版)
人工智能·python·opencv
ARM+FPGA+AI工业主板定制专家35 分钟前
基于RK3576/RK3588+FPGA+AI深度学习的轨道异物检测技术研究
人工智能·深度学习
赛丽曼38 分钟前
机器学习-分类算法评估标准
人工智能·机器学习·分类
Bran_Liu39 分钟前
【LeetCode 刷题】字符串-字符串匹配(KMP)
python·算法·leetcode
伟贤AI之路40 分钟前
从音频到 PDF:AI 全流程打造完美英文绘本教案
人工智能
weixin_3077791342 分钟前
分析一个深度学习项目并设计算法和用PyTorch实现的方法和步骤
人工智能·pytorch·python
helianying551 小时前
云原生架构下的AI智能编排:ScriptEcho赋能前端开发
前端·人工智能·云原生·架构
池央1 小时前
StyleGAN - 基于样式的生成对抗网络
人工智能·神经网络·生成对抗网络
Channing Lewis2 小时前
flask实现重启后需要重新输入用户名而避免浏览器使用之前已经记录的用户名
后端·python·flask
Channing Lewis2 小时前
如何在 Flask 中实现用户认证?
后端·python·flask