【Transformer系列(3)】cross attention(交叉注意力)

一、cross attention和self-attention的不同

Cross attentionself-attention是在自然语言处理中常用的两种注意力机制。

Self-attention是一种自我关注机制,用于计算输入序列中每个元素与其他元素的关联程度。具体来说,对于给定的输入序列,self-attention机制将每个元素与序列中的所有元素计算关联度,并根据关联度对序列中的每个元素进行加权求和。这种机制使模型能够捕捉到输入序列中不同元素之间的关系,从而更好地理解输入的上下文信息。

Cross attention是在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。

因此,self-attentioncross attention的主要区别在于计算关联度和加权求和的对象不同。self-attention用于单一输入序列,用于捕捉序列内元素之间的关系;而cross attention用于两个不同输入序列之间,用于建立不同序列之间的关联关系。

二、代码实现

这个代码实现了一个简单的交叉注意力模块,它接受两个输入x1和x2,并计算它们之间的交叉注意力。在forward方法中,我们首先通过线性变换将输入进行映射,然后计算注意力权重,最后使用注意力权重加权求和得到输出结果。注意力权重使用softmax函数进行归一化处理。

python 复制代码
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(CrossAttention, self).__init__()
        self.linear_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_k = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.linear_v = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.linear_out = nn.Linear(hidden_dim, hidden_dim, bias=False)
    
    def forward(self, x1, x2):
        q = self.linear_q(x1)  # query
        k = self.linear_k(x2)  # key
        v = self.linear_v(x2)  # value
        
        # 计算注意力权重
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = self.softmax(attn_weights)
        
        # 使用注意力权重加权求和
        attn_output = torch.matmul(attn_weights, v)
        
        # 输出结果
        output = self.linear_out(attn_output)
        return output

# 示例输入
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
x2 = torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)

# 创建交叉注意力模型
cross_attention = CrossAttention(hidden_dim=3)

# 前向传播计算结果
output = cross_attention(x1, x2)
print(output)
 
相关推荐
明似水几秒前
Perplexity AI:对话式搜索引擎的革新者与未来认知操作系统
人工智能·搜索引擎
heart000_114 分钟前
从实验室到生产线:机器学习模型部署的七大陷阱及PyTorch Serving避坑指南
人工智能·pytorch·机器学习
锈儿海老师33 分钟前
关于平凡AI 提示词造就世界最强ast-grep 规则这件事
前端·javascript·人工智能
一点.点34 分钟前
李沐动手深度学习(pycharm中运行笔记)——12.权重衰退
pytorch·笔记·深度学习·pycharm
抱抱宝38 分钟前
Transformer:现代自然语言处理的革命性架构
深度学习·自然语言处理·transformer
腾讯云开发者1 小时前
腾讯云架构师技术沙龙 · 长沙站圆满落幕,共话AI驱动下的技术架构与前沿应用
人工智能
学习OK呀2 小时前
MCP 的相关实操学习
人工智能·后端
明似水2 小时前
点点(小红书AI搜索):生活场景的智能搜索助手
人工智能·生活
ZzzZ314159262 小时前
七天速成数字图像处理之七(颜色图像处理基础)
图像处理·人工智能·深度学习·计算机视觉·数学建模
明似水2 小时前
通义千问(Qwen):阿里云打造的全能AI大模型
人工智能·阿里云·云计算