STBP推导

STBP推导

STBP的推导容易引起混乱,主要在于论文中上来先分了4个case,而不是先讲清楚梯度是怎么流动的,然后自然而然引出这4种情况。此外,论文公式推导中存在一定的错误。我们**不妨先盯住某个时刻某个特定神经元节点的膜电位 u i t , n u_i^{t,n} uit,n,**看它在前向计算图里怎样继续影响损失 L L L。

原文使用的迭代LIF模型是

x i t + 1 , n = ∑ j = 1 l ( n − 1 ) w i j n o j t + 1 , n − 1 , x_i^{t+1,n}=\sum_{j=1}^{l(n-1)}w_{ij}^{n}o_j^{t+1,n-1}, xit+1,n=j=1∑l(n−1)wijnojt+1,n−1,

u i t + 1 , n = u i t , n f ( o i t , n ) + x i t + 1 , n + b i n , u_i^{t+1,n}=u_i^{t,n}f(o_i^{t,n})+x_i^{t+1,n}+b_i^n, uit+1,n=uit,nf(oit,n)+xit+1,n+bin,

o i t + 1 , n = g ( u i t + 1 , n ) . o_i^{t+1,n}=g(u_i^{t+1,n}). oit+1,n=g(uit+1,n).

最终需要通过时空反向传播更新的是可学习参数,也就是需要:

∂ L ∂ W n , ∂ L ∂ b n . \frac{\partial L}{\partial W^n},\qquad \frac{\partial L}{\partial b^n}. ∂Wn∂L,∂bn∂L.

而 W n , b n W^n,b^n Wn,bn 都是先进入膜电位 u t , n u^{t,n} ut,n,再通过神经元状态影响后面的计算。于是根据链式法则:

∂ L ∂ b i n = ∑ t = 1 T ∂ L ∂ u i t , n ∂ u i t , n ∂ b i n , \frac{\partial L}{\partial b_i^n} =\sum_{t=1}^{T} \frac{\partial L}{\partial u_i^{t,n}} \frac{\partial u_i^{t,n}}{\partial b_i^n}, ∂bin∂L=t=1∑T∂uit,n∂L∂bin∂uit,n,

∂ L ∂ w i j n = ∑ t = 1 T ∂ L ∂ u i t , n ∂ u i t , n ∂ w i j n . \frac{\partial L}{\partial w_{ij}^{n}} =\sum_{t=1}^{T} \frac{\partial L}{\partial u_i^{t,n}} \frac{\partial u_i^{t,n}}{\partial w_{ij}^{n}}. ∂wijn∂L=t=1∑T∂uit,n∂L∂wijn∂uit,n.

根据前向更新式,对于上面两个参数梯度式中各自的第二项,显然有:

∂ u i t , n ∂ b i n = 1 , ∂ u i t , n ∂ w i j n = o j t , n − 1 . \frac{\partial u_i^{t,n}}{\partial b_i^n}=1,\qquad \frac{\partial u_i^{t,n}}{\partial w_{ij}^{n}}=o_j^{t,n-1}. ∂bin∂uit,n=1,∂wijn∂uit,n=ojt,n−1.

所以求参数梯度的关键其实是求出对膜电位的梯度。记

ϵ i t , n = ∂ L ∂ u i t , n . \epsilon_i^{t,n}=\frac{\partial L}{\partial u_i^{t,n}}. ϵit,n=∂uit,n∂L.

只要 ϵ i t , n \epsilon_i^{t,n} ϵit,n 求出来,就有

∂ L ∂ b i n = ∑ t = 1 T ϵ i t , n , ∂ L ∂ w i j n = ∑ t = 1 T ϵ i t , n o j t , n − 1 . \frac{\partial L}{\partial b_i^n} =\sum_{t=1}^{T}\epsilon_i^{t,n}, \qquad \frac{\partial L}{\partial w_{ij}^{n}} =\sum_{t=1}^{T}\epsilon_i^{t,n}o_j^{t,n-1}. ∂bin∂L=t=1∑Tϵit,n,∂wijn∂L=t=1∑Tϵit,nojt,n−1.

我们来想一下 u i t , n u_i^{t,n} uit,n 是怎样把自己的影响传到最终的损失函数 L L L 的。下面先按最一般的情况来讨论,也就是假设这个神经元既不是输出层,也不是最后时刻,即 t < T , n < N t<T,n<N t<T,n<N。四个case只是这个一般情形在边界处删掉或替换某些路径。

**第一类是空间路径。**膜电位先产生当前spike:

u i t , n ⟶ o i t , n . u_i^{t,n}\longrightarrow o_i^{t,n}. uit,n⟶oit,n.

这个spike会作为下一层的输入,也就是下一层所有神经元的输入电流的一部分,再进入它们的膜电位:

u i t , n ⟶ o i t , n ⟶ x j t , n + 1 ⟶ u j t , n + 1 , j = 1 , ... , l ( n + 1 ) . u_i^{t,n}\longrightarrow o_i^{t,n}\longrightarrow x_j^{t,n+1}\longrightarrow u_j^{t,n+1}, \qquad j=1,\ldots,l(n+1). uit,n⟶oit,n⟶xjt,n+1⟶ujt,n+1,j=1,...,l(n+1).

**第二类是时间路径。**由膜电位更新式可以看出,当前膜电位还会参与生成下一时刻的膜电位。不过这里要稍微拆开看,因为 u i t , n u_i^{t,n} uit,n 对 u i t + 1 , n u_i^{t+1,n} uit+1,n 的影响其实有两种方式。

第一种方式是直接作为旧膜电位保留到下一时刻。也就是说,在乘积

u i t , n f ( o i t , n ) u_i^{t,n}f(o_i^{t,n}) uit,nf(oit,n)

里,如果先把 f ( o i t , n ) f(o_i^{t,n}) f(oit,n) 看作一个已经确定的系数,那么 u i t , n u_i^{t,n} uit,n 会直接影响 u i t + 1 , n u_i^{t+1,n} uit+1,n:

u i t , n ⟶ u i t + 1 , n . u_i^{t,n}\longrightarrow u_i^{t+1,n}. uit,n⟶uit+1,n.

第二种方式是先影响当前spike,再通过forget gate(即 f f f)影响下一时刻膜电位。因为

o i t , n = g ( u i t , n ) , o_i^{t,n}=g(u_i^{t,n}), oit,n=g(uit,n),

所以 u i t , n u_i^{t,n} uit,n 还会先改变 o i t , n o_i^{t,n} oit,n,再通过 f ( o i t , n ) f(o_i^{t,n}) f(oit,n) 改变旧膜电位保留到下一时刻的比例:

u i t , n → o i t , n → f ( o i t , n ) → u i t + 1 , n . u_i^{t,n}\to o_i^{t,n}\to f(o_i^{t,n})\to u_i^{t+1,n}. uit,n→oit,n→f(oit,n)→uit+1,n.

这几类路径就是反传时的来源。按照空间路径和时间路径两类来写, u i t , n u_i^{t,n} uit,n 对损失的影响可以先写成

ϵ i t , n = ∂ L ∂ u i t , n = ∑ j ∂ L ∂ u j t , n + 1 ∂ u j t , n + 1 ∂ o i t , n ∂ o i t , n ∂ u i t , n ⏟ 空间路径 + ∂ L ∂ u i t + 1 , n ∂ u i t + 1 , n ∂ u i t , n ⏟ 时间路径 . \epsilon_i^{t,n}= \frac{\partial L}{\partial u_i^{t,n}}= \underbrace{ \sum_j \frac{\partial L}{\partial u_j^{t,n+1}} \frac{\partial u_j^{t,n+1}}{\partial o_i^{t,n}} \frac{\partial o_i^{t,n}}{\partial u_i^{t,n}} }{\text{空间路径}}+ \underbrace{ \frac{\partial L}{\partial u_i^{t+1,n}} \frac{\partial u_i^{t+1,n}}{\partial u_i^{t,n}} }{\text{时间路径}}. ϵit,n=∂uit,n∂L=空间路径 j∑∂ujt,n+1∂L∂oit,n∂ujt,n+1∂uit,n∂oit,n+时间路径 ∂uit+1,n∂L∂uit,n∂uit+1,n.

这个式子和前面的直觉是对应的:第一项表示当前膜电位先发放spike,再影响下一层;第二项表示当前膜电位影响下一时刻膜电位。接下来分别计算这两项。

先看空间路径。由

o i t , n = g ( u i t , n ) o_i^{t,n}=g(u_i^{t,n}) oit,n=g(uit,n)

可得

∂ o i t , n ∂ u i t , n = g ′ ( u i t , n ) . \frac{\partial o_i^{t,n}}{\partial u_i^{t,n}}=g'(u_i^{t,n}). ∂uit,n∂oit,n=g′(uit,n).

由于 g g g 是阶跃函数,训练时实际用替代梯度。记

h i t , n : = g ′ ( u i t , n ) . h_i^{t,n}:=g'(u_i^{t,n}). hit,n:=g′(uit,n).

再由

x j t , n + 1 = ∑ k w j k n + 1 o k t , n x_j^{t,n+1}=\sum_k w_{jk}^{n+1}o_k^{t,n} xjt,n+1=k∑wjkn+1okt,n

可得

∂ u j t , n + 1 ∂ o i t , n = ∂ x j t , n + 1 ∂ o i t , n = w j i n + 1 . \frac{\partial u_j^{t,n+1}}{\partial o_i^{t,n}}= \frac{\partial x_j^{t,n+1}}{\partial o_i^{t,n}}=w_{ji}^{n+1}. ∂oit,n∂ujt,n+1=∂oit,n∂xjt,n+1=wjin+1.

所以空间路径传回来的梯度是

∑ j ϵ j t , n + 1 w j i n + 1 h i t , n . \sum_j \epsilon_j^{t,n+1}w_{ji}^{n+1}h_i^{t,n}. j∑ϵjt,n+1wjin+1hit,n.

再看时间路径。上面的时间路径项需要计算的是

∂ L ∂ u i t + 1 , n ∂ u i t + 1 , n ∂ u i t , n . \frac{\partial L}{\partial u_i^{t+1,n}} \frac{\partial u_i^{t+1,n}}{\partial u_i^{t,n}}. ∂uit+1,n∂L∂uit,n∂uit+1,n.

其中关键是局部导数 ∂ u i t + 1 , n ∂ u i t , n \frac{\partial u_i^{t+1,n}}{\partial u_i^{t,n}} ∂uit,n∂uit+1,n。由膜电位更新式可知,和 u i t , n u_i^{t,n} uit,n 有关的部分是

u i t , n f ( o i t , n ) . u_i^{t,n}f(o_i^{t,n}). uit,nf(oit,n).

因此

∂ u i t + 1 , n ∂ u i t , n = ∂ ∂ u i t , n u i t , n f ( o i t , n ) . \frac{\partial u_i^{t+1,n}}{\partial u_i^{t,n}}= \frac{\partial}{\partial u_i^{t,n}} \leftu_i\^{t,n}f(o_i\^{t,n})\\right. ∂uit,n∂uit+1,n=∂uit,n∂uit,nf(oit,n).

对这个乘积使用乘积法则:

∂ ∂ u i t , n u i t , n f ( o i t , n ) = f ( o i t , n ) + u i t , n ∂ f ( o i t , n ) ∂ u i t , n . \frac{\partial}{\partial u_i^{t,n}} \leftu_i\^{t,n}f(o_i\^{t,n})\\right=f(o_i^{t,n})+u_i^{t,n} \frac{\partial f(o_i^{t,n})}{\partial u_i^{t,n}}. ∂uit,n∂uit,nf(oit,n)=f(oit,n)+uit,n∂uit,n∂f(oit,n).

第二项继续用链式法则。因为 o i t , n = g ( u i t , n ) o_i^{t,n}=g(u_i^{t,n}) oit,n=g(uit,n),所以

∂ f ( o i t , n ) ∂ u i t , n = f ′ ( o i t , n ) ∂ o i t , n ∂ u i t , n = f ′ ( o i t , n ) h i t , n . \frac{\partial f(o_i^{t,n})}{\partial u_i^{t,n}}=f'(o_i^{t,n}) \frac{\partial o_i^{t,n}}{\partial u_i^{t,n}}=f'(o_i^{t,n})h_i^{t,n}. ∂uit,n∂f(oit,n)=f′(oit,n)∂uit,n∂oit,n=f′(oit,n)hit,n.

于是

∂ u i t + 1 , n ∂ u i t , n = f ( o i t , n ) + u i t , n f ′ ( o i t , n ) h i t , n . \frac{\partial u_i^{t+1,n}}{\partial u_i^{t,n}}=f(o_i^{t,n})+u_i^{t,n}f'(o_i^{t,n})h_i^{t,n}. ∂uit,n∂uit+1,n=f(oit,n)+uit,nf′(oit,n)hit,n.

这里第一项 f ( o i t , n ) f(o_i^{t,n}) f(oit,n) 来自直接时间路径

u i t , n → u i t + 1 , n , u_i^{t,n}\to u_i^{t+1,n}, uit,n→uit+1,n,

第二项 u i t , n f ′ ( o i t , n ) h i t , n u_i^{t,n}f'(o_i^{t,n})h_i^{t,n} uit,nf′(oit,n)hit,n 来自间接时间路径

u i t , n → o i t , n → f ( o i t , n ) → u i t + 1 , n . u_i^{t,n}\to o_i^{t,n}\to f(o_i^{t,n})\to u_i^{t+1,n}. uit,n→oit,n→f(oit,n)→uit+1,n.

所以时间路径传回来的梯度是

ϵ i t + 1 , n f ( o i t , n ) + u i t , n f ′ ( o i t , n ) h i t , n . \epsilon_i^{t+1,n} \leftf(o_i\^{t,n})+u_i\^{t,n}f'(o_i\^{t,n})h_i\^{t,n} \\right. ϵit+1,nf(oit,n)+uit,nf′(oit,n)hit,n.

把空间路径和时间路径相加,就得到一般位置 t < T , n < N t<T,n<N t<T,n<N 的膜电位梯度:

ϵ i t , n = ∑ j ϵ j t , n + 1 w j i n + 1 h i t , n + ϵ i t + 1 , n f ( o i t , n ) + u i t , n f ′ ( o i t , n ) h i t , n . \epsilon_i^{t,n}= \sum_j \epsilon_j^{t,n+1}w_{ji}^{n+1}h_i^{t,n}+ \epsilon_i^{t+1,n} \leftf(o_i\^{t,n})+u_i\^{t,n}f'(o_i\^{t,n})h_i\^{t,n} \\right. ϵit,n=j∑ϵjt,n+1wjin+1hit,n+ϵit+1,nf(oit,n)+uit,nf′(oit,n)hit,n.

上面这个式子可以整理成论文中显式使用 ∂ L ∂ o i t , n \frac{\partial L}{\partial o_i^{t,n}} ∂oit,n∂L 的形式。先把括号打开:

ϵ i t , n = ∑ j ϵ j t , n + 1 w j i n + 1 h i t , n + ϵ i t + 1 , n f ( o i t , n ) + ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) h i t , n . \epsilon_i^{t,n}= \sum_j\epsilon_j^{t,n+1}w_{ji}^{n+1}h_i^{t,n}+ \epsilon_i^{t+1,n}f(o_i^{t,n})+ \epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n})h_i^{t,n}. ϵit,n=j∑ϵjt,n+1wjin+1hit,n+ϵit+1,nf(oit,n)+ϵit+1,nuit,nf′(oit,n)hit,n.

其中第一项和第三项都乘着

h i t , n = ∂ o i t , n ∂ u i t , n , h_i^{t,n}=\frac{\partial o_i^{t,n}}{\partial u_i^{t,n}}, hit,n=∂uit,n∂oit,n,

也就是说,这两项都先经过当前spike节点 o i t , n o_i^{t,n} oit,n,再回到膜电位 u i t , n u_i^{t,n} uit,n。把这两个经过spike的项合在一起:

ϵ i t , n = ( ∑ j ϵ j t , n + 1 w j i n + 1 + ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) ) h i t , n + ϵ i t + 1 , n f ( o i t , n ) . \epsilon_i^{t,n}= \left( \sum_j\epsilon_j^{t,n+1}w_{ji}^{n+1}+ \epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n}) \right)h_i^{t,n}+ \epsilon_i^{t+1,n}f(o_i^{t,n}). ϵit,n=(j∑ϵjt,n+1wjin+1+ϵit+1,nuit,nf′(oit,n))hit,n+ϵit+1,nf(oit,n).

括号里的量就是流到当前spike节点 o i t , n o_i^{t,n} oit,n 的总梯度,也就是

δ i t , n : = ∂ L ∂ o i t , n = ∑ j ϵ j t , n + 1 w j i n + 1 + ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) . \delta_i^{t,n}:= \frac{\partial L}{\partial o_i^{t,n}}= \sum_j\epsilon_j^{t,n+1}w_{ji}^{n+1}+ \epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n}). δit,n:=∂oit,n∂L=j∑ϵjt,n+1wjin+1+ϵit+1,nuit,nf′(oit,n).

它有两部分来源:第一部分是通过发放spike影响下一层膜电位的空间路径;第二部分是通过 f ( o i t , n ) f(o_i^{t,n}) f(oit,n) 影响下一时刻膜电位的gate时间路径。因此上面的式子就可以写成论文中常见的形式:

ϵ i t , n = δ i t , n h i t , n + ϵ i t + 1 , n f ( o i t , n ) . \epsilon_i^{t,n}= \delta_i^{t,n}h_i^{t,n}+ \epsilon_i^{t+1,n}f(o_i^{t,n}). ϵit,n=δit,nhit,n+ϵit+1,nf(oit,n).

两种写法是同一个式子。前一种更直接对应"空间路径 + 时间路径",后一种更方便和论文中的四个case对应。

输出层的直接损失项来自原文损失函数:

L = 1 2 S ∑ s = 1 S ∥ y s − 1 T ∑ t = 1 T o s t , N ∥ 2 2 . L=\frac{1}{2S}\sum_{s=1}^{S} \left\|y_s-\frac{1}{T}\sum_{t=1}^{T}o_s^{t,N} \right\|_2^2. L=2S1s=1∑S ys−T1t=1∑Tost,N 22.

省略样本下标后,输出层任意时刻的直接梯度为

r i : = ∂ L ∂ o i t , N ∣ loss = − 1 T S ( y i − 1 T ∑ k = 1 T o i k , N ) . r_i:= \left. \frac{\partial L}{\partial o_i^{t,N}} \right|{\text{loss}}=-\frac{1}{TS} \left(y_i-\frac{1}{T}\sum{k=1}^{T}o_i^{k,N} \right). ri:=∂oit,N∂L loss=−TS1(yi−T1k=1∑Toik,N).

下面把一般公式在边界处删除/替换某些项,就得到四个case。为了避免直接跳到结果,先把一般位置的两条递推式拿出来(即公式(36)、(37)):

δ i t , n = ∑ j ϵ j t , n + 1 w j i n + 1 ⏟ 空间项 + ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) ⏟ gate时间项 , \delta_i^{t,n}= \underbrace{\sum_j\epsilon_j^{t,n+1}w_{ji}^{n+1}}{\text{空间项}}+ \underbrace{\epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n})}{\text{gate时间项}}, δit,n=空间项 j∑ϵjt,n+1wjin+1+gate时间项 ϵit+1,nuit,nf′(oit,n),

ϵ i t , n = δ i t , n h i t , n ⏟ 经过spike回到膜电位 + ϵ i t + 1 , n f ( o i t , n ) ⏟ 直接时间项 . \epsilon_i^{t,n}= \underbrace{\delta_i^{t,n}h_i^{t,n}}{\text{经过spike回到膜电位}}+ \underbrace{\epsilon_i^{t+1,n}f(o_i^{t,n})}{\text{直接时间项}}. ϵit,n=经过spike回到膜电位 δit,nhit,n+直接时间项 ϵit+1,nf(oit,n).

这里要注意两种边界情况:如果 n = N n=N n=N,说明没有下一层,那么 δ \delta δ 公式里的第一项不再是"发放spike作为下一层输入"的空间项,而是被"输出spike直接进入损失函数"的梯度 r i r_i ri 取代;如果 t = T t=T t=T,说明没有下一时刻,所以上面的时间项都为0。

Case 1: t = T , n = N t=T,\ n=N t=T, n=N。

先看 ϵ \epsilon ϵ 递推式。把 t = T , n = N t=T,n=N t=T,n=N 代入:

ϵ i T , N = δ i T , N h i T , N + ϵ i T + 1 , N f ( o i T , N ) . \epsilon_i^{T,N}= \delta_i^{T,N}h_i^{T,N}+ \epsilon_i^{T+1,N}f(o_i^{T,N}). ϵiT,N=δiT,NhiT,N+ϵiT+1,Nf(oiT,N).

因为 T T T 是最后时刻,没有 T + 1 T+1 T+1,所以

ϵ i T + 1 , N = 0. \epsilon_i^{T+1,N}=0. ϵiT+1,N=0.

于是

ϵ i T , N = δ i T , N h i T , N . \epsilon_i^{T,N}= \delta_i^{T,N}h_i^{T,N}. ϵiT,N=δiT,NhiT,N.

现在只需要求 δ i T , N \delta_i^{T,N} δiT,N。看 δ \delta δ 递推式。由于 n = N n=N n=N,原本通向下一层的空间项被输出损失项 r i r_i ri 取代;又由于 t = T t=T t=T,gate时间项中的 ϵ i T + 1 , N = 0 \epsilon_i^{T+1,N}=0 ϵiT+1,N=0。所以

δ i T , N = r i + ϵ i T + 1 , N u i T , N f ′ ( o i T , N ) = r i . \delta_i^{T,N}=r_i+\epsilon_i^{T+1,N}u_i^{T,N}f'(o_i^{T,N})=r_i. δiT,N=ri+ϵiT+1,NuiT,Nf′(oiT,N)=ri.

代回 ϵ \epsilon ϵ 式,得到

ϵ i T , N = r i h i T , N . \epsilon_i^{T,N}=r_i h_i^{T,N}. ϵiT,N=rihiT,N.

所以

δ i T , N = r i , ϵ i T , N = r i h i T , N . \boxed{ \delta_i^{T,N}=r_i, \qquad \epsilon_i^{T,N}=r_i h_i^{T,N}. } δiT,N=ri,ϵiT,N=rihiT,N.

Case 2: t = T , n < N t=T,\ n<N t=T, n<N。

先看 ϵ \epsilon ϵ 递推式。把 t = T t=T t=T 代入:

ϵ i T , n = δ i T , n h i T , n + ϵ i T + 1 , n f ( o i T , n ) . \epsilon_i^{T,n}= \delta_i^{T,n}h_i^{T,n}+ \epsilon_i^{T+1,n}f(o_i^{T,n}). ϵiT,n=δiT,nhiT,n+ϵiT+1,nf(oiT,n).

因为 T T T 是最后时刻,所以 ϵ i T + 1 , n = 0 \epsilon_i^{T+1,n}=0 ϵiT+1,n=0,于是

ϵ i T , n = δ i T , n h i T , n . \epsilon_i^{T,n}= \delta_i^{T,n}h_i^{T,n}. ϵiT,n=δiT,nhiT,n.

再求 δ i T , n \delta_i^{T,n} δiT,n。把 t = T t=T t=T 代入 δ \delta δ 递推式:

δ i T , n = ∑ j ϵ j T , n + 1 w j i n + 1 + ϵ i T + 1 , n u i T , n f ′ ( o i T , n ) . \delta_i^{T,n}= \sum_j\epsilon_j^{T,n+1}w_{ji}^{n+1}+ \epsilon_i^{T+1,n}u_i^{T,n}f'(o_i^{T,n}). δiT,n=j∑ϵjT,n+1wjin+1+ϵiT+1,nuiT,nf′(oiT,n).

同样因为 ϵ i T + 1 , n = 0 \epsilon_i^{T+1,n}=0 ϵiT+1,n=0,gate时间项消失,因此

δ i T , n = ∑ j ϵ j T , n + 1 w j i n + 1 . \delta_i^{T,n}= \sum_j\epsilon_j^{T,n+1}w_{ji}^{n+1}. δiT,n=j∑ϵjT,n+1wjin+1.

代回 ϵ \epsilon ϵ 式,得到

ϵ i T , n = ( ∑ j ϵ j T , n + 1 w j i n + 1 ) h i T , n . \epsilon_i^{T,n}= \left( \sum_j\epsilon_j^{T,n+1}w_{ji}^{n+1} \right)h_i^{T,n}. ϵiT,n=(j∑ϵjT,n+1wjin+1)hiT,n.

也可以保留 δ \delta δ 记号写成

δ i T , n = ∑ j ϵ j T , n + 1 w j i n + 1 , ϵ i T , n = δ i T , n h i T , n . \boxed{ \delta_i^{T,n}= \sum_j\epsilon_j^{T,n+1}w_{ji}^{n+1}, \qquad \epsilon_i^{T,n}= \delta_i^{T,n}h_i^{T,n}. } δiT,n=j∑ϵjT,n+1wjin+1,ϵiT,n=δiT,nhiT,n.

论文原式(13)中,下划线处的 g ′ g' g′ 下标不对:

要和论文原式对比,直接从上面的正确式出发即可。正确式中

δ i T , n = ∑ j ϵ j T , n + 1 w j i n + 1 . \delta_i^{T,n}= \sum_j\epsilon_j^{T,n+1}w_{ji}^{n+1}. δiT,n=j∑ϵjT,n+1wjin+1.

如果想把它写成论文里的 δ g ′ \delta g' δg′ 形式,本质上只是把其中的 ϵ j T , n + 1 \epsilon_j^{T,n+1} ϵjT,n+1 展开。因为在同一个 case 下已经有

ϵ i T , n = δ i T , n h i T , n , \epsilon_i^{T,n}= \delta_i^{T,n}h_i^{T,n}, ϵiT,n=δiT,nhiT,n,

把上标从 ( T , n ) (T,n) (T,n) 移到下一层第 j j j 个神经元,就得到

ϵ j T , n + 1 = δ j T , n + 1 h j T , n + 1 = δ j T , n + 1 g ′ ( u j T , n + 1 ) . \epsilon_j^{T,n+1}= \delta_j^{T,n+1}h_j^{T,n+1}= \delta_j^{T,n+1}g'(u_j^{T,n+1}). ϵjT,n+1=δjT,n+1hjT,n+1=δjT,n+1g′(ujT,n+1).

代回 δ i T , n \delta_i^{T,n} δiT,n 的正确式:

δ i T , n = ∑ j δ j T , n + 1 g ′ ( u j T , n + 1 ) w j i n + 1 . \delta_i^{T,n}= \sum_j \delta_j^{T,n+1}g'(u_j^{T,n+1})w_{ji}^{n+1}. δiT,n=j∑δjT,n+1g′(ujT,n+1)wjin+1.

所以如果论文要写成 δ g ′ \delta g' δg′ 的形式, g ′ g' g′ 必须属于下一层第 j j j 个神经元,也就是 g ′ ( u j T , n + 1 ) g'(u_j^{T,n+1}) g′(ujT,n+1)。论文原式把它写成当前层当前神经元的 g ′ ( u i T , n ) g'(u_i^{T,n}) g′(uiT,n),下标就错了。

论文原式(14)中,下划线处引入了不存在的未来时刻:

在 t = T t=T t=T 时,边界条件是

ϵ i T + 1 , n = 0. \epsilon_i^{T+1,n}=0. ϵiT+1,n=0.

所以正确的膜电位梯度只有

ϵ i T , n = δ i T , n h i T , n . \epsilon_i^{T,n}= \delta_i^{T,n}h_i^{T,n}. ϵiT,n=δiT,nhiT,n.

Case 3: t < T , n = N t<T,\ n=N t<T, n=N。

先看 ϵ \epsilon ϵ 递推式。把 n = N n=N n=N 代入:

ϵ i t , N = δ i t , N h i t , N + ϵ i t + 1 , N f ( o i t , N ) . \epsilon_i^{t,N}= \delta_i^{t,N}h_i^{t,N}+ \epsilon_i^{t+1,N}f(o_i^{t,N}). ϵit,N=δit,Nhit,N+ϵit+1,Nf(oit,N).

这里 t < T t<T t<T,所以直接时间项保留。接下来要求 δ i t , N \delta_i^{t,N} δit,N。看 δ \delta δ 递推式:由于 n = N n=N n=N,空间项被输出损失项 r i r_i ri 取代;由于 t < T t<T t<T,gate时间项保留,所以

δ i t , N = r i + ϵ i t + 1 , N u i t , N f ′ ( o i t , N ) . \delta_i^{t,N}=r_i+\epsilon_i^{t+1,N}u_i^{t,N}f'(o_i^{t,N}). δit,N=ri+ϵit+1,Nuit,Nf′(oit,N).

代回 ϵ \epsilon ϵ 式,得到

ϵ i t , N = r i + ϵ i t + 1 , N u i t , N f ′ ( o i t , N ) h i t , N + ϵ i t + 1 , N f ( o i t , N ) . \epsilon_i^{t,N}= \leftr_i+\\epsilon_i\^{t+1,N}u_i\^{t,N}f'(o_i\^{t,N}) \\righth_i^{t,N}+ \epsilon_i^{t+1,N}f(o_i^{t,N}). ϵit,N=ri+ϵit+1,Nuit,Nf′(oit,N)hit,N+ϵit+1,Nf(oit,N).

也可以保留 δ \delta δ 记号写成

δ i t , N = r i + ϵ i t + 1 , N u i t , N f ′ ( o i t , N ) , \boxed{ \delta_i^{t,N}=r_i+\epsilon_i^{t+1,N}u_i^{t,N}f'(o_i^{t,N}), } δit,N=ri+ϵit+1,Nuit,Nf′(oit,N),

ϵ i t , N = δ i t , N h i t , N + ϵ i t + 1 , N f ( o i t , N ) . \boxed{ \epsilon_i^{t,N}= \delta_i^{t,N}h_i^{t,N} + \epsilon_i^{t+1,N}f(o_i^{t,N}). } ϵit,N=δit,Nhit,N+ϵit+1,Nf(oit,N).

Case 4: t < T , n < N t<T,\ n<N t<T, n<N。

这其实就是我们前面推出来的一般情况,也就是既不是最后时刻、也不是输出层的情况。先看 ϵ \epsilon ϵ 递推式:

ϵ i t , n = δ i t , n h i t , n + ϵ i t + 1 , n f ( o i t , n ) . \epsilon_i^{t,n}= \delta_i^{t,n}h_i^{t,n}+ \epsilon_i^{t+1,n}f(o_i^{t,n}). ϵit,n=δit,nhit,n+ϵit+1,nf(oit,n).

这里 t < T t<T t<T,所以直接时间项保留。再看 δ \delta δ 递推式。因为 n < N n<N n<N,空间项仍然是下一层传回来的空间项;因为 t < T t<T t<T,gate时间项也保留。因此

δ i t , n = ∑ j ϵ j t , n + 1 w j i n + 1 + ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) . \delta_i^{t,n}= \sum_j\epsilon_j^{t,n+1}w_{ji}^{n+1}+ \epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n}). δit,n=j∑ϵjt,n+1wjin+1+ϵit+1,nuit,nf′(oit,n).

代回 ϵ \epsilon ϵ 式,得到

ϵ i t , n = ∑ j ϵ j t , n + 1 w j i n + 1 + ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) h i t , n + ϵ i t + 1 , n f ( o i t , n ) . \epsilon_i^{t,n}= \left \\sum_j\\epsilon_j\^{t,n+1}w_{ji}\^{n+1}+ \\epsilon_i\^{t+1,n}u_i\^{t,n}f'(o_i\^{t,n}) \\righth_i^{t,n}+ \epsilon_i^{t+1,n}f(o_i^{t,n}). ϵit,n=j∑ϵjt,n+1wjin+1+ϵit+1,nuit,nf′(oit,n)hit,n+ϵit+1,nf(oit,n).

也可以保留 δ \delta δ 记号写成

δ i t , n = ∑ j ϵ j t , n + 1 w j i n + 1 + ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) , \boxed{ \delta_i^{t,n}= \sum_j\epsilon_j^{t,n+1}w_{ji}^{n+1}+ \epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n}), } δit,n=j∑ϵjt,n+1wjin+1+ϵit+1,nuit,nf′(oit,n),

ϵ i t , n = δ i t , n h i t , n + ϵ i t + 1 , n f ( o i t , n ) . \boxed{ \epsilon_i^{t,n}= \delta_i^{t,n}h_i^{t,n}+ \epsilon_i^{t+1,n}f(o_i^{t,n}). } ϵit,n=δit,nhit,n+ϵit+1,nf(oit,n).

论文原式(19)中,两个下划线处的 g ′ g' g′ 都没有落在正确的节点上:

这里的问题不只是 g ′ g' g′ 下标写错,更根本的原因是:在 Case 4 中 t < T t<T t<T,所以一般不能像 Case 2 那样把某个 ϵ \epsilon ϵ 直接展开成 δ g ′ \delta g' δg′。因为此时膜电位梯度 ϵ \epsilon ϵ 还包含继续沿时间方向传播的直接时间项。

先看 δ i t , n \delta_i^{t,n} δit,n 的空间项。正确式里空间项是 ∑ j ϵ j t , n + 1 w j i n + 1 \sum_j\epsilon_j^{t,n+1}w_{ji}^{n+1} ∑jϵjt,n+1wjin+1。其中 ϵ j t , n + 1 \epsilon_j^{t,n+1} ϵjt,n+1 是下一层第 j j j 个神经元在当前时刻的完整膜电位梯度。由于这里仍然是一般时刻 t < T t<T t<T,它不仅仅是空间项 δ j t , n + 1 g ′ ( u j t , n + 1 ) \delta_j^{t,n+1}g'(u_j^{t,n+1}) δjt,n+1g′(ujt,n+1),而是还包含下一层内部继续沿时间传播的直接时间项。也就是说,将空间项中的 ϵ j t , n + 1 \epsilon_j^{t,n+1} ϵjt,n+1 正确展开应当是:

ϵ j t , n + 1 = δ j t , n + 1 h j t , n + 1 + ϵ j t + 1 , n + 1 f ( o j t , n + 1 ) . \epsilon_j^{t,n+1}= \delta_j^{t,n+1}h_j^{t,n+1}+ \epsilon_j^{t+1,n+1}f(o_j^{t,n+1}). ϵjt,n+1=δjt,n+1hjt,n+1+ϵjt+1,n+1f(ojt,n+1).

因此,原论文把完整的 ϵ j t , n + 1 \epsilon_j^{t,n+1} ϵjt,n+1 直接写成 δ j t , n + 1 g ′ ( ⋅ ) \delta_j^{t,n+1}g'(\cdot) δjt,n+1g′(⋅) 这一类形式会漏掉时间项;而且即便只看其中经过 spike 的空间项那部分, g ′ g' g′ 也应该落在下一层第 j j j 个神经元上,即 g ′ ( u j t , n + 1 ) g'(u_j^{t,n+1}) g′(ujt,n+1),不是当前层当前神经元的 g ′ ( u i t , n ) g'(u_i^{t,n}) g′(uit,n)。

再看 δ i t , n \delta_i^{t,n} δit,n 的时间相关项。正确的时间项是 ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) \epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n}) ϵit+1,nuit,nf′(oit,n)。

这里的 ϵ i t + 1 , n \epsilon_i^{t+1,n} ϵit+1,n 是下一时刻同一个神经元的完整膜电位梯度,它同样不能一般性地替换成 δ i t + 1 , n g ′ ( u i t + 1 , n ) \delta_i^{t+1,n}g'(u_i^{t+1,n}) δit+1,ng′(uit+1,n),因为 ϵ i t + 1 , n \epsilon_i^{t+1,n} ϵit+1,n 还包含继续传播到 u i t + 2 , n u_i^{t+2,n} uit+2,n 的时间项。所以更清晰、也更不容易出错的写法是时间项也不展开替换,直接停在膜电位梯度 ϵ i t + 1 , n u i t , n f ′ ( o i t , n ) \epsilon_i^{t+1,n}u_i^{t,n}f'(o_i^{t,n}) ϵit+1,nuit,nf′(oit,n)。

此外,论文原式(21)中,下划线处的时间下标也不对:

直接比对式(69),这里应为

f ( o i t , n ) , f(o_i^{t,n}), f(oit,n),

而不是 f ( o i t + 1 , n ) f(o_i^{t+1,n}) f(oit+1,n)。

最后回到参数。由前面的参数梯度公式可以得到:

∂ L ∂ b i n = ∑ t = 1 T ϵ i t , n , \boxed{ \frac{\partial L}{\partial b_i^n}= \sum_{t=1}^{T}\epsilon_i^{t,n}, } ∂bin∂L=t=1∑Tϵit,n,

∂ L ∂ w i j n = ∑ t = 1 T ϵ i t , n o j t , n − 1 . \boxed{ \frac{\partial L}{\partial w_{ij}^{n}}= \sum_{t=1}^{T}\epsilon_i^{t,n}o_j^{t,n-1}. } ∂wijn∂L=t=1∑Tϵit,nojt,n−1.

矩阵形式为

∂ L ∂ W n = ∑ t = 1 T ϵ t , n ( o t , n − 1 ) T . \boxed{ \frac{\partial L}{\partial W^n}= \sum_{t=1}^{T} \epsilon^{t,n}\left(o^{t,n-1}\right)^T. } ∂Wn∂L=t=1∑Tϵt,n(ot,n−1)T.

相关论文版本:arXiv https://arxiv.org/abs/1706.02609;Frontiers https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2018.00331/full