从零手写LSTM:从门控原理到PyTorch源码级实现

在学习深度学习框架时,调用nn.LSTM往往只是一行代码,但理解其内部门控机制与矩阵运算才能真正掌握序列建模的本质。本文将从头实现长短期记忆网络(LSTM)的前向传播,深入解析输入门、遗忘门、输出门的计算细节,并与PyTorch官方实现进行对比验证。

一、LSTM基础原理

LSTM(Long Short-Term Memory)通过引入细胞状态(Cell State)和三个门控机制(输入门、遗忘门、输出门)来解决传统RNN的长期依赖问题。

核心公式

LSTM的计算流程可分为四个门控计算两种状态更新

四个门控: i t = σ ( W i i x t + b i i + W h i h t − 1 + b h i ) // 输入门 f t = σ ( W i f x t + b i f + W h f h t − 1 + b h f ) // 遗忘门 g t = tanh ⁡ ( W i g x t + b i g + W h g h t − 1 + b h g ) // 候选细胞状态 o t = σ ( W i o x t + b i o + W h o h t − 1 + b h o ) // 输出门 两种状态更新: c t = f t ⊙ c t − 1 + i t ⊙ g t // 细胞状态更新 h t = o t ⊙ tanh ⁡ ( c t ) // 隐藏状态输出 \begin{aligned} \text{四个门控:} \\ i_t &= \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi}) \quad \text{// 输入门} \\ f_t &= \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf}) \quad \text{// 遗忘门} \\ g_t &= \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg}) \quad \text{// 候选细胞状态} \\ o_t &= \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho}) \quad \text{// 输出门} \\ \text{两种状态更新:} \\ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \quad \text{// 细胞状态更新} \\ h_t &= o_t \odot \tanh(c_t) \quad \text{// 隐藏状态输出} \end{aligned} 四个门控:itftgtot两种状态更新:ctht=σ(Wiixt+bii+Whiht−1+bhi)// 输入门=σ(Wifxt+bif+Whfht−1+bhf)// 遗忘门=tanh(Wigxt+big+Whght−1+bhg)// 候选细胞状态=σ(Wioxt+bio+Whoht−1+bho)// 输出门=ft⊙ct−1+it⊙gt// 细胞状态更新=ot⊙tanh(ct)// 隐藏状态输出

其中:

  • h t h_t ht:时刻 t t t 的隐藏状态(Hidden State)
  • c t c_t ct:时刻 t t t 的细胞状态(Cell State),LSTM的核心记忆载体
  • x t x_t xt:时刻 t t t 的输入特征
  • i t , f t , o t i_t, f_t, o_t it,ft,ot:输入门、遗忘门、输出门,通过Sigmoid激活输出0-1之间的值控制信息流动
  • g t g_t gt:候选记忆内容,通过Tanh激活
  • ⊙ \odot ⊙:Hadamard积(逐元素乘积)

门控机制图解

  • 遗忘门 f t f_t ft:决定从细胞状态中丢弃哪些信息(0表示完全遗忘,1表示完全保留)
  • 输入门 i t i_t it:决定哪些新信息存入细胞状态
  • 输出门 o t o_t ot:决定基于细胞状态输出哪些隐藏状态

二、PyTorch LSTM参数详解

在实现之前,先理解PyTorch API的设计逻辑:

核心参数

  • input_size -- 输入 x x x 中预期特征的数量(输入特征维度 H i n H_{in} Hin)
  • hidden_size -- 隐藏状态 h h h 和细胞状态 c c c 中的特征数量( H o u t H_{out} Hout / H c e l l H_{cell} Hcell)
  • num_layers -- 循环层的数量。例如num_layers=2表示堆叠两个LSTM,第二层接收第一层的隐藏状态作为输入。默认值:1
  • bias -- 如果为False,则不使用偏置 b i h b_{ih} bih 和 b h h b_{hh} bhh。默认值:True
  • batch_first -- 如果为True,输入输出张量格式为(batch, seq, feature)而非(seq, batch, feature)。注意:这不适用于隐藏状态和细胞状态。默认值:False
  • dropout -- 如果非零,在除最后一层外的每个LSTM层输出上引入Dropout层。默认值:0
  • bidirectional -- 如果为True,则变为双向LSTM。默认值:False
  • proj_size -- 如果大于0,将使用对应大小的投影LSTM, h t h_t ht维度将从hidden_size变为proj_size。默认值:0

输入:input, (h_0, c_0)

  • input: 张量形状:

    • 无batch输入: ( L , H i n ) (L, H_{in}) (L,Hin)
    • batch_first=False时: ( L , N , H i n ) (L, N, H_{in}) (L,N,Hin),即(seq, batch, feature)
    • batch_first=True时: ( N , L , H i n ) (N, L, H_{in}) (N,L,Hin),即(batch, seq, feature)
  • h_0 : 初始隐藏状态,形状为 ( D × num_layers , N , H o u t ) (D \times \text{num\layers}, N, H{out}) (D×num_layers,N,Hout)(批处理情况)

  • c_0 : 初始细胞状态,形状为 ( D × num_layers , N , H c e l l ) (D \times \text{num\layers}, N, H{cell}) (D×num_layers,N,Hcell)

其中:
N = batch size(批量大小) L = sequence length(序列长度) D = 2 如果 bidirectional=True,否则为 1 H i n = input_size H c e l l = hidden_size H o u t = 如果 proj_size>0 则为 proj_size,否则为 hidden_size \begin{aligned} N &= \text{batch size(批量大小)} \\ L &= \text{sequence length(序列长度)} \\ D &= 2 \text{ 如果 bidirectional=True,否则为 } 1 \\ H_{in} &= \text{input\size} \\ H{cell} &= \text{hidden\size} \\ H{out} &= \text{如果 proj\_size>0 则为 proj\_size,否则为 hidden\_size} \end{aligned} NLDHinHcellHout=batch size(批量大小)=sequence length(序列长度)=2 如果 bidirectional=True,否则为 1=input_size=hidden_size=如果 proj_size>0 则为 proj_size,否则为 hidden_size

输出:output, (h_n, c_n)

  • output : 包含LSTM最后一层每个时间步的输出特征 h t h_t ht,形状与input对应:

    • batch_first=False时: ( L , N , D × H o u t ) (L, N, D \times H_{out}) (L,N,D×Hout)
    • batch_first=True时: ( N , L , D × H o u t ) (N, L, D \times H_{out}) (N,L,D×Hout)
  • h_n : 最终隐藏状态,形状 ( D × num_layers , N , H o u t ) (D \times \text{num\layers}, N, H{out}) (D×num_layers,N,Hout)

  • c_n : 最终细胞状态,形状 ( D × num_layers , N , H c e l l ) (D \times \text{num\layers}, N, H{cell}) (D×num_layers,N,Hcell)

三、手撕LSTM单层实现

首先准备数据,并调用PyTorch官方API作为对比基准:

python 复制代码
import torch
import torch.nn as nn

# 定义常量
batch_size, sequence_len, input_size, hidden_size = 2, 3, 4, 5

# 构造输入input和初始状态c0, h0
input = torch.randn(batch_size, sequence_len, input_size)  # [B, L, Hin]
c0 = torch.randn(batch_size, hidden_size)                  # [B, Hout]
h0 = torch.randn(batch_size, hidden_size)                  # [B, Hout]

# 调用官方API
lstm_layer = nn.LSTM(input_size, hidden_size, batch_first=True)
output, (h_final, c_final) = lstm_layer(
    input, 
    (h0.unsqueeze(0), c0.unsqueeze(0))  # 官方API要求 [num_layers, B, Hout]
)

# 官方API输出维度:
# output: [B, L, Hout]
# h_final: [1, B, Hout]
# c_final: [1, B, Hout]

查看官方参数组织方式

python 复制代码
for k, v in lstm_layer.named_parameters():
    print(k, v.shape)

# 输出:
# weight_ih_l0 torch.Size([20, 4])   # [4*Hout, Hin]
# weight_hh_l0 torch.Size([20, 5])   # [4*Hout, Hout]
# bias_ih_l0 torch.Size([20])        # [4*Hout]
# bias_hh_l0 torch.Size([20])        # [4*Hout]

关键理解:PyTorch将四个门的权重矩阵拼接存储:

  • weight_ih = [ W i i ; W i f ; W i g ; W i o ] [W_{ii}; W_{if}; W_{ig}; W_{io}] [Wii;Wif;Wig;Wio](按行拼接,形状 [ 4 × H o u t , H i n ] [4\times H_{out}, H_{in}] [4×Hout,Hin])
  • weight_hh = [ W h i ; W h f ; W h g ; W h o ] [W_{hi}; W_{hf}; W_{hg}; W_{ho}] [Whi;Whf;Whg;Who](形状 [ 4 × H o u t , H o u t ] [4\times H_{out}, H_{out}] [4×Hout,Hout])
  • bias_ih = [ b i i ; b i f ; b i g ; b i o ] [b_{ii}; b_{if}; b_{ig}; b_{io}] [bii;bif;big;bio](形状 [ 4 × H o u t ] [4\times H_{out}] [4×Hout])
  • bias_hh = [ b h i ; b h f ; b h g ; b h o ] [b_{hi}; b_{hf}; b_{hg}; b_{ho}] [bhi;bhf;bhg;bho](形状 [ 4 × H o u t ] [4\times H_{out}] [4×Hout])

接下来是核心------手动实现LSTM前向传播:

python 复制代码
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """
    手动实现LSTM前向传播
    
    Args:
        input: [B, L, Hin]
        initial_states: (h0, c0),每个都是 [B, Hout]
        w_ih: [4*Hout, Hin],输入到隐藏层的权重
        w_hh: [4*Hout, Hout],隐藏层到隐藏层的循环权重
        b_ih, b_hh: [4*Hout],偏置项
    """
    h0, c0 = initial_states
    batch_size, sequence_len, input_size = input.shape
    hidden_size = w_ih.shape[0] // 4  # 因为是四个门的权重拼接

    prev_h = h0  # [B, Hout]
    prev_c = c0  # [B, Hout]

    # 为批次并行计算准备权重矩阵:扩展batch维度
    # w_ih: [4*Hout, Hin] -> [B, 4*Hout, Hin]
    batch_w_ih = w_ih.unsqueeze(0).tile(batch_size, 1, 1)
    # w_hh: [4*Hout, Hout] -> [B, 4*Hout, Hout]
    batch_w_hh = w_hh.unsqueeze(0).tile(batch_size, 1, 1)

    # 初始化输出序列 [B, L, Hout]
    output = torch.zeros(batch_size, sequence_len, hidden_size)

    # 遍历时间步
    for t in range(sequence_len):
        x = input[:, t, :]  # 获取当前时刻输入 [B, Hin]
        
        # 计算 w_ih @ x: [B, 4*Hout, Hin] @ [B, Hin, 1] -> [B, 4*Hout, 1]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(2)).squeeze(-1)  # [B, 4*Hout]
        
        # 计算 w_hh @ h_{t-1}: [B, 4*Hout, Hout] @ [B, Hout, 1] -> [B, 4*Hout]
        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(2)).squeeze(-1)  # [B, 4*Hout]

        # 分别计算四个门(通过切片获取对应权重部分)
        # 输入门 i_t
        i_t = torch.sigmoid(
            w_times_x[:, :hidden_size] + 
            w_times_h_prev[:, :hidden_size] +
            b_ih[:hidden_size].unsqueeze(0).tile(batch_size, 1) +
            b_hh[:hidden_size].unsqueeze(0).tile(batch_size, 1)
        )  # [B, Hout]
        
        # 遗忘门 f_t
        f_t = torch.sigmoid(
            w_times_x[:, hidden_size:2*hidden_size] + 
            w_times_h_prev[:, hidden_size:2*hidden_size] +
            b_ih[hidden_size:2*hidden_size].unsqueeze(0).tile(batch_size, 1) +
            b_hh[hidden_size:2*hidden_size].unsqueeze(0).tile(batch_size, 1)
        )  # [B, Hout]
        
        # 候选细胞状态 g_t
        g_t = torch.tanh(
            w_times_x[:, 2*hidden_size:3*hidden_size] + 
            w_times_h_prev[:, 2*hidden_size:3*hidden_size] +
            b_ih[2*hidden_size:3*hidden_size].unsqueeze(0).tile(batch_size, 1) +
            b_hh[2*hidden_size:3*hidden_size].unsqueeze(0).tile(batch_size, 1)
        )  # [B, Hout]
        
        # 输出门 o_t
        o_t = torch.sigmoid(
            w_times_x[:, 3*hidden_size:] + 
            w_times_h_prev[:, 3*hidden_size:] +
            b_ih[3*hidden_size:].unsqueeze(0).tile(batch_size, 1) +
            b_hh[3*hidden_size:].unsqueeze(0).tile(batch_size, 1)
        )  # [B, Hout]

        # 细胞状态更新:c_t = f_t * c_{t-1} + i_t * g_t
        prev_c = f_t * prev_c + i_t * g_t  # [B, Hout]
        
        # 隐藏状态输出:h_t = o_t * tanh(c_t)
        prev_h = o_t * torch.tanh(prev_c)  # [B, Hout]
        
        # 存储当前时刻输出
        output[:, t, :] = prev_h
    
    return output, (prev_h, prev_c)

四、关键实现细节解析

1. 权重矩阵的批次并行(bmm)

为了实现批次数据的并行计算,需要将权重矩阵扩展到批次维度:

python 复制代码
batch_w_ih = w_ih.unsqueeze(0).tile(batch_size, 1, 1)  # [1, 4*Hout, Hin] -> [B, 4*Hout, Hin]

这样每个样本都使用相同的权重进行独立计算,通过torch.bmm(批次矩阵乘法)实现高效并行。

2. 维度对齐技巧

  • 输入 x t x_t xt 初始形状 [B, Hin],通过unsqueeze(2)变为[B, Hin, 1],便于与权重[B, 4*Hout, Hin]进行矩阵乘法
  • 结果[B, 4*Hout, 1]通过squeeze(-1)压缩回[B, 4*Hout],便于后续门控计算

3. 四权重拼接的切片访问

PyTorch将四个门的权重拼接存储,因此通过切片获取:

python 复制代码
# 第0~Hout行:输入门权重
# 第Hout~2*Hout行:遗忘门权重  
# 第2*Hout~3*Hout行:候选状态权重
# 第3*Hout~4*Hout行:输出门权重

4. 偏置的广播机制

由于PyTorch存储的偏置是1维[4*Hout],需要扩展为[B, Hout]以进行批次加法:

python 复制代码
b_ih[:hidden_size].unsqueeze(0).tile(batch_size, 1)  # [Hout] -> [1, Hout] -> [B, Hout]

五、验证实现正确性

使用官方API的参数测试手写实现:

python 复制代码
# 提取官方LSTM的参数
w_ih = lstm_layer.weight_ih_l0
w_hh = lstm_layer.weight_hh_l0
b_ih = lstm_layer.bias_ih_l0
b_hh = lstm_layer.bias_hh_l0

# 调用自定义实现
custom_output, (custom_h_final, custom_c_final) = lstm_forward(
    input, (h0, c0), w_ih, w_hh, b_ih, b_hh
)

print("PyTorch API nn.LSTM output:")
print(output, h_final.squeeze(0), c_final.squeeze(0))
print("\ncustom lstm_forward() output:")
print(custom_output, custom_h_final, custom_c_final)

# 数值对比(应几乎完全相同,误差在1e-6级别)
print("\nDifference:", torch.abs(output - custom_output).max())

验证结果 :若两者输出一致(误差小于 10 − 6 10^{-6} 10−6),说明我们的数学推导和实现逻辑与PyTorch底层C++实现完全吻合。

六、LSTM的优缺点总结

LSTM的优点

  • 有效解决RNN的长期依赖问题(梯度消失/爆炸),可记忆数十到数百个时间步的信息
  • 通过门控机制精确控制信息的遗忘与更新,适合长序列建模(如文本生成、时间序列预测)
  • 细胞状态 c t c_t ct 的线性传递(累加操作)提供了稳定的梯度传播路径

LSTM的缺点

  • 计算复杂度高:每个时间步需要计算4个门控,参数量是标准RNN的4倍
  • 串行计算无法并行化时间步,推理速度较慢
  • 内存占用大:需要同时保存 h t h_t ht和 c t c_t ct两个状态向量

七、总结

通过手撕LSTM的实现,我们深入理解了:

  1. 门控机制:输入门、遗忘门、输出门如何协同工作,通过Sigmoid控制信息流动比例,Tanh生成候选内容
  2. 双状态设计 :细胞状态 c t c_t ct作为长期记忆载体(线性更新),隐藏状态 h t h_t ht作为短期输出(非线性变换)
  3. 矩阵拼接:理解PyTorch将四组权重拼接存储的工程实践,掌握张量切片与维度对齐技巧
  4. 批次并行 :使用bmm实现多样本并行计算,理解unsqueeze/squeeze在维度对齐中的作用

虽然实际项目中直接使用nn.LSTM即可,但这种底层实现训练有助于理解:

  • 梯度裁剪在LSTM中的必要性
  • 双向LSTM的实现原理(正反向各一个LSTM,输出拼接)
  • 多层堆叠LSTM的实现(上一层的 h t h_t ht作为下一层的 x t x_t xt)
  • 变长序列的PackedSequence处理

完整代码已验证:上述代码可直接运行,输出结果与PyTorch官方API完全一致,证明我们的实现正确复现了LSTM的门控逻辑与矩阵运算流程。


注:对于带投影的LSTM(proj_size > 0),只需在输出时增加 h t = W h r h t h_t = W_{hr}h_t ht=Whrht 的线性投影步骤,将维度从hidden_size映射到proj_size,其余逻辑完全一致。

相关推荐
勾股导航16 小时前
大模型Skill
人工智能·python·机器学习
卷福同学18 小时前
【养虾日记】Openclaw操作浏览器自动化发文
人工智能·后端·算法
春日见19 小时前
如何入门端到端自动驾驶?
linux·人工智能·算法·机器学习·自动驾驶
光锥智能19 小时前
从自动驾驶到 AI 能力体系,元戎启行 GTC 发布基座模型新进展
人工智能
luoganttcc19 小时前
自动驾驶 世界模型 有哪些
人工智能·机器学习·自动驾驶
潘高19 小时前
10分钟教你手撸一个小龙虾(OpenClaw)
人工智能
禁默19 小时前
光学与机器视觉:解锁“机器之眼”的核心密码-《第五届光学与机器视觉国际学术会议(ICOMV 2026)》
人工智能·计算机视觉·光学
深小乐19 小时前
不是DeepSeek V4!这两个神秘的 Hunter 模型竟然来自小米
人工智能
laozhao43220 小时前
科大讯飞中标教育管理应用升级开发项目
大数据·人工智能
rainbow72424420 小时前
AI人才简历评估选型:技术面试、代码评审与项目复盘的综合运用方案
人工智能·面试·职场和发展