NLP:Transformer各子模块作用(特别分享1)

本文目录:

  • [一、Transformer 整体架构图](#一、Transformer 整体架构图)
  • [二、 输入层各子模块](#二、 输入层各子模块)
    • [**(一)输入嵌入(Input Embedding)**](#(一)输入嵌入(Input Embedding))
    • [**(二)位置编码(Positional Encoding)**](#(二)位置编码(Positional Encoding))
  • 三、编码器各子模块
    • [**(一)多头自注意力机制(Multi-Head Self-Attention)**](#(一)多头自注意力机制(Multi-Head Self-Attention))
    • [**(二)前馈神经网络(Feed Forward Network)**](#(二)前馈神经网络(Feed Forward Network))
  • [四、 解码器各子模块](#四、 解码器各子模块)
    • [**(一)掩码多头注意力(Masked Multi-Head Attention)**](#(一)掩码多头注意力(Masked Multi-Head Attention))
    • [**(二)编码器-解码器注意力(Encoder-Decoder Attention)**](#(二)编码器-解码器注意力(Encoder-Decoder Attention))
  • [五、Transformer 各模块协同工作流程](#五、Transformer 各模块协同工作流程)
  • 六、各模块的核心作用总结
  • [七、 实际应用中的变体](#七、 实际应用中的变体)
  • 八、总结

前言:Transformer 是深度学习领域的革命性架构,彻底改变了NLP的发展方向。前面分享了Transformer的大概构建思路,本文特别分享Transformer的各子模块作用。

一、Transformer 整体架构图

首先,我们来看Transformer的整体结构,它主要由输入层、编码器、解码器和输出层 四部分组成:

复制代码
Input → Input Embedding → Positional Encoding → Encoder Stack → Decoder Stack → Output

二、 输入层各子模块

(一)输入嵌入(Input Embedding)

python 复制代码
def input_embedding(input_tokens):
    """
    将离散的token转换为连续的向量表示
    """
    # 作用:将词汇映射到高维空间,捕获语义信息
    # 实现:查找表(lookup table)或神经网络
    return embedding_matrix[input_tokens]

# 示例:将"apple"映射为[0.2, -0.5, 0.8, ...]的256维向量

(二)位置编码(Positional Encoding)

python 复制代码
def positional_encoding(embedding):
    """
    为序列添加位置信息(因为Transformer没有RNN的循环结构)
    """
    # 使用正弦和余弦函数生成位置信息
    # 公式:PE(pos, 2i) = sin(pos/10000^(2i/d_model))
    #       PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))
    return embedding + position_matrix

# 为什么需要:自注意力机制本身没有位置概念,需要额外添加

三、编码器各子模块

经典的Transformer架构中的Encoder模块包含6个Encoder Block. * 每个Encoder Block包含两个子模块, 分别是多头自注意力层, 和前馈全连接层。

(一)多头自注意力机制(Multi-Head Self-Attention)

python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        """
        d_model: 模型维度(如512)
        num_heads: 注意力头数(如8)
        """
        # 核心公式:Attention(Q, K, V) = softmax(QK^T/√d_k)V
        # 多头:将Q、K、V投影到多个子空间,捕获不同方面的信息
        
    def forward(self, x):
        # 1. 线性投影得到Q、K、V
        q = self.wq(x)  # Query:当前要关注的词
        k = self.wk(x)  # Key:被关注的词  
        v = self.wv(x)  # Value:实际的信息内容
        
        # 2. 计算注意力权重
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        attention_weights = F.softmax(scores, dim=-1)
        
        # 3. 加权求和
        output = torch.matmul(attention_weights, v)
        return output

注意力机制的作用

  • 捕获长距离依赖:直接计算任意两个位置的关系
  • 并行计算:比RNN更高效
  • 可解释性:注意力权重显示模型关注点

(二)前馈神经网络(Feed Forward Network)

python 复制代码
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # 扩展维度(如2048)
        self.linear2 = nn.Linear(d_ff, d_model)  # 恢复维度
        self.relu = nn.ReLU()
    
    def forward(self, x):
        # 作用:提供非线性变换,增强模型表达能力
        return self.linear2(self.relu(self.linear1(x)))
## **(二) 残差连接和层归一化(Add & Norm)**
```python
def residual_connection(x, sublayer):
    """
    残差连接:解决深层网络梯度消失问题
    """
    return x + sublayer(x)  # 原始输入 + 子层输出

def layer_norm(x):
    """
    层归一化:稳定训练过程,加速收敛
    """
    return (x - x.mean()) / (x.std() + eps) * gamma + beta

四、 解码器各子模块

经典的Transformer架构中的Decoder模块包含6个Decoder Block. * 每个Decoder Block包含3个子模块, 分别是多头自注意力层, Encoder-Decoder Attention层, 和前馈全连接层。

解码器与编码器结构类似,但注意力层有一些关键区别,如下:

(一)掩码多头注意力(Masked Multi-Head Attention)

python 复制代码
class MaskedMultiHeadAttention(nn.Module):
    def forward(self, x):
        # 在解码时,不能让当前词看到后面的词
        # 使用上三角掩码(upper triangular mask)
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))  # 将未来位置设为负无穷
        return super().forward(x)

(二)编码器-解码器注意力(Encoder-Decoder Attention)

python 复制代码
class EncoderDecoderAttention(nn.Module):
    def forward(self, decoder_x, encoder_output):
        # Query来自解码器,Key和Value来自编码器输出
        q = self.wq(decoder_x)  # Query:当前要生成的词
        k = self.wk(encoder_output)  # Key:编码器的信息
        v = self.wv(encoder_output)  # Value:编码器的信息
        
        # 这让解码器可以关注输入序列的相关部分
        return attention(q, k, v)

五、Transformer 各模块协同工作流程

python 复制代码
def transformer_forward(input_seq, target_seq):
    # 编码器流程
    enc_output = input_embedding(input_seq) + positional_encoding
    for _ in range(N):  # 6个编码器层
        enc_output = encoder_layer(enc_output)
    
    # 解码器流程  
    dec_output = input_embedding(target_seq) + positional_encoding
    for _ in range(N):  # 6个解码器层
        dec_output = decoder_layer(dec_output, enc_output)
    
    return final_output

六、各模块的核心作用总结

模块 英文名 主要作用 为什么重要
输入嵌入 Input Embedding 将离散token转为连续向量 提供语义表示
位置编码 Positional Encoding 添加序列位置信息 弥补无序列结构的缺陷
自注意力 Self-Attention 计算词与词之间的关联度 捕获长距离依赖关系
多头注意力 Multi-Head Attention 从多个角度计算注意力 捕获不同类型的语义关系
前馈网络 Feed Forward 非线性变换 增强模型表达能力
残差连接 Residual Connection 直连通道 解决梯度消失问题
层归一化 Layer Normalization 标准化激活值 稳定训练过程
掩码注意力 Masked Attention 防止信息泄露 保证自回归生成的性质
编码器-解码器注意力 Encoder-Decoder Attention 连接输入和输出 实现序列到序列的转换

七、 实际应用中的变体

(一)BERT(仅编码器)

python 复制代码
# 用于理解任务:文本分类、NER、情感分析
model = TransformerEncoderOnly()
output = model(input_text)  # 得到每个token的表示

(二)GPT(仅解码器)

python 复制代码
# 用于生成任务:文本生成、对话、创作
model = TransformerDecoderOnly()
output = model.generate(prompt)  # 自回归生成文本

(三)原始Transformer(编码器-解码器)

python 复制代码
# 用于序列到序列任务:翻译、摘要、问答
model = TransformerEncoderDecoder()
output = model.translate(english_text)  # 英译中

八、总结

Transformer的成功在于:

  1. 自注意力机制:彻底解决长距离依赖问题
  2. 并行计算:大幅提升训练效率
  3. 模块化设计:各司其职,协同工作
  4. 可扩展性:易于堆叠更多层数

每个子模块都承担着关键角色,共同构成了这个革命性的架构,为后来的BERT、GPT等模型奠定了坚实基础!

今天的分享到此结束。

相关推荐
数巨小码人8 分钟前
AI+数据库:国内DBA职业发展与国产化转型实践
数据库·人工智能·ai·dba
黑客影儿26 分钟前
使用UE5开发2.5D开放世界战略养成类游戏的硬件配置指南
开发语言·c++·人工智能·游戏·智能手机·ue5·游戏引擎
Coovally AI模型快速验证36 分钟前
YOLOv8-SMOT:基于切片辅助训练与自适应运动关联的无人机视角小目标实时追踪框架
人工智能·深度学习·yolo·计算机视觉·目标跟踪·无人机
新智元1 小时前
刚刚,英伟达新模型上线!4B 推理狂飙 53 倍,全新注意力架构超越 Mamba 2
人工智能·openai
新智元1 小时前
北大数学家终结 50 年猜想!一只蝴蝶翅膀,竟难倒菲尔兹奖得主
人工智能·openai
vivo互联网技术1 小时前
EMNLP 2025|vivo 等提出 DiMo-GUI:模态分治+动态聚焦,GUI 智能体推理时扩展的新范式
前端·人工智能·agent
机器之心1 小时前
热议!DeepSeek V3.1惊现神秘「极」字Bug,模型故障了?
人工智能·openai
wan5555cn1 小时前
AI 时代“驯导师”职业发展方向探究
大数据·人工智能·笔记·深度学习
Java中文社群1 小时前
超简单!手把手教你玩转ClaudeCode,无魔法不会员!
人工智能·程序员
算家计算1 小时前
算力暴增!英伟达发布新一代机器人超级计算机,巨量算力驱动物理AI革命
人工智能·云计算·nvidia