67 自注意力_by《李沐:动手学深度学习v2》pytorch版

系列文章目录


文章目录


自注意力

给定一个由词元组成的输入序列 x 1 , ... , x n \mathbf{x}_1, \ldots, \mathbf{x}_n x1,...,xn,其中任意 x i ∈ R d \mathbf{x}_i \in \mathbb{R}^d xi∈Rd( 1 ≤ i ≤ n 1 \leq i \leq n 1≤i≤n)。自注意力池化层将 x i x_i xi当作key,value,query来对序列抽取特征(该序列的自注意力输出为一个长度相同的序列)得到 y 1 , ... , y n \mathbf{y}_1, \ldots, \mathbf{y}_n y1,...,yn,其中:

y i = f ( x i , ( x 1 , x 1 ) , ... , ( x n , x n ) ) ∈ R d \mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d yi=f(xi,(x1,x1),...,(xn,xn))∈Rd

相当于给定一个序列,我可以根据序列中一个词,给定一个输出。

位置编码

🏷subsec_positional-encoding

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。根CNN/RNN不同,自注意力并没有记录位置信息,我只是计算q和k,其实先q与哪个k都没有区别。而如果在模型中加入位置信息,会减少并行性,或者计算量。

为了使用序列的顺序信息,通过在输入表示中添加位置编码 (positional encoding)来注入绝对的或相对的位置信息到输入里。

位置编码可以通过学习得到也可以直接固定得到。接下来描述的是基于正弦函数和余弦函数的固定位置编码

:cite:Vaswani.Shazeer.Parmar.ea.2017

假设输入表示 X ∈ R n × d \mathbf{X} \in \mathbb{R}^{n \times d} X∈Rn×d包含一个序列中 n n n个词元的 d d d维嵌入表示。那么位置编码使用相同形状的位置嵌入矩阵 P ∈ R n × d \mathbf{P} \in \mathbb{R}^{n \times d} P∈Rn×d输出 X + P \mathbf{X} + \mathbf{P} X+P,来作为自编码输入,矩阵第 i i i行、第 2 j 2j 2j列和 2 j + 1 2j+1 2j+1列上的元素为:

p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) , p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) . \begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).\end{aligned} pi,2jpi,2j+1=sin(100002j/di),=cos(100002j/di).

:eqlabel:eq_positional-encoding-def

知道下面类的作用是生成 P P P即可

python 复制代码
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X) # 防止模型对P太过于敏感

总结

  1. 自注意力池化层将x当做key,value,query来对序列抽取特征
  2. 完全并行、最长序列为1、但对长序列计算复杂度高
  3. 位置编码在输入中加入位置信息,使得自注意力能够记忆位置信息

小结

  • 在自注意力中,查询、键和值都来自同一组输入。
  • 卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。
  • 为了使用序列的顺序信息,可以通过在输入表示中添加位置编码,来注入绝对的或相对的位置信息。

练习

  1. 假设设计一个深度架构,通过堆叠基于位置编码的自注意力层来表示序列。可能会存在什么问题?
  2. 请设计一种可学习的位置编码方法。
相关推荐
桃花键神36 分钟前
AI可信论坛亮点:合合信息分享视觉内容安全技术前沿
人工智能
野蛮的大西瓜1 小时前
开源呼叫中心中,如何将ASR与IVR菜单结合,实现动态的IVR交互
人工智能·机器人·自动化·音视频·信息与通信
CountingStars6191 小时前
目标检测常用评估指标(metrics)
人工智能·目标检测·目标跟踪
tangjunjun-owen2 小时前
第四节:GLM-4v-9b模型的tokenizer源码解读
人工智能·glm-4v-9b·多模态大模型教程
冰蓝蓝2 小时前
深度学习中的注意力机制:解锁智能模型的新视角
人工智能·深度学习
橙子小哥的代码世界2 小时前
【计算机视觉基础CV-图像分类】01- 从历史源头到深度时代:一文读懂计算机视觉的进化脉络、核心任务与产业蓝图
人工智能·计算机视觉
新加坡内哥谈技术2 小时前
苏黎世联邦理工学院与加州大学伯克利分校推出MaxInfoRL:平衡内在与外在探索的全新强化学习框架
大数据·人工智能·语言模型
fanstuck3 小时前
Prompt提示工程上手指南(七)Prompt编写实战-基于智能客服问答系统下的Prompt编写
人工智能·数据挖掘·openai
lovelin+v175030409663 小时前
安全性升级:API接口在零信任架构下的安全防护策略
大数据·数据库·人工智能·爬虫·数据分析
wydxry3 小时前
LoRA(Low-Rank Adaptation)模型微调
深度学习