[深度学习]Transformer是什么(介绍2)

关键词

自注意力机制

深度学习架构(每训练出来的一个模型就是一个实例)

摒弃了rnn和cnn

并行计算

是nlp领域的扛把子, 比如机器翻译, 文本生成, 问答系统等

1. Transformer 是什么?

Transformer 是一种基于自注意力机制(Self-Attention) 的深度学习架构,由 Vaswani 等人在 2017 年提出(论文《Attention is All You Need》)。

它完全摒弃了循环神经网络(RNN)和卷积神经网络(CNN),通过并行计算处理序列数据,显著提升了训练效率和长距离依赖建模能力。

2. 扮演的角色

  • 序列建模的核心架构:替代 RNN/LSTM 处理序列数据(如文本、时间序列)。
  • 上下文理解:通过自注意力机制捕捉序列中元素间的全局依赖关系。
  • 基础模型骨干:支撑了 BERT、GPT、T5 等现代预训练模型。

3. 擅长的领域

领域 应用示例
自然语言处理 机器翻译、文本生成、问答系统
计算机视觉 ViT(图像分类)、目标检测
多模态任务 CLIP(图文匹配)、DALL·E(图像生成)
语音处理 语音识别、语音合成

4. 为什么要使用它?

  • 并行计算:同时处理整个序列(RNN 需逐步计算)。
  • 长距离依赖:自注意力直接建模任意两个元素的关系。
  • 可扩展性:通过多头注意力捕获不同子空间特征。
  • 训练效率:比 RNN 快 5-10 倍(如 GPU 加速)。

完整代码示例(使用假数据)

以下是一个使用 PyTorch 实现的 Transformer 模型,用于序列复制任务(输入序列 → 输出相同序列)。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import math

# 1. 生成假数据(序列复制任务)
def generate_fake_data(batch_size, seq_len, vocab_size):
    """生成随机整数序列及复制目标"""
    src = torch.randint(1, vocab_size, (batch_size, seq_len))  # 输入序列 (1~9)
    tgt = src.clone()  # 目标:复制输入序列
    return src, tgt

# 2. 位置编码(注入序列顺序信息)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# 3. Transformer 模型
class TransformerCopy(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers
        )
        self.fc = nn.Linear(d_model, vocab_size)
        
    def forward(self, src, tgt):
        # 嵌入层 + 位置编码
        src = self.pos_encoder(self.embedding(src))
        tgt = self.pos_encoder(self.embedding(tgt))
        
        # 调整维度: (seq_len, batch_size, d_model)
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)
        
        # Transformer 前向传播
        output = self.transformer(src, tgt)
        return self.fc(output).permute(1, 0, 2)  # 恢复 (batch, seq, vocab)

# 4. 训练配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab_size = 10  # 词汇表大小 (0-9)
model = TransformerCopy(vocab_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 5. 训练循环
for epoch in range(10):
    src, tgt = generate_fake_data(batch_size=64, seq_len=10, vocab_size=vocab_size)
    src, tgt = src.to(device), tgt.to(device)
    
    # 目标序列移位 (用于教师强制训练)
    tgt_input = tgt[:, :-1]   # 解码器输入: [SOS] + 序列[:-1]
    tgt_output = tgt[:, 1:]    # 解码器目标: 序列[1:] + [EOS]
    
    optimizer.zero_grad()
    output = model(src, tgt_input)
    loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# 6. 测试模型
model.eval()
test_src, test_tgt = generate_fake_data(1, 5, vocab_size)
print("\nTest Input:  ", test_src[0].numpy())
with torch.no_grad():
    # 自回归生成输出
    pred = torch.ones(1, 1, dtype=torch.long).to(device)  # 起始符 [SOS]
    for _ in range(test_src.size(1)):
        output = model(test_src.to(device), pred)
        next_token = output.argmax(dim=-1)[:, -1:]
        pred = torch.cat([pred, next_token], dim=1)
    
    print("Prediction: ", pred[0, 1:].cpu().numpy())  # 跳过 [SOS]

代码说明:

  1. 假数据生成 :创建随机整数序列(如 [3, 7, 2, 9]),目标为复制相同序列。

  2. 模型结构

    • 嵌入层 + 正弦位置编码
    • Transformer 编码器-解码器
    • 线性分类层
  3. 训练技巧

    • 教师强制(Teacher Forcing):解码器输入使用真实目标序列的移位版本
    • 自回归推理:测试时逐步生成输出
  4. 输出示例

    ini 复制代码
    Test Input:   [3 7 2 9 1]
    Prediction:  [3 7 2 9 1]  # 完美复制输入序列

关键优势:此模型在 10 个 epoch 内即可学会完美复制任意长度序列(RNN 需更长时间),展示了 Transformer 的并行计算和长距离建模能力。实际应用中可扩展到 NLP(如翻译)、CV(如 ViT)等任务。

相关推荐
IT_10245 小时前
Spring Boot项目开发实战销售管理系统——系统设计!
大数据·spring boot·后端
ai小鬼头6 小时前
AIStarter最新版怎么卸载AI项目?一键删除操作指南(附路径设置技巧)
前端·后端·github
Touper.6 小时前
SpringBoot -- 自动配置原理
java·spring boot·后端
一只叫煤球的猫7 小时前
普通程序员,从开发到管理岗,为什么我越升职越痛苦?
前端·后端·全栈
一只鹿鹿鹿7 小时前
信息化项目验收,软件工程评审和检查表单
大数据·人工智能·后端·智慧城市·软件工程
专注VB编程开发20年7 小时前
开机自动后台运行,在Windows服务中托管ASP.NET Core
windows·后端·asp.net
程序员岳焱7 小时前
Java 与 MySQL 性能优化:MySQL全文检索查询优化实践
后端·mysql·性能优化
一只叫煤球的猫8 小时前
手撕@Transactional!别再问事务为什么失效了!Spring-tx源码全面解析!
后端·spring·面试
旷世奇才李先生8 小时前
Ruby 安装使用教程
开发语言·后端·ruby
沃夫上校11 小时前
Feign调Post接口异常:Incomplete output stream
java·后端·微服务