【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码------交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码------交叉注意力(Cross-Attention)


文章目录


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:

可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 交叉注意力的起源与提出

交叉注意力(Cross-Attention)是在深度学习中提出的一种重要注意力机制,用于在多个输入之间建立关联,主要用于多模态任务中(如图像和文本、视频和音频的联合处理)。

与常规的自注意力机制不同,交叉注意力专注于从两个不同的输入特征空间中提取和结合关键信息。这种机制最初在自然语言处理和计算机视觉的融合任务中得到应用,例如在多模态Transformer、机器翻译和图像-文本任务(如CLIP、DALL·E、VQA等)中。

  • 提出背景:交叉注意力通常用于处理两种不同类型的数据,通过这种机制,一个输入可以对另一个输入进行查询,捕捉和增强跨模态之间的关联。相比自注意力(仅在同一个输入中找到相关性),交叉注意力能够有效地捕捉多模态数据的交互信息。

2. 交叉注意力的原理

交叉注意力的核心思想是将一个输入(例如图像)作为查询(Query),另一个输入(例如文本)作为键(Key)和值(Value),通过注意力机制让查询能够从键和值中选择和关注相关信息。

交叉注意力的步骤:

  • 查询、键、值的生成: 假设有两个不同的输入数据 X1 和 X2,分别生成对应的 Query、Key 和 Value 矩阵。对于 X1,我们可以生成 Query 矩阵,而对于 X2,则可以生成 Key 和 Value 矩阵。
  • 注意力计算 : 与自注意力类似,交叉注意力通过计算 Query 和 Key 的相似性来获得注意力权重:

    其中 Q 来自 X1,而 K 和 V 来自 X2 。通过这种计算,Query 可以从X2 中提取与其最相关的信息,这种机制实现了两个输入数据之间的特征融合和信息传递。
  • 权重与输出 : 计算出的注意力权重应用到 X2的 Value 矩阵上,得到 X1在
    X2上的相关信息。这种机制实现了两个输入数据之间的特征融合和信息传递。

3. 交叉注意力的数学表示

假设有两个输入特征 X 1 ∈ R T 1 × d X_1∈R^{T_1×d} X1∈RT1×d和 X 2 ∈ R T 2 × d X_2∈R^{T_2×d} X2∈RT2×d,其中 T 1 T_1 T1和 T 2 T_2 T2分别表示两个输入的长度(如序列长度或特征维度), d d d 表示特征维度。

Query、Key 和 Value 的生成:

  • 对于 X 1 X_1 X1:生成查询矩阵 Q = W q X 1 Q=W_qX_1 Q=WqX1
  • 对于 X 2 X_2 X2:生成键矩阵 K = W k X 2 K=W_kX_2 K=WkX2和值矩阵 V = W v X 2 V=W_vX_2 V=WvX2

注意力计算:

其中, W q W_q Wq, W k W_k Wk, W v W_v Wv ∈ R d × d ∈R^{d×d} ∈Rd×d是线性变换矩阵, d d d 是键的维度。

结果输出: 注意力权重应用于 V V V 后的结果,即:

4. 交叉注意力的应用场景与发展

交叉注意力在以下场景中得到广泛应用:

  • 多模态学习:交叉注意力在视觉和语言任务中的多模态联合建模中尤为常见,如图像与文本的对齐(CLIP)、视觉问答(VQA)和跨模态生成任务(如DALL·E)。
  • 机器翻译:交叉注意力在Transformer中的"解码器"部分用于让生成的序列(目标语言)参考源语言的表示,这大大提高了翻译质量。
  • Transformer架构的扩展:在诸如BERT、GPT等基于Transformer的模型中,交叉注意力也被用于各种任务,例如文本生成、序列到序列任务等。

发展过程中,交叉注意力机制已经被改进和扩展。例如,层次化交叉注意力(Hierarchical Cross-Attention)通过在不同层次上融合多模态信息,进一步提升了模型在多模态任务上的性能。

5. 代码实现

下面是一个基于PyTorch的交叉注意力机制的简单实现,用于展示如何在两个不同的输入(例如图像和文本)之间计算交叉注意力。

csharp 复制代码
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # 线性变换,用于生成 Q, K, V 矩阵
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        # 输出的线性变换
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x1, x2):
        # x1 是 Query,x2 是 Key 和 Value
        B, T1, C = x1.shape  # x1 的形状: [batch_size, seq_len1, dim]
        _, T2, _ = x2.shape  # x2 的形状: [batch_size, seq_len2, dim]

        # 生成 Q, K, V 矩阵
        Q = self.q_proj(x1).view(B, T1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力得分
        attn_scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = self.softmax(attn_scores)  # 注意力权重
        attn_weights = self.dropout(attn_weights)  # dropout 防止过拟合

        # 使用注意力权重加权值矩阵
        attn_output = attn_weights @ V
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T1, C)

        # 输出线性变换
        output = self.out_proj(attn_output)
        return output

# 测试交叉注意力机制
if __name__ == "__main__":
    B, T1, T2, C = 2, 10, 20, 64  # batch_size, seq_len1, seq_len2, channels
    x1 = torch.randn(B, T1, C)  # Query 输入
    x2 = torch.randn(B, T2, C)  # Key 和 Value 输入

    cross_attn = CrossAttention(dim=C, num_heads=4)
    output = cross_attn(x1, x2)
    
    print("输出形状:", output.shape)  # 输出应该为 [batch_size, seq_len1, channels]

6. 代码解释

CrossAttention 类:该类实现了交叉注意力机制,允许将两个不同的输入(x1x2)进行交叉信息融合。

  • q_proj, k_proj, v_proj:三个线性层,用于将输入分别映射到 Query、Key 和 Value 空间。
  • num_headshead_dim:定义了多头注意力机制的头数和每个头的维度。

forward 函数:实现前向传播过程。

  • Q, K, V:分别从 x1x2 中生成 Query、Key 和 Value 矩阵,形状为 [batch_size, num_heads, seq_len, head_dim]
  • attn_scores:计算 Query 和 Key 的点积,得到注意力得分。
  • attn_weights:通过 softmax 对得分进行归一化,得到注意力权重。
  • attn_output:利用注意力权重对 Value 矩阵进行加权求和,得到最终的注意力输出。

测试部分:随机生成两个输入张量 x1x2,并测试交叉注意力的输出形状,确保与预期一致。

7. 总结

交叉注意力在多模态学习中起到了至关重要的作用,能够有效融合不同类型的数据,使得模型可以同时处理图像、文本等多种信息。通过捕捉模态之间的相关性,交叉注意力为多模态任务中的特征融合提供了强大的工具。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:

可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

相关推荐
UMS攸信技术1 小时前
汽车电子行业数字化转型的实践与探索——以盈趣汽车电子为例
人工智能·汽车
测试老哥1 小时前
Python+Selenium+Pytest+POM自动化测试框架封装(完整版)
自动化测试·软件测试·python·selenium·测试工具·职场和发展·测试用例
ws2019071 小时前
聚焦汽车智能化与电动化︱AUTO TECH 2025 华南展,以展带会,已全面启动,与您相约11月广州!
大数据·人工智能·汽车
Ws_1 小时前
蓝桥杯 python day01 第一题
开发语言·python·蓝桥杯
神雕大侠mu2 小时前
函数式接口与回调函数实践
开发语言·python
堇舟2 小时前
斯皮尔曼相关(Spearman correlation)系数
人工智能·算法·机器学习
爱写代码的小朋友2 小时前
使用 OpenCV 进行人脸检测
人工智能·opencv·计算机视觉
Cici_ovo3 小时前
摄像头点击器常见问题——摄像头视窗打开慢
人工智能·单片机·嵌入式硬件·物联网·计算机视觉·硬件工程
萧鼎3 小时前
【Python】高效数据处理:使用Dask处理大规模数据
开发语言·python
互联网杂货铺3 小时前
Python测试框架—pytest详解
自动化测试·软件测试·python·测试工具·测试用例·pytest·1024程序员节