基于pytorch从0开始实现transformer

动手学深度学习------从0开始实现Transformer

原文链接 : zh-v2.d2l.ai/chapter_att...

简介

本文参考《动手学深度学习》中的内容,基于pytorch从0开始实现完整的transformer模型代码,并对于代码中重要部分添加了详细的注释,以此来进一步加深对于代码以及模型结构的理解。阅读本文之前,需要对于transformer模型架构有一定了解。

transformer架构图

带掩码的mask操作

在计算注意力分数的时候,并非所有KV都会被纳入到注意力汇聚中,因为有的文本序列可能是没有意义的token 为了仅将有意义的词元作为值来获取注意力汇聚, 可以指定一个有效序列长度

一般来说,对于一个形状为(batch_size, query_size, key_size)的注意力分数来说,第三个维度中的值就代表着某一个查询在不同KV上的注意力分数 。我们一般利用形如(batch_size,query_size)或者(batch_size,)的valid_lens来进行mask操作

python 复制代码
import math
import torch
from torch import nn
from d2l import torch as d2l
import torch.nn.functional as F



# 掩蔽mask操作


# 在一般的Transformer实现中,掩码机制通常是在计算注意力分数(QKV的点积注意力)之后应用的。
# 在自注意力机制中,注意力分数计算是在Q(查询)、K(键)和V(数值)之间进行的,然后通过点积操作来计算注意力分数。
# 这种顺序确保了在计算注意力分数时,模型可以访问所有键和值,然后通过掩码来限制查询关注的范围。这是Transformer中实现注意力机制的典型顺序
def masked_softmax(X, valid_lens):
    # X:3d张量 valid_lens:1d或者2d张量

    if valid_lens is None:
        return F.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0

        # X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        X = X.reshape(-1, shape[-1])
        for i in range(len(valid_lens)):
            X[i, valid_lens[i]:] = -1e6

    return F.softmax(X.reshape(shape), dim=-1)


print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))

#tensor([[[0.4523, 0.5477, 0.0000, 0.0000],
#        [0.4412, 0.5588, 0.0000, 0.0000]],
#
#       [[0.3466, 0.3465, 0.3069, 0.0000],
#       [0.3222, 0.4589, 0.2189, 0.0000]]])

缩放点积注意力

使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度d

例如基于n个查询和m个键-值对计算注意力 Q(n,d) K(m,d) V(m,v)

python 复制代码
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries(batch_size, 查询个数, d)
    # keys(batch_size, 键值对个数, d)
    # values(batch_size, 键值对个数, 值的维度)
    # valid_lens(batch_size,) or (batch_size, 查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # torch.bmm()用于执行批量矩阵乘法
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # scores:(batch_size, 查询个数, 键值对个数)
        attention_weight = masked_softmax(scores, valid_lens)
        # attention_weight(batch_size, 查询个数, 键值对个数)
        # 最后结果:(batch_size, 查询个数, 值的维度)
        return torch.bmm(self.dropout(attention_weight), values)

多头注意力机制

允许注意力机制组合使用查询、键和值的不同子空间表示(对应不同是注意力头),可能是有益的与其只用一个注意力汇聚,倒不如通过线性投影去学习h组不同的(q,k,v) 最后将h组注意力汇聚结合在一起,并再通过一个可学习的线性投影进行变换,以得到最终的输出

在此处的多头注意力代码中,是通过将隐藏层中的维度大小进行分割实现的。

原始的注意力头中的Q:(batch_size, 查询or键值对数量(两者相同), num_hiddens)

多头注意力机制的注意力头中的Q:(batch_size, 查询or键值对数量(两者相同), new_num_hiddens * num_heads) 其中num_heads表示注意力头的数量

为了可以实现并行计算,我们可以将num_heads这个维度放到第一个维度中去,即变为(batch_size * num_heads, 查询or键值对数量(两者相同), new_num_hiddens)

放在batch_size这个维度上之后,便可以将不同的head看做是不同批次中的注意力头,可以实现并行计算

最终结果是:(batch_size * num_heads, 查询or键值对数量(两者相同), new_num_hiddens) 经过简单变化,再将多头注意力分数汇聚到一个头中 ,即(batch_size, 查询or键值对数量(两者相同), num_hiddens)

python 复制代码
# QKV一开始的维度都是(batch_size, Q/K/V_size, d)
# 首先会经过线性层将维度d=>num_hiddens

class MultiHeadAttention(nn.Module):
    """多头注意力机制"""

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        
        # 维度变换
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        
        # 多头注意力输出汇聚
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # q,k,v的形状
        # (batch_size, 查询or键值对数量(两者相同), num_hiddens)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            valid_lens = valid_lens.repeat_interleave(
                repeats=self.num_heads, dim=0
            )
        # 由于使用多头注意力机制,所以valid_lens的第一个维度(batch_size)也要变成
        # batch_size * num_heads

        output = self.attention(queries, keys, values, valid_lens)

        output_concat = transpose_output(output, self.num_heads)

        return self.W_o(output_concat)


# 为了实现多头注意力的并行计算,需要写两个转置函数
def transpose_qkv(X, num_heads):
    # 输入X的形状:(batch_size,查询或者"键-值"对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者"键-值"对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
    # 相当于把不同的head视作不同的batch中的数据,实现了并行计算的效果

    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    # X的形状:(batch_size * num_heads, 查询或者"键-值"对的个数, num_hiddens / num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # X形状:(batch_size, num_heads, 查询或者"键-值"对的个数, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # X形状:(batch_size, 查询或者"键-值"对的个数, num_heads, num_hiddens / num_heads)
    return X.reshape(X.shape[0], X.shape[1], -1)
    # X形状:(batch_size, 查询或者"键-值"对的个数, num_hiddens)


num_hiddens = 100
num_heads = 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens,
                               num_heads, 0)
print(attention.eval())

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)


# 总结一下:初始的qkv维度都是(batch_size, 查询数量(一般和键值对数量相同), 维度d)
# 由于需要使用多头注意力机制,所以需要经过一个线性层,将其维度变为hidden_size(其实一般和d也是一样的)
# 经过线性变换后,需要将qkv拆分成num_heads个小的qkv
# 拆分的过程其实就是改变张量形状的过程(batch_size, 查询数量, num_hiddens)=>(batch_size*num_head, 查询数量, num_hiddens / num_heads)
# 将num_head放在batch_size这一维度,巧妙地实现了并行计算的效果,相当于把每个头都当成了一个batch
# 最终输出结果的形状为(batch_size * num_head, 查询数量, num_hiddens / num_heads)
# 最终把num_heads这个维度放到最后一个维度上去,得到输出(batch_size, 查询数量, num_hiddens)
# 最终需要汇聚多个头的注意力输出(其实已经汇聚了,但是还是要经过一个线性层),利用线性层将最后的num_hiddens这个维度
# 转换为指定的维度大小


# 自注意力机制
# 具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。
# 由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention)

位置编码

原论文中使用的是基于正弦和余弦 的位置编码,位置编码在输入X的基础上,加上一个位置编码信息P,即X------>X + P。具体编码公式为:

python 复制代码
# 位置编码


class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        self.P = torch.zeros(1, max_len, num_hiddens)

        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        X = X / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)

        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, X.shape[1], :].to(X.device)
        return self.dropout(X)

基于位置的前馈网络

此处代码比较简单,主要是让X通过两个线性层和一个激活函数层,变换过程中X的第三个维度大小变化:ffn_num_input->ffn_num_output

python 复制代码
# 输入X:(batch_size, 时间步数(序列长度), 特征维度)
# 输出Y:(batch_size, 时间步数(序列长度), ffn_num_outputs)

class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

残差连接与层规范化(layer normalization)

首先进行残差连接,然后再进行层归一化 操作。层归一化操作可以直接使用nn.LayerNorm(norm_shape)函数实现。norm_shape表示被归一化的维度大小。例如X的形状为(batch_size, seq_length, num_hiddens),则若norm_shape=[num_hiddens],则表示在X的第三个维度上进行层归一化,若norm_shape=[seq_length, num_hiddens],则表示在第二和第三个维度上进行归一化

python 复制代码
# 残差连接与layer normalization(层规范化)
# 层规范化是基于一个批次内的特征维度进行规范化, 批量规范化是基于不同批次的相同特征进行规范化

class AddNorm(nn.Module):
    """残差连接后进行层规范化"""

    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # X是前一层网络的输入, Y是前一层网路的输出
        # 先执行残差连接, 然后再进行层规范化,output = ln(X + Y)
        # print(X.shape, Y.shape)
        return self.ln(self.dropout(Y) + X)

编码块

EncoderBlock类包含两个子层:1.多头自注意力2.基于位置的前馈网络 ,这两个子层都使用了残差连接 和紧随的层归一化,直接将前面定义的方法拿来用即可。相比于解码器简单很多

python 复制代码
class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # 此处的ffn_num_input,ffn_num_hiddens,qkv_size,num_hiddens应该都是相同
        # norm_shape可能是一个二维张量
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens,
                                            num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # return self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))


# X = torch.ones((2, 100, 24))
# valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)


# encoder_blk.eval()
# print(encoder_blk(X, valid_lens).shape)

编码器

num_layer 个编码块组成,在经过编码器之前,需要先经过归一化和位置编码。首先利用nn.Embedding()函数将输入转为张量表示,然后进行归一化,最后再加上位置编码作为编码块的输入。

python 复制代码
# 下面实现Transformer中的编码器(由num_layers个EncoderBlock组成)
# 由于位置编码使用的是固定位置编码,且范围为[-1,1],所以要对输入嵌入进行缩放
# 然后再和位置编码进行相加

class TransformerEncoder(nn.Module):
    """transformer编码器"""

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens

        # nn.Embedding:嵌入层通常用于将离散的整数或类别型数据(如单词、标签等)映射到连续的低维向量空间中
        # num_hiddens:嵌入层的输出维度,它表示每个输入类别或词汇将被映射到一个多少维度的向量空间中
        self.embedding = nn.Embedding(vocab_size, num_hiddens)

        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()

        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                                 EncoderBlock(key_size, query_size, value_size,
                                              num_hiddens, norm_shape, ffn_num_input,
                                              ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # 要对输入X进行归一化处理,然后再和位置编码相加
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # self.attention_weights[i] = blk.attention.attention.attention_weights
        return X


encoder = TransformerEncoder(
    200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
# print(encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape)

解码器

在DecoderBlock类中实现的每个层包含了三个子层:1.解码器自注意力层 2.编码-解码交叉注意力层 3.基于位置的前馈网络层

训练 阶段,解码器的输入和编码器一样,是整个序列,在计算注意力分数时候,需要注意使用掩码操作,即某个token的注意力输出只取决于其前面的token,即计算注意力分数时,KV只会来自前面的tokens

预测阶段,输出序列的token是逐个生成的,故在任何解码器时间步中,只有生成的词元才能用于解码器的自注意力计算中,所以不需要进行掩码

python 复制代码
# 解码器实现

# 与编码器类似,上述三个子层后也会有残差连接+归一化层

# 值得注意的是,在解码器的自注意力中
# 由于在预测阶段,输出序列的token是逐个生成的,故在任何解码器时间步中,只有生成的词元才能用于解码器的自注意力计算中
# 在第一个注意力层中,qkv都来自上一个子层的输出
# 为了在解码器中保留"自回归的属性",其掩蔽自注意力设定了参数dec_valid_lens,以便任何查询都只会与解码器中所有已经生成
# 词元的位置进行注意力计算

class DecoderBlock(nn.Module):
    def __init__(self, key_Size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        # 参数i表示这是当前解码器中的第i个块
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_Size, query_size, value_size, num_hiddens,
                                             num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_Size, query_size, value_size, num_hiddens,
                                             num_heads, dropout)

        self.addnorm2 = AddNorm(norm_shape, dropout)

        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)

        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        # state[0]:编码器的输出?
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 此处与编码器很不一样!!!
        # 训练阶段,输出序列的全部词元都在同一时间处理
        # 训练的时候根本不需要管state[2][self.i]

        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        # state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
        # 具体来说,假设我们需要预测5个token长度的
        # 当我们已经预测了2个token时,对于第一个decoder块来说
        # 其输入即为[<sos>,token1,token2]
        # 但是我们传入的X是token2,其网络块内部保存着state[2][i](对于第一块来说,就是[<sos>,token1])
        # 所以此处需要  key_values = torch.cat((state[2][self.i], X), dim=1)
        # 而对于训练来说,可以直接并行计算(训练的时候,需要注意掩码),所以无需用到state[2][i]这个变量

        state[2][self.i] = key_values

        # 如果是训练阶段
        if self.training:
            batch_size, num_steps, _ = X.shape  # 分别将X.shape的三个维度的大小赋值给变量

            # 由此可见X是当前编码器的输入,需要根据X的实际长度去设置dec_valid_lens
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device
            ).repeat(batch_size, 1)

        else:
            dec_valid_lens = None

        # 自注意力
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)

        # 交叉注意力
        # Q来自解码器上一个网络块的输出,KV来自编码器的输出
        Y2 = self.attention2(X, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)

        return self.addnorm3(Z, self.ffn(Z)), state


decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, None), None, [None]]

解码器

与编码器中的类似

todo

python 复制代码
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                                 DecoderBlock(key_size, query_size, value_size, num_hiddens,
                                              norm_shape, ffn_num_input, ffn_num_hiddens,
                                              num_heads, dropout, i))
        #  编码器的最后一层需要通过一个线性层将维度映射到vocab_size,以实现单词预测
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)

        return self.dense(X), state
相关推荐
LNTON羚通3 小时前
摄像机视频分析软件下载LiteAIServer视频智能分析平台玩手机打电话检测算法技术的实现
算法·目标检测·音视频·监控·视频监控
哭泣的眼泪4085 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
Microsoft Word6 小时前
c++基础语法
开发语言·c++·算法
天才在此6 小时前
汽车加油行驶问题-动态规划算法(已在洛谷AC)
算法·动态规划
莫叫石榴姐7 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
茶猫_8 小时前
力扣面试题 - 25 二进制数转字符串
c语言·算法·leetcode·职场和发展
肥猪猪爸10 小时前
使用卡尔曼滤波器估计pybullet中的机器人位置
数据结构·人工智能·python·算法·机器人·卡尔曼滤波·pybullet
readmancynn10 小时前
二分基本实现
数据结构·算法
萝卜兽编程10 小时前
优先级队列
c++·算法
盼海10 小时前
排序算法(四)--快速排序
数据结构·算法·排序算法