【Transformer系列(3)】cross attention(交叉注意力)

一、cross attention和self-attention的不同

Cross attentionself-attention是在自然语言处理中常用的两种注意力机制。

Self-attention是一种自我关注机制,用于计算输入序列中每个元素与其他元素的关联程度。具体来说,对于给定的输入序列,self-attention机制将每个元素与序列中的所有元素计算关联度,并根据关联度对序列中的每个元素进行加权求和。这种机制使模型能够捕捉到输入序列中不同元素之间的关系,从而更好地理解输入的上下文信息。

Cross attention是在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。

因此,self-attentioncross attention的主要区别在于计算关联度和加权求和的对象不同。self-attention用于单一输入序列,用于捕捉序列内元素之间的关系;而cross attention用于两个不同输入序列之间,用于建立不同序列之间的关联关系。

二、代码实现

这个代码实现了一个简单的交叉注意力模块,它接受两个输入x1和x2,并计算它们之间的交叉注意力。在forward方法中,我们首先通过线性变换将输入进行映射,然后计算注意力权重,最后使用注意力权重加权求和得到输出结果。注意力权重使用softmax函数进行归一化处理。

python 复制代码
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(CrossAttention, self).__init__()
        self.linear_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_k = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_v = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.linear_out = nn.Linear(hidden_dim, hidden_dim, bias=False)
    
    def forward(self, x1, x2):
        q = self.linear_q(x1)  # query
        k = self.linear_k(x2)  # key
        v = self.linear_v(x2)  # value
        
        # 计算注意力权重
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = self.softmax(attn_weights)
        
        # 使用注意力权重加权求和
        attn_output = torch.matmul(attn_weights, v)
        
        # 输出结果
        output = self.linear_out(attn_output)
        return output

# 示例输入
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
x2 = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)

# 创建交叉注意力模型
cross_attention = CrossAttention(hidden_dim=3)

# 前向传播计算结果
output = cross_attention(x1, x2)
print(output)
 
相关推荐
正在走向自律4 分钟前
Trae上手指南:AI编程从0到1的奇妙跃迁
人工智能
MILI元宇宙9 分钟前
DeepSeek R1开源模型的技术突破与AI产业格局的重构
人工智能·重构·开源
江苏泊苏系统集成有限公司1 小时前
半导体晶圆制造洁净厂房的微振控制方案-江苏泊苏系统集成有限公司
人工智能·深度学习·目标检测·机器学习·创业创新·制造·远程工作
猿小猴子2 小时前
主流 AI IDE 之一的 Windsurf 介绍
ide·人工智能
智联视频超融合平台3 小时前
无人机+AI视频联网:精准狙击,让‘罪恶之花’无处藏身
人工智能·网络协议·安全·系统安全·音视频·无人机
AiTEN_Robotics3 小时前
智能仓储落地:机器人如何通过自动化减少仓库操作失误?
人工智能·机器人·自动化
江湖有缘4 小时前
华为云Flexus+DeepSeek征文 | 初探华为云ModelArts Studio:部署DeepSeek-V3/R1商用服务的详细步骤
人工智能·华为云·modelarts
Vizio<4 小时前
基于FashionMnist数据集的自监督学习(生成式自监督学习AE算法)
人工智能·笔记·深度学习·神经网络·自监督学习
梅一一4 小时前
5款AI对决:Gemini学术封神,但日常办公我选它
大数据·人工智能·数据可视化
kyle~4 小时前
Pytorch---ImageFolder
人工智能·pytorch·python