一、输入部分:词嵌入与位置编码
输入部分是Transformer处理原始文本的第一步,负责将离散的文本符号转化为包含语义和位置信息的连续向量。
1. 词嵌入(embeddings
类)
-
作用:将文本中的每个词(用索引表示)映射到高维向量空间,捕捉词的语义信息。
-
核心代码解析:
pythonclass embeddings(nn.Module): def __init__(self, vocab_size, d_model): super().__init__() self.embedding = nn.Embedding(vocab_size, d_model) # 词嵌入层 def forward(self, x): # 乘以缩放系数√d_model,控制嵌入向量的方差 return self.embedding(x) * math.sqrt(self.d_model)
-
关键细节:
-
输入
x
是词索引张量(形状:[batch_size, seq_len]
),输出是词嵌入向量(形状:[batch_size, seq_len, d_model]
)。 -
乘以
math.sqrt(d_model)
的原因:补偿Xavier初始化的非正态分布特性,使嵌入向量保持合理的方差,避免后续计算中数值过大或过小。
-
2. 位置编码(positional_encoding
类)
-
作用:Transformer没有循环结构,需通过位置编码向模型注入词的位置信息,使模型感知词在序列中的顺序。
-
核心代码解析:
pythonclass positional_encoding(nn.Module): def __init__(self, d_model, dropout, max_len=100): super().__init__() self.droupout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) # 位置编码矩阵 position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # 位置索引 _2i = torch.arange(0, d_model, 2).float() # 偶数维度索引 # 偶数维度用正弦函数,奇数维度用余弦函数 pe[:, 0::2] = torch.sin(position / (10000 **(_2i / d_model))) pe[:, 1::2] = torch.cos(position / (10000** (_2i / d_model))) pe = pe.unsqueeze(0) # 增加batch维度 self.register_buffer('pe', pe) # 非参数化缓冲区,不参与训练 def forward(self, x): x = x + self.pe[:, :x.size(1)] # 与词嵌入相加(广播机制) return self.droupout(x)
-
关键细节:
-
位置编码公式:对于位置
pos
和维度i
,偶数i
用sin(pos/10000^(i/d_model))
,奇数i
用cos(pos/10000^(i/d_model))
。 -
优势:能表示任意长度的序列(通过外推),且相邻位置的编码具有相似性。
-
输出形状与词嵌入相同(
[batch_size, seq_len, d_model]
),与词嵌入向量逐元素相加后经dropout输出。
-
二、核心机制:注意力机制
注意力机制是Transformer的核心,用于捕捉序列中不同词之间的依赖关系(如"猫"和"它"的指代关系)。
1. 基础注意力计算(attention
函数)
-
作用:通过query(查询)、key(键)、value(值)计算注意力分布,输出加权求和的向量。
-
核心代码解析:
pythondef attention(query, key, value, mask=None, dropout=None): d_k = query.shape[-1] # 每个头的维度 # 计算注意力分数:(Q*K^T)/√d_k(缩放点积) scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # 掩码处理(如屏蔽填充词或未来词) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # 掩码位置置为负无穷 # 计算注意力权重(softmax归一化) p_attn = torch.softmax(scores, dim=-1) # dropout防止过拟合 if dropout is not None: p_attn = dropout(p_attn) # 加权求和(注意力权重 * value) return torch.matmul(p_attn, value), p_attn
-
关键细节:
-
缩放因子
√d_k
的作用:当d_k
较大时,点积结果可能过大,导致softmax梯度消失,缩放后可稳定数值范围。 -
掩码(
mask
):用于屏蔽无效信息(如编码器中屏蔽填充词,解码器中屏蔽未来词),确保模型只关注有效位置。
-
2. 多头注意力(multi_head_attn
类)
-
作用:将注意力机制分为多个"头"(head),并行计算不同子空间的注意力,捕捉更丰富的依赖关系。
-
核心代码解析:
pythonclass multi_head_attn(nn.Module): def __init__(self, d_model, n_head, dropout=0.1): super().__init__() assert d_model % n_head == 0 # d_model必须能被头数整除 self.n_head = n_head # 头数 self.d_k = d_model // n_head # 每个头的维度 self.linears = clones(nn.Linear(d_model, d_model), 4) # 4个线性层(Q、K、V投影+输出投影) self.dropout = nn.Dropout(dropout) def forward(self, query, key, value, mask=None): if mask is not None: mask = mask.unsqueeze(0) # 扩展掩码维度以适配多头 batch_size = query.size(0) # 1. 线性投影并拆分多头:[batch, seq_len, d_model] → [batch, n_head, seq_len, d_k] query, key, value = [ model(x).reshape(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) for model, x in zip(self.linears, [query, key, value]) ] # 2. 计算多头注意力 attn, self.attn_weights = attention(query, key, value, mask=mask, dropout=self.dropout) # 3. 合并多头:[batch, n_head, seq_len, d_k] → [batch, seq_len, d_model] attn = attn.transpose(1, 2).reshape(batch_size, -1, self.n_head * self.d_k) # 4. 输出投影 return self.linears[-1](attn)
-
关键细节:
-
多头拆分:将
d_model
维度拆分为n_head
个d_k
维度(d_model = n_head * d_k
),每个头独立计算注意力。 -
优势:并行学习不同的注意力模式(如语法依赖、语义关联),提升模型表达能力。
-
三、编码器部分
编码器负责对输入序列进行特征提取,由N
个相同的编码器层堆叠而成。
1. 前馈网络(FeedForward
类)
-
作用:对注意力输出进行非线性变换,增强模型对复杂模式的拟合能力。
-
核心代码解析:
pythonclass FeedForward(nn.Module): def __init__(self, d_model, dff, dropout=0.1): super().__init__() self.linear1 = nn.Linear(d_model, dff) # 升维:d_model → dff self.linear2 = nn.Linear(dff, d_model) # 降维:dff → d_model self.dropout = nn.Dropout(dropout) def forward(self, x): # 线性变换→ReLU激活→dropout→线性变换 return self.linear2(self.dropout(torch.relu(self.linear1(x))))
-
关键细节:
- 输入输出维度均为
d_model
,中间通过dff
(通常为4*d_model
)升维,引入非线性(ReLU)后再降维。
- 输入输出维度均为
2. 层归一化(layer_norm
类)
-
作用:对每一层的输出进行归一化,使数据分布稳定,加速训练收敛。
-
核心代码解析:
pythonclass layer_norm(nn.Module): def __init__(self, size, eps=1e-6): super().__init__() self.a = nn.Parameter(torch.ones(size)) # 缩放参数(可学习) self.b = nn.Parameter(torch.zeros(size)) # 偏移参数(可学习) self.eps = eps # 防止除零 def forward(self, x): mean = x.mean(dim=-1, keepdim=True) # 沿最后一维(特征维度)计算均值 std = x.std(dim=-1, keepdim=True) # 沿最后一维计算标准差 return self.a * (x - mean) / (std + self.eps) + self.b # 归一化+缩放偏移
3. 子层连接(sub_layer_conncetion
类)
-
作用:将注意力层/前馈网络与残差连接、层归一化结合,缓解深层网络的梯度消失问题。
-
核心代码解析:
pythonclass sub_layer_conncetion(nn.Module): def __init__(self, size, dropout=0.1): super().__init__() self.norm = layer_norm(size) # 层归一化 self.dropout = nn.Dropout(dropout) # dropout def forward(self, x, sub_layer): # 残差连接:x + 子层输出(子层输入先归一化) return x + self.dropout(sub_layer(self.norm(x)))
-
流程 :输入
x
先经层归一化,再送入子层(注意力或前馈网络),子层输出经dropout后与原始x
残差相加。
4. 编码器层(encoder_layer
类)
-
作用:编码器的基本单元,包含"多头自注意力"和"前馈网络"两个子层。
-
核心代码解析:
pythonclass encoder_layer(nn.Module): def __init__(self, size, self_attn, feed_forward, dropout): super().__init__() self.self_attn = self_attn # 多头自注意力 self.feed_forward = feed_forward # 前馈网络 self.sub_layer = clones(sub_layer_conncetion(size, dropout), 2) # 2个子层连接 def forward(self, x, mask): # 第1个子层:多头自注意力(输入x既是Q、K,也是V) x = self.sub_layer[0](x, lambda x: self.self_attn(x, x, x, mask)) # 第2个子层:前馈网络 x = self.sub_layer[1](x, lambda x: self.feed_forward(x)) return x
5. 编码器(encoder
类)
-
作用 :堆叠
N
个编码器层(通常N=6
),对输入序列进行深度编码。 -
核心代码解析:
pythonclass encoder(nn.Module): def __init__(self, layer, N): super().__init__() self.layers = clones(layer, N) # 克隆N个编码器层 self.norm = layer_norm(layer.size) # 最终层归一化 def forward(self, x, mask): for layer in self.layers: x = layer(x, mask) # 依次通过每个编码器层 return self.norm(x) # 最终归一化输出
四、解码器部分
解码器负责根据编码器的输出(memory
)和目标序列,生成输出序列,由N
个相同的解码器层堆叠而成。
1. 解码器层(DecoderLayer
类)
-
作用:解码器的基本单元,包含3个子层:"解码器自注意力""编码器-解码器注意力""前馈网络"。
-
核心代码解析:
pythonclass DecoderLayer(nn.Module): def __init__(self, size, self_attn, src_attn, feed_forward, dropout): super().__init__() self.self_attn = self_attn # 解码器自注意力(目标序列内部) self.src_attn = src_attn # 编码器-解码器注意力(关联输入和目标) self.feed_forward = feed_forward # 前馈网络 self.sub_layers = clones(sub_layer_conncetion(size, dropout), 3) # 3个子层连接 def forward(self, x, memory, source_mask, target_mask): m = memory # 编码器输出 # 第1个子层:解码器自注意力(带目标掩码,防止关注未来词) x = self.sub_layers[0](x, lambda x: self.self_attn(x, x, x, target_mask)) # 第2个子层:编码器-解码器注意力(Q=解码器输出,K=V=编码器输出) x = self.sub_layers[1](x, lambda x: self.src_attn(x, m, m, source_mask)) # 第3个子层:前馈网络 x = self.sub_layers[2](x, lambda x: self.feed_forward(x)) return x
2. 解码器(Decoder
类)
-
作用 :堆叠
N
个解码器层(通常N=6
),生成目标序列的特征表示。 -
核心代码解析:
pythonclass Decoder(nn.Module): def __init__(self, layer, N): super().__init__() self.layers = clones(layer, N) # 克隆N个解码器层 self.norm = layer_norm(layer.size) # 最终层归一化 def forward(self, x, memory, source_mask, target_mask): for layer in self.layers: x = layer(x, memory, source_mask, target_mask) # 依次通过每个解码器层 return self.norm(x) # 最终归一化输出
五、生成器(Generator
类)
-
作用:将解码器的输出转化为词表上的概率分布,生成最终的输出序列。
-
核心代码解析:
pythonclass Generator(nn.Module): def __init__(self, d_model, vocab_size): super().__init__() self.linear = nn.Linear(d_model, vocab_size) # 映射到词表大小 def forward(self, x): # log_softmax便于计算交叉熵损失 return torch.log_softmax(self.linear(x), dim=-1)
-
输出 :形状为
[batch_size, seq_len, vocab_size]
,每个位置对应词表中所有词的对数概率。
六、关键工具函数
-
clones
函数 :克隆N
个相同的模块(如编码器层、线性层),确保参数独立但结构相同。pythondef clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
-
掩码函数:
subsequent_mask(size)
:生成下三角掩码,用于解码器自注意力,防止模型关注未来的词(如翻译时"我吃"不能关注"饭"来预测"吃")。
七、整体流程总结
-
输入处理:文本→词嵌入(+缩放)→+位置编码→输入向量。
-
编码器 :输入向量经
N
个编码器层(多头自注意力+前馈网络)→memory
(编码后的输入特征)。 -
解码器 :目标序列经词嵌入+位置编码后,与
memory
一起输入N
个解码器层(解码器自注意力+编码器-解码器注意力+前馈网络)→解码特征。 -
生成器:解码特征→词表概率分布→输出序列。
通过上述模块的协作,Transformer能够高效捕捉序列中的长距离依赖,在机器翻译、文本摘要等任务中表现优异。