【Transformer系列(3)】cross attention(交叉注意力)

一、cross attention和self-attention的不同

Cross attentionself-attention是在自然语言处理中常用的两种注意力机制。

Self-attention是一种自我关注机制,用于计算输入序列中每个元素与其他元素的关联程度。具体来说,对于给定的输入序列,self-attention机制将每个元素与序列中的所有元素计算关联度,并根据关联度对序列中的每个元素进行加权求和。这种机制使模型能够捕捉到输入序列中不同元素之间的关系,从而更好地理解输入的上下文信息。

Cross attention是在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。

因此,self-attentioncross attention的主要区别在于计算关联度和加权求和的对象不同。self-attention用于单一输入序列,用于捕捉序列内元素之间的关系;而cross attention用于两个不同输入序列之间,用于建立不同序列之间的关联关系。

二、代码实现

这个代码实现了一个简单的交叉注意力模块,它接受两个输入x1和x2,并计算它们之间的交叉注意力。在forward方法中,我们首先通过线性变换将输入进行映射,然后计算注意力权重,最后使用注意力权重加权求和得到输出结果。注意力权重使用softmax函数进行归一化处理。

python 复制代码
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(CrossAttention, self).__init__()
        self.linear_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_k = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_v = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.linear_out = nn.Linear(hidden_dim, hidden_dim, bias=False)
    
    def forward(self, x1, x2):
        q = self.linear_q(x1)  # query
        k = self.linear_k(x2)  # key
        v = self.linear_v(x2)  # value
        
        # 计算注意力权重
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = self.softmax(attn_weights)
        
        # 使用注意力权重加权求和
        attn_output = torch.matmul(attn_weights, v)
        
        # 输出结果
        output = self.linear_out(attn_output)
        return output

# 示例输入
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
x2 = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)

# 创建交叉注意力模型
cross_attention = CrossAttention(hidden_dim=3)

# 前向传播计算结果
output = cross_attention(x1, x2)
print(output)
 
相关推荐
飞哥数智坊21 分钟前
先理需求再写代码:新版 Cursor 用 Plan Mode 落地费曼学习法
人工智能·ai编程·cursor
abcd_zjq21 分钟前
【2025最新】【win10】vs2026+qt6.9+opencv(cmake编译opencv_contrib拓展模
人工智能·qt·opencv·计算机视觉·visual studio
Voyager_424 分钟前
图像处理踩坑:浮点数误差导致的缩放尺寸异常与解决办法
数据结构·图像处理·人工智能·python·算法
知行力26 分钟前
【GitHub每日速递 251011】无需注册!本地开源AI应用构建器Dyad,跨平台速下载!
人工智能·开源·github
jie*26 分钟前
小杰深度学习(ten)——视觉-经典神经网络——LetNet
人工智能·python·深度学习·神经网络·计算机网络·数据分析
xwz小王子29 分钟前
Nature Machine Intelligence丨多模态大型语言模型中的视觉认知
人工智能·语言模型·自然语言处理
冰糖猕猴桃38 分钟前
【AI】深入 LangChain 生态:核心包架构解析
人工智能·ai·架构·langchain
松果财经1 小时前
千亿级赛道,Robobus 赛道中标新加坡自动驾驶巴士项目的“确定性机会”
人工智能·机器学习·自动驾驶
TMT星球1 小时前
滴滴自动驾驶张博:坚持负责任的科技创新,积极探索新型就业空间
人工智能·科技·自动驾驶
Blossom.1181 小时前
用一颗MCU跑通7B大模型:RISC-V+SRAM极致量化实战
人工智能·python·单片机·嵌入式硬件·opencv·机器学习·risc-v