白话transformer(三):Q K V矩阵代码演示

在前面文章讲解了QKV矩阵的原理,属于比较主观的解释,下面用简单的代码再过一遍加深下印象。

B站视频

白话transformer(三)

1、生成数据

我们呢就使用一个句子来做一个测试,

python 复制代码
text1 = "我喜欢的水果是橙子和苹果"
text2 = "相比苹果我更加喜欢国产的华为"

比如我们有两个句子,里面都有苹果这个词。我们用text1来走下流程

1.1 创建词嵌入

我们使用spacy进行词嵌入生成,代码很简单

python 复制代码
nlp = spacy.load('zh_core_web_sm')
doc = nlp(text1)

我们为了简单一点只取前10个维度,实际上spacy默认的词嵌入维度是很高的,我们只是用前十个来过一下流程。

python 复制代码
emd_dim = 10

dics = {}
for token in doc:
    dics[token.text] = token.vector[:emd_dim]
X = pd.DataFrame(dics)

这样我们就得到了第一个句子中所有词的embedding表示

2、初始化 W q W_q Wq, W k W_k Wk, W v W_v Wv

具体的内容可以查看之前的文章Bert基础(一)--自注意力机制

为了创建查询矩阵、键矩阵和值矩阵,我们需要先创建另外三个权重矩阵,分别为 W Q 、 W K 、 W V W^Q 、W^K、W^V WQ、WK、WV。用矩阵X分别乘以矩阵 W Q 、 W K 、 W V W^Q 、W^K、W^V WQ、WK、WV,就可以依次创建出查询矩阵Q、键矩阵K和值矩阵V。

python 复制代码
d_k = 6       # QKV向量的维度

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^{T}}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

d_k是指公式中的d_k

python 复制代码
Wq = np.random.randn(emd_dim, d_k)

Wq矩阵的格式,就是10*6

  • 10:是指词嵌入的维度
  • 6:d_k,Q的维度

Wk, Wv,同样

3、计算QKV

Q = X * Wq

python 复制代码
np.dot(X.T, Wq)

这样就得到了查询矩阵Q,Q其实可以理解为每个词需要查询的内容。

同样可以计算K和V矩阵

4、相似矩阵

计算公式为:
X W Q ∗ ( W K X ) T XW^Q *(W^KX )^T XWQ∗(WKX)T

其实就是我们计算好的Q和K
Q K T Q K^T QKT

直接点乘就可以得到每个词和每个词的相似性:

5、点积缩放

python 复制代码
Q@K.T/ np.sqrt(d_k)

6、Soft Max

我们自己遍历计算一下即可

python 复制代码
# 计算Softmax
for i in range(len(df_QK)):
    exp_v = np.exp(df_QK.iloc[i])
    softmax = exp_v / np.sum(exp_v)
    df_QK.iloc[i] = softmax

现在就得到了最后的相似性矩阵

7、attention

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^{T}}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

根据公示直接将前面计算的结构点乘V

相关推荐
安徽必海微马春梅_6688A2 分钟前
A实验:大鼠脑定位仪 小鼠脑定位仪 大动物定位仪 小动物脑定位仪 资料说明。
人工智能·深度学习
机器学习之心13 分钟前
198种组合算法+优化TCN-Transformer+SHAP分析+新数据预测+多输出!深度学习可解释分析,强烈安利,粉丝必备!
深度学习·算法·transformer·shap分析·新数据预测
一瞬祈望20 分钟前
⭐ 深度学习入门体系(第 15 篇): 从 RNN 到 LSTM:为什么深度网络需要“记忆能力”?
rnn·深度学习·lstm
LeeeX!21 分钟前
基于YOLO11实现明厨亮灶系统实时检测【多场景数据+模型训练、推理、导出】
深度学习·算法·目标检测·数据集·明厨亮灶
知乎的哥廷根数学学派26 分钟前
基于高阶统计量引导的小波自适应块阈值地震信号降噪算法(MATLAB)
网络·人工智能·pytorch·深度学习·算法·机器学习·matlab
墨北小七27 分钟前
CNN深度学习模型在小说创作领域的应用
人工智能·深度学习·cnn
Yeats_Liao32 分钟前
昇腾910B与DeepSeek:国产算力与开源模型的架构适配分析
人工智能·python·深度学习·神经网络·机器学习·架构·开源
子午41 分钟前
【2026原创】昆虫识别系统~Python+深度学习+卷积算法+模型训练+人工智能
人工智能·python·深度学习
李泽辉_43 分钟前
深度学习算法学习(六):深度学习-处理文本:神经网络处理文本、Embedding层
深度学习·学习·算法
高洁0143 分钟前
AI智能体搭建(1)
人工智能·深度学习·机器学习·transformer·知识图谱