transformer代码注解

其中代码均来自李沐老师的动手学pytorch中。

python 复制代码
class PositionWiseFFN(nn.Module):
    '''
    ffn_num_inputs 4
    ffn_num_hiddens 4
    ffn_num_outputs 8
    '''
    def __init__(self,ffn_num_inputs,ffn_num_hiddens,ffn_num_outputs):
        super(PositionWiseFFN,self).__init__()
        self.dense1 = nn.Linear(ffn_num_inputs,ffn_num_hiddens)#4*4
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens,ffn_num_outputs)#4*8
    def forward(self,X):
        return self.dense2(self.relu(self.dense1(X)))
positionWiseFFN = PositionWiseFFN(4,4,8)
positionWiseFFN.eval()
positionWiseFFN(torch.ones(size=(2,3,4)))[0]

上面的代码为前馈神经网络结构,其实也就是一个全连接层。

python 复制代码
class AddNorm(nn.Module):
    def __init__(self,normalized_shape,dropout):
        super(AddNorm, self).__init__()
        self.dropout=nn.Dropout(dropout)
        self.layer_norm=nn.LayerNorm(normalized_shape=normalized_shape)
    def forward(self,x,y):
        return self.layer_norm(self.dropout(y)+x)
#比如[3, 4]或torch.Size([3, 4]),则会对网络最后的两维进行归一化,且要求输入数据的最后两维尺寸也是[3, 4]
add_norm = AddNorm(normalized_shape=[3,4],dropout=0.5)
add_norm.eval()
add_norm(torch.ones(size=(2,3,4)),torch.ones(size=(2,3,4)))

这里实现的是残差化和规范化。nn.LayerNorm(normalized_shape=normalized_shape)为layer规范化,其中normalized_shape为[3, 4],对网络最后的两维进行归一化。

python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self,query_size,key_size,value_size,num_hiddens,num_heads,dropout,bias=False):
        super(MultiHeadAttention, self).__init__()
        self.num_heads=num_heads
        #用独立学习得到的 ℎ 组不同的线性投影(linear projections)来变换查询、键和值
        self.attention=d2l.torch.DotProductAttention(dropout)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=bias)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
#总之就是:我们的Q,K,V的embedding,怎么拆分成k个头的数据,然后放到一个大头中,一遍算出multi_head的值
#这里是一组QKV乘一组W 直接生成特征大小的结果,在切分成8份,放到batch里等价于并行计算
    def forward(self,queries,keys,values,valid_lens):
        # print('----')
        # print(queries)
        queries = transpose_qkv(self.W_q(queries),self.num_heads)
        # print(queries)
        # print('----')
        keys = transpose_qkv(self.W_k(keys),self.num_heads)
        values = transpose_qkv(self.W_v(values),self.num_heads)
        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)
        #valid_lens tensor([3, 3, 3, 3, 3, 2, 2, 2, 2, 2])
        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        # print(queries.shape)
        # print(keys.shape)
        '''
        queries-->torch.Size([10, 4, 20])
        keys----->torch.Size([10, 6, 20])
        两个批次,每次五个多头注意力,就一共会有十个注意力需要做。得出的矩阵为10*4*6,表示为10次注意力
        每个注意力query和key的矩阵为4*6
              keys  keys  keys  keys  keys  keys
        Query
        Query
        Query
        Query
        在经过mask时 需要将10*4*6的矩阵,转为二维矩阵,就是40*6。
        valid_lens首先会在上面的代码中,扩展至num_heads,然后会在masked_softmax中扩至40大小。
        
        '''
        output = self.attention(queries,keys,values,valid_lens)
        # print('-----')
        # print(output)
        # print('-----')
        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output,self.num_heads)
       # print(output_concat.shape)torch.Size([2, 4, 100])
        return self.W_o(output_concat)
def transpose_qkv(X,num_heads):
    # 2,6,100 2,4,100
    X = X.reshape(X.shape[0],X.shape[1],num_heads,-1)
    # 2,5,6,20  2,5,4,20
    # 输出X的形状: (batch_size,num_heads,查询或者"键-值"对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    #最终输出的形状: (batch_size * num_heads,查询或者"键-值"对的个数,num_hiddens/num_heads)
    #10,6,20  10,6,20
    return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X,num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1,num_heads,X.shape[1],X.shape[2])
    X = X.permute(0,2,1,3)
    return X.reshape(X.shape[0],X.shape[1],-1)
#在这里,我们设置head为5,也就是一共有5次self-attention。
num_hiddens,num_heads = 100,5
multiHeadAttention = MultiHeadAttention(num_hiddens,num_hiddens,num_hiddens,num_hiddens,5,0.5)
multiHeadAttention.eval()

batch_size,num_queries = 2,4
num_kvpairs,valid_lens = 6,torch.tensor([3,2])
#2,6,100 批次 句子长度 embedsize
Y = torch.ones(size=(batch_size,num_kvpairs,num_hiddens))
#2,4,100
X = torch.ones(size=(batch_size,num_queries,num_hiddens))
print(multiHeadAttention(X,Y,Y,valid_lens).shape)

首先我们设置head为5。num_hiddens可以理解为query或者key的大小,num_kvpairs表示每次注意力中key的数量,num_queries表示每次注意力中query的数量。value的数量与key的数量一样。另外一种理解就是,将X理解为批次*句子长度(单词的数量)*embedding size。每个单词对应一次查询。随后就是__init__,创建几个全连接层,对query、key、value进行变换,不同注意力的query、key和value,均不一样。

主要是实现图中红色部分。然后会调用forward函数,transpose_qkv函数进行切分,假定原本的输入为2 * 6 * 100,因为大小为两个批次,每个批次需要做五个注意力机制,每个注意力机制的key的数量为6,所以将输入为2 * 6 * 100,转换为10 * 6 * 20。意思就是10次注意力,每个注意力中的key为6个,每个key由20维度的向量表示。query同理。因为我们要并行计算,这样使用torch.bmm可以直接进行计算,计算得出Query和key矩阵。在上面的例子中,计算得出的为10 * 4 * 6大小的矩阵。

在训练时刻的mask中,首先会将结果转变为二维矩阵40 * 6,其中的每一行代表了query与不同key计算的结果,有时候query只能和部分key进行计算,比如:第二个词的query只能计算第一个词与第二个词的key,而之后key需要进行mask。我们会给定一个valid_lens 代表需要保留的计算结果。其中mask部分会调用以下代码:

python 复制代码
 mask = torch.arange((maxlen), dtype=torch.float32,device=X.device)[None, :] < valid_len[:, None]
 X[~mask] = value#value为极小值。

torch.arange((maxlen)会生成从0到5的矩阵,valid_len在之前会经过扩展为1*40大小的矩阵,然后转换为40 * 1的矩阵。最终的mask会变成40 * 6大小矩阵就像以下形式:

True,True,True,False,False

而最后两个False是需要进行mask的,X[~mask] = value将最后两个Fasle,变为负极小值,再经过softmax之后,结果将趋近于0,从而将其mask。然后与value相乘,得出结果为10 * 4* 20矩阵大小的结果,在经过变换,变为2 * 4* 100矩阵,最后再经过最后一次全连接层,然后输出结果。

python 复制代码
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self,num_hiddens,dropout,max_len=1000):
        super(PositionalEncoding,self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros(size=(1,max_len,num_hiddens))
        X = torch.arange(max_len,dtype=torch.float32).reshape(-1,1)/torch.pow(1000,torch.arange(0,num_hiddens,2,dtype=torch.float32)/num_hiddens)
        self.P[:,:,0::2] = torch.sin(X)
        self.P[:,:,1::2] = torch.cos(X)
    def forward(self,X):
        X = X+self.P[:,:X.shape[1],:].to(X.device)
        return self.dropout(X)

主要实现位置编码,将基于正弦函数和余弦函数的固定位置编码公式进行实现,X就是将公式进行实现,P的大小为批次 * 输入模型单词可能最多的数量 * 每个单词的embedding size。我们可以假设P的大小为1 * 1000 * 32的矩阵,其中1000代表网络一次最多输入1000个词,每个词使用32维度向量表示。生成的X,是一个1000 * 16大小的矩阵,其中每一行的数值均不相同。P中每一行的偶数位置数据是由torch.sin(X)来生成的,奇数位置数据由torch.cos(X)生成。这样位置编码已经提前生成好了,在需要进行位置编码的时候,直接拿取前多少行,就行了。

python 复制代码
class EncoderBlock(nn.Module):
    """transformer编码器块"""
#EncoderBlock(query_size=24, key_size=24, value_size=24, num_hiddens=24, normalized_shape=[100, 24],ffn_num_inputs=24, ffn_num_hiddens=48, num_heads=8, dropout=0.5, use_bias=False)
    def __init__(self,query_size,key_size,value_size,num_hiddens,normalized_shape,ffn_num_inputs,ffn_num_hiddens,num_heads,dropout,use_bias=False):
        super(EncoderBlock,self).__init__()
        self.multihead_attention = MultiHeadAttention(key_size,query_size,value_size,num_hiddens,num_heads,dropout,use_bias)
        self.addnorm1 = AddNorm(normalized_shape,dropout)
        self.ffn = PositionWiseFFN(ffn_num_inputs,ffn_num_hiddens,num_hiddens)
        self.addnorm2 = AddNorm(normalized_shape,dropout)
    def forward(self,X,valid_lens):
        Y = self.addnorm1(X,self.multihead_attention(X,X,X,valid_lens))
        return self.addnorm2(Y,self.ffn(Y))

EncoderBlock对一个encoderBlock进行实现。先后经过,多头注意力机制,残差和规范化,前馈神经网络,残差和规范化,最后将结果输出。

python 复制代码
class TransformerEncoder(d2l.torch.Encoder):
    """transformer编码器"""
    def __init__(self,vocab_size,query_size,key_size,value_size,num_hiddens,normalized_shape,ffn_num_inputs,ffn_num_hiddens,num_heads,num_layers,dropout,use_bias=False):
        super(TransformerEncoder,self).__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size,num_hiddens)
        self.positionalEncoding = d2l.torch.PositionalEncoding(num_hiddens,dropout)
        self.encoder_blocks = nn.Sequential()
        for i in range(num_layers):
            self.encoder_blocks.add_module(f'encoder_block{i}',
            EncoderBlock(query_size,key_size,value_size,num_hiddens,normalized_shape,ffn_num_inputs,ffn_num_hiddens,num_heads,dropout,use_bias=use_bias))
    def forward(self, X,valid_lens, *args):。
        X = self.positionalEncoding(self.embedding(X)*math.sqrt(self.num_hiddens))
        self.attention_weights = [None]*len(self.encoder_blocks)
        for i,encoder_block in enumerate(self.encoder_blocks):
            X = encoder_block(X,valid_lens)
            self.attention_weights[i] = encoder_block.multihead_attention.attention.attention_weights
        return X

transformer编码器,对encoder进行堆叠,self.embedding(X)*math.sqrt(self.num_hiddens 主要因为embedding的值相对于位置编码比较小,乘以math.sqrt(self.num_hiddens,使得值与位置编码的值,差不多大小。

python 复制代码
class DecoderBlock(nn.Module):
    """解码器中第i个块"""
    #decoder_block = DecoderBlock(24,24,24,24,[100,24],24,48,8,0.5,0,use_bias=False)
    def __init__(self, query_size, key_size, value_size, num_hiddens, normalized_shape, ffn_num_inputs, ffn_num_hiddens,
                 num_heads, dropout, i, use_bias=False):
        super(DecoderBlock, self).__init__()
        self.i = i  # i表示这是第i个DecoderBlock块
        self.mask_multihead_attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens,
                                                                      num_heads, dropout, bias=use_bias)
        self.addnorm1 = AddNorm(normalized_shape, dropout)
        self.mutilhead_attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens,
                                                                 num_heads, dropout, bias=use_bias)
        self.addnorm2 = AddNorm(normalized_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_inputs, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(normalized_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 训练阶段,输出序列的所有词元都在同一时间处理,
        # 因此state[2][self.i]初始化为None。
        # 预测阶段,输出序列是通过词元一个接着一个解码的,
        # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
        # 训练时,由于每次都需要调用init_state函数,因此重新训练一个batch时,state[2]始终是一个None列表,
        # 当测试时,由于每次根据当前时间步的词元预测下一个词元时都不会重新调用init_state()函数,
        # 不会重新初始化state,因此state[2]里面保存的是之前时间步预测出来的词元信息(存的是decoder每层第一个掩码多头注意力state信息)
        if state[2][self.i] is None:
            keys_values = X
        else:
            keys_values = torch.cat([state[2][self.i], X], dim=1)
        state[2][self.i] = keys_values
        if self.training:
            #[2, 100, 24]
            batch_size, num_step, _ = X.shape
            # 训练时执行当前时间步的query时只看它前面的keys,values,不看它后面的keys,values。
            # 因为预测时是从左往右预测的,右边还没有预测出来,因此右侧的keys是没有的,看不到右侧的keys;
            # 训练时预测当前时间步词元能看到后面的目标词元,因此需要dec_valid_lens
            # dec_valid_lens的开头:(batch_size,num_steps),
            # 其中每一行是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(1, num_step + 1, device=X.device).repeat(batch_size, 1)
            print(dec_valid_lens)
        else:
            # 测试时预测当前时间步的词元只能看到之前预测出来的词元,后面还没预测的词元还看不到,因此dec_valid_lens可以不需要
            dec_valid_lens = None
        # 自注意力
        X2 = self.mask_multihead_attention1(X, keys_values, keys_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # 编码器-解码器注意力。
        # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
        Y2 = self.mutilhead_attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

区分主要在两点,就是训练的时候,会执行if state[2][self.i] is None:,所以第一次多头注意力,输入的key和value,均为本身。而在预测阶段,第一次多头注意力输入的为之前生成

相关推荐
逢生博客18 分钟前
使用 Python 项目管理工具 uv 快速创建 MCP 服务(Cherry Studio、Trae 添加 MCP 服务)
python·sqlite·uv·deepseek·trae·cherry studio·mcp服务
xwz小王子20 分钟前
Nature Communications 面向形状可编程磁性软材料的数据驱动设计方法—基于随机设计探索与神经网络的协同优化框架
深度学习
堕落似梦24 分钟前
Pydantic增强SQLALchemy序列化(FastAPI直接输出SQLALchemy查询集)
python
生信碱移1 小时前
大语言模型时代,单细胞注释也需要集思广益(mLLMCelltype)
人工智能·经验分享·深度学习·语言模型·自然语言处理·数据挖掘·数据可视化
坐吃山猪1 小时前
Python-Agent调用多个Server-FastAPI版本
开发语言·python·fastapi
Bruce-li__2 小时前
使用Django REST Framework快速开发API接口
python·django·sqlite
小兜全糖(xdqt)2 小时前
python 脚本引用django中的数据库model
python·django
Arenaschi2 小时前
SQLite 是什么?
开发语言·网络·python·网络协议·tcp/ip
纪元A梦2 小时前
华为OD机试真题——推荐多样性(2025A卷:200分)Java/python/JavaScript/C++/C语言/GO六种最佳实现
java·javascript·c++·python·华为od·go·华为od机试题
硅谷秋水2 小时前
通过模仿学习实现机器人灵巧操作:综述(上)
人工智能·深度学习·机器学习·计算机视觉·语言模型·机器人