动手学深度学习------从0开始实现Transformer
原文链接 : zh-v2.d2l.ai/chapter_att...
简介
本文参考《动手学深度学习》中的内容,基于pytorch从0开始实现完整的transformer模型代码,并对于代码中重要部分添加了详细的注释,以此来进一步加深对于代码以及模型结构的理解。阅读本文之前,需要对于transformer模型架构有一定了解。
transformer架构图
带掩码的mask操作
在计算注意力分数的时候,并非所有KV都会被纳入到注意力汇聚中,因为有的文本序列可能是没有意义的token 为了仅将有意义的词元作为值来获取注意力汇聚, 可以指定一个有效序列长度
一般来说,对于一个形状为(batch_size, query_size, key_size)
的注意力分数来说,第三个维度中的值就代表着某一个查询在不同KV上的注意力分数 。我们一般利用形如(batch_size,query_size)
或者(batch_size,)
的valid_lens来进行mask操作
python
import math
import torch
from torch import nn
from d2l import torch as d2l
import torch.nn.functional as F
# 掩蔽mask操作
# 在一般的Transformer实现中,掩码机制通常是在计算注意力分数(QKV的点积注意力)之后应用的。
# 在自注意力机制中,注意力分数计算是在Q(查询)、K(键)和V(数值)之间进行的,然后通过点积操作来计算注意力分数。
# 这种顺序确保了在计算注意力分数时,模型可以访问所有键和值,然后通过掩码来限制查询关注的范围。这是Transformer中实现注意力机制的典型顺序
def masked_softmax(X, valid_lens):
# X:3d张量 valid_lens:1d或者2d张量
if valid_lens is None:
return F.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
# X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
X = X.reshape(-1, shape[-1])
for i in range(len(valid_lens)):
X[i, valid_lens[i]:] = -1e6
return F.softmax(X.reshape(shape), dim=-1)
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))
#tensor([[[0.4523, 0.5477, 0.0000, 0.0000],
# [0.4412, 0.5588, 0.0000, 0.0000]],
#
# [[0.3466, 0.3465, 0.3069, 0.0000],
# [0.3222, 0.4589, 0.2189, 0.0000]]])
缩放点积注意力
使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度d
例如基于n个查询和m个键-值对计算注意力 Q(n,d) K(m,d) V(m,v)
python
class DotProductAttention(nn.Module):
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
# queries(batch_size, 查询个数, d)
# keys(batch_size, 键值对个数, d)
# values(batch_size, 键值对个数, 值的维度)
# valid_lens(batch_size,) or (batch_size, 查询的个数)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# torch.bmm()用于执行批量矩阵乘法
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
# scores:(batch_size, 查询个数, 键值对个数)
attention_weight = masked_softmax(scores, valid_lens)
# attention_weight(batch_size, 查询个数, 键值对个数)
# 最后结果:(batch_size, 查询个数, 值的维度)
return torch.bmm(self.dropout(attention_weight), values)
多头注意力机制
允许注意力机制组合使用查询、键和值的不同子空间表示(对应不同是注意力头),可能是有益的与其只用一个注意力汇聚,倒不如通过线性投影去学习h组不同的(q,k,v) 最后将h组注意力汇聚结合在一起,并再通过一个可学习的线性投影进行变换,以得到最终的输出
在此处的多头注意力代码中,是通过将隐藏层中的维度大小进行分割实现的。
原始的注意力头中的Q:(batch_size, 查询or键值对数量(两者相同), num_hiddens)
多头注意力机制的注意力头中的Q:(batch_size, 查询or键值对数量(两者相同), new_num_hiddens * num_heads)
其中num_heads表示注意力头的数量
为了可以实现并行计算,我们可以将num_heads这个维度放到第一个维度中去,即变为(batch_size * num_heads, 查询or键值对数量(两者相同), new_num_hiddens)
放在batch_size这个维度上之后,便可以将不同的head看做是不同批次中的注意力头,可以实现并行计算
最终结果是:(batch_size * num_heads, 查询or键值对数量(两者相同), new_num_hiddens)
经过简单变化,再将多头注意力分数汇聚到一个头中 ,即(batch_size, 查询or键值对数量(两者相同), num_hiddens)
python
# QKV一开始的维度都是(batch_size, Q/K/V_size, d)
# 首先会经过线性层将维度d=>num_hiddens
class MultiHeadAttention(nn.Module):
"""多头注意力机制"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)
# 维度变换
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
# 多头注意力输出汇聚
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# q,k,v的形状
# (batch_size, 查询or键值对数量(两者相同), num_hiddens)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
valid_lens = valid_lens.repeat_interleave(
repeats=self.num_heads, dim=0
)
# 由于使用多头注意力机制,所以valid_lens的第一个维度(batch_size)也要变成
# batch_size * num_heads
output = self.attention(queries, keys, values, valid_lens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
# 为了实现多头注意力的并行计算,需要写两个转置函数
def transpose_qkv(X, num_heads):
# 输入X的形状:(batch_size,查询或者"键-值"对的个数,num_hiddens)
# 输出X的形状:(batch_size,查询或者"键-值"对的个数,num_heads,num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# 输出X的形状:(batch_size,num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3)
# 最终输出的形状:(batch_size*num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
# 相当于把不同的head视作不同的batch中的数据,实现了并行计算的效果
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
# X的形状:(batch_size * num_heads, 查询或者"键-值"对的个数, num_hiddens / num_heads)
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
# X形状:(batch_size, num_heads, 查询或者"键-值"对的个数, num_hiddens / num_heads)
X = X.permute(0, 2, 1, 3)
# X形状:(batch_size, 查询或者"键-值"对的个数, num_heads, num_hiddens / num_heads)
return X.reshape(X.shape[0], X.shape[1], -1)
# X形状:(batch_size, 查询或者"键-值"对的个数, num_hiddens)
num_hiddens = 100
num_heads = 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens,
num_heads, 0)
print(attention.eval())
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)
# 总结一下:初始的qkv维度都是(batch_size, 查询数量(一般和键值对数量相同), 维度d)
# 由于需要使用多头注意力机制,所以需要经过一个线性层,将其维度变为hidden_size(其实一般和d也是一样的)
# 经过线性变换后,需要将qkv拆分成num_heads个小的qkv
# 拆分的过程其实就是改变张量形状的过程(batch_size, 查询数量, num_hiddens)=>(batch_size*num_head, 查询数量, num_hiddens / num_heads)
# 将num_head放在batch_size这一维度,巧妙地实现了并行计算的效果,相当于把每个头都当成了一个batch
# 最终输出结果的形状为(batch_size * num_head, 查询数量, num_hiddens / num_heads)
# 最终把num_heads这个维度放到最后一个维度上去,得到输出(batch_size, 查询数量, num_hiddens)
# 最终需要汇聚多个头的注意力输出(其实已经汇聚了,但是还是要经过一个线性层),利用线性层将最后的num_hiddens这个维度
# 转换为指定的维度大小
# 自注意力机制
# 具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。
# 由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention)
位置编码
原论文中使用的是基于正弦和余弦 的位置编码,位置编码在输入X的基础上,加上一个位置编码信息P,即X------>X + P
。具体编码公式为:
python
# 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
self.P = torch.zeros(1, max_len, num_hiddens)
X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
X = X / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, X.shape[1], :].to(X.device)
return self.dropout(X)
基于位置的前馈网络
此处代码比较简单,主要是让X通过两个线性层和一个激活函数层,变换过程中X的第三个维度大小变化:ffn_num_input->ffn_num_output
python
# 输入X:(batch_size, 时间步数(序列长度), 特征维度)
# 输出Y:(batch_size, 时间步数(序列长度), ffn_num_outputs)
class PositionWiseFFN(nn.Module):
def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
super(PositionWiseFFN, self).__init__(**kwargs)
self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)
def forward(self, X):
return self.dense2(self.relu(self.dense1(X)))
残差连接与层规范化(layer normalization)
首先进行残差连接,然后再进行层归一化 操作。层归一化操作可以直接使用nn.LayerNorm(norm_shape)
函数实现。norm_shape表示被归一化的维度大小。例如X的形状为(batch_size, seq_length, num_hiddens)
,则若norm_shape=[num_hiddens]
,则表示在X的第三个维度上进行层归一化,若norm_shape=[seq_length, num_hiddens]
,则表示在第二和第三个维度上进行归一化
python
# 残差连接与layer normalization(层规范化)
# 层规范化是基于一个批次内的特征维度进行规范化, 批量规范化是基于不同批次的相同特征进行规范化
class AddNorm(nn.Module):
"""残差连接后进行层规范化"""
def __init__(self, normalized_shape, dropout, **kwargs):
super(AddNorm, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
self.ln = nn.LayerNorm(normalized_shape)
def forward(self, X, Y):
# X是前一层网络的输入, Y是前一层网路的输出
# 先执行残差连接, 然后再进行层规范化,output = ln(X + Y)
# print(X.shape, Y.shape)
return self.ln(self.dropout(Y) + X)
编码块
EncoderBlock类包含两个子层:1.多头自注意力 和 2.基于位置的前馈网络 ,这两个子层都使用了残差连接 和紧随的层归一化,直接将前面定义的方法拿来用即可。相比于解码器简单很多
python
class EncoderBlock(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
dropout, use_bias=False, **kwargs):
super(EncoderBlock, self).__init__(**kwargs)
# 此处的ffn_num_input,ffn_num_hiddens,qkv_size,num_hiddens应该都是相同
# norm_shape可能是一个二维张量
self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens,
num_heads, dropout, use_bias)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(norm_shape, dropout)
def forward(self, X, valid_lens):
# return self.attention(X, X, X, valid_lens)
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
return self.addnorm2(Y, self.ffn(Y))
# X = torch.ones((2, 100, 24))
# valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
# encoder_blk.eval()
# print(encoder_blk(X, valid_lens).shape)
编码器
由num_layer 个编码块组成,在经过编码器之前,需要先经过归一化和位置编码。首先利用nn.Embedding()函数将输入转为张量表示,然后进行归一化,最后再加上位置编码作为编码块的输入。
python
# 下面实现Transformer中的编码器(由num_layers个EncoderBlock组成)
# 由于位置编码使用的是固定位置编码,且范围为[-1,1],所以要对输入嵌入进行缩放
# 然后再和位置编码进行相加
class TransformerEncoder(nn.Module):
"""transformer编码器"""
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, num_layers, dropout, use_bias=False, **kwargs):
super(TransformerEncoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
# nn.Embedding:嵌入层通常用于将离散的整数或类别型数据(如单词、标签等)映射到连续的低维向量空间中
# num_hiddens:嵌入层的输出维度,它表示每个输入类别或词汇将被映射到一个多少维度的向量空间中
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block" + str(i),
EncoderBlock(key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_input,
ffn_num_hiddens, num_heads, dropout, use_bias))
def forward(self, X, valid_lens, *args):
# 要对输入X进行归一化处理,然后再和位置编码相加
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
# self.attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
# self.attention_weights[i] = blk.attention.attention.attention_weights
return X
encoder = TransformerEncoder(
200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
# print(encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape)
解码器
在DecoderBlock类中实现的每个层包含了三个子层:1.解码器自注意力层 2.编码-解码交叉注意力层 3.基于位置的前馈网络层
在训练 阶段,解码器的输入和编码器一样,是整个序列,在计算注意力分数时候,需要注意使用掩码操作,即某个token的注意力输出只取决于其前面的token,即计算注意力分数时,KV只会来自前面的tokens
在预测阶段,输出序列的token是逐个生成的,故在任何解码器时间步中,只有生成的词元才能用于解码器的自注意力计算中,所以不需要进行掩码
python
# 解码器实现
# 与编码器类似,上述三个子层后也会有残差连接+归一化层
# 值得注意的是,在解码器的自注意力中
# 由于在预测阶段,输出序列的token是逐个生成的,故在任何解码器时间步中,只有生成的词元才能用于解码器的自注意力计算中
# 在第一个注意力层中,qkv都来自上一个子层的输出
# 为了在解码器中保留"自回归的属性",其掩蔽自注意力设定了参数dec_valid_lens,以便任何查询都只会与解码器中所有已经生成
# 词元的位置进行注意力计算
class DecoderBlock(nn.Module):
def __init__(self, key_Size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
dropout, i, **kwargs):
# 参数i表示这是当前解码器中的第i个块
super(DecoderBlock, self).__init__(**kwargs)
self.i = i
self.attention1 = MultiHeadAttention(key_Size, query_size, value_size, num_hiddens,
num_heads, dropout)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.attention2 = MultiHeadAttention(key_Size, query_size, value_size, num_hiddens,
num_heads, dropout)
self.addnorm2 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(norm_shape, dropout)
def forward(self, X, state):
# state[0]:编码器的输出?
enc_outputs, enc_valid_lens = state[0], state[1]
# 此处与编码器很不一样!!!
# 训练阶段,输出序列的全部词元都在同一时间处理
# 训练的时候根本不需要管state[2][self.i]
if state[2][self.i] is None:
key_values = X
else:
key_values = torch.cat((state[2][self.i], X), dim=1)
# state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
# 具体来说,假设我们需要预测5个token长度的
# 当我们已经预测了2个token时,对于第一个decoder块来说
# 其输入即为[<sos>,token1,token2]
# 但是我们传入的X是token2,其网络块内部保存着state[2][i](对于第一块来说,就是[<sos>,token1])
# 所以此处需要 key_values = torch.cat((state[2][self.i], X), dim=1)
# 而对于训练来说,可以直接并行计算(训练的时候,需要注意掩码),所以无需用到state[2][i]这个变量
state[2][self.i] = key_values
# 如果是训练阶段
if self.training:
batch_size, num_steps, _ = X.shape # 分别将X.shape的三个维度的大小赋值给变量
# 由此可见X是当前编码器的输入,需要根据X的实际长度去设置dec_valid_lens
dec_valid_lens = torch.arange(
1, num_steps + 1, device=X.device
).repeat(batch_size, 1)
else:
dec_valid_lens = None
# 自注意力
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
# 交叉注意力
# Q来自解码器上一个网络块的输出,KV来自编码器的输出
Y2 = self.attention2(X, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
return self.addnorm3(Z, self.ffn(Z)), state
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, None), None, [None]]
解码器
与编码器中的类似
todo
python
class TransformerDecoder(nn.Module):
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, num_layers, dropout, **kwargs):
super(TransformerDecoder, self).__init__(**kwargs)
self.num_hiddens = num_hiddens
self.num_layers = num_layers
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_layers):
self.blks.add_module("block" + str(i),
DecoderBlock(key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_input, ffn_num_hiddens,
num_heads, dropout, i))
# 编码器的最后一层需要通过一个线性层将维度映射到vocab_size,以实现单词预测
self.dense = nn.Linear(num_hiddens, vocab_size)
def init_state(self, enc_outputs, enc_valid_lens, *args):
return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
return self.dense(X), state