《动手学深度学习 Pytorch版》 10.5 多头注意力

多头注意力(multihead attention):用独立学习得到的 h 组不同的线性投影(linear projections)来变换查询、键和值,然后并行地送到注意力汇聚中。最后,将这 h 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。

对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。

10.5.1 模型

用数学语言描述多头注意力:

h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p \boldsymbol{h}_i=f(\boldsymbol{W}_i^{(q)}\boldsymbol{q},\boldsymbol{W}_i^{(k)}\boldsymbol{k},\boldsymbol{W}_i^{(v)}\boldsymbol{v})\in\R^p hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rp

参数字典:

  • f f f 表示注意力汇聚函数

  • q ∈ R d q \boldsymbol{q}\in\R^{d_q} q∈Rdq、 k ∈ R d k \boldsymbol{k}\in\R^{d_k} k∈Rdk 和 v ∈ R d v \boldsymbol{v}\in\R^{d_v} v∈Rdv 分别是查询、键和值

  • W i ( q ) ∈ R p d × d q \boldsymbol{W}_i^{(q)}\in\R^{p_d\times d_q} Wi(q)∈Rpd×dq、 W i ( k ) ∈ R p k × d k \boldsymbol{W}_i^{(k)}\in\R^{p_k\times d_k} Wi(k)∈Rpk×dk 和 W i ( v ) ∈ R p v × d v \boldsymbol{W}_i^{(v)}\in\R^{p_v\times d_v} Wi(v)∈Rpv×dv 均为可学习参数

多头注意力的输出需要经过另一个线性转换:

y = [ h 1 ⋮ h h ] ∈ R p o y= \begin{bmatrix} \boldsymbol{h}_1\\ \vdots\\ \boldsymbol{h}_h \end{bmatrix} \in\R^{p_o} y= h1⋮hh ∈Rpo

python 复制代码
import math
import torch
from torch import nn
from d2l import torch as d2l

10.5.2 实现

在实现过程中通常选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,设定 p q = p k = p v = p o / h p_q=p_k=p_v=p_o/h pq=pk=pv=po/h。值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_qh=p_kh=p_vh=p_o pqh=pkh=pvh=po,则可以并行计算 h 个头。在下面的实现中, p o p_o po 是通过参数 num_hiddens 指定的。

MultiHeadAttention 类将使用下面定义的两个转置函数,transpose_output 函数反转了 transpose_qkv 函数的操作。转来转去是为了避免 for 循环。

python 复制代码
#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者"键-值"对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者"键-值"对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
python 复制代码
#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状: (batch_size,查询或者"键-值"对的个数,num_hiddens)
        # valid_lens 的形状: (batch_size,) 或 (batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状: (batch_size*num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
python 复制代码
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
python 复制代码
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

练习

(1)分别可视化这个实验中的多个头的注意力权重。

python 复制代码
d2l.show_heatmaps(attention.attention.attention_weights.reshape((2, 5, 4, 6)),
                  xlabel='Keys', ylabel='Queries', figsize=(5,5))


(2)假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

不会,略。

相关推荐
m0_743106461 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106461 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
井底哇哇4 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证4 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩5 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控5 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
一水鉴天5 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
倔强的石头1065 小时前
解锁辅助驾驶新境界:基于昇腾 AI 异构计算架构 CANN 的应用探秘
人工智能·架构
佛州小李哥6 小时前
Agent群舞,在亚马逊云科技搭建数字营销多代理(Multi-Agent)(下篇)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
IE066 小时前
深度学习系列75:sql大模型工具vanna
深度学习