基于BERT和GPT2的实现来理解Transformer的结构和原理

Transformer

核心就是编码器和解码器,简单理解:编码器就是特征提取,解码器就是特征还原。

Transformer 完整架构

Transformer最初是一个Encoder-Decoder架构,用于机器翻译任务:

复制代码
输入序列 → [Encoder] → 编码表示 → [Decoder] → 输出序列

1. 原始Transformer结构

复制代码
┌─────────────────────────────────────────────┐
│                 Transformer                  │
├─────────────────────┬───────────────────────┤
│      Encoder        │       Decoder         │
├─────────────────────┼───────────────────────┤
│  Multi-Head         │  Masked Multi-Head    │
│  Self-Attention     │  Self-Attention       │
│        ↓            │         ↓             │
│  Add & Norm         │  Add & Norm           │
│        ↓            │         ↓             │
│  Feed Forward       │  Multi-Head           │
│        ↓            │  Cross-Attention      │
│  Add & Norm         │         ↓             │
│        ↓            │  Add & Norm           │
│   (重复N次)         │         ↓             │
│                     │  Feed Forward         │
│                     │         ↓             │
│                     │  Add & Norm           │
│                     │    (重复N次)          │
└─────────────────────┴───────────────────────┘

2. 核心组件详解

2.1 自注意力机制(Self-Attention)
python 复制代码
# 核心公式:Attention(Q,K,V) = softmax(QK^T/√d_k)V

# BERT中的使用(双向注意力)
attention_mask = data["attention_mask"]  # [1,1,1,0,0] 标记真实token
# 可以看到所有位置的信息

# GPT2中的使用(因果注意力)
# 使用下三角mask,只能看到当前位置之前的信息
2.2 多头注意力(Multi-Head Attention)
复制代码
Multi-Head = Concat(head_1, head_2, ..., head_h)W^O
其中 head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

每个头关注不同的语义信息:

  • 头1:可能关注语法关系
  • 头2:可能关注语义相似性
  • 头3:可能关注位置关系
2.3 位置编码(Positional Encoding)

Transformer没有循环结构,需要位置信息:

python 复制代码
# BERT的位置编码
position_ids = torch.arange(seq_length)
position_embeddings = self.position_embeddings(position_ids)

# 原始Transformer使用正弦位置编码
PE(pos,2i) = sin(pos/10000^(2i/d_model))
PE(pos,2i+1) = cos(pos/10000^(2i/d_model))

3. BERT:只用Encoder

BERT使用了Transformer的Encoder部分,实现双向理解:

python 复制代码
# demo_5/net.py 的实现
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # BERT是12层Encoder堆叠
        self.fc = torch.nn.Linear(768, 2)  # 768是隐藏维度
    
    def forward(self, input_ids, attention_mask, token_type_ids):
        # BERT的三个输入
        # input_ids: token的ID [batch, seq_len]
        # attention_mask: 标记padding [batch, seq_len]
        # token_type_ids: 区分句子A/B [batch, seq_len]
        
        with torch.no_grad():
            out = pretrained(input_ids, attention_mask, token_type_ids)
        
        # 取[CLS]的表示做分类
        out = self.fc(out.last_hidden_state[:,0])
        return out

BERT的特点

  • 双向注意力:每个位置都能看到全文
  • MLM预训练:随机mask 15%的token进行预测
  • NSP任务:判断两个句子是否相邻

4. GPT2:只用Decoder

GPT2使用了Transformer的Decoder部分(去掉Cross-Attention):

python 复制代码
# demo_8/train.py 的实现
model = AutoModelForCausalLM.from_pretrained(...)  # 因果语言模型

def collate_fn(data):
    data = tokenizer.batch_encode_plus(data, ...)
    # 关键:标签就是输入向右偏移一位
    data['labels'] = data['input_ids'].clone()
    return data

GPT2的特点

  • 单向注意力:使用因果mask,只能看到之前的token
  • 自回归生成:逐个token生成
  • 统一架构:12/24/48层Decoder堆叠

5. 注意力掩码对比

python 复制代码
# BERT的注意力掩码(可以看到所有位置)
attention_mask = [
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1]
]

# GPT2的因果掩码(只能看到之前的位置)
causal_mask = [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [1, 1, 1, 1]
]

6. 前馈网络(Feed Forward)

两个模型都使用相同的FFN结构:

python 复制代码
FFN(x) = max(0, xW1 + b1)W2 + b2
# 通常:d_model=768, d_ff=3072

7. 层归一化和残差连接

python 复制代码
# 每个子层都有
output = LayerNorm(x + Sublayer(x))

总结对比

组件 原始Transformer BERT GPT2
架构 Encoder-Decoder Encoder only Decoder only
层数 6+6 12/24 12/24/48
注意力 双向+单向 双向 单向(因果)
预训练 监督翻译 MLM+NSP 语言建模
应用 序列到序列 理解任务 生成任务

实际应用示例

BERT处理流程

复制代码
输入: "这个产品[MASK]好用" 
→ Tokenize: [101, 2110, 782, 103, 1962, 102]
→ 12层Encoder双向编码
→ 输出: 每个位置的768维表示
→ 预测[MASK]: "很"

GPT2生成流程

复制代码
输入: "今天天气"
→ Tokenize: [791, 1921, 1921, 2698]
→ 12层Decoder单向编码
→ 预测下一个: "很"
→ 继续预测: "好"
→ 最终: "今天天气很好"

Transformer的革命性在于完全基于注意力机制,抛弃了RNN/CNN,实现了并行计算和长距离依赖建模。BERT和GPT2分别展示了其在理解和生成任务上的强大能力。

相关推荐
九年义务漏网鲨鱼1 小时前
【大模型学习 | MINIGPT-4原理】
人工智能·深度学习·学习·语言模型·多模态
元宇宙时间1 小时前
Playfun即将开启大型Web3线上活动,打造沉浸式GameFi体验生态
人工智能·去中心化·区块链
开发者工具分享1 小时前
文本音频违规识别工具排行榜(12选)
人工智能·音视频
产品经理独孤虾2 小时前
人工智能大模型如何助力电商产品经理打造高效的商品工业属性画像
人工智能·机器学习·ai·大模型·产品经理·商品画像·商品工业属性
老任与码2 小时前
Spring AI Alibaba(1)——基本使用
java·人工智能·后端·springaialibaba
蹦蹦跳跳真可爱5892 小时前
Python----OpenCV(图像増强——高通滤波(索贝尔算子、沙尔算子、拉普拉斯算子),图像浮雕与特效处理)
人工智能·python·opencv·计算机视觉
雷羿 LexChien2 小时前
从 Prompt 管理到人格稳定:探索 Cursor AI 编辑器如何赋能 Prompt 工程与人格风格设计(上)
人工智能·python·llm·编辑器·prompt
两棵雪松3 小时前
如何通过向量化技术比较两段文本是否相似?
人工智能
heart000_13 小时前
128K 长文本处理实战:腾讯混元 + 云函数 SCF 构建 PDF 摘要生成器
人工智能·自然语言处理·pdf
敲键盘的小夜猫3 小时前
LLM复杂记忆存储-多会话隔离案例实战
人工智能·python·langchain