【Transformer系列(3)】cross attention(交叉注意力)

一、cross attention和self-attention的不同

Cross attentionself-attention是在自然语言处理中常用的两种注意力机制。

Self-attention是一种自我关注机制,用于计算输入序列中每个元素与其他元素的关联程度。具体来说,对于给定的输入序列,self-attention机制将每个元素与序列中的所有元素计算关联度,并根据关联度对序列中的每个元素进行加权求和。这种机制使模型能够捕捉到输入序列中不同元素之间的关系,从而更好地理解输入的上下文信息。

Cross attention是在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。

因此,self-attentioncross attention的主要区别在于计算关联度和加权求和的对象不同。self-attention用于单一输入序列,用于捕捉序列内元素之间的关系;而cross attention用于两个不同输入序列之间,用于建立不同序列之间的关联关系。

二、代码实现

这个代码实现了一个简单的交叉注意力模块,它接受两个输入x1和x2,并计算它们之间的交叉注意力。在forward方法中,我们首先通过线性变换将输入进行映射,然后计算注意力权重,最后使用注意力权重加权求和得到输出结果。注意力权重使用softmax函数进行归一化处理。

python 复制代码
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(CrossAttention, self).__init__()
        self.linear_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_k = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_v = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.linear_out = nn.Linear(hidden_dim, hidden_dim, bias=False)
    
    def forward(self, x1, x2):
        q = self.linear_q(x1)  # query
        k = self.linear_k(x2)  # key
        v = self.linear_v(x2)  # value
        
        # 计算注意力权重
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = self.softmax(attn_weights)
        
        # 使用注意力权重加权求和
        attn_output = torch.matmul(attn_weights, v)
        
        # 输出结果
        output = self.linear_out(attn_output)
        return output

# 示例输入
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
x2 = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)

# 创建交叉注意力模型
cross_attention = CrossAttention(hidden_dim=3)

# 前向传播计算结果
output = cross_attention(x1, x2)
print(output)
 
相关推荐
那个村的李富贵14 小时前
光影魔术师:CANN加速实时图像风格迁移,让每张照片秒变大师画作
人工智能·aigc·cann
腾讯云开发者16 小时前
“痛点”到“通点”!一份让 AI 真正落地产生真金白银的实战指南
人工智能
CareyWYR16 小时前
每周AI论文速递(260202-260206)
人工智能
hopsky17 小时前
大模型生成PPT的技术原理
人工智能
禁默17 小时前
打通 AI 与信号处理的“任督二脉”:Ascend SIP Boost 加速库深度实战
人工智能·信号处理·cann
心疼你的一切18 小时前
昇腾CANN实战落地:从智慧城市到AIGC,解锁五大行业AI应用的算力密码
数据仓库·人工智能·深度学习·aigc·智慧城市·cann
AI绘画哇哒哒18 小时前
【干货收藏】深度解析AI Agent框架:设计原理+主流选型+项目实操,一站式学习指南
人工智能·学习·ai·程序员·大模型·产品经理·转行
数据分析能量站18 小时前
Clawdbot(现名Moltbot)-现状分析
人工智能
那个村的李富贵18 小时前
CANN加速下的AIGC“即时翻译”:AI语音克隆与实时变声实战
人工智能·算法·aigc·cann
二十雨辰18 小时前
[python]-AI大模型
开发语言·人工智能·python