Transformer 深度学习详解:从“注意力”到股票预测实战(代码拿来就能用)

1. 引言:为什么是 Transformer?

想象一下,你正在一个嘈杂的咖啡馆里和朋友聊天。虽然周围有很多声音,但你的大脑能神奇地"聚焦"在朋友的话语上,忽略背景噪音。这种"选择性关注"的能力,就是 注意力机制(Attention Mechanism) 的核心思想。

2017年,Google的研究团队在论文《Attention Is All You Need》中提出了 Transformer 模型。它彻底抛弃了传统的循环神经网络(RNN),仅依靠注意力机制来处理序列数据(如文本、时间序列)。这一创新不仅让模型训练速度大幅提升,更在机器翻译、文本生成、语音识别等领域取得了突破性进展,成为当今深度学习领域的基石。

本文将带你深入浅出地理解 Transformer,并用一个股票价格预测的实战项目,手把手教你用 PyTorch 实现一个简化版的 Transformer 模型。

2. Transformer 核心思想:用"注意力"取代"循环"

2.1 传统 RNN 的困境

在 Transformer 出现之前,处理序列数据(如一句话、一段股价历史)的主流模型是循环神经网络(RNN)。RNN 像一条"传送带",信息必须按顺序一步步传递。这导致两个问题:

  1. 计算慢:无法并行处理,训练耗时。
  2. 长距离依赖弱:信息在长序列中传递时容易丢失或变形(梯度消失/爆炸)。

2.2 Transformer 的解决方案:全局注意力

Transformer 的答案很简单:为什么不一次性看完整个序列,再决定每个位置应该关注哪里?

这就像你读一篇文章时,不是逐字逐句线性阅读,而是先快速扫一眼全文,找到关键段落和它们之间的联系,再深入理解。Transformer 通过 自注意力(Self-Attention) 机制实现了这一点。

核心公式(缩放点积注意力):

复制代码
Attention(Q, K, V) = softmax( (Q · K^T) / √d_k ) · V

这个公式看似复杂,但我们可以用生活中的例子来理解它的三个核心角色:

  • Q (Query - 查询) :好比 "我想知道什么?"。例如,在预测明天股价时,模型会问:"基于过去30天的数据,哪些天的情况对预测明天最重要?"
  • K (Key - 键) :好比 "我有什么标签?"。序列中的每一天(一个位置)都自带一些特征标签(如开盘价、成交量),K 就是这些标签的表示。
  • V (Value - 值) :好比 "我的实际内容是什么?"。这是每个位置所包含的真实信息。

计算过程通俗解释:

  1. 匹配(Q·K^T):计算"当前疑问"与"每个历史位置的标签"的匹配程度(相似度)。
  2. 缩放与归一化(除以√d_k, softmax):对匹配分数进行缩放(防止数值过大)并转化为权重(所有权重和为1)。这步得到了一个"注意力分布",表示对于当前预测,应该给历史每一天分配多少注意力。
  3. 聚合(·V):用上一步得到的权重,对每个位置的"实际内容(V)"进行加权求和,得到最终的上下文表示。

多头注意力(Multi-Head Attention):为了让模型同时关注不同方面的信息(例如,既关注短期波动趋势,也关注长期周期),Transformer 将 Q、K、V 投影到多个不同的"子空间"(即多个头),并行计算注意力,最后把结果拼接起来。这就像让多个专家从不同角度分析同一份数据。

2.3 位置编码:给无位置的模型注入"顺序感"

注意力机制是"位置无关"的,打乱输入顺序,输出权重关系不变。但序列的顺序显然很重要("猫追老鼠"和"老鼠追猫"意思相反)。为此,Transformer 引入了 位置编码(Positional Encoding)

它使用一组固定的正弦和余弦函数,为序列中的每个位置生成一个独一无二的"位置向量",然后把这个向量加到输入数据中。这样,模型在计算注意力时,就能间接感知到位置信息。

位置编码的数学原理可视化:
#mermaid-svg-vLCdqZkzC50peaVF{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-vLCdqZkzC50peaVF .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-vLCdqZkzC50peaVF .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-vLCdqZkzC50peaVF .error-icon{fill:#552222;}#mermaid-svg-vLCdqZkzC50peaVF .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-vLCdqZkzC50peaVF .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-vLCdqZkzC50peaVF .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-vLCdqZkzC50peaVF .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-vLCdqZkzC50peaVF .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-vLCdqZkzC50peaVF .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-vLCdqZkzC50peaVF .marker{fill:#333333;stroke:#333333;}#mermaid-svg-vLCdqZkzC50peaVF .marker.cross{stroke:#333333;}#mermaid-svg-vLCdqZkzC50peaVF svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-vLCdqZkzC50peaVF p{margin:0;}#mermaid-svg-vLCdqZkzC50peaVF .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster-label text{fill:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster-label span{color:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster-label span p{background-color:transparent;}#mermaid-svg-vLCdqZkzC50peaVF .label text,#mermaid-svg-vLCdqZkzC50peaVF span{fill:#333;color:#333;}#mermaid-svg-vLCdqZkzC50peaVF .node rect,#mermaid-svg-vLCdqZkzC50peaVF .node circle,#mermaid-svg-vLCdqZkzC50peaVF .node ellipse,#mermaid-svg-vLCdqZkzC50peaVF .node polygon,#mermaid-svg-vLCdqZkzC50peaVF .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .rough-node .label text,#mermaid-svg-vLCdqZkzC50peaVF .node .label text,#mermaid-svg-vLCdqZkzC50peaVF .image-shape .label,#mermaid-svg-vLCdqZkzC50peaVF .icon-shape .label{text-anchor:middle;}#mermaid-svg-vLCdqZkzC50peaVF .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .rough-node .label,#mermaid-svg-vLCdqZkzC50peaVF .node .label,#mermaid-svg-vLCdqZkzC50peaVF .image-shape .label,#mermaid-svg-vLCdqZkzC50peaVF .icon-shape .label{text-align:center;}#mermaid-svg-vLCdqZkzC50peaVF .node.clickable{cursor:pointer;}#mermaid-svg-vLCdqZkzC50peaVF .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-vLCdqZkzC50peaVF .arrowheadPath{fill:#333333;}#mermaid-svg-vLCdqZkzC50peaVF .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-vLCdqZkzC50peaVF .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-vLCdqZkzC50peaVF .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-vLCdqZkzC50peaVF .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-vLCdqZkzC50peaVF .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-vLCdqZkzC50peaVF .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-vLCdqZkzC50peaVF .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-vLCdqZkzC50peaVF .cluster text{fill:#333;}#mermaid-svg-vLCdqZkzC50peaVF .cluster span{color:#333;}#mermaid-svg-vLCdqZkzC50peaVF div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-vLCdqZkzC50peaVF .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-vLCdqZkzC50peaVF rect.text{fill:none;stroke-width:0;}#mermaid-svg-vLCdqZkzC50peaVF .icon-shape,#mermaid-svg-vLCdqZkzC50peaVF .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-vLCdqZkzC50peaVF .icon-shape p,#mermaid-svg-vLCdqZkzC50peaVF .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-vLCdqZkzC50peaVF .icon-shape .label rect,#mermaid-svg-vLCdqZkzC50peaVF .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-vLCdqZkzC50peaVF .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-vLCdqZkzC50peaVF .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-vLCdqZkzC50peaVF :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 位置编码生成过程
示例:不同位置的编码
相似
相似
差异大
位置 1
编码向量 1
位置 2
编码向量 2
位置 3
编码向量 3
位置 100
编码向量 100
位置索引 pos
位置向量

pos, pos, ..., pos

维度索引 i
频率计算

10000^(2i/d_model)
正弦函数 sin()
余弦函数 cos()
除法 pos / freq
pos / freq
pos / freq
偶数维度 PE(pos, 2i)
奇数维度 PE(pos, 2i+1)
组合成位置编码向量
加到输入嵌入中
输入嵌入
带位置信息的输入

3. 项目实战:用 Transformer 预测股票价格

接下来,我们将把理论付诸实践,构建一个用于时间序列预测的 Transformer 模型,并以上证指数收盘价预测为例。

3.1 项目概述与数据准备

目标:使用过去 N 天(例如30天)的股票多项特征(开盘价、最高价、最低价、收盘价、成交量),来预测第 N+1 天的收盘价。

数据来源 :我们使用 akshare 库获取国内股票数据(以上证指数 sh000001 为例)。这是一个免费、强大的国内金融数据接口库。

核心步骤

  1. 数据下载与预处理:下载历史数据,提取特征列,并进行归一化处理。
  2. 构建时间窗口 :将连续的时序数据转化为 (样本数, 时间步长, 特征数) 的格式。
  3. 定义模型 :构建基于 TransformerEncoder 的预测模型。
  4. 训练与评估:划分训练集和测试集,训练模型并评估其预测效果。
  5. 可视化结果:绘制预测曲线与损失曲线。

3.2 代码拆解与讲解

下面,我们将结合提供的代码,分模块讲解关键实现。

模块一:全局配置与数据预处理 (LOOK_BACK, preprocess_data)
python 复制代码
LOOK_BACK = 30  # 使用过去30天的数据来预测下一天
FEATURE_COLS = ["open", "high", "low", "close", "volume"] # 使用的特征

def preprocess_data(df):
    # 1. 列名映射与数据清洗
    # 2. 特征归一化 (使用 MinMaxScaler,按列独立归一化)
    # 3. 构建时间窗口样本 (X, y)
    # X.shape: (样本数, LOOK_BACK, 特征数)
    # y.shape: (样本数, ), 对应第 N+1 天的收盘价
  • LOOK_BACK :这是 Transformer 的"记忆长度"。模型在预测时,只能看到过去这 LOOK_BACK 步的信息。这个值需要根据数据特性调整。
  • 归一化:不同特征(如价格和成交量)数值尺度差异巨大,归一化到 0,1 区间有助于模型稳定训练。
  • 构建样本:这是时间序列预测的标准操作。例如,用第1-30天数据预测第31天,用第2-31天数据预测第32天,以此类推。
模块二:模型核心------位置编码 (PositionalEncoding)
python 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度用sin
        pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度用cos
        self.register_buffer('pe', pe) # 注册为不参与训练的缓冲区

    def forward(self, x):
        x = x + self.pe[:x.size(0), :] # 将位置编码加到输入上
        return self.dropout(x)
  • 正弦余弦函数:使用不同频率的正弦和余弦函数,可以生成唯一且平滑的位置编码,并能外推到比训练时更长的序列。
  • register_buffer:将位置编码矩阵注册为"缓冲区",这意味着它是模型的一部分,会随模型保存和加载,但不参与梯度更新(因为位置编码是固定的先验知识)。
模块三:时间序列 Transformer 模型 (TimeSeriesTransformer)

这是本项目最核心的类。

python 复制代码
class TimeSeriesTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        # 1. 输入投影层:将原始特征维度映射到模型内部维度 d_model
        self.input_projection = nn.Linear(input_dim, d_model)
        # 2. 位置编码层
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # 3. Transformer编码器层(核心)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead, # 多头注意力的头数
            dim_feedforward=dim_feedforward, # 前馈网络中间层维度
            dropout=dropout,
            activation='relu',
            batch_first=False # PyTorch Transformer 默认 (seq, batch, feature)
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # 4. 输出层:将最终的上下文表示映射为预测值(一个标量)
        self.fc_out = nn.Linear(d_model, 1)

    def forward(self, x):
        # x: (batch_size, seq_len, input_dim)
        x = x.permute(1, 0, 2) # 调整为 (seq_len, batch_size, input_dim)
        x = self.input_projection(x) # -> (seq_len, batch_size, d_model)
        x = self.pos_encoder(x)      # 注入位置信息
        x = self.transformer_encoder(x) # -> (seq_len, batch_size, d_model)
        # 取最后一个时间步的输出作为预测依据
        x = x[-1, :, :] # -> (batch_size, d_model)
        out = self.fc_out(x) # -> (batch_size, 1)
        return out

模型数据流与核心思想图解:
#mermaid-svg-lyGiA9TzSh2VD26V{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-lyGiA9TzSh2VD26V .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-lyGiA9TzSh2VD26V .error-icon{fill:#552222;}#mermaid-svg-lyGiA9TzSh2VD26V .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-lyGiA9TzSh2VD26V .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-lyGiA9TzSh2VD26V .marker{fill:#333333;stroke:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V .marker.cross{stroke:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-lyGiA9TzSh2VD26V p{margin:0;}#mermaid-svg-lyGiA9TzSh2VD26V .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster-label text{fill:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster-label span{color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster-label span p{background-color:transparent;}#mermaid-svg-lyGiA9TzSh2VD26V .label text,#mermaid-svg-lyGiA9TzSh2VD26V span{fill:#333;color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .node rect,#mermaid-svg-lyGiA9TzSh2VD26V .node circle,#mermaid-svg-lyGiA9TzSh2VD26V .node ellipse,#mermaid-svg-lyGiA9TzSh2VD26V .node polygon,#mermaid-svg-lyGiA9TzSh2VD26V .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .rough-node .label text,#mermaid-svg-lyGiA9TzSh2VD26V .node .label text,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape .label,#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape .label{text-anchor:middle;}#mermaid-svg-lyGiA9TzSh2VD26V .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .rough-node .label,#mermaid-svg-lyGiA9TzSh2VD26V .node .label,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape .label,#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape .label{text-align:center;}#mermaid-svg-lyGiA9TzSh2VD26V .node.clickable{cursor:pointer;}#mermaid-svg-lyGiA9TzSh2VD26V .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V .arrowheadPath{fill:#333333;}#mermaid-svg-lyGiA9TzSh2VD26V .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-lyGiA9TzSh2VD26V .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-lyGiA9TzSh2VD26V .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-lyGiA9TzSh2VD26V .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-lyGiA9TzSh2VD26V .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-lyGiA9TzSh2VD26V .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-lyGiA9TzSh2VD26V .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster text{fill:#333;}#mermaid-svg-lyGiA9TzSh2VD26V .cluster span{color:#333;}#mermaid-svg-lyGiA9TzSh2VD26V div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-lyGiA9TzSh2VD26V .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-lyGiA9TzSh2VD26V rect.text{fill:none;stroke-width:0;}#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape p,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-lyGiA9TzSh2VD26V .icon-shape .label rect,#mermaid-svg-lyGiA9TzSh2VD26V .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-lyGiA9TzSh2VD26V .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-lyGiA9TzSh2VD26V .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-lyGiA9TzSh2VD26V :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 编码器内部
多头自注意力
加残差 & 层归一化
前馈网络
加残差 & 层归一化
输入: (批次, 30, 5)
维度变换与投影
(30, 批次, 64)

  • 位置编码
    取最后一个时间步

(批次, 64)
全连接层
输出预测值

(批次, 1)

关键点解析:

  1. input_projection :原始特征(如5维)通常与模型内部维度 (d_model) 不匹配。这个线性层将其映射到统一的 d_model 维空间,方便后续计算。
  2. batch_first=False :PyTorch 的 Transformer 模块默认期望输入形状为 (序列长度, 批次大小, 特征维度)。所以我们用 permute 调整了维度。
  3. TransformerEncoder :这是 PyTorch 封装好的模块,内部包含了多头自注意力层和前馈神经网络层,以及残差连接和层归一化。我们只需要堆叠 num_layers 层即可。
  4. 取最后一个时间步 :在时间序列预测中,我们通常用过去 LOOK_BACK 天的信息来预测下一天。经过 Transformer 编码后,序列中每个位置都包含了全局上下文信息。我们取最后一个位置(即最近一天)的表示来进行预测,这是一种常见且有效的做法。
模块四:训练与评估循环 (train_model, evaluate_model)

训练过程是标准的 PyTorch 流程:前向传播、计算损失(均方误差 MSE)、反向传播、优化器步进。我们使用了 Adam 优化器和学习率调度器 (StepLR)。

一个重要的技巧:梯度裁剪 (clip_grad_norm_)

python 复制代码
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Transformer 模型有时会遇到梯度爆炸问题。梯度裁剪将梯度的范数限制在一个阈值内,可以显著提高训练稳定性。

4. 运行结果与模型对比

运行完整的代码后,你会得到两张图:

  1. 预测结果对比图:显示训练集真实值、测试集真实值以及模型在测试集上的预测值。可以直观看到模型是否捕捉到了价格趋势。
  2. 训练损失曲线:显示训练集和测试集的损失随训练轮次(Epoch)下降的过程,用于判断模型是否过拟合或欠拟合。

Transformer 与传统序列模型对比表:

特性 RNN GRU LSTM Transformer
核心机制 简单循环 门控循环单元 长短时记忆网络 自注意力
并行计算 ❌ (顺序) ❌ (顺序) ❌ (顺序) ✅ (全局)
长距离依赖
位置感知 内置顺序 内置顺序 内置顺序 需位置编码
训练速度
诞生年份 1986 2014 1997 2017

Transformer 的优势总结:

  • 并行化:极大加速训练。
  • 全局视野:每个输出位置都能直接关注到输入序列的所有位置,擅长捕捉长距离依赖。
  • 可解释性:注意力权重可以可视化,看到模型在关注什么。

在本项目中的局限:

  • 这是一个简化版,未使用解码器(Decoder),适合单变量预测。
  • 股价预测是极其复杂的任务,受众多因素影响。本例主要用于演示 Transformer 在时序数据上的应用,切勿直接用于实际投资。

5. 总结与拓展

通过这个项目,我们不仅理解了 Transformer 的核心思想------自注意力机制,还亲手实现了一个用于时间序列预测的 Transformer 模型。你可以尝试:

  1. 调整超参数 :如 LOOK_BACK(历史窗口)、d_model(模型维度)、nhead(注意力头数)、num_layers(编码器层数),观察模型性能变化。
  2. 添加更多特征:除了价格和成交量,还可以加入技术指标(如均线、RSI、MACD)。
  3. 尝试更复杂的架构:例如 Transformer + 卷积层,或引入时间特征编码。
  4. 更换任务:将代码稍作修改,可用于其他时序预测任务,如天气预测、电力负荷预测、销量预测等。

Transformer 的设计是优雅而强大的,它证明了"注意力"确实可以成为你所需要的全部。希望这篇结合了理论、生活比喻和实战代码的文章,能帮助你顺利踏入 Transformer 的大门。


附录:完整可运行代码

以下代码整合了上述所有模块,复制到 PyTorch 环境中即可运行。请确保已安装所需库:pip install akshare

python 复制代码
# -*- coding: utf-8 -*-
"""
================================================================================
 Transformer 初学者教程 ------ 使用自注意力机制预测股票价格
================================================================================
 数据集:通过 akshare(国内金融数据接口)自动下载「上证指数」历史数据
 框架:  PyTorch
 模型:  TransformerEncoder(自注意力 + 前馈网络,无循环/无卷积)
 任务:  用前 N 天的多项特征(开高低收量),预测第 N+1 天的收盘价

 ═══════════════════════════════════════════════════════════════════
  Transformer 核心思想(Vaswani et al., 2017 "Attention Is All You Need")

  核心公式 ------ 缩放点积注意力(Scaled Dot-Product Attention):

                Q · K^T
    Attention(Q, K, V) = softmax( ─────── ) · V
                                 √d_k

  通俗理解:
    Q (Query):  "我想查什么?" ------ 当前位置想关注什么
    K (Key):    "我有什么标签?" ------ 每个位置的特征标识
    V (Value):  "我的内容是什么?" ------ 每个位置的实际信息
    softmax(QK^T/√d_k):  相似度矩阵,表示"当前位置应该关注每个历史位置多少"

  多头注意力(Multi-Head Attention):
    把 Q/K/V 拆成 h 个"头",每个头独立做注意力,最后拼接。
    → 不同头可以关注不同的模式(短期趋势、长期趋势、周期等)

  Transformer vs RNN 的关键区别:
    ┌──────────────┬─────────────────────┬─────────────────────┐
    │     特性     │     RNN 家族         │    Transformer      │
    ├──────────────┼─────────────────────┼─────────────────────┤
    │ 处理方式     │ 逐步递归(串行)      │ 并行注意力(并行)   │
    │ 长距离依赖   │ 逐步传递(可能丢失)  │ 直接关注任意位置     │
    │ 训练速度     │ 慢(不能并行)        │ 快(可并行)         │
    │ 位置信息     │ 天然内置              │ 需要位置编码注入     │
    │ 核心机制     │ 循环/门控             │ 自注意力             │
    └──────────────┴─────────────────────┴─────────────────────┘

 运行方式:
     pip install torch numpy pandas scikit-learn matplotlib akshare
     python rnn_transformer_demo.py
================================================================================
"""

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler

# ============================================================================
# 0. 全局配置
# ============================================================================
LOOK_BACK = 30            # Transformer 擅长处理长序列,可以设大一些
BATCH_SIZE = 32
D_MODEL = 64              # 词向量/嵌入维度(整个 Transformer 的特征维度)
NHEAD = 4                 # 多头注意力的"头"数(d_model 必须能被 nhead 整除)
NUM_LAYERS = 2            # Transformer Encoder 层数
DIM_FEEDFORWARD = 128     # 前馈网络中间层维度(通常是 d_model × 2~4)
DROPOUT = 0.1             # Dropout 比例
LEARNING_RATE = 0.001
EPOCHS = 80
TRAIN_RATIO = 0.8
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 使用多项特征(Transformer 擅长捕捉特征间关系)
FEATURE_COLS = ["open", "high", "low", "close", "volume"]
NUM_FEATURES = len(FEATURE_COLS)  # 5

print(f"🔧 使用设备: {DEVICE}")
print(f"📦 PyTorch 版本: {torch.__version__}")


# ============================================================================
# 1. 下载数据 ------ akshare(国内源)
# ============================================================================
def download_data():
    print("\n📥 正在下载上证指数历史数据...")
    try:
        import akshare as ak
        df = ak.stock_zh_index_daily(symbol="sh000001")
        print(f"   ✅ 下载成功!共 {len(df)} 条记录")
        print(f"   📅 数据范围: {df['date'].iloc[0]} ~ {df['date'].iloc[-1]}")
        return df
    except ImportError:
        print("   ❌ akshare 未安装,请运行: pip install akshare")
        raise
    except Exception as e:
        print(f"   ⚠️ 主接口失败 ({e}),尝试备用接口...")
        try:
            import akshare as ak
            df = ak.index_zh_a_hist(symbol="000001", period="daily")
            print(f"   ✅ 备用接口下载成功!共 {len(df)} 条记录")
            return df
        except Exception as e2:
            print(f"   ❌ 所有接口均失败: {e2}")
            raise


# ============================================================================
# 2. 数据预处理
# ============================================================================
def preprocess_data(df):
    print("\n🔧 数据预处理中...")

    # --- 2.1 列名映射 ---
    col_mapping = {}
    for target in FEATURE_COLS:
        for col in df.columns:
            if col.lower() == target or target in col.lower():
                col_mapping[target] = col
                break

    used_keys = [k for k in FEATURE_COLS if k in col_mapping]
    data = df[[col_mapping[k] for k in used_keys]].values.astype(np.float32)
    data = np.nan_to_num(data, nan=0.0)

    num_features = data.shape[1]
    print(f"   📊 多特征数据形状: {data.shape}")

    # --- 2.2 每列独立归一化 ---
    scalers = []
    data_scaled = np.zeros_like(data)
    for c in range(num_features):
        scaler = MinMaxScaler(feature_range=(0, 1))
        data_scaled[:, c] = scaler.fit_transform(data[:, c].reshape(-1, 1)).ravel()
        scalers.append(scaler)

    close_idx = used_keys.index("close") if "close" in used_keys else 3
    close_scaler = scalers[close_idx]

    # --- 2.3 构造时序样本 ---
    X, y = [], []
    for i in range(LOOK_BACK, len(data_scaled)):
        X.append(data_scaled[i - LOOK_BACK:i, :])
        y.append(data_scaled[i, close_idx])

    X = np.array(X)  # (样本数, LOOK_BACK, 特征数)
    y = np.array(y)  # (样本数,)

    print(f"   ✅ X 形状: {X.shape}(样本={X.shape[0]}, 时间步={X.shape[1]}, 特征={X.shape[2]})")
    print(f"   ✅ y 形状: {y.shape}")
    print(f"   📋 使用特征: {used_keys}")

    return X, y, close_scaler


# ============================================================================
# 3. 划分训练集 / 测试集
# ============================================================================
def split_data(X, y):
    split_idx = int(len(X) * TRAIN_RATIO)
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]

    print(f"\n📂 训练集: {len(X_train)} 条  |  测试集: {len(X_test)} 条")

    X_train_t = torch.tensor(X_train, dtype=torch.float32)
    y_train_t = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
    X_test_t = torch.tensor(X_test, dtype=torch.float32)
    y_test_t = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

    train_loader = DataLoader(TensorDataset(X_train_t, y_train_t),
                              batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(TensorDataset(X_test_t, y_test_t),
                             batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, test_loader, X_train_t, y_train_t, X_test_t, y_test_t


# ============================================================================
# 4. ★ 位置编码(Positional Encoding)
# ============================================================================
class PositionalEncoding(nn.Module):
    """
    Transformer 没有循环结构,不知道位置的先后顺序。
    位置编码就是给每个位置注入一个唯一的"位置信号"。

    使用正弦/余弦函数的原因:
      - 不同位置有不同编码(唯一性)
      - 相邻位置编码相似(平滑性)
      - 可以外推到训练时未见过的更长序列

    公式:
      PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
      PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
                       ↑
              偶数列用 sin,奇数列用 cos
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        """
        参数:
            d_model: 嵌入维度(和 Transformer 的 d_model 一致)
            max_len: 最大序列长度(预计算这么多位置)
            dropout:  加在编码后的 dropout
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # ------ 预计算位置编码矩阵 (max_len, d_model) ------
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        # div_term: 10000^(2i/d_model) 的分母
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)   # 偶数列 sin
        pe[:, 1::2] = torch.cos(position * div_term)   # 奇数列 cos
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model) 适配 batch 维度

        # 注册为 buffer(不参与训练的参数,但随模型保存/加载)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        x: (seq_len, batch_size, d_model)
        返回: x + 位置编码 (seq_len, batch_size, d_model)
        """
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


# ============================================================================
# 5. ★ 定义 Transformer 模型
# ============================================================================
class TimeSeriesTransformer(nn.Module):
    """
    基于 TransformerEncoder 的时间序列预测模型。

    结构(数据流):
      Input (batch, seq_len, features)
         │
         ▼
      线性投影 → (batch, seq_len, d_model)    ← 将原始特征映射到 d_model 维
         │
         ▼
      + 位置编码                               ← 注入位置信息
         │
         ▼
      TransformerEncoder × N 层                ← 自注意力 + 前馈网络
         │  ┌─────────────────────────┐
         │  │  Multi-Head Attention   │ ← "我应该关注序列中哪些位置?"
         │  │        + Add & Norm     │
         │  │  Feed-Forward Network   │ ← "提取更复杂的特征"
         │  │        + Add & Norm     │
         │  └─────────────────────────┘
         │        × NUM_LAYERS 层
         ▼
      取最后一个位置的输出 → (batch, d_model)
         │
         ▼
      FC → (batch, 1)                          ← 输出预测的收盘价
    """

    def __init__(self, input_dim=NUM_FEATURES, d_model=D_MODEL,
                 nhead=NHEAD, num_layers=NUM_LAYERS,
                 dim_feedforward=DIM_FEEDFORWARD, dropout=DROPOUT,
                 output_size=1):
        super(TimeSeriesTransformer, self).__init__()

        self.d_model = d_model
        self.input_dim = input_dim

        # ① 输入投影:将原始特征数映射到 d_model 维度
        self.input_projection = nn.Linear(input_dim, d_model)

        # ② 位置编码
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)

        # ③ ★ 核心:Transformer Encoder 层
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,              # 特征维度
            nhead=nhead,                  # 注意力头数
            dim_feedforward=dim_feedforward,  # FFN 中间层维度
            dropout=dropout,
            activation="relu",            # FFN 激活函数
            batch_first=False,            # 输入格式 (seq, batch, feature)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,        # 堆叠层数
        )

        # ④ 输出层
        self.fc_out = nn.Linear(d_model, output_size)

    def forward(self, x):
        """
        参数:
            x: (batch, seq_len, input_dim)  例如 (32, 30, 5)

        返回:
            out: (batch, output_size)      预测的收盘价
        """
        # ── 调整形状:Transformer 期望 (seq_len, batch, features) ──
        x = x.permute(1, 0, 2)   # (look_back, batch, num_features)

        # ── 投影到 d_model 维度 ──
        x = self.input_projection(x)  # (seq_len, batch, d_model)

        # ── 注入位置编码 ──
        x = self.pos_encoder(x)       # (seq_len, batch, d_model)

        # ── ★ Transformer Encoder(自注意力核心)──
        # 这里每个位置会计算与所有其他位置的注意力权重
        # 输出的 x 每个位置已经融合了全局上下文信息
        x = self.transformer_encoder(x)  # (seq_len, batch, d_model)

        # ── 取最后一个时间步的输出做预测 ──
        # (也可以取平均池化,但取最后位置是最常见的做法)
        x = x[-1, :, :]                 # (batch, d_model)

        # ── 全连接输出 ──
        out = self.fc_out(x)            # (batch, 1)
        return out


# ============================================================================
# 6. 训练
# ============================================================================
def train_model(model, train_loader, criterion, optimizer):
    model.train()
    total_loss = 0

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

        predictions = model(X_batch)
        loss = criterion(predictions, y_batch)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)


# ============================================================================
# 7. 评估
# ============================================================================
def evaluate_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            predictions = model(X_batch)
            loss = criterion(predictions, y_batch)
            total_loss += loss.item()

    return total_loss / len(test_loader)


# ============================================================================
# 8. 可视化
# ============================================================================
def plot_results(y_real, y_pred, scaler, train_size):
    y_real_inv = scaler.inverse_transform(y_real.reshape(-1, 1))
    y_pred_inv = scaler.inverse_transform(y_pred.reshape(-1, 1))

    plt.figure(figsize=(14, 6))
    plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False

    x_axis = range(len(y_real_inv))
    plt.plot(x_axis[:train_size], y_real_inv[:train_size],
             color="blue", alpha=0.4, linewidth=1, label="训练集(真实值)")
    plt.plot(x_axis[train_size:], y_real_inv[train_size:],
             color="green", alpha=0.7, linewidth=1.5, label="测试集(真实值)")
    plt.plot(x_axis[train_size:], y_pred_inv[train_size:],
             color="red", alpha=0.8, linewidth=1.5, label="测试集(Transformer 预测值)")

    plt.axvline(x=train_size, color="gray", linestyle="--", alpha=0.7)
    plt.xlabel("时间(天)", fontsize=12)
    plt.ylabel("收盘价", fontsize=12)
    plt.title(f"Transformer 股票价格预测 --- 上证指数\n"
              f"(Look Back={LOOK_BACK}, d_model={D_MODEL}, heads={NHEAD}, layers={NUM_LAYERS})",
              fontsize=14)
    plt.legend(loc="upper left")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("transformer_prediction_result.png", dpi=150)
    plt.show()
    print("\n📊 图表已保存为 transformer_prediction_result.png")


def plot_training_curve(train_losses, test_losses):
    plt.figure(figsize=(10, 5))
    plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False

    plt.plot(train_losses, label="训练损失", color="blue", linewidth=1.5)
    plt.plot(test_losses, label="测试损失", color="orange", linewidth=1.5)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Loss (MSE)", fontsize=12)
    plt.title("Transformer 训练过程 --- Loss 曲线", fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("transformer_loss_curve.png", dpi=150)
    plt.show()
    print("📊 损失曲线已保存为 transformer_loss_curve.png")


# ============================================================================
# 9. Transformer 原理讲解
# ============================================================================
def print_transformer_explanation():
    print("\n" + "=" * 60)
    print("   📖 Transformer 核心原理 ------ Attention Is All You Need")
    print("=" * 60)
    print(f"""
   Transformer 用"自注意力"取代了 RNN 的"循环"。
   一句话理解:每个位置直接看序列中所有位置,而不是一步步传递。

   ┌─────────────────────────────────────────────────────────┐
   │              Scaled Dot-Product Attention                │
   │                                                          │
   │   Q (Query)  ← 当前位置的线性投影                         │
   │   K (Key)    ← 所有位置的线性投影(用于匹配)              │
   │   V (Value)  ← 所有位置的线性投影(用于聚合)              │
   │                                                          │
   │                          Q·K^T                            │
   │   Attention = softmax( ──────── ) · V                    │
   │                          √d_k                             │
   │                     ↑                                    │
   │              除以 √d_k 防止点积过大导致 softmax 梯度消失    │
   └─────────────────────────────────────────────────────────┘

   ★ 自注意力的优势:
     1. 并行计算:不依赖上一时间步,整条序列同时算
     2. 长距离依赖:不管位置相隔多远,直接建立联系
     3. 可解释性:注意力权重可以可视化"模型在关注哪里"

   ★ 为什么还需要位置编码?
     Transformer 本身是"位置无关"的------打乱输入顺序,输出也会对应打乱。
     位置编码给每个位置注入唯一信号,让模型知道"谁在前谁在后"。
""")
    print("=" * 60)


# ============================================================================
# 10. 主函数
# ============================================================================
def main():
    print("=" * 60)
    print("   🧠 Transformer 初学者教程 --- Attention Is All You Need")
    print("=" * 60)

    print_transformer_explanation()

    # Step 1: 下载数据
    df = download_data()

    # Step 2: 预处理
    X, y, close_scaler = preprocess_data(df)

    # Step 3: 划分
    train_loader, test_loader, X_train, y_train, X_test, y_test = split_data(X, y)

    # Step 4: 创建模型
    model = TimeSeriesTransformer(
        input_dim=NUM_FEATURES,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
    ).to(DEVICE)

    print(f"\n🏗️ 模型结构:\n{model}")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"📐 总参数量: {total_params:,}  |  可训练: {trainable_params:,}")
    print(f"📐 d_model={D_MODEL}, heads={NHEAD}, layers={NUM_LAYERS}, ff_dim={DIM_FEEDFORWARD}")

    # Step 5 & 6: 训练 + 评估
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Transformer 训练通常使用学习率预热(warmup),这里简化处理
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    train_losses, test_losses = [], []

    print(f"\n🚀 开始训练(共 {EPOCHS} 轮)...")
    print(f"   💡 Transformer 比 RNN 并行度高,训练速度通常更快")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_model(model, train_loader, criterion, optimizer)
        test_loss = evaluate_model(model, test_loader, criterion)
        scheduler.step()

        train_losses.append(train_loss)
        test_losses.append(test_loss)

        if epoch % 10 == 0 or epoch == 1:
            lr_now = scheduler.get_last_lr()[0]
            print(f"   Epoch [{epoch:3d}/{EPOCHS}]  "
                  f"Train Loss: {train_loss:.6f}  |  Test Loss: {test_loss:.6f}  "
                  f"|  LR: {lr_now:.6f}")

    print(f"\n✅ 训练完成!最终 Test Loss: {test_losses[-1]:.6f}")

    # Step 7: 可视化
    model.eval()
    with torch.no_grad():
        y_pred_train = model(X_train.to(DEVICE)).cpu().numpy()
        y_pred_test = model(X_test.to(DEVICE)).cpu().numpy()

    y_all_real = np.concatenate([y_train.numpy(), y_test.numpy()])
    y_all_pred = np.concatenate([y_pred_train, y_pred_test])

    plot_results(y_all_real, y_all_pred, close_scaler, len(y_train))
    plot_training_curve(train_losses, test_losses)

    # 小结
    print("\n" + "=" * 60)
    print("🎉 程序运行完毕!")
    print(f"\n📝 全系列模型对比总结:")
    print(f"   ┌──────────────┬──────────┬──────────┬──────────┬──────────────┬──────────────┐")
    print(f"   │    特性      │   RNN    │   GRU    │   LSTM   │  ConvLSTM    │ Transformer  │")
    print(f"   ├──────────────┼──────────┼──────────┼──────────┼──────────────┼──────────────┤")
    print(f"   │ 核心机制     │   循环   │  门控    │  三门控  │ 卷积+门控   │  自注意力    │")
    print(f"   │ 并行计算     │   ❌     │   ❌     │   ❌     │   ❌         │   ✅         │")
    print(f"   │ 长距离依赖   │   弱     │   中     │   强     │   中         │   强         │")
    print(f"   │ 位置感知     │  内置    │  内置    │  内置    │  内置        │  需编码      │")
    print(f"   │ 训练速度     │   慢     │   中     │   慢     │   慢         │   快         │")
    print(f"   │ 提出年份     │  1986    │  2014    │  1997    │  2015        │  2017        │")
    print(f"   └──────────────┴──────────┴──────────┴──────────┴──────────────┴──────────────┘")
    print(f"\n🔧 可调超参数:")
    print(f"   LOOK_BACK={LOOK_BACK}  d_model={D_MODEL}  nhead={NHEAD}  "
          f"num_layers={NUM_LAYERS}  ff_dim={DIM_FEEDFORWARD}  lr={LEARNING_RATE}")
    print("=" * 60)


if __name__ == "__main__":
    main()