1. 引言:为什么是 Transformer?
想象一下,你正在一个嘈杂的咖啡馆里和朋友聊天。虽然周围有很多声音,但你的大脑能神奇地"聚焦"在朋友的话语上,忽略背景噪音。这种"选择性关注"的能力,就是 注意力机制(Attention Mechanism) 的核心思想。
2017年,Google的研究团队在论文《Attention Is All You Need》中提出了 Transformer 模型。它彻底抛弃了传统的循环神经网络(RNN),仅依靠注意力机制来处理序列数据(如文本、时间序列)。这一创新不仅让模型训练速度大幅提升,更在机器翻译、文本生成、语音识别等领域取得了突破性进展,成为当今深度学习领域的基石。
本文将带你深入浅出地理解 Transformer,并用一个股票价格预测的实战项目,手把手教你用 PyTorch 实现一个简化版的 Transformer 模型。
2. Transformer 核心思想:用"注意力"取代"循环"
2.1 传统 RNN 的困境
在 Transformer 出现之前,处理序列数据(如一句话、一段股价历史)的主流模型是循环神经网络(RNN)。RNN 像一条"传送带",信息必须按顺序一步步传递。这导致两个问题:
- 计算慢:无法并行处理,训练耗时。
- 长距离依赖弱:信息在长序列中传递时容易丢失或变形(梯度消失/爆炸)。
2.2 Transformer 的解决方案:全局注意力
Transformer 的答案很简单:为什么不一次性看完整个序列,再决定每个位置应该关注哪里?
这就像你读一篇文章时,不是逐字逐句线性阅读,而是先快速扫一眼全文,找到关键段落和它们之间的联系,再深入理解。Transformer 通过 自注意力(Self-Attention) 机制实现了这一点。
核心公式(缩放点积注意力):
Attention(Q, K, V) = softmax( (Q · K^T) / √d_k ) · V
这个公式看似复杂,但我们可以用生活中的例子来理解它的三个核心角色:
- Q (Query - 查询) :好比 "我想知道什么?"。例如,在预测明天股价时,模型会问:"基于过去30天的数据,哪些天的情况对预测明天最重要?"
- K (Key - 键) :好比 "我有什么标签?"。序列中的每一天(一个位置)都自带一些特征标签(如开盘价、成交量),K 就是这些标签的表示。
- V (Value - 值) :好比 "我的实际内容是什么?"。这是每个位置所包含的真实信息。
计算过程通俗解释:
- 匹配(Q·K^T):计算"当前疑问"与"每个历史位置的标签"的匹配程度(相似度)。
- 缩放与归一化(除以√d_k, softmax):对匹配分数进行缩放(防止数值过大)并转化为权重(所有权重和为1)。这步得到了一个"注意力分布",表示对于当前预测,应该给历史每一天分配多少注意力。
- 聚合(·V):用上一步得到的权重,对每个位置的"实际内容(V)"进行加权求和,得到最终的上下文表示。
多头注意力(Multi-Head Attention):为了让模型同时关注不同方面的信息(例如,既关注短期波动趋势,也关注长期周期),Transformer 将 Q、K、V 投影到多个不同的"子空间"(即多个头),并行计算注意力,最后把结果拼接起来。这就像让多个专家从不同角度分析同一份数据。
2.3 位置编码:给无位置的模型注入"顺序感"
注意力机制是"位置无关"的,打乱输入顺序,输出权重关系不变。但序列的顺序显然很重要("猫追老鼠"和"老鼠追猫"意思相反)。为此,Transformer 引入了 位置编码(Positional Encoding)。
它使用一组固定的正弦和余弦函数,为序列中的每个位置生成一个独一无二的"位置向量",然后把这个向量加到输入数据中。这样,模型在计算注意力时,就能间接感知到位置信息。
位置编码的数学原理可视化:
#mermaid-svg-vLCdqZkzC50peaVF{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-vLCdqZkzC50peaVF .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-vLCdqZkzC50peaVF .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-vLCdqZkzC50peaVF .error-icon{fill:#552222;}#mermaid-svg-vLCdqZkzC50peaVF .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-vLCdqZkzC50peaVF .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-vLCdqZkzC50peaVF .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-vLCdqZkzC50peaVF .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-vLCdqZkzC50peaVF .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-vLCdqZkzC50peaVF .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-vLCdqZkzC50peaVF .marker{fill:#333333;stroke:#333333;}#mermaid-svg-vLCdqZkzC50peaVF .marker.cross{stroke:#333333;}#mermaid-svg-vLCdqZkzC50peaVF svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-vLCdqZkzC50peaVF p{margin:0;}#mermaid-svg-vLCdqZkzC50peaVF .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster-label text{fill:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster-label span{color:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster-label span p{background-color:transparent;}#mermaid-svg-vLCdqZkzC50peaVF .label text,#mermaid-svg-vLCdqZkzC50peaVF span{fill:#333;color:#333;}#mermaid-svg-vLCdqZkzC50peaVF .node rect,#mermaid-svg-vLCdqZkzC50peaVF .node circle,#mermaid-svg-vLCdqZkzC50peaVF .node ellipse,#mermaid-svg-vLCdqZkzC50peaVF .node polygon,#mermaid-svg-vLCdqZkzC50peaVF .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .rough-node .label text,#mermaid-svg-vLCdqZkzC50peaVF .node .label text,#mermaid-svg-vLCdqZkzC50peaVF .image-shape .label,#mermaid-svg-vLCdqZkzC50peaVF .icon-shape .label{text-anchor:middle;}#mermaid-svg-vLCdqZkzC50peaVF .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .rough-node .label,#mermaid-svg-vLCdqZkzC50peaVF .node .label,#mermaid-svg-vLCdqZkzC50peaVF .image-shape .label,#mermaid-svg-vLCdqZkzC50peaVF .icon-shape .label{text-align:center;}#mermaid-svg-vLCdqZkzC50peaVF .node.clickable{cursor:pointer;}#mermaid-svg-vLCdqZkzC50peaVF .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-vLCdqZkzC50peaVF .arrowheadPath{fill:#333333;}#mermaid-svg-vLCdqZkzC50peaVF .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-vLCdqZkzC50peaVF .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-vLCdqZkzC50peaVF .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-vLCdqZkzC50peaVF .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-vLCdqZkzC50peaVF .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-vLCdqZkzC50peaVF .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-vLCdqZkzC50peaVF .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .cluster text{fill:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster span{color:#333;}#mermaid-svg-vLCdqZkzC50peaVF div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-vLCdqZkzC50peaVF .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-vLCdqZkzC50peaVF rect.text{fill:none;stroke-width:0;}#mermaid-svg-vLCdqZkzC50peaVF .icon-shape,#mermaid-svg-vLCdqZkzC50peaVF .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-vLCdqZkzC50peaVF .icon-shape p,#mermaid-svg-vLCdqZkzC50peaVF .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-vLCdqZkzC50peaVF .icon-shape .label rect,#mermaid-svg-vLCdqZkzC50peaVF .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-vLCdqZkzC50peaVF .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-vLCdqZkzC50peaVF .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-vLCdqZkzC50peaVF :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 位置编码生成过程
示例:不同位置的编码
相似
相似
差异大
位置 1
编码向量 1
位置 2
编码向量 2
位置 3
编码向量 3
位置 100
编码向量 100
位置索引 pos
位置向量
pos, pos, ..., pos
维度索引 i
频率计算
10000^(2i/d_model)
正弦函数 sin()
余弦函数 cos()
除法 pos / freq
pos / freq
pos / freq
偶数维度 PE(pos, 2i)
奇数维度 PE(pos, 2i+1)
组合成位置编码向量
加到输入嵌入中
输入嵌入
带位置信息的输入
3. 项目实战:用 Transformer 预测股票价格
接下来,我们将把理论付诸实践,构建一个用于时间序列预测的 Transformer 模型,并以上证指数收盘价预测为例。
3.1 项目概述与数据准备
目标:使用过去 N 天(例如30天)的股票多项特征(开盘价、最高价、最低价、收盘价、成交量),来预测第 N+1 天的收盘价。
数据来源 :我们使用 akshare 库获取国内股票数据(以上证指数 sh000001 为例)。这是一个免费、强大的国内金融数据接口库。
核心步骤:
- 数据下载与预处理:下载历史数据,提取特征列,并进行归一化处理。
- 构建时间窗口 :将连续的时序数据转化为
(样本数, 时间步长, 特征数)的格式。 - 定义模型 :构建基于
TransformerEncoder的预测模型。 - 训练与评估:划分训练集和测试集,训练模型并评估其预测效果。
- 可视化结果:绘制预测曲线与损失曲线。
3.2 代码拆解与讲解
下面,我们将结合提供的代码,分模块讲解关键实现。
模块一:全局配置与数据预处理 (LOOK_BACK, preprocess_data)
python
LOOK_BACK = 30 # 使用过去30天的数据来预测下一天
FEATURE_COLS = ["open", "high", "low", "close", "volume"] # 使用的特征
def preprocess_data(df):
# 1. 列名映射与数据清洗
# 2. 特征归一化 (使用 MinMaxScaler,按列独立归一化)
# 3. 构建时间窗口样本 (X, y)
# X.shape: (样本数, LOOK_BACK, 特征数)
# y.shape: (样本数, ), 对应第 N+1 天的收盘价
LOOK_BACK:这是 Transformer 的"记忆长度"。模型在预测时,只能看到过去这LOOK_BACK步的信息。这个值需要根据数据特性调整。- 归一化:不同特征(如价格和成交量)数值尺度差异巨大,归一化到 0,1 区间有助于模型稳定训练。
- 构建样本:这是时间序列预测的标准操作。例如,用第1-30天数据预测第31天,用第2-31天数据预测第32天,以此类推。
模块二:模型核心------位置编码 (PositionalEncoding)
python
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度用sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度用cos
self.register_buffer('pe', pe) # 注册为不参与训练的缓冲区
def forward(self, x):
x = x + self.pe[:x.size(0), :] # 将位置编码加到输入上
return self.dropout(x)
- 正弦余弦函数:使用不同频率的正弦和余弦函数,可以生成唯一且平滑的位置编码,并能外推到比训练时更长的序列。
register_buffer:将位置编码矩阵注册为"缓冲区",这意味着它是模型的一部分,会随模型保存和加载,但不参与梯度更新(因为位置编码是固定的先验知识)。
模块三:时间序列 Transformer 模型 (TimeSeriesTransformer)
这是本项目最核心的类。
python
class TimeSeriesTransformer(nn.Module):
def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward, dropout):
super().__init__()
# 1. 输入投影层:将原始特征维度映射到模型内部维度 d_model
self.input_projection = nn.Linear(input_dim, d_model)
# 2. 位置编码层
self.pos_encoder = PositionalEncoding(d_model, dropout)
# 3. Transformer编码器层(核心)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead, # 多头注意力的头数
dim_feedforward=dim_feedforward, # 前馈网络中间层维度
dropout=dropout,
activation='relu',
batch_first=False # PyTorch Transformer 默认 (seq, batch, feature)
)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# 4. 输出层:将最终的上下文表示映射为预测值(一个标量)
self.fc_out = nn.Linear(d_model, 1)
def forward(self, x):
# x: (batch_size, seq_len, input_dim)
x = x.permute(1, 0, 2) # 调整为 (seq_len, batch_size, input_dim)
x = self.input_projection(x) # -> (seq_len, batch_size, d_model)
x = self.pos_encoder(x) # 注入位置信息
x = self.transformer_encoder(x) # -> (seq_len, batch_size, d_model)
# 取最后一个时间步的输出作为预测依据
x = x[-1, :, :] # -> (batch_size, d_model)
out = self.fc_out(x) # -> (batch_size, 1)
return out
模型数据流与核心思想图解:
#mermaid-svg-lyGiA9TzSh2VD26V{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-lyGiA9TzSh2VD26V .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-lyGiA9TzSh2VD26V .error-icon{fill:#552222;}#mermaid-svg-lyGiA9TzSh2VD26V .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-lyGiA9TzSh2VD26V .marker{fill:#333333;stroke:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V .marker.cross{stroke:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-lyGiA9TzSh2VD26V p{margin:0;}#mermaid-svg-lyGiA9TzSh2VD26V .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster-label text{fill:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster-label span{color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster-label span p{background-color:transparent;}#mermaid-svg-lyGiA9TzSh2VD26V .label text,#mermaid-svg-lyGiA9TzSh2VD26V span{fill:#333;color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .node rect,#mermaid-svg-lyGiA9TzSh2VD26V .node circle,#mermaid-svg-lyGiA9TzSh2VD26V .node ellipse,#mermaid-svg-lyGiA9TzSh2VD26V .node polygon,#mermaid-svg-lyGiA9TzSh2VD26V .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .rough-node .label text,#mermaid-svg-lyGiA9TzSh2VD26V .node .label text,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape .label,#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape .label{text-anchor:middle;}#mermaid-svg-lyGiA9TzSh2VD26V .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .rough-node .label,#mermaid-svg-lyGiA9TzSh2VD26V .node .label,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape .label,#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape .label{text-align:center;}#mermaid-svg-lyGiA9TzSh2VD26V .node.clickable{cursor:pointer;}#mermaid-svg-lyGiA9TzSh2VD26V .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V .arrowheadPath{fill:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-lyGiA9TzSh2VD26V .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-lyGiA9TzSh2VD26V .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-lyGiA9TzSh2VD26V .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-lyGiA9TzSh2VD26V .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-lyGiA9TzSh2VD26V .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-lyGiA9TzSh2VD26V .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster text{fill:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster span{color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-lyGiA9TzSh2VD26V .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-lyGiA9TzSh2VD26V rect.text{fill:none;stroke-width:0;}#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape p,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape .label rect,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-lyGiA9TzSh2VD26V .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-lyGiA9TzSh2VD26V .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-lyGiA9TzSh2VD26V :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 编码器内部
多头自注意力
加残差 & 层归一化
前馈网络
加残差 & 层归一化
输入: (批次, 30, 5)
维度变换与投影
(30, 批次, 64)
- 位置编码
取最后一个时间步
(批次, 64)
全连接层
输出预测值
(批次, 1)
关键点解析:
input_projection:原始特征(如5维)通常与模型内部维度 (d_model) 不匹配。这个线性层将其映射到统一的d_model维空间,方便后续计算。batch_first=False:PyTorch 的 Transformer 模块默认期望输入形状为(序列长度, 批次大小, 特征维度)。所以我们用permute调整了维度。TransformerEncoder:这是 PyTorch 封装好的模块,内部包含了多头自注意力层和前馈神经网络层,以及残差连接和层归一化。我们只需要堆叠num_layers层即可。- 取最后一个时间步 :在时间序列预测中,我们通常用过去
LOOK_BACK天的信息来预测下一天。经过 Transformer 编码后,序列中每个位置都包含了全局上下文信息。我们取最后一个位置(即最近一天)的表示来进行预测,这是一种常见且有效的做法。
模块四:训练与评估循环 (train_model, evaluate_model)
训练过程是标准的 PyTorch 流程:前向传播、计算损失(均方误差 MSE)、反向传播、优化器步进。我们使用了 Adam 优化器和学习率调度器 (StepLR)。
一个重要的技巧:梯度裁剪 (clip_grad_norm_)
python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Transformer 模型有时会遇到梯度爆炸问题。梯度裁剪将梯度的范数限制在一个阈值内,可以显著提高训练稳定性。
4. 运行结果与模型对比
运行完整的代码后,你会得到两张图:
- 预测结果对比图:显示训练集真实值、测试集真实值以及模型在测试集上的预测值。可以直观看到模型是否捕捉到了价格趋势。
- 训练损失曲线:显示训练集和测试集的损失随训练轮次(Epoch)下降的过程,用于判断模型是否过拟合或欠拟合。
Transformer 与传统序列模型对比表:
| 特性 | RNN | GRU | LSTM | Transformer |
|---|---|---|---|---|
| 核心机制 | 简单循环 | 门控循环单元 | 长短时记忆网络 | 自注意力 |
| 并行计算 | ❌ (顺序) | ❌ (顺序) | ❌ (顺序) | ✅ (全局) |
| 长距离依赖 | 弱 | 中 | 强 | 强 |
| 位置感知 | 内置顺序 | 内置顺序 | 内置顺序 | 需位置编码 |
| 训练速度 | 慢 | 中 | 慢 | 快 |
| 诞生年份 | 1986 | 2014 | 1997 | 2017 |
Transformer 的优势总结:
- 并行化:极大加速训练。
- 全局视野:每个输出位置都能直接关注到输入序列的所有位置,擅长捕捉长距离依赖。
- 可解释性:注意力权重可以可视化,看到模型在关注什么。
在本项目中的局限:
- 这是一个简化版,未使用解码器(Decoder),适合单变量预测。
- 股价预测是极其复杂的任务,受众多因素影响。本例主要用于演示 Transformer 在时序数据上的应用,切勿直接用于实际投资。
5. 总结与拓展
通过这个项目,我们不仅理解了 Transformer 的核心思想------自注意力机制,还亲手实现了一个用于时间序列预测的 Transformer 模型。你可以尝试:
- 调整超参数 :如
LOOK_BACK(历史窗口)、d_model(模型维度)、nhead(注意力头数)、num_layers(编码器层数),观察模型性能变化。 - 添加更多特征:除了价格和成交量,还可以加入技术指标(如均线、RSI、MACD)。
- 尝试更复杂的架构:例如 Transformer + 卷积层,或引入时间特征编码。
- 更换任务:将代码稍作修改,可用于其他时序预测任务,如天气预测、电力负荷预测、销量预测等。
Transformer 的设计是优雅而强大的,它证明了"注意力"确实可以成为你所需要的全部。希望这篇结合了理论、生活比喻和实战代码的文章,能帮助你顺利踏入 Transformer 的大门。
附录:完整可运行代码
以下代码整合了上述所有模块,复制到 PyTorch 环境中即可运行。请确保已安装所需库:pip install akshare
python
# -*- coding: utf-8 -*-
"""
================================================================================
Transformer 初学者教程 ------ 使用自注意力机制预测股票价格
================================================================================
数据集:通过 akshare(国内金融数据接口)自动下载「上证指数」历史数据
框架: PyTorch
模型: TransformerEncoder(自注意力 + 前馈网络,无循环/无卷积)
任务: 用前 N 天的多项特征(开高低收量),预测第 N+1 天的收盘价
═══════════════════════════════════════════════════════════════════
Transformer 核心思想(Vaswani et al., 2017 "Attention Is All You Need")
核心公式 ------ 缩放点积注意力(Scaled Dot-Product Attention):
Q · K^T
Attention(Q, K, V) = softmax( ─────── ) · V
√d_k
通俗理解:
Q (Query): "我想查什么?" ------ 当前位置想关注什么
K (Key): "我有什么标签?" ------ 每个位置的特征标识
V (Value): "我的内容是什么?" ------ 每个位置的实际信息
softmax(QK^T/√d_k): 相似度矩阵,表示"当前位置应该关注每个历史位置多少"
多头注意力(Multi-Head Attention):
把 Q/K/V 拆成 h 个"头",每个头独立做注意力,最后拼接。
→ 不同头可以关注不同的模式(短期趋势、长期趋势、周期等)
Transformer vs RNN 的关键区别:
┌──────────────┬─────────────────────┬─────────────────────┐
│ 特性 │ RNN 家族 │ Transformer │
├──────────────┼─────────────────────┼─────────────────────┤
│ 处理方式 │ 逐步递归(串行) │ 并行注意力(并行) │
│ 长距离依赖 │ 逐步传递(可能丢失) │ 直接关注任意位置 │
│ 训练速度 │ 慢(不能并行) │ 快(可并行) │
│ 位置信息 │ 天然内置 │ 需要位置编码注入 │
│ 核心机制 │ 循环/门控 │ 自注意力 │
└──────────────┴─────────────────────┴─────────────────────┘
运行方式:
pip install torch numpy pandas scikit-learn matplotlib akshare
python rnn_transformer_demo.py
================================================================================
"""
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
# ============================================================================
# 0. 全局配置
# ============================================================================
LOOK_BACK = 30 # Transformer 擅长处理长序列,可以设大一些
BATCH_SIZE = 32
D_MODEL = 64 # 词向量/嵌入维度(整个 Transformer 的特征维度)
NHEAD = 4 # 多头注意力的"头"数(d_model 必须能被 nhead 整除)
NUM_LAYERS = 2 # Transformer Encoder 层数
DIM_FEEDFORWARD = 128 # 前馈网络中间层维度(通常是 d_model × 2~4)
DROPOUT = 0.1 # Dropout 比例
LEARNING_RATE = 0.001
EPOCHS = 80
TRAIN_RATIO = 0.8
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 使用多项特征(Transformer 擅长捕捉特征间关系)
FEATURE_COLS = ["open", "high", "low", "close", "volume"]
NUM_FEATURES = len(FEATURE_COLS) # 5
print(f"🔧 使用设备: {DEVICE}")
print(f"📦 PyTorch 版本: {torch.__version__}")
# ============================================================================
# 1. 下载数据 ------ akshare(国内源)
# ============================================================================
def download_data():
print("\n📥 正在下载上证指数历史数据...")
try:
import akshare as ak
df = ak.stock_zh_index_daily(symbol="sh000001")
print(f" ✅ 下载成功!共 {len(df)} 条记录")
print(f" 📅 数据范围: {df['date'].iloc[0]} ~ {df['date'].iloc[-1]}")
return df
except ImportError:
print(" ❌ akshare 未安装,请运行: pip install akshare")
raise
except Exception as e:
print(f" ⚠️ 主接口失败 ({e}),尝试备用接口...")
try:
import akshare as ak
df = ak.index_zh_a_hist(symbol="000001", period="daily")
print(f" ✅ 备用接口下载成功!共 {len(df)} 条记录")
return df
except Exception as e2:
print(f" ❌ 所有接口均失败: {e2}")
raise
# ============================================================================
# 2. 数据预处理
# ============================================================================
def preprocess_data(df):
print("\n🔧 数据预处理中...")
# --- 2.1 列名映射 ---
col_mapping = {}
for target in FEATURE_COLS:
for col in df.columns:
if col.lower() == target or target in col.lower():
col_mapping[target] = col
break
used_keys = [k for k in FEATURE_COLS if k in col_mapping]
data = df[[col_mapping[k] for k in used_keys]].values.astype(np.float32)
data = np.nan_to_num(data, nan=0.0)
num_features = data.shape[1]
print(f" 📊 多特征数据形状: {data.shape}")
# --- 2.2 每列独立归一化 ---
scalers = []
data_scaled = np.zeros_like(data)
for c in range(num_features):
scaler = MinMaxScaler(feature_range=(0, 1))
data_scaled[:, c] = scaler.fit_transform(data[:, c].reshape(-1, 1)).ravel()
scalers.append(scaler)
close_idx = used_keys.index("close") if "close" in used_keys else 3
close_scaler = scalers[close_idx]
# --- 2.3 构造时序样本 ---
X, y = [], []
for i in range(LOOK_BACK, len(data_scaled)):
X.append(data_scaled[i - LOOK_BACK:i, :])
y.append(data_scaled[i, close_idx])
X = np.array(X) # (样本数, LOOK_BACK, 特征数)
y = np.array(y) # (样本数,)
print(f" ✅ X 形状: {X.shape}(样本={X.shape[0]}, 时间步={X.shape[1]}, 特征={X.shape[2]})")
print(f" ✅ y 形状: {y.shape}")
print(f" 📋 使用特征: {used_keys}")
return X, y, close_scaler
# ============================================================================
# 3. 划分训练集 / 测试集
# ============================================================================
def split_data(X, y):
split_idx = int(len(X) * TRAIN_RATIO)
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]
print(f"\n📂 训练集: {len(X_train)} 条 | 测试集: {len(X_test)} 条")
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
train_loader = DataLoader(TensorDataset(X_train_t, y_train_t),
batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(TensorDataset(X_test_t, y_test_t),
batch_size=BATCH_SIZE, shuffle=False)
return train_loader, test_loader, X_train_t, y_train_t, X_test_t, y_test_t
# ============================================================================
# 4. ★ 位置编码(Positional Encoding)
# ============================================================================
class PositionalEncoding(nn.Module):
"""
Transformer 没有循环结构,不知道位置的先后顺序。
位置编码就是给每个位置注入一个唯一的"位置信号"。
使用正弦/余弦函数的原因:
- 不同位置有不同编码(唯一性)
- 相邻位置编码相似(平滑性)
- 可以外推到训练时未见过的更长序列
公式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
↑
偶数列用 sin,奇数列用 cos
"""
def __init__(self, d_model, max_len=5000, dropout=0.1):
"""
参数:
d_model: 嵌入维度(和 Transformer 的 d_model 一致)
max_len: 最大序列长度(预计算这么多位置)
dropout: 加在编码后的 dropout
"""
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# ------ 预计算位置编码矩阵 (max_len, d_model) ------
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
# div_term: 10000^(2i/d_model) 的分母
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数列 sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数列 cos
pe = pe.unsqueeze(1) # (max_len, 1, d_model) 适配 batch 维度
# 注册为 buffer(不参与训练的参数,但随模型保存/加载)
self.register_buffer("pe", pe)
def forward(self, x):
"""
x: (seq_len, batch_size, d_model)
返回: x + 位置编码 (seq_len, batch_size, d_model)
"""
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
# ============================================================================
# 5. ★ 定义 Transformer 模型
# ============================================================================
class TimeSeriesTransformer(nn.Module):
"""
基于 TransformerEncoder 的时间序列预测模型。
结构(数据流):
Input (batch, seq_len, features)
│
▼
线性投影 → (batch, seq_len, d_model) ← 将原始特征映射到 d_model 维
│
▼
+ 位置编码 ← 注入位置信息
│
▼
TransformerEncoder × N 层 ← 自注意力 + 前馈网络
│ ┌─────────────────────────┐
│ │ Multi-Head Attention │ ← "我应该关注序列中哪些位置?"
│ │ + Add & Norm │
│ │ Feed-Forward Network │ ← "提取更复杂的特征"
│ │ + Add & Norm │
│ └─────────────────────────┘
│ × NUM_LAYERS 层
▼
取最后一个位置的输出 → (batch, d_model)
│
▼
FC → (batch, 1) ← 输出预测的收盘价
"""
def __init__(self, input_dim=NUM_FEATURES, d_model=D_MODEL,
nhead=NHEAD, num_layers=NUM_LAYERS,
dim_feedforward=DIM_FEEDFORWARD, dropout=DROPOUT,
output_size=1):
super(TimeSeriesTransformer, self).__init__()
self.d_model = d_model
self.input_dim = input_dim
# ① 输入投影:将原始特征数映射到 d_model 维度
self.input_projection = nn.Linear(input_dim, d_model)
# ② 位置编码
self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
# ③ ★ 核心:Transformer Encoder 层
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, # 特征维度
nhead=nhead, # 注意力头数
dim_feedforward=dim_feedforward, # FFN 中间层维度
dropout=dropout,
activation="relu", # FFN 激活函数
batch_first=False, # 输入格式 (seq, batch, feature)
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers, # 堆叠层数
)
# ④ 输出层
self.fc_out = nn.Linear(d_model, output_size)
def forward(self, x):
"""
参数:
x: (batch, seq_len, input_dim) 例如 (32, 30, 5)
返回:
out: (batch, output_size) 预测的收盘价
"""
# ── 调整形状:Transformer 期望 (seq_len, batch, features) ──
x = x.permute(1, 0, 2) # (look_back, batch, num_features)
# ── 投影到 d_model 维度 ──
x = self.input_projection(x) # (seq_len, batch, d_model)
# ── 注入位置编码 ──
x = self.pos_encoder(x) # (seq_len, batch, d_model)
# ── ★ Transformer Encoder(自注意力核心)──
# 这里每个位置会计算与所有其他位置的注意力权重
# 输出的 x 每个位置已经融合了全局上下文信息
x = self.transformer_encoder(x) # (seq_len, batch, d_model)
# ── 取最后一个时间步的输出做预测 ──
# (也可以取平均池化,但取最后位置是最常见的做法)
x = x[-1, :, :] # (batch, d_model)
# ── 全连接输出 ──
out = self.fc_out(x) # (batch, 1)
return out
# ============================================================================
# 6. 训练
# ============================================================================
def train_model(model, train_loader, criterion, optimizer):
model.train()
total_loss = 0
for X_batch, y_batch in train_loader:
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
predictions = model(X_batch)
loss = criterion(predictions, y_batch)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
# ============================================================================
# 7. 评估
# ============================================================================
def evaluate_model(model, test_loader, criterion):
model.eval()
total_loss = 0
with torch.no_grad():
for X_batch, y_batch in test_loader:
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
predictions = model(X_batch)
loss = criterion(predictions, y_batch)
total_loss += loss.item()
return total_loss / len(test_loader)
# ============================================================================
# 8. 可视化
# ============================================================================
def plot_results(y_real, y_pred, scaler, train_size):
y_real_inv = scaler.inverse_transform(y_real.reshape(-1, 1))
y_pred_inv = scaler.inverse_transform(y_pred.reshape(-1, 1))
plt.figure(figsize=(14, 6))
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
x_axis = range(len(y_real_inv))
plt.plot(x_axis[:train_size], y_real_inv[:train_size],
color="blue", alpha=0.4, linewidth=1, label="训练集(真实值)")
plt.plot(x_axis[train_size:], y_real_inv[train_size:],
color="green", alpha=0.7, linewidth=1.5, label="测试集(真实值)")
plt.plot(x_axis[train_size:], y_pred_inv[train_size:],
color="red", alpha=0.8, linewidth=1.5, label="测试集(Transformer 预测值)")
plt.axvline(x=train_size, color="gray", linestyle="--", alpha=0.7)
plt.xlabel("时间(天)", fontsize=12)
plt.ylabel("收盘价", fontsize=12)
plt.title(f"Transformer 股票价格预测 --- 上证指数\n"
f"(Look Back={LOOK_BACK}, d_model={D_MODEL}, heads={NHEAD}, layers={NUM_LAYERS})",
fontsize=14)
plt.legend(loc="upper left")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("transformer_prediction_result.png", dpi=150)
plt.show()
print("\n📊 图表已保存为 transformer_prediction_result.png")
def plot_training_curve(train_losses, test_losses):
plt.figure(figsize=(10, 5))
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
plt.plot(train_losses, label="训练损失", color="blue", linewidth=1.5)
plt.plot(test_losses, label="测试损失", color="orange", linewidth=1.5)
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Loss (MSE)", fontsize=12)
plt.title("Transformer 训练过程 --- Loss 曲线", fontsize=14)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("transformer_loss_curve.png", dpi=150)
plt.show()
print("📊 损失曲线已保存为 transformer_loss_curve.png")
# ============================================================================
# 9. Transformer 原理讲解
# ============================================================================
def print_transformer_explanation():
print("\n" + "=" * 60)
print(" 📖 Transformer 核心原理 ------ Attention Is All You Need")
print("=" * 60)
print(f"""
Transformer 用"自注意力"取代了 RNN 的"循环"。
一句话理解:每个位置直接看序列中所有位置,而不是一步步传递。
┌─────────────────────────────────────────────────────────┐
│ Scaled Dot-Product Attention │
│ │
│ Q (Query) ← 当前位置的线性投影 │
│ K (Key) ← 所有位置的线性投影(用于匹配) │
│ V (Value) ← 所有位置的线性投影(用于聚合) │
│ │
│ Q·K^T │
│ Attention = softmax( ──────── ) · V │
│ √d_k │
│ ↑ │
│ 除以 √d_k 防止点积过大导致 softmax 梯度消失 │
└─────────────────────────────────────────────────────────┘
★ 自注意力的优势:
1. 并行计算:不依赖上一时间步,整条序列同时算
2. 长距离依赖:不管位置相隔多远,直接建立联系
3. 可解释性:注意力权重可以可视化"模型在关注哪里"
★ 为什么还需要位置编码?
Transformer 本身是"位置无关"的------打乱输入顺序,输出也会对应打乱。
位置编码给每个位置注入唯一信号,让模型知道"谁在前谁在后"。
""")
print("=" * 60)
# ============================================================================
# 10. 主函数
# ============================================================================
def main():
print("=" * 60)
print(" 🧠 Transformer 初学者教程 --- Attention Is All You Need")
print("=" * 60)
print_transformer_explanation()
# Step 1: 下载数据
df = download_data()
# Step 2: 预处理
X, y, close_scaler = preprocess_data(df)
# Step 3: 划分
train_loader, test_loader, X_train, y_train, X_test, y_test = split_data(X, y)
# Step 4: 创建模型
model = TimeSeriesTransformer(
input_dim=NUM_FEATURES,
d_model=D_MODEL,
nhead=NHEAD,
num_layers=NUM_LAYERS,
dim_feedforward=DIM_FEEDFORWARD,
dropout=DROPOUT,
).to(DEVICE)
print(f"\n🏗️ 模型结构:\n{model}")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"📐 总参数量: {total_params:,} | 可训练: {trainable_params:,}")
print(f"📐 d_model={D_MODEL}, heads={NHEAD}, layers={NUM_LAYERS}, ff_dim={DIM_FEEDFORWARD}")
# Step 5 & 6: 训练 + 评估
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Transformer 训练通常使用学习率预热(warmup),这里简化处理
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
train_losses, test_losses = [], []
print(f"\n🚀 开始训练(共 {EPOCHS} 轮)...")
print(f" 💡 Transformer 比 RNN 并行度高,训练速度通常更快")
for epoch in range(1, EPOCHS + 1):
train_loss = train_model(model, train_loader, criterion, optimizer)
test_loss = evaluate_model(model, test_loader, criterion)
scheduler.step()
train_losses.append(train_loss)
test_losses.append(test_loss)
if epoch % 10 == 0 or epoch == 1:
lr_now = scheduler.get_last_lr()[0]
print(f" Epoch [{epoch:3d}/{EPOCHS}] "
f"Train Loss: {train_loss:.6f} | Test Loss: {test_loss:.6f} "
f"| LR: {lr_now:.6f}")
print(f"\n✅ 训练完成!最终 Test Loss: {test_losses[-1]:.6f}")
# Step 7: 可视化
model.eval()
with torch.no_grad():
y_pred_train = model(X_train.to(DEVICE)).cpu().numpy()
y_pred_test = model(X_test.to(DEVICE)).cpu().numpy()
y_all_real = np.concatenate([y_train.numpy(), y_test.numpy()])
y_all_pred = np.concatenate([y_pred_train, y_pred_test])
plot_results(y_all_real, y_all_pred, close_scaler, len(y_train))
plot_training_curve(train_losses, test_losses)
# 小结
print("\n" + "=" * 60)
print("🎉 程序运行完毕!")
print(f"\n📝 全系列模型对比总结:")
print(f" ┌──────────────┬──────────┬──────────┬──────────┬──────────────┬──────────────┐")
print(f" │ 特性 │ RNN │ GRU │ LSTM │ ConvLSTM │ Transformer │")
print(f" ├──────────────┼──────────┼──────────┼──────────┼──────────────┼──────────────┤")
print(f" │ 核心机制 │ 循环 │ 门控 │ 三门控 │ 卷积+门控 │ 自注意力 │")
print(f" │ 并行计算 │ ❌ │ ❌ │ ❌ │ ❌ │ ✅ │")
print(f" │ 长距离依赖 │ 弱 │ 中 │ 强 │ 中 │ 强 │")
print(f" │ 位置感知 │ 内置 │ 内置 │ 内置 │ 内置 │ 需编码 │")
print(f" │ 训练速度 │ 慢 │ 中 │ 慢 │ 慢 │ 快 │")
print(f" │ 提出年份 │ 1986 │ 2014 │ 1997 │ 2015 │ 2017 │")
print(f" └──────────────┴──────────┴──────────┴──────────┴──────────────┴──────────────┘")
print(f"\n🔧 可调超参数:")
print(f" LOOK_BACK={LOOK_BACK} d_model={D_MODEL} nhead={NHEAD} "
f"num_layers={NUM_LAYERS} ff_dim={DIM_FEEDFORWARD} lr={LEARNING_RATE}")
print("=" * 60)
if __name__ == "__main__":
main()