【机器学习】21. Transformer: 最通俗易懂讲解

1. 结构:Encoder and Decoeder

  • Encoder Component:结构相同的编码器堆栈
    • Self-attention Layer: 帮助查看输入句子中的其他单词,因为它编码了一个特定的单词。
  • Decoding component: 由相同数量的编码器组成的一组解码器。
    • Encoder-decoder attention layer:帮助将注意力集中在输入句子的相关部分

2. Encoder 第一个输入的embedding

和传统的NLP网络一样,文本数据都是通过embedding转换成向量进行操作

3. self-attention 层

假设输入是这个句子,自注意力机制能够帮助it连接到animal

"The animal didn't cross the street because it was too tried."

self-attention 实现

  1. 给每行一个encoder的输入向量,创建 query 向量,key向量,value向量
    通过将embedding乘以我们在训练过程中训练的三个矩阵 得到,Q,K,V矩阵
    三个矩阵是什么?是右侧的W(Q),W(K)和W(V)
  2. Q * K计算分数
    取Q向量与我们要评分的单词的K向量的点积 Q*K
  3. 除以8
    为什么是8?作者是这么做的。64是key vector的大小,因为embedding是512,把key vector取了64,开平方得到了8.
  4. softmax得到概率

    这些分数决定了编码当前位置的词, 即Thinking的时候, 对所有位置的词分别有多少的注意力. 很明显, 在上图的例子中, 当前位置的词Thinking对自己有最高的注意力0.88
  5. value * softmax的结果
    保留我们想要关注的单词的价值,并淹没不相关的单词。
  6. 把最后加权过的V向量相加
    在这个位置产生自注意层的输出(对于第一个单词)

4. 多头机制


一个embedding有多个输出。

因为前馈神经网络接受的是1个矩阵(每个词的一个向量), 所以我们需要有一种方法把10个矩阵整合为一个矩阵

输出左右拼接,再有一个W向量,二者相乘得到最终结果向量。

回顾之前翻译的案例,可以得到一个注意头是animal,另一个是tire

5. Position Encoding(和embedding相加)

前面的embedding并没有位置信息

Position Encoding

  • 提供嵌入向量之间有意义的距离,一旦投影到Q/K/V向量。
  • 确定每个单词的位置。


最后相加的结果传递给self-attention层

6. Residuals

跟ResNet的操作一样,输入和输出相加作为下一层的输入


7. Decoder

顶部编码器的输出被转化为一组注意力向量K和V

  • 用于"编码器-解码器注意"层解码器。
  • 关注输入序列中的适当位置

每一步的输出在下一个时间步中被馈入底层解码器。

8. 最后的Linear and Softmax layer

假设我们的模型认识5000个唯一的英文单词, 那么logits向量的维度就是5000,跟其他神经网络一样,这里是概率。

9. 训练概述

首先,真实的标签会用one-hot转换成向量。

之后使用损失函数

该模型为每个单词生成一个的概率。

  • 使用反向传播使输出更接近实际输出。
  • 比较两种概率分布:
    • Cross-Entropy
    • Kullback-Leiber Divergence

10. 选择Transformer的理由

优点:

更好的远程连接

更容易并行化

在实践中,可以使其比RNN更深(层次更多)

缺点:

注意力计算在技术上是O(n^2) 【n是最长的句子】

实现起来有点复杂(位置编码等)

优点似乎远远大于缺点,transformer在许多情况下比rnn和LSTMs工作得更好

可以说是过去十年中最重要的序列建模改进之一。

相关推荐
不去幼儿园3 分钟前
【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参
人工智能·python·算法·机器学习·强化学习
想成为高手4999 分钟前
生成式AI在教育技术中的应用:变革与创新
人工智能·aigc
YSGZJJ1 小时前
股指期货的套保策略如何精准选择和规避风险?
人工智能·区块链
无脑敲代码,bug漫天飞1 小时前
COR 损失函数
人工智能·机器学习
HPC_fac130520678162 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
小陈phd5 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao6 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
秀儿还能再秀8 小时前
神经网络(系统性学习三):多层感知机(MLP)
神经网络·学习笔记·mlp·多层感知机
wxl7812279 小时前
如何使用本地大模型做数据分析
python·数据挖掘·数据分析·代码解释器
ZHOU_WUYI9 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt