【SegRNN 源码理解】图示理解 forward的过程

复制代码
输入: x [16, 60, 7]
(16个批次,每个60个时间步,每步7个特征)
            │
            ▼
┌──────────────────────────────┐
│      RevIN 标准化 + 维度置换     │
│   x = revinLayer(x, 'norm')   │
│      .permute(0, 2, 1)       │
└──────────────┬───────────────┘
               │
               ▼
          x [16, 7, 60]
(16个批次,7个特征,每个特征60个时间步)
            │
            ▼
┌──────────────────────────────┐
│         重塑为分段格式          │
│ x.reshape(-1, seg_num_x, seg_len) │
└──────────────┬───────────────┘
               │
               ▼
          x [112, 5, 12]
(112个序列=16批次×7特征,每个分5段,每段12步)
            │
            ▼
┌──────────────────────────────┐
│          段值嵌入             │
│     x = valueEmbedding(x)     │
│  (Linear: 12 → 512 + ReLU)    │
└──────────────┬───────────────┘
               │
               ▼
          x [112, 5, 512]
(112个序列,5个段,每段表示为512维向量)
            │
            ▼
┌──────────────────────────────┐
│           GRU 编码            │
│      _, hn = self.rnn(x)      │
└──────────────┬───────────────┘
               │
               ▼
          hn [1, 112, 512]
(1层GRU,112个序列的最终隐藏状态)
            │
            ▼
┌────────────────┬─────────────┐
│    RMF 解码     │    PMF 解码   │
└────────┬───────┴──────┬──────┘
         │              │
┌────────▼───────┐  ┌───▼──────────┐
│  循环多步预测    │  │  并行多步预测   │
│(逐段自回归预测)  │  │(一次性预测所有段)│
└────────┬───────┘  └───┬──────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │  位置和通道嵌入   │
         │       │   组合成条件     │
         │       └──────┬─────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │ pos_emb [224, 1, 512] │
         │       │ (224=16×7×2)   │
         │       └──────┬─────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │   条件GRU解码    │
         │       │ _, hy = rnn(pos_emb, hn) │
         │       └──────┬─────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │ hy [1, 224, 512] │
         │       └──────┬─────────┘
         │              │
┌────────▼───────┐ ┌────▼──────────┐
│ 预测 + 堆叠各段  │ │     预测       │
└────────┬───────┘ │ y = predict(hy) │
         │         └────┬───────────┘
         │              │
         └──────┬───────┘
                │
                ▼
           y [16, 7, 24]
 (预测结果: 16个批次,7个特征,预测24个时间步)
                │
                ▼
┌───────────────────────────────┐
│      维度置换 + RevIN反标准化    │
│   y = revinLayer(y.permute(0, 2, 1), 'denorm')   │
└───────────────┬───────────────┘
                │
                ▼
       最终输出 y [16, 24, 7]
(16个批次,预测未来24个时间步,每步7个特征)

SegRNN 前向传播的关键处理阶段解释

1. 数据预处理和视角转换

  • 输入形状[16, 60, 7] (批次, 时间步, 特征)
  • RevIN标准化:对每个特征序列进行可逆实例标准化,消除分布偏移
  • 维度置换 :将视角从"时间优先"转为"特征优先" → [16, 7, 60]
  • 作用:为每个特征序列建立独立处理路径

2. 序列分段和嵌入

  • 重塑和分段[16, 7, 60][112, 5, 12]
    • 112 = 16(批次) × 7(特征)
    • 每个序列分为5段,每段12个时间步
  • 线性嵌入:将每个12维的段映射到512维空间
  • 作用:引入层次化表示,捕获不同时间尺度的模式

3. GRU序列编码

  • 序列编码:GRU处理5个段之间的时间依赖关系
  • 输出隐藏状态hn [1, 112, 512]
  • 作用:将历史信息压缩到一个固定长度向量

4. 解码阶段(PMF模式)

  • 位置和通道嵌入组合
    • 位置嵌入:[2, 256] → 预测的2个段
    • 通道嵌入:[7, 256] → 7个不同特征
    • 组合后:[224, 1, 512] (224 = 16×7×2)
  • 条件GRU解码
    • 输入:位置嵌入序列
    • 条件:复制并重塑的编码器隐藏状态
    • 输出:hy [1, 224, 512]
  • 预测 :线性层 [224, 512][224, 12] → 重塑为 [16, 7, 24]
  • 作用:基于历史信息和位置/通道条件,一次生成所有预测段

5. 输出处理

  • 维度置换[16, 7, 24][16, 24, 7]
  • RevIN反标准化:将标准化的数据转换回原始分布
  • 最终输出[16, 24, 7] (16个批次,预测24个时间步,每步7个特征)

模型设计亮点

  1. 层次化架构:通过分段+RNN的两级架构,有效处理不同时间尺度
  2. 特征独立处理:每个特征有独立的处理路径,减少干扰
  3. 可逆标准化:RevIN处理分布偏移,保留原始分布特性
  4. 条件生成:位置嵌入+通道嵌入提供细粒度控制的条件生成

这种设计使SegRNN能高效处理长序列预测问题,特别是在多变量时间序列中,不同特征具有不同时间模式的情况。

相关推荐
vx_biyesheji00019 小时前
Python 全国城市租房洞察系统 Django框架 Requests爬虫 可视化 房子 房源 大数据 大模型 计算机毕业设计源码(建议收藏)✅
爬虫·python·机器学习·django·flask·课程设计·旅游
湘美书院--湘美谈教育9 小时前
湘美谈教育湘美书院网文研究:人工智能与微型小说选集
人工智能·深度学习·神经网络·机器学习·ai写作
zh路西法11 小时前
【宇树机器人强化学习】(七):复杂地形的生成与训练
python·深度学习·机器学习·机器人
OpenBayes贝式计算14 小时前
教程上新丨基于 GPU 部署 OpenClaw,轻松接入飞书/Discord 等社交软件
人工智能·深度学习·机器学习
Master_oid14 小时前
机器学习35:元学习的应用
人工智能·学习·机器学习
Echo_NGC223714 小时前
【卷积神经网络 CNN】一文讲透卷积神经网络CNN的核心概念与演进历程
人工智能·深度学习·神经网络·目标检测·机器学习·自然语言处理·cnn
郑同学zxc14 小时前
机器学习19-tensorflow4.2
人工智能·机器学习
LSssT.15 小时前
【02】线性回归:机器学习的入门第一课
人工智能·机器学习·线性回归
vx_biyesheji000116 小时前
计算机毕业设计:Python多源新闻数据智能舆情挖掘平台 Flask框架 爬虫 SnowNLP ARIMA 可视化 数据分析 大数据(建议收藏)✅
爬虫·python·机器学习·数据分析·django·flask·课程设计
忧郁的橙子.17 小时前
08-QLora微调&GGUF模型转换、Qwen打包部署 ollama 运行
人工智能·深度学习·机器学习·qlora·打包部署 ollama