LSTM、GRU 与 Transformer网络模型参数计算

参数计算公式对比

模型类型 参数计算公式 关键组成部分
LSTM 4 × (embed_dim × hidden_size + hidden_size² + hidden_size) 4个门控结构
GRU 3 × (embed_dim × hidden_size + hidden_size² + hidden_size) 3个门控结构
Transformer (Encoder) 12 × embed_dim² + 9 × embed_dim × ff_dim + 14 × embed_dim 多头注意力 + FFN
Transformer (Decoder) 14 × embed_dim² + 9 × embed_dim × ff_dim + 15 × embed_dim 多头注意力 + FFN + 掩码注意力

详细参数计算解析

1. LSTM 参数计算

LSTM 单元包含 4 个门控结构(输入门、遗忘门、候选单元、输出门)

Python

复制代码
LSTM_params = 4 × (input_size × hidden_size +   # Wi, Wf, Wc, Wo
                   hidden_size × hidden_size +  # Ui, Uf, Uc, Uo
                   hidden_size)                 # bi, bf, bc, bo

简化公式LSTM_params ≈ 4 × hidden_size × (input_size + hidden_size + 1)

2. GRU 参数计算

GRU 单元包含 3 个门控结构(更新门、重置门、候选门)

复制代码
GRU_params = 3 × (input_size × hidden_size +   # Wz, Wr, Wh
                  hidden_size × hidden_size +   # Uz, Ur, Uh
                  hidden_size)                 # bz, br, bh

简化公式GRU_params ≈ 3 × hidden_size × (input_size + hidden_size + 1)

3. Transformer 参数计算

Transformer 由多层堆叠,每层包含:

  • 多头注意力机制(Multi-Head Attention)
  • 前馈神经网络(Feed-Forward Network)
  • 层归一化(LayerNorm)
  • 残差连接(Skip Connections)
单层参数分解:
复制代码
# 多头注意力层
QKV_proj = 3 × embed_dim × embed_dim  # Wq, Wk, Wv
output_proj = embed_dim × embed_dim   # Wo
attention_params = 4 × embed_dim²

# 前馈神经网络
FFN_params = 2 × (embed_dim × ff_dim + ff_dim × embed_dim) + (ff_dim + embed_dim)
           = 2 × embed_dim × ff_dim + 2 × ff_dim × embed_dim + ff_dim + embed_dim
           = 4 × embed_dim × ff_dim + ff_dim + embed_dim

# 层归一化 (2个)
LayerNorm_params = 2 × 2 × embed_dim  # 每个LN有gamma和beta参数

# 总单层参数
Encoder_layer = attention_params + FFN_params + LayerNorm_params
              = 4×embed_dim² + (4×embed_dim×ff_dim + ff_dim + embed_dim) + 4×embed_dim

完整 Transformer 参数公式

对于 N 层 Transformer:

其中:

  • d = embed_dim (嵌入维度)
  • d_ff = ff_dim (前馈网络隐藏层维度)
  • Embedding = vocab_size × embed_dim (词嵌入参数)

参数对比示例

假设配置:

  • 嵌入维度 (embed_dim) = 512
  • 隐藏层维度 (hidden_size) = 512
  • FFN 维度 (ff_dim) = 2048
  • 词表大小 (vocab_size) = 50000
  • LSTM/GRU 层数 = 1
  • Transformer 层数 = 6

参数计算结果:

模型 参数计算 总量 占比
LSTM 4 × (512×512 + 512² + 512) = 4×(262,144 + 262,144 + 512) = 2,100,352 2.10M 基准
GRU 3 × (512×512 + 512² + 512) = 3×(262,144 + 262,144 + 512) = 1,574,400 1.57M 75%
Transformer Encoder 6×(4×512² + 4×512×2048 + 2048 + 5×512) + 50000×512 = 6×(1,048,576 + 4,194,304 + 2048 + 2,560) + 25,600,000 = 6×5,247,488 + 25,600,000 = **57,084,928** 57.1M 27.2倍
Embedding层 50000×512 = 25,600,000 25.6M -

参数计算工具函数

复制代码
def calculate_params(model_type, embed_dim, hidden_size=None, 
                     ff_dim=None, num_layers=1, vocab_size=None):
    params = 0
    
    if model_type == "LSTM":
        # LSTM参数计算
        params = 4 * (embed_dim * hidden_size + hidden_size**2 + hidden_size)
    
    elif model_type == "GRU":
        # GRU参数计算
        params = 3 * (embed_dim * hidden_size + hidden_size**2 + hidden_size)
    
    elif model_type == "Transformer-Encoder":
        # Transformer编码器参数计算
        per_layer = (4 * embed_dim**2) + (4 * embed_dim * ff_dim) + ff_dim + (5 * embed_dim)
        encoder_params = num_layers * per_layer
        embedding_params = vocab_size * embed_dim
        params = encoder_params + embedding_params
    
    elif model_type == "Transformer-Decoder":
        # Transformer解码器参数计算
        per_layer = (8 * embed_dim**2) + (4 * embed_dim * ff_dim) + ff_dim + (6 * embed_dim)
        decoder_params = num_layers * per_layer
        embedding_params = vocab_size * embed_dim
        params = decoder_params + embedding_params
    
    return params

# 示例使用
lstm_params = calculate_params("LSTM", embed_dim=512, hidden_size=512)
transformer_params = calculate_params("Transformer-Encoder", embed_dim=512, 
                                     ff_dim=2048, num_layers=6, vocab_size=50000)
相关推荐
简简单单做算法6 小时前
基于LSTM深度学习网络的视频类型分类算法matlab仿真
深度学习·matlab·分类·lstm·视频类型分类
机器学习之心17 小时前
三种深度学习模型(GRU、CNN-GRU、贝叶斯优化的CNN-GRU/BO-CNN-GRU)对北半球光伏数据进行时间序列预测
gru·cnn-gru·贝叶斯优化的cnn-gru
王上上1 天前
【论文阅读51】-CNN-LSTM-安全系数和失效概率预测
论文阅读·cnn·lstm
叫我:松哥1 天前
优秀案例:基于python django的智能家居销售数据采集和分析系统设计与实现,使用混合推荐算法和LSTM算法情感分析
爬虫·python·算法·django·lstm·智能家居·推荐算法
王小王-1231 天前
基于Transform、ARIMA、LSTM、Prophet的药品销量预测分析
lstm·arima·transform·prophet·药品销量预测·时序建模预测
老鱼说AI2 天前
Transformer Masked loss原理精讲及其PyTorch逐行实现
人工智能·pytorch·python·深度学习·transformer
lucky_lyovo2 天前
循环神经网络--LSTM模型
rnn·机器学习·lstm
李加号pluuuus2 天前
【论文阅读+复现】LayoutDM: Transformer-based Diffusion Model for Layout Generation
论文阅读·深度学习·transformer
9呀2 天前
【人工智能99问】长短期记忆网络(LSTM)的结构和原理是什么?(12/99)
人工智能·rnn·lstm
叫我:松哥3 天前
基于python的微博评论和博文文本分析,包括LDA+聚类+词频分析+lstm热度预测,数据量10000条
python·机器学习·数据挖掘·数据分析·lstm·聚类