注意力机制:让AI拥有黄金七秒记忆的魔法--(注意力机制中的Q、K、V)

注意力机制:让AI拥有"黄金七秒记忆"的魔法--(注意力机制中的Q、K、V)

在注意⼒机制中,查询(Query)、键(Key)和值(Value)是三个关键部分。

■ 查询(Query):是指当前需要处理的信息。模型根据查询向量在输⼊序列中查找相关信息。

■ 键(Key):是指来⾃输⼊序列的⼀组表示。它们⽤于根据查询向量计算注意⼒权重。注意⼒权重反映了不同位置的输⼊数据与查询的相关性。

■ 值(Value):是指来⾃输⼊序列的⼀组表示。它们⽤于根据注意⼒权重计算加权和,得到最终的注意⼒输出向量,其包含了与查询最相关的输⼊信息。

用下面栗子打一个比方:

py 复制代码
import torch # 导入 torch
import torch.nn.functional as F # 导入 nn.functional
# 1. 创建两个张量 x1 和 x2
x1 = torch.randn(2, 3, 4) # 形状 (batch_size, seq_len1, feature_dim)
x2 = torch.randn(2, 5, 4) # 形状 (batch_size, seq_len2, feature_dim)
# 2. 计算原始权重
raw_weights = torch.bmm(x1, x2.transpose(1, 2)) # 形状 (batch_size, seq_len1, seq_len2)
# 3. 用 softmax 函数对原始权重进行归一化
attn_weights = F.softmax(raw_weights, dim=2) # 形状 (batch_size, seq_len1, seq_len2)
# 4. 将注意力权重与 x2 相乘,计算加权和
attn_output = torch.bmm(attn_weights, x2)  # 形状 (batch_size, seq_len1, feature_dim)

我们可以将x1视为查询(Query,Q )向量,将x2视为键(Key,K )和值(Value,V )向量。这是因为我们直接使⽤x1和x2的点积作为相似度得分,并将权重应⽤于x2本身来计算加权信息。所以,在这个简化示例中,Q 对应于x1,KV都对应于x2。

然⽽,在Transformer中,QKV通常是从相同的输⼊序列经过不同的线性变换得到的不同向量。

py 复制代码
import torch
import torch.nn.functional as F
#1. 创建 Query、Key 和 Value 张量
q = torch.randn(2, 3, 4) # 形状 (batch_size, seq_len1, feature_dim)
k = torch.randn(2, 4, 4) # 形状 (batch_size, seq_len2, feature_dim)
v = torch.randn(2, 4, 4) # 形状 (batch_size, seq_len2, feature_dim)
# 2. 计算点积,得到原始权重,形状为 (batch_size, seq_len1, seq_len2)
raw_weights = torch.bmm(q, k.transpose(1, 2))
# 3. 将原始权重进行缩放(可选),形状仍为 (batch_size, seq_len1, seq_len2)
scaling_factor = q.size(-1) ** 0.5
scaled_weights = raw_weights / scaling_factor
# 4. 应用 softmax 函数,使结果的值在 0 和 1 之间,且每一行的和为 1
attn_weights = F.softmax(scaled_weights, dim=-1) # 形状仍为 (batch_size, seq_len1, seq_len2)
# 5. 与 Value 相乘,得到注意力分布的加权和 , 形状为 (batch_size, seq_len1, feature_dim)
attn_output = torch.bmm(attn_weights, v)

KV的维度是否需完全相同呢?

在缩放点积注意⼒中,KV 向量的维度不⼀定需要完全相同。在这种注意⼒机制中,KV 的序列⻓度维度(在这⾥是第2维)应该相同,因为它们描述了同⼀个序列的不同部分。然⽽,它们的特征(或隐藏层)维度(在这⾥是第3维)可以不同。V向量的第⼆个维度则决定了最终输出张量的特征维度,这个维度可以根据具体任务和模型设计进⾏调整。

K 向量的序列⻓度维度(在这⾥是第2维)和Q 向量的序列⻓度维度可以不同,因为它们可以来⾃不同的输⼊序列,但是,K 向量的特征维度(在这⾥是第3维)需要与Q向量的特征维度相同,因为它们之间要计算点积。

在实践中,KV的各个维度通常是相同的,因为它们通常来⾃同⼀个输⼊序列并经过不同的线性变换。

在注意力机制中,k (Key)和 v(Value)的初始值并不是随机产生的,而是由输入数据经过各自的线性变换得到的。具体来说:

来源相同但变换不同:

  • 假设我们有一个输入序列的表示矩阵 X(例如编码器的输出或者词嵌入),

  • 我们通过三个不同的线性层(也就是不同的权重矩阵)分别计算 Query、Key 和 Value:

    • q = X W q q=XW_q q=XWq
    • k = X W k k=XW_k k=XWk
    • v = X W v v= XW_v v=XWv
  • 这里, W q W_q Wq、 W k W_k Wk 和 W v W_v Wv 是模型在训练过程中学习到的参数矩阵。

  • 确定方式

    这些矩阵 W k W_k Wk 和 W v W_v Wv 在模型设计时就被定义好,并在训练过程中通过反向传播进行更新。

  • 作用不同

    • Key (k) :通过 W k W_k Wk 得到,用来与 Query 进行匹配,计算注意力分数,决定输入中哪些部分对当前 Query 最重要。
    • Value (v) :通过 W v W_v Wv得到,它携带的是具体的信息内容,最终会根据注意力分数被加权求和,形成输出的上下文向量。
  • k 与 v 的初始值都源自相同的输入 X,但它们经过了各自独立的线性变换,参数 W k W_k Wk 和 W v W_v Wv 决定了它们具体的数值和表示。

  • 这两个过程是在训练过程中自动学习并调整的,确保模型能够有效地捕捉和利用输入信息。

这样,通过学习到的权重矩阵,模型可以从输入中抽取出适合进行匹配(Key)和传递信息(Value)的表示。

现在,重写缩放点积注意⼒的计算过程,如下所述。

(1)计算Q 向量和K向量的点积。

(2)将点积结果除以缩放因⼦(Q向量特征维度的平⽅根)。

(3)应⽤softmax函数得到注意⼒权重。

(4)使⽤注意⼒权重对V向量进⾏加权求和。

这个过程的图示如下⻚图所示:

具体到编码器-解码器注意⼒来说,可以这样理解QKV向量。

Q 向量代表了解码器在当前时间步的表示,⽤于和K 向量进⾏匹配,以计算注意⼒权重。Q 向量通常是解码器隐藏状态的线性变换。

K 向量是编码器输出的⼀种表示,⽤于和Q 向量进⾏匹配,以确定哪些编码器输出对于当前解码器时间步来说最相关。K 向量通常是编码器隐藏状态的线性变换。

V 向量是编码器输出的另⼀种表示,⽤于计算加权求和,⽣成注意⼒上下⽂向量。注意⼒权重会作⽤在V 向量上,以便在解码过程中关注输⼊序列中的特定部分。V 向量通常也是编码器隐藏状态的线性变换。

在刚才的编码器-解码器注意⼒示例中,直接使⽤了编码器隐藏状态和解码器隐藏状态来计算注意⼒。这⾥的QKV 向量并没有显式地表示出来(⽽且,此处KV是同⼀个向量),但它们的概念仍然隐含在实现中:

■ 编码器隐藏状态(encoder_hidden_states)充当了KV向量的⻆⾊。

■ 解码器隐藏状态(decoder_hidden_states)充当了Q向量的⻆⾊。

我们计算Q 向量(解码器隐藏状态)与K 向量(编码器隐藏状态)之间的点积来得到注意⼒权重,然后⽤这些权重对V向量(编码器隐藏状态)进⾏加权求和,得到上下⽂向量。

当然了,在⼀些更复杂的注意⼒机制(如Transformer中的多头⾃注意⼒机制)中,QKV 向量通常会更明确地表示出来,因为我们需要通过使⽤不同的线性层将相同的输⼊序列显式地映射到不同的QKV向量空间。

V 向量表示值,⽤于计算加权信息。通过将注意⼒权重应⽤于V 向量,我们可以获取输⼊序列中与Q 向量相关的信息。它们(QKV)其实都是输⼊序列,有时是编码器输⼊序列,有时是解码器输⼊序列,有时是神经⽹络中的隐藏状态(也来⾃输⼊序列)的线性表示,也都是序列的"嵌⼊向量"。

相关推荐
TG:@yunlaoda360 云老大2 小时前
腾讯WAIC发布“1+3+N”AI全景图:混元3D世界模型开源,具身智能平台Tairos亮相
人工智能·3d·开源·腾讯云
这张生成的图像能检测吗2 小时前
(论文速读)Fast3R:在一个向前通道中实现1000+图像的3D重建
人工智能·深度学习·计算机视觉·3d重建
兴趣使然黄小黄5 小时前
【AI-agent】LangChain开发智能体工具流程
人工智能·microsoft·langchain
出门吃三碗饭5 小时前
Transformer前世今生——使用pytorch实现多头注意力(八)
人工智能·深度学习·transformer
l1t5 小时前
利用DeepSeek改写SQLite版本的二进制位数独求解SQL
数据库·人工智能·sql·sqlite
说私域6 小时前
开源AI智能名片链动2+1模式S2B2C商城小程序FAQ设计及其意义探究
人工智能·小程序
开利网络6 小时前
合规底线:健康产品营销的红线与避坑指南
大数据·前端·人工智能·云计算·1024程序员节
非著名架构师6 小时前
量化“天气风险”:金融与保险机构如何利用气候大数据实现精准定价与投资决策
大数据·人工智能·新能源风光提高精度·疾风气象大模型4.0
巫婆理发2227 小时前
评估指标+数据不匹配+贝叶斯最优误差(分析方差和偏差)+迁移学习+多任务学习+端到端深度学习
深度学习·学习·迁移学习
熙梦数字化7 小时前
2025汽车零部件行业数字化转型落地方案
大数据·人工智能·汽车