多种注意力机制(文本->残差->视频)

1.初代自我注意机制(多头注意力机制)

1.1原理

总体架构

上图是 Self-Attention 的结构,在计算的时候需要用到矩阵Q(查询),K(键值),V(值)。在实际中,Self-Attention 接收的是输入(单词的表示向量x组成的矩阵X) 或者上一个 Encoder block 的输出。而Q,K,V正是通过 Self-Attention 的输入进行线性变换得到的。

Q, K, V 的计算 这个X无疑是由不同单词(词向量)组成的

如下图无疑是计算出了Q K V

多头注意力的输出

公式中计算矩阵QK 每一行向量的内积,为了防止内积过大,因此除以 的平方根。Q 乘以K的转置后,得到的矩阵行列数都为 n,n 为句子单词数,这个矩阵可以表示单词之间的 attention 强度

换一话来说 对图像而言 是不是patch 对视频而言 是不是帧)!!!

之后,使用 Softmax 计算每一个单词对于其他单词的 attention 系数,公式中的 Softmax 是对矩阵的每一行进行 Softmax,即每一行的和都变为 1.

得到 Softmax 矩阵之后可以和V 相乘,得到最终的输出Z

就拿结果而言我们是不是得到了一个带有位置信息 和对所有样本有注意力系数的词向量 (可以一个patch 也可以是一个帧)

多头注意力

在上一步,我们已经知道怎么通过 Self-Attention 计算得到输出矩阵 Z,而 Multi-Head Attention 是由多个 Self-Attention 组合形成的,下图是论文中 Multi-Head Attention 的结构图。

从上图可以看到 Multi-Head Attention 包含多个 Self-Attention 层,首先将输入X 分别传递到 h 个不同的 Self-Attention 中,计算得到 h 个输出矩阵Z。下图是 h=8 时候的情况,此时会得到 8 个输出矩阵

得到 8 个输出矩阵 Z1 到 Z8 之后,Multi-Head Attention 将它们拼接在一起 (Concat) ,然后传入一个Linear层,得到 Multi-Head Attent

最终的输出Z

掩码注意力就是看不见未来的情况

参考这一篇transform

代码实现

调用api

python 复制代码
nn.MultiheadAttention(d_model, n_head)

 def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

2.残差注意力机制(图像识别用的多)

工作过程

  • 特征提取:首先,输入数据经过一系列的卷积层和池化层等操作进行特征提取,得到初始的特征图。

  • 注意力计算:将初始特征图输入到注意力模块中,通过一系列的计算生成注意力掩码。这个过程通常包括对特征图进行卷积操作、激活函数处理以及归一化等步骤,以得到每个位置的注意力权重。

  • 特征加权:将注意力掩码与初始特征图相乘,使得特征图中重要的部分得到增强,不重要的部分得到抑制。

  • 残差融合:将经过特征加权后的特征图与原始输入特征图通过残差连接相加,得到最终的输出特征图。这个输出特征图既包含了经过注意力机制筛选后的重要特征,又保留了原始输入中的信息,从而提高了特征的质量和模型的性能。

代码实现

python 复制代码
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

3.跨帧注意力块(视频用的多)

跨帧融合注意力(CFA)

CFA 的作用是允许帧之间直接交换信息。具体来说:

  • 每个帧都有一个 消息令牌(message token),用于抽象当前帧的视觉信息。

  • 这些消息令牌会参与到全局的时间依赖性建模中。

  • 消息令牌的计算方式如下:

为t帧在l层的消息令牌 w是一个线性矩阵,是前一层的特征输出。

所有帧的消息令牌会参与到跨帧融合注意力中

CFA 是一个自注意力模块,它允许消息令牌之间相互作用,从而捕捉帧之间的时间依赖性

帧内融合注意力(IFA)

IFA 的作用是在帧内传播全局的时间信息,增强帧的表示能力

  • 其中,是帧特征的更新表示,m ˉt (l)​ 是消息令牌的更新表示。

  • 消息令牌在每个块中仅用于信息交换,不会传递到下一个块。

CFA

首先一个数据由(b,t,c.h,w)->(bt,c,h,w)->(wxh,bt,d)(你得了解这个视频处理)

python 复制代码
 l, bt, d = x.size()

然后为了得到时间 一个知识点 三维是图像块输入 四维是帧输入

python 复制代码
b = bt // self.T
x = x.view(l, b, self.T, d) 

初始化消息令牌(vit的csl)

python 复制代码
msg_token = self.message_fc(x[0,:,:,:]) # # 使用线性变换初始化消息令牌 使用线性变换初始化消#息令牌 这个代表帧在不同批次和步长
msg_token = msg_token.view(b, self.T, 1, d) # 调整形状为 (b, T, 1, d)
msg_token = msg_token.permute(1,2,0,3).view(self.T, b, d)

实现跨帧传输(t不同无疑就是 不同帧)

python 复制代码
msg_token = msg_token + self.drop_path(self.message_attn(self.message_ln(msg_token),self.message_ln(msg_token),self.message_ln(msg_token),need_weights=False)[0])

IFA:帧内

python 复制代码
 x = x.view(l+1, -1, d)
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x[:l,:,:]

4.补充一个知识点

python 复制代码
MultiHeadAttention

传入一个输入的多头注意力机制:这种形式可以让模型关注序列内部的上下文信息,捕捉序列中不同位置之间的依赖关系。例如在文本处理中,一个句子中的每个词都可以通过自注意力机制与其他词进行交互,从而更好地理解整个句子的语义。

传入三个输入的多头注意力机制:形式允许模型在不同的表示子空间中并行地计算注意力,从而捕捉输入序列之间的不同类型的依赖关系。通过分别传入 Q、K、V,模型可以灵活地根据查询去关注键值对中的不同部分,提高模型的表达能力。

相关推荐
飞哥数智坊13 分钟前
Anthropic 出手,大幅降低 MCP 的 Token 消耗
人工智能·claude·mcp
贝多财经16 分钟前
千里科技报考港股上市:高度依赖吉利,AI智驾转型收入仍为零
大数据·人工智能·科技
嵌入式-老费27 分钟前
自己动手写深度学习框架(pytorch入门)
人工智能·pytorch·深度学习
irisMoon0636 分钟前
yolov5单目测距+速度测量+目标跟踪
人工智能·yolo·目标跟踪
Linux猿39 分钟前
365科技简报 2025年11月13日 星期四
人工智能·科技简报
终端域名1 小时前
当今前沿科技:脑机共生界面(脑机接口)深度解析
人工智能·智能电视
汗流浃背了吧,老弟!1 小时前
预训练语言模型(Pre-trained Language Model, PLM)介绍
深度学习·语言模型·自然语言处理
化作星辰1 小时前
深度学习_神经网络激活函数
人工智能·深度学习·神经网络
陈天伟教授2 小时前
人工智能技术- 语音语言- 03 ChatGPT 对话、写诗、写小说
人工智能·chatgpt
llilian_162 小时前
智能数字式毫秒计在实际生活场景中的应用 数字式毫秒计 智能毫秒计
大数据·网络·人工智能