一、Transformer 核心概念
Transformer 是 2017 年提出的基于自注意力机制的序列建模架构,核心是抛弃 RNN/CNN 的顺序依赖,用注意力同时捕捉全局信息,整体分为 Encoder(编码器,负责理解输入) 和 Decoder(解码器,负责生成输出) 两大部分。
关键核心模块包括:
- 位置编码:弥补 Transformer 无顺序感知的缺陷,给每个位置的 token 加位置信息;
- 缩放点积自注意力(Scaled Dot-Product Attention):核心中的核心,计算[查询(Q)- 键(K)- 值(V)]的关联度;
- 多头注意力(Multi-Head Attention):拆分注意力为多个子空间,捕捉不同维度的关联;
- 前馈网络(Feed Forward Network):对注意力输出做非线性变换;
- 残差连接 + 层归一化:防止梯度消失,稳定训练。
二、核心数学公式
1. 缩放点积自注意力
符号解读:
- Q(Query):查询矩阵,维度 [batch,seqlen,dk],代表「当前 token 要找什么」;
- K(Key):键矩阵,维度 [batch,seqlen,dk],代表「所有 token 能提供什么」;
- V(Value):值矩阵,维度 [batch,seqlen,dv],代表「所有 token 的实际内容」;
:Q/K 的维度(通常取 64),
是缩放因子, 防止
过大时内积值溢出,导致 softmax 饱和;
- QKT:计算 Q 和 K 的相似度;
- softmax:将相似度归一化为 0-1 的权重;
- 最终输出:权重乘以 V,得到加权后的 token 内容。
2. 多头注意力
符号解读:
- h:头数,将 Q/K/V 拆分为 h 个子矩阵;
:第 i 个头的投影矩阵,用于拆分维度;
- Concat:拼接所有头的输出;
:最终投影矩阵,将拼接后的维度还原。
3. 位置编码
Transformer 本身无法感知序列顺序,因此给每个位置的每个维度 i 加位置信息,使用正余弦交替的方法捕捉不同频率的位置特征,兼容长序列。
三、代码解释(PyTorch 版)
模块一:导入核心库
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
模块二:位置编码
python
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 初始化位置编码矩阵,shape [max_len, d_model]
pe = torch.zeros(max_len, d_model)
# 生成位置序列 [max_len, 1]
pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 计算分母项,避免重复计算
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
# 偶数维度用sin,奇数维度用cos
pe[:, 0::2] = torch.sin(pos * div_term)
pe[:, 1::2] = torch.cos(pos * div_term)
# 增加batch维度,shape [1, max_len, d_model]
pe = pe.unsqueeze(0)
# 注册为非参数的缓冲区(不参与训练)
self.register_buffer('pe', pe)
def forward(self, x):
# x: [batch, seq_len, d_model]
# 给输入添加位置编码(只取对应序列长度的部分)
x = x + self.pe[:, :x.size(1), :]
return x
模块三:缩放点积自注意力
python
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super().__init__()
def forward(self, q, k, v, mask=None):
# q/k/v: [batch, head, seq_len, d_k]
d_k = q.size(-1)
# 1. 计算Q*K^T / sqrt(d_k),shape [batch, head, seq_len, seq_len]
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
# 2. 应用mask(比如Decoder的掩码,防止看未来token)
if mask is not None:
# mask为True的位置设为负无穷,softmax后权重为0
scores = scores.masked_fill(mask == 0, -1e9)
# 3. softmax计算权重
attn_weights = F.softmax(scores, dim=-1)
# 4. 权重乘以V,得到最终输出
output = torch.matmul(attn_weights, v)
return output, attn_weights
模块四:多头注意力模块
python
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads # 头数
self.d_model = d_model # 总维度
self.d_k = d_model // n_heads # 每个头的维度
# 定义Q/K/V的投影矩阵(将d_model拆分为n_heads*d_k)
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
# 最终输出的投影矩阵
self.w_o = nn.Linear(d_model, d_model)
# 缩放点积注意力
self.attention = ScaledDotProductAttention()
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# 1. 投影并拆分多头:[batch, seq_len, d_model] → [batch, n_heads, seq_len, d_k]
q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# 2. 计算缩放点积注意力
attn_output, attn_weights = self.attention(q, k, v, mask)
# 3. 拼接多头:[batch, n_heads, seq_len, d_k] → [batch, seq_len, d_model]
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 4. 最终投影
output = self.w_o(attn_output)
return output, attn_weights
模块五:前馈网络
python
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff=2048):
super().__init__()
# 两层全连接:d_model → d_ff → d_model
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()
def forward(self, x):
# x: [batch, seq_len, d_model]
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
模块六:Encoder层
python
class EncoderLayer(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.mha = MultiHeadAttention(d_model, n_heads) # 多头自注意力
self.ffn = FeedForward(d_model, d_ff) # 前馈网络
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# dropout
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 1. 多头自注意力 + 残差连接 + 层归一化
attn_output, _ = self.mha(x, x, x, mask) # 自注意力:Q=K=V=x
attn_output = self.dropout1(attn_output)
x = self.norm1(x + attn_output) # 残差连接
# 2. 前馈网络 + 残差连接 + 层归一化
ffn_output = self.ffn(x)
ffn_output = self.dropout2(ffn_output)
x = self.norm2(x + ffn_output)
return x
模块七:Transformer
python
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_len=5000, dropout=0.1):
super().__init__()
self.d_model = d_model
# 词嵌入层:将token转为d_model维向量
self.embedding = nn.Embedding(vocab_size, d_model)
# 位置编码层
self.pos_encoding = PositionalEncoding(d_model, max_len)
# 多个Encoder层堆叠
self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# x: [batch, seq_len](输入的token序列)
batch_size, seq_len = x.size()
# 1. 词嵌入 + 位置编码(乘以sqrt(d_model)是论文中的小技巧,平衡嵌入和位置编码的幅度)
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
x = self.dropout(x)
# 2. 逐层通过Encoder
for layer in self.layers:
x = layer(x, mask)
return x
模块八:测试
python
if __name__ == "__main__":
# 超参数设置(和论文一致)
vocab_size = 10000 # 词汇表大小
d_model = 512 # 总维度
n_heads = 8 # 多头数
n_layers = 6 # Encoder层数
d_ff = 2048 # 前馈网络中间维度
# 创建模型
model = TransformerEncoder(vocab_size, d_model, n_heads, n_layers, d_ff)
# 生成测试输入:batch_size=2,seq_len=10的随机token序列
test_input = torch.randint(0, vocab_size, (2, 10))
# 前向传播
output = model(test_input)
print(f"输入shape: {test_input.shape}") # torch.Size([2, 10])
print(f"输出shape: {output.shape}") # torch.Size([2, 10, 512])
print("Transformer Encoder 运行成功!")
位置编码(无依赖)→ 缩放点积注意力(无依赖)→ 多头注意力(依赖缩放点积)→ 前馈网络(无依赖)→ 单层Encoder(依赖多头+前馈)→ 整体Encoder(依赖位置编码+多层单层Encoder)。
把 "无意义的整数 token" 变成 "有语义的向量",且向量能体现:
- 单个 token 本身的含义------通过词嵌入;
- token 在序列中的位置------通过位置编码;
- 该 token 与上下文其他 token 的关联------通过注意力机制。
运行结果
输入shape: torch.Size([2, 10])
输出shape: torch.Size([2, 10, 512])
四、总结
- 核心本质:Transformer 用自注意力机制替代 RNN 的顺序计算,能并行处理序列,同时捕捉全局关联;
- 关键模块:位置编码补顺序、多头注意力分维度捕捉关联、残差 + 层归一化稳训练;
- 数学核心:缩放点积注意力是核心公式,通过 Q/K/V 的矩阵运算实现加权求和,多头是对该公式的扩展。
感谢大家的观看,transform在文本分类和情感分析的应用比较广泛,上面的代码只是简单的解释方法原理。如果大家需要具体的实例讲解,欢迎大家评论区留言!