时空反向传播 (STBP) 算法

时空反向传播 (STBP) 算法深度解析

1. 核心挑战:SNN 为什么难训练?

传统的深度学习(ANN)使用反向传播(BP)算法,依赖链式法则计算梯度。SNN 想要使用 BP,面临一个数学上的"死胡同":

  • 脉冲的不可导性: SNN 中的神经元发放脉冲是一个二值的阶跃函数(Step Function)。
    • O u t p u t = 1 Output = 1 Output=1 if u > V t h u > V_{th} u>Vth else 0 0 0
  • 梯度消失或爆炸: 阶跃函数的导数在阈值处是无穷大(狄拉克 δ \delta δ 函数),在其他地方是 0。这意味着梯度要么无法传播,要么爆炸,导致网络无法通过标准的梯度下降来更新权重。

STBP 的解决思路: 将 SNN 在时间维度 上展开,将其视为一个特殊的循环神经网络(RNN),并在反向传播时引入代理梯度(Surrogate Gradient)

2. STBP 的"时空"含义

STBP 的精髓在于它同时在两个维度上计算梯度的传播:

  1. 空间维度 (Spatial Domain):
    • 就像传统的 CNN/MLP 一样,误差从输出层向输入层,逐层(Layer-by-Layer)反向传播。
    • 这解决了"如何调整权重以提取特征"的问题。
  2. 时间维度 (Temporal Domain):
    • 由于 LIF 神经元有膜电位泄漏(Leakage)和累积特性,当前时刻的状态依赖于上一时刻。误差需要从 t t t 时刻向 t − 1 t-1 t−1 时刻传播。
    • 这解决了"如何利用历史信息"的问题。

3. 算法数学推导 (结合 Paper 1)

A. 前向传播 (LIF 动力学)

在离散时间步 t t t,LIF 神经元 i i i 的行为被建模为:

  1. 预突触输入: x i t = ∑ j w i j s j t − 1 x_i^t = \sum_j w_{ij} s_j^{t-1} xit=∑jwijsjt−1
  2. 膜电位更新: u i t = u i t − 1 ⋅ τ d e c a y + x i t + b i u_i^t = u_i^{t-1} \cdot \tau_{decay} + x_i^t + b_i uit=uit−1⋅τdecay+xit+bi (若上一步未发放脉冲)
  3. 脉冲发放: s i t = g ( u i t − V t h ) s_i^t = g(u_i^t - V_{th}) sit=g(uit−Vth),其中 g ( x ) g(x) g(x) 是海维赛德阶跃函数。

B. 反向传播 (链式法则)

我们要计算损失函数 L L L 对权重 W W W 的梯度 ∂ L ∂ W \frac{\partial L}{\partial W} ∂W∂L。根据链式法则,总梯度是空间梯度和时间梯度的总和。

∂ L ∂ u i t = ∂ L ∂ s i t ∂ s i t ∂ u i t ⏟ 空间传播 + ∂ L ∂ u i t + 1 ∂ u i t + 1 ∂ u i t ⏟ 时间传播 \frac{\partial L}{\partial u_i^t} = \underbrace{\frac{\partial L}{\partial s_i^t} \frac{\partial s_i^t}{\partial u_i^t}}{\text{空间传播}} + \underbrace{\frac{\partial L}{\partial u_i^{t+1}} \frac{\partial u_i^{t+1}}{\partial u_i^t}}{\text{时间传播}} ∂uit∂L=空间传播 ∂sit∂L∂uit∂sit+时间传播 ∂uit+1∂L∂uit∂uit+1

这里有两个关键项:

  1. 时间依赖项 ∂ u i t + 1 ∂ u i t \frac{\partial u_i^{t+1}}{\partial u_i^t} ∂uit∂uit+1**:**

    这对应于神经元的泄漏因子(decay factor)。

    ∂ u i t + 1 ∂ u i t ≈ τ d e c a y (忽略复位影响的简化) \frac{\partial u_i^{t+1}}{\partial u_i^t} \approx \tau_{decay} \text{ (忽略复位影响的简化)} ∂uit∂uit+1≈τdecay (忽略复位影响的简化)

  2. 脉冲导数项 ∂ s i t ∂ u i t \frac{\partial s_i^t}{\partial u_i^t} ∂uit∂sit (Crucial!):

    这是 s s s 对 u u u 求导。由于 s s s 是阶跃函数,直接求导不可行。STBP 在这里引入了"代理梯度"

C. 代理梯度 (Surrogate Gradient)

前向传播 时,仍然使用阶跃函数以保持 SNN 的二值特性;但在反向传播 计算梯度时,使用一个平滑的可导函数 h ( u ) h(u) h(u) 来近似阶跃函数。

Paper 1 中选择的代理梯度函数是一个类高斯函数(类似于概率密度函数):

h ( u ) = 1 2 π e − ( u − V t h ) 2 h(u) = \frac{1}{\sqrt{2\pi}} e^{-(u - V_{th})^2} h(u)=2π 1e−(u−Vth)2

  • 这意味着:当膜电位 u u u 接近阈值 V t h V_{th} Vth 时,我们认为它"很有可能"发放脉冲,因此给予较大的梯度;当 u u u 远离阈值时,梯度衰减。这使得梯度可以顺滑地传回网络。

4. Paper 1 的改进:硬件感知 STBP

Paper 1 并没有止步于标准的 STBP,而是针对 低功耗 ASIC 设计 修改了损失函数,使其训练出的网络天生适合硬件。

损失函数设计

L o s s t o t a l = L o s s M S E + λ s L o s s R a t e + λ w L o s s W e i g h t Loss_{total} = Loss_{MSE} + \lambda_s Loss_{Rate} + \lambda_w Loss_{Weight} Losstotal=LossMSE+λsLossRate+λwLossWeight

  1. L o s s M S E Loss_{MSE} LossMSE (准确率项):

    传统的分类误差(均方误差),让网络输出正确的分类。

  2. L o s s R a t e Loss_{Rate} LossRate (发放率正则化项):

    λ s ∑ ∑ ∣ ∣ s i t ∣ ∣ 2 2 \lambda_s \sum \sum ||s_i^t||_2^2 λs∑∑∣∣sit∣∣22

    • 目的: 惩罚脉冲的发放。
    • 硬件意义: 芯片采用了"脉冲驱动(Spike-Driven)"架构,功耗与脉冲数量成正比。强制网络变得稀疏(Firing Rate 从 25% 压到 15%),直接降低了芯片的动态功耗。
  3. L o s s W e i g h t Loss_{Weight} LossWeight (权重正则化项):

    λ w ∑ ∣ ∣ w ∣ ∣ 1 \lambda_w \sum ||w||_1 λw∑∣∣w∣∣1

    • 目的: L1 正则化,使权重趋向于 0。
    • 硬件意义: 配合剪枝(Pruning),将接近 0 的权重直接移除。这减少了存储需求(SRAM),并允许使用稀疏存储格式。

5. 算法流程图解

mermind 复制代码
graph TD
    subgraph Time_Step_t ["时间步 t"]
        Pre_Spike_t(输入脉冲 S_in) -->|x W| Mem_Pot_t(膜电位 u_t)
        Mem_Pot_t -->|Step Function| Spike_Out_t(输出脉冲 S_out)
    end
    
    subgraph Time_Step_t_plus_1 ["时间步 t+1"]
        Mem_Pot_t -.->|Leakage| Mem_Pot_t1(膜电位 u_t+1)
    end
    
    subgraph Backpropagation ["反向传播 (STBP)"]
        Err_Out(输出误差) -->|空间梯度| Grad_S_t(dS/du)
        Grad_S_t --"代理梯度 h(u)"--> Grad_U_t(du)
        Err_Next(t+1时刻误差) -->|时间梯度| Grad_U_t
    end

6. 总结

STBP 算法通过引入时间维度的展开代理梯度近似,打通了 SNN 训练的数学链路。

在 Paper 1 中,该算法不仅仅是为了训练一个能用的网络,更是通过修改损失函数 ,充当了硬件/算法协同设计 (Co-design) 的桥梁:它"逼迫"神经网络学出一种稀疏的、低发放率的、权重简单的形态,从而完美契合其 ASIC 芯片的低功耗特性。

相关推荐
CoovallyAIHub17 小时前
Moonshine:比 Whisper 快 100 倍的端侧语音识别神器,Star 6.6K!
深度学习·算法·计算机视觉
CoovallyAIHub18 小时前
速度暴涨10倍、成本暴降6倍!Mercury 2用扩散取代自回归,重新定义LLM推理速度
深度学习·算法·计算机视觉
CoovallyAIHub18 小时前
实时视觉AI智能体框架来了!Vision Agents 狂揽7K Star,延迟低至30ms,YOLO+Gemini实时联动!
算法·架构·github
CoovallyAIHub18 小时前
开源:YOLO最强对手?D-FINE目标检测与实例分割框架深度解析
人工智能·算法·github
CoovallyAIHub19 小时前
OpenClaw:从“19万星标”到“行业封杀”,这只“赛博龙虾”究竟触动了谁的神经?
算法·架构·github
刀法如飞19 小时前
程序员必须知道的核心算法思想
算法·编程开发·算法思想
徐小夕20 小时前
pxcharts Ultra V2.3更新:多维表一键导出 PDF,渲染兼容性拉满!
vue.js·算法·github
CoovallyAIHub21 小时前
OpenClaw一脚踩碎传统CV?机器终于不再只是看世界
深度学习·算法·计算机视觉
CoovallyAIHub21 小时前
仅凭单目相机实现3D锥桶定位?UNet-RKNet破解自动驾驶锥桶检测难题
深度学习·算法·计算机视觉