作为STAR的笔记,方便理解与复习。
STAR来自论文"STAR: STABILITY-INDUCING WEIGHT PERTURBATION
FOR CONTINUAL LEARNING",其链接为:点此跳转
另外,本文总结的公式编号和文中对应公式的编号是一致的,但是为了方便总结,更改了部分公式的叙述顺序,所以可能出现第 i i i 个出现的公式编号为 j j j 的情况。
一.背景与原理
持续学习的目标是在训练结束时,最小化所有已见数据流上的分类误差,因此,我们可以将持续学习的目标表示为:
min θ E ( x , y ) ∼ X 0 : τ [ 1 [ f θ ( x ) ≠ y ] ] (1) \min {\theta} \mathbb{E}{(x, y)\sim \mathcal {X}{0: \tau }}\left[ 1{[f_{\theta }(x)\neq y]}\right] \tag{1} θminE(x,y)∼X0:τ[1[fθ(x)=y]](1)其中:
- X t \mathcal{X}_t Xt 是时间 t t t 时的训练数据分布。
- X 0 : τ \mathcal {X}{0: \tau } X0:τ 表示分布集合 { X i } i ≤ τ \{\mathcal {X}i \}{i \leq \tau} {Xi}i≤τ 的均匀混合,即从 X 0 : τ \mathcal {X}{0: \tau } X0:τ 中抽取的任何样本,都等可能地来自于任何一个 X i \mathcal{X}_{i} Xi,它是时间 t t t 及之前所有观察到的数据的分布。
- 1 [ c o n d i t i o n ] 1_{[condition]} 1[condition] 表示指示函数,当条件成立时取值为 1,否则取值为 0,式(1)中 1 [ f θ ( x ) ≠ y ] 1_{[f_{\theta }(x)\neq y]} 1[fθ(x)=y] 表示预测值 f θ ( x ) f_{\theta }(x) fθ(x) 与标签 y y y 不一致时取 1,一致时取 0。
更正式地,我们可以通过训练过程中两个时间戳 t t t 和 t + s t+s t+s 之间的误差差异来量化遗忘:
F o r g e t t i n g ( t , s ) : = E ( x , y ) ∼ X 0 : t [ 1 [ f θ t + s ( x ) ≠ y ] − 1 [ f θ t ( x ) ≠ y ] ] (2) Forgetting(t, s) := \mathbb{E}{(x,y)\sim\mathcal{X}{0:t}}\left[\mathbf{1}{[f{\theta_{t+s}}(x)\neq y]} - \mathbf{1}{[f{\theta_{t}}(x)\neq y]}\right] \tag{2} Forgetting(t,s):=E(x,y)∼X0:t[1[fθt+s(x)=y]−1[fθt(x)=y]](2)该式子可以分为四种情况:
- 未来参数预测情况和当前参数都正确,即 f θ t + s ( x ) = y f_{\theta_{t+s}}(x) = y fθt+s(x)=y 且 f θ t ( x ) = y f_{\theta_{t}}(x)=y fθt(x)=y 时,该样本产生的值为 0。
- 未来参数预测情况和当前参数都错误,即 f θ t + s ( x ) ≠ y f_{\theta_{t+s}}(x) \neq y fθt+s(x)=y 且 f θ t ( x ) ≠ y f_{\theta_{t}}(x)\neq y fθt(x)=y 时,该样本产生的值为 0。
- 未来参数预测情况比当前参数要好,即 f θ t + s ( x ) = y f_{\theta_{t+s}}(x) = y fθt+s(x)=y 且 f θ t ( x ) ≠ y f_{\theta_{t}}(x) \neq y fθt(x)=y 时,该样本产生的值为 -1。
- 未来参数预测情况比当前参数要差,即 f θ t + s ( x ) ≠ y f_{\theta_{t+s}}(x) \neq y fθt+s(x)=y 且 f θ t ( x ) = y f_{\theta_{t}}(x) = y fθt(x)=y 时,该样本产生的值为 1。
这也就代表 F o r g e t t i n g ( t , s ) Forgetting(t, s) Forgetting(t,s) 越大,模型遗忘情况就越严重。
排除公式(2)中两种值为 0 的情况,原式子可以简化为:
F o r g e t t i n g ( t , s ) = E ( x , y ) ∼ X 0 : t [ 1 [ f θ t ( x ) = y ∧ f θ t + s ( x ) ≠ y ] ⏟ Forgotten samples − 1 [ f θ t ( x ) ≠ y ∧ f θ t + s ( x ) = y ] ⏟ Newly learned samples ] (4) Forgetting(t,s)=\mathbb{E}{(x,y)\sim\mathcal{X}{0:t}}\left[ \underbrace{\mathbf{1}{[f{\theta_{t}}(x)=y\wedge f_{\theta_{t+s}}(x)\neq y]}}{\text{Forgotten samples}} - \underbrace{\mathbf{1}{[f_{\theta_{t}}(x)\neq y\wedge f_{\theta_{t+s}}(x)=y]}}_{\text{Newly learned samples}} \right]\tag{4} Forgetting(t,s)=E(x,y)∼X0:t Forgotten samples 1[fθt(x)=y∧fθt+s(x)=y]−Newly learned samples 1[fθt(x)=y∧fθt+s(x)=y] (4)在这个式子中,前一项代表"遗忘样本",也就是未来模型将之前正确分类的样本进行了错误分类,是遗忘的表现。后一项代表"新习得样本",也就是未来模型将之前错误分类的样本进行了正确分类,是新学习的知识对先前知识的促进表现。
为了减少前者,即"遗忘样本"带来的误差,作者提出了以下损失函数:
E ( x , y ) ∼ X 0 : t [ 1 [ f θ t ( x ) = y ] ⋅ K L ( q θ t ( x ) ∣ q θ t + s ( x ) ) ] (5) \mathbb{E}{(x,y)\sim\mathcal{X}{0:t}}\left[ \mathbf{1}{[f{\theta_t}(x)=y]} \cdot \mathcal{KL}\big(q_{\theta_t}(x)|q_{\theta_{t+s}}(x)\big) \right]\tag{5} E(x,y)∼X0:t[1[fθt(x)=y]⋅KL(qθt(x)∣qθt+s(x))](5)其中:
- q θ ( x ) q_{\theta}(x) qθ(x) 表示由参数 θ \theta θ参数化的模型对输入 x x x 的输出概率分布。
- K L ( q θ t ( x ) ∣ q θ t + s ( x ) ) \mathcal{KL}\big(q_{\theta_t}(x)|q_{\theta_{t+s}}(x)\big) KL(qθt(x)∣qθt+s(x)) 是KL散度,衡量了模型在时刻 t t t 和未来时刻 t + s t+s t+s 对输入 x x x 的输出概率分布 q θ t ( x ) q_{\theta_t}(x) qθt(x) 和 q θ t + s ( x ) q_{\theta_{t+s}}(x) qθt+s(x) 之间的差异,它的值越小,代表两个分布的差异越小,值为 0 时,代表两者分布相同。
该损失函数本质上是通过拉近两个不同时刻的分布来减小遗忘,但是这里仍然有两个问题需要解决:
- X 0 : t \mathcal{X}_{0:t} X0:t 的数据在在线学习中无法全部存储,无法直接表示。
- θ t + s \theta_{t+s} θt+s 表示未来参数,无法提前预知。
针对第一个问题,论文中使用缓冲区 M t M_t Mt 作为 X 0 : t \mathcal{X}{0:t} X0:t 的替代,此时,公式(5)可以优化为:
L F G ( θ t , θ t + s ) = ∑ ( x , y ) ∈ M t ∗ K L ( q θ t ( x ) ∣ q θ t + s ( x ) ) (6) \mathcal{L}{FG}(\theta_t,\theta_{t+s})=\sum_{(x,y)\in\mathcal{M}^{*}t}\mathcal{KL}\big(q{\theta_t}(x)|q_{\theta_{t+s}}(x)\big)\tag{6} LFG(θt,θt+s)=(x,y)∈Mt∗∑KL(qθt(x)∣qθt+s(x))(6)其中, M t ∗ = { ( x , y ) ∈ M t ∣ f θ t ( x ) = y } \mathcal{M}^{*}_t=\{(x,y)\in\mathcal{M}t|f{\theta_t}(x)=y\} Mt∗={(x,y)∈Mt∣fθt(x)=y}。
针对第二个问题,论文中提议使用局部最坏情况扰动 θ t + δ \theta_t + \delta θt+δ(其中 δ ∈ R P \delta \in \mathbb{R}^P δ∈RP)来近似 θ t + s \theta_{t+s} θt+s,以此确保无论未来参数更新的方向如何,都能减少遗忘,优化后的公式,也即最终STAR的损失函数如下:
L S T A R ( θ t ) : = max δ L F G ( θ t , θ t + δ ) (7) \mathcal{L}{STAR}(\theta_t):=\max{\delta}\mathcal{L}_{FG}(\theta_t,\theta_t+\delta)\tag{7} LSTAR(θt):=δmaxLFG(θt,θt+δ)(7)其中, ∥ δ ∥ 2 ≤ d \|\delta\|_2 \leq d ∥δ∥2≤d , d d d 是一个控制扰动领域大小的任意正数,它是STAR方法的超参数。
大多数基于重放的持续学习方法的训练损失形式如下:
L C L ( θ t ) = ℓ ( f θ t ( x t ) , y t ) + ℓ M t ( f θ t ( x M t ) , y M t ) (3) \mathcal{L}{CL}(\theta_t) = \ell(f{\theta_t}(x_t), y_t) + \ell_{\mathcal{M}t}(f{\theta_t}(x_{\mathcal{M}t}), y{\mathcal{M}_t})\tag{3} LCL(θt)=ℓ(fθt(xt),yt)+ℓMt(fθt(xMt),yMt)(3)其中:
- ℓ \ell ℓ 对应新任务的损失函数。
- x t x_t xt 和 y t y_t yt 是时刻 t t t 从当前任务数据分布 X t \mathcal{X}_t Xt 中采样的输入数据和对应的标签。
- ℓ M t \ell_{\mathcal{M}_t} ℓMt 是对应于记忆缓冲区的损失函数。
- M t \mathcal{M}_t Mt 是时刻 t t t 的记忆缓冲区,存储了之前观察到的少量样本。
- x M t x_{\mathcal{M}t} xMt 和 y M t y{\mathcal{M}_t} yMt 是从缓冲区 M t \mathcal{M}_t Mt 中采样的输入数据和对应的标签。
将STAR损失函数与 L C L \mathcal{L}{CL} LCL 组合可以得到:
L f i n a l = L C L ( θ t ) + λ L S T A R ( θ t ) (8) \mathcal{L}{final} = \mathcal{L}{CL}(\theta_t) + \lambda \mathcal{L}{STAR}(\theta_t)\tag{8} Lfinal=LCL(θt)+λLSTAR(θt)(8)其中 λ \lambda λ 是控制每个目标重要性的超参数。
从这里可以看出, L S T A R \mathcal{L}_{STAR} LSTAR 可以作为一个即插即用的组件,运用于基于回放的任意方法。
二. STAR的实现方式
从 L S T A R \mathcal{L}_{STAR} LSTAR 的定义来看,由于要对 δ \delta δ 进行最大化,需要面临两个挑战:
- 计算权重扰动 δ \delta δ
- 在损失依赖于 δ \delta δ 的情况下,计算相对于原始参数 θ t \theta_t θt 的梯度。
直接优化并不现实,所以作者通过借鉴现有工作,提出了相关的解决方法:
1.优化扰动 δ \delta δ
作者提议使用梯度上升来计算 δ \delta δ,虽然由于 q θ q_\theta qθ 通常是非凸的,计算精确的局部最大值具有挑战性,但是在实践中,仅执行一次最大化梯度步骤就足以提升性能,也就是说后续提到的梯度上升,均只需执行一次。
同样借鉴现有的工作,论文中并没有使用单一常数 r r r 作为范数约束,而是逐层归一化模型的梯度,这是因为:
- 每一层权重的数值分布不同,采用单一常数 r r r 可能导致部分层权重相对变化过大(比如数值增大到原来的200%)或过小(比如数值减小到原来的99.99%),逐层归一化模型可以让每一层的相对变化较为固定(比如均为10%),从而避免扰动后的参数不稳定或者效果不足的问题。
- 权重具有尺度不变性(例如,在某一层乘以 10,在下一层除以 10,网络的输出保持不变),说明网络的参数完全有可能在训练时某一层参数扩大一定倍数,另一层缩小一定倍数,这进一步印证了单一常数 r r r 在这种情况上的不可靠。
下面是归一化梯度的步骤:
①初始化 δ \delta δ
由于 δ = 0 \delta=0 δ=0 是KL项的全局最小值(此时梯度为0),所以采用将 δ \delta δ 初始化为 δ 0 \delta_0 δ0 ------ 带有与层范数成比例的小噪声,具体如下:
δ 0 ( l ) ∼ N ( 0 , ϵ ∥ θ ( l ) ∥ I ) (9) \delta_{0}^{(l)} \sim \mathcal{N}(0, \epsilon \|\theta^{(l)}\| I)\tag{9} δ0(l)∼N(0,ϵ∥θ(l)∥I)(9)其中:
- l l l 是第 l l l 层的索引
- ϵ \epsilon ϵ 是一个小的正数,它可以控制 δ 0 \delta_0 δ0 的强度,让使其小但非零。
- I I I 是维度适当的单位矩阵
这样有两个好处:
- δ = 0 \delta=0 δ=0 是 KL 的全局最小点,其梯度为 0,无法开始最大化; δ 0 ≠ 0 \delta_0 \neq 0 δ0=0 能保证 KL 项具有非零梯度,从而允许进行梯度上升以寻找最坏扰动 δ \delta δ。
- δ 0 \delta_0 δ0 的噪声按层权重规模进行缩放 ( ∝ ∥ θ ( l ) ∥ ) (\propto \|\theta^{(l)}\|) (∝∥θ(l)∥),保证不同层的 δ \delta δ 具有合理且一致的尺度,使最大化过程可稳定进行。
②计算梯度
令 g g g 表示公式(7)中内层 KL 项的梯度:
g : = ∇ θ + δ 0 L F G ( θ , θ + δ 0 ) (10) g := \nabla_{\theta+\delta_0}\mathcal{L}_{FG}(\theta, \theta+\delta_0)\tag{10} g:=∇θ+δ0LFG(θ,θ+δ0)(10)至于这里求的是 θ + δ 0 \theta+\delta_0 θ+δ0 的梯度而不是 θ \theta θ 的梯度的原因也很简单,我们现在找的是让公式(7)最大的扰动 δ 0 \delta_0 δ0 而不是 θ \theta θ,相当于此时的 θ \theta θ 被视为常数,我们只需要求 δ 0 \delta_0 δ0 的梯度即可,而在 θ \theta θ 被视为常数的情况下, δ 0 \delta_0 δ0 的梯度与 θ + δ 0 \theta+\delta_0 θ+δ0 的梯度完全等价,所以这里选择了实现更加方便、可读性更好的后者。
③更新 δ \delta δ
随后,计算每一层 l l l 的 δ \delta δ:
δ ( l ) : = δ 0 ( l ) + γ ∥ θ ( l ) ∥ 2 ∥ g ( l ) ∥ 2 g ( l ) (11) \delta^{(l)} := \delta_0^{(l)} + \gamma \frac{\|\theta^{(l)}\|_2}{\|g^{(l)}\|_2} g^{(l)}\tag{11} δ(l):=δ0(l)+γ∥g(l)∥2∥θ(l)∥2g(l)(11)其中 1 ≤ l ≤ L 1 \leq l \leq L 1≤l≤L, δ ( l ) \delta^{(l)} δ(l) 是对第 l l l 层参数 θ ( l ) \theta^{(l)} θ(l) 的扰动, 0 ≤ γ 0 \leq \gamma 0≤γ 是控制扰动比例的超参数(即 δ ( l ) \delta^{(l)} δ(l) 被归一化,使得 ∥ δ ( l ) ∥ 2 ∥ θ ( l ) ∥ 2 ≈ γ \frac{\|\delta^{(l)}\|_2}{\|\theta^{(l)}\|_2} \approx \gamma ∥θ(l)∥2∥δ(l)∥2≈γ)。
对于上面 ∥ δ ( l ) ∥ 2 ∥ θ ( l ) ∥ 2 ≈ γ \frac{\|\delta^{(l)}\|_2}{\|\theta^{(l)}\|_2} \approx \gamma ∥θ(l)∥2∥δ(l)∥2≈γ 的解释:
- 忽略很小的初始化噪声 δ 0 ( l ) \delta_0^{(l)} δ0(l) ,主要部分是:
δ ( l ) ≈ γ ∥ θ ( l ) ∥ 2 ∥ g ( l ) ∥ 2 g ( l ) \delta^{(l)} \approx \gamma \frac{\|\theta^{(l)}\|_2}{\|g^{(l)}\|_2} g^{(l)} δ(l)≈γ∥g(l)∥2∥θ(l)∥2g(l) - 计算范数:
∥ δ ( l ) ∥ ≈ γ ∥ θ ( l ) ∥ ∥ g ( l ) ∥ ∥ g ( l ) ∥ = γ ∥ θ ( l ) ∥ \|\delta^{(l)}\| \approx \gamma \frac{\|\theta^{(l)}\|}{\|g^{(l)}\|}\|g^{(l)}\| = \gamma \|\theta^{(l)}\| ∥δ(l)∥≈γ∥g(l)∥∥θ(l)∥∥g(l)∥=γ∥θ(l)∥ - 移项得到:
∥ δ ( l ) ∥ 2 ∥ θ ( l ) ∥ 2 ≈ γ \frac{\|\delta^{(l)}\|_2}{\|\theta^{(l)}\|_2} \approx \gamma ∥θ(l)∥2∥δ(l)∥2≈γ
2.更新参数 θ \theta θ
为了使用公式 (11) 计算的 δ \delta δ,通过 L S T A R L_{STAR} LSTAR 执行 SGD(随机梯度下降),我们需要计算:
∇ θ L F G ( θ , θ + δ ) \nabla_\theta \mathcal{L}{FG}(\theta, \theta + \delta) ∇θLFG(θ,θ+δ)但是直接求该梯度需要计算海森矩阵或海森向量积,计算代价昂贵,所以借鉴已有工作的同类方法中的常用类似:
∇ θ L F G ( θ , θ + δ ) ≈ ∇ θ + δ L F G ( θ , θ + δ ) \nabla\theta \mathcal{L}{FG}(\theta, \theta + \delta) \approx \nabla{\theta+\delta} \mathcal{L}_{FG}(\theta, \theta + \delta) ∇θLFG(θ,θ+δ)≈∇θ+δLFG(θ,θ+δ)以加速计算。
这里略过该近似可行性的详细证明,而为什么左边优化麻烦,右边优化简单,是因为 δ \delta δ 可以视为 θ \theta θ 的函数,当对 θ + δ \theta+\delta θ+δ 求梯度时,可以把 θ \theta θ 当成常数,但是对 θ \theta θ 求梯度时,还要额外考虑 θ + δ \theta+\delta θ+δ 的问题,而 δ \delta δ 又是通过最大化求到的,本身就包含一个梯度,所以再求梯度会产生一个额外的海森矩阵。
三. 总结
1、核心思想
目标 :通过 减少遗忘 来提高模型在序列学习中的表现,尤其是在面对不断变化的数据分布时,避免模型对已学任务的遗忘。
手段 :通过 对抗性思维 ,设计最坏的扰动 δ \delta δ,并利用该扰动反向优化参数,使模型趋向稳定的参数空间,从而减少未来参数更新对模型稳定性和性能的影响。
2、关键公式(三行核心)
-
遗忘代理 (衡量遗忘程度):
L F G = ∑ 正确样本 KL ( q θ t ( x ) ∣ q θ t + s ( x ) ) \mathcal{L}{FG} = \sum{\text{正确样本}} \text{KL}(q_{\theta_t}(x) | q_{\theta_{t+s}}(x)) LFG=正确样本∑KL(qθt(x)∣qθt+s(x))这个公式表示 KL散度 ,用于度量当前模型参数 θ t \theta_t θt 和未来参数 θ t + s \theta_{t+s} θt+s 之间的预测差异,作为 遗忘的代理。 -
STAR损失 (推动模型稳定性):
L S T A R = max ∣ δ ∣ ≤ d L F G ( θ , θ + δ ) \mathcal{L}{STAR} = \max{|\delta|\leq d} \mathcal{L}_{FG}(\theta, \theta+\delta) LSTAR=∣δ∣≤dmaxLFG(θ,θ+δ)该公式通过计算 最大扰动 (即最坏扰动) δ \delta δ,确保在扰动范围内,模型输出的稳定性最大化,避免对已学习任务的遗忘。 -
总目标 (联合目标优化):
L = L C L + λ L S T A R \mathcal{L} = \mathcal{L}{CL} + \lambda \mathcal{L}{STAR} L=LCL+λLSTAR这是模型训练中的总目标函数,其中 L C L \mathcal{L}{CL} LCL 表示传统的持续学习损失(比如回放或正则化损失),而 L S T A R \mathcal{L}{STAR} LSTAR 则是通过STAR促进稳定性的目标, λ \lambda λ 为超参数,控制两者的权重。
3、实现技巧
① 计算扰动 δ \delta δ
-
方法 :采用 梯度上升法 (Gradient Ascent)一步来计算扰动 δ \delta δ,通过反向传播计算梯度来寻找最坏的扰动。
-
归一化 :为了确保扰动的相对大小一致,每层的扰动 δ ( l ) \delta^{(l)} δ(l) 会根据该层的梯度和权重进行归一化:
δ ( l ) ≈ γ ∣ θ ( l ) ∣ ∣ g ( l ) ∣ g ( l ) \delta^{(l)} \approx \gamma \frac{|\theta^{(l)}|}{|g^{(l)}|} g^{(l)} δ(l)≈γ∣g(l)∣∣θ(l)∣g(l)这里 γ \gamma γ 控制每层扰动的比例, ∣ θ ( l ) ∣ ) 和 ( ∣ g ( l ) ∣ |\theta^{(l)}|) 和 (|g^{(l)}| ∣θ(l)∣)和(∣g(l)∣ 分别是当前层权重和梯度的范数,确保扰动不会因为某些层的权重或梯度过大而变得不稳定。
② 更新参数 θ \theta θ
-
关键近似 :通过 梯度计算的近似 :
∇ θ L F G ≈ ∇ θ + δ L F G \nabla_\theta \mathcal{L}{FG} \approx \nabla{\theta+\delta} \mathcal{L}_{FG} ∇θLFG≈∇θ+δLFG这个近似方法简化了对扰动后参数的计算。它假设通过扰动后参数 θ + δ \theta + \delta θ+δ 计算的梯度可以近似代替原始参数 θ \theta θ 计算的梯度,从而避免了计算海森矩阵(即二阶导数)。 -
实际操作 :
在实际更新时,我们利用 计算扰动 δ \delta δ 时的梯度 g g g ,将其作为梯度来更新参数 θ \theta θ。
θ \theta θ 的更新公式为:
θ t + 1 = θ t − α ∇ θ + δ L F G \theta_{t+1} = \theta_t - \alpha \nabla_{\theta+\delta} \mathcal{L}_{FG} θt+1=θt−α∇θ+δLFG这里, α \alpha α 是学习率,而扰动 δ \delta δ 影响着更新的方向,逼迫模型进入一个稳定区域,从而减小遗忘的风险。
总的来说,STAR 通过对 参数施加最大扰动 ,使得模型 进入平坦区域 ,从而 减少遗忘 。这种扰动不仅仅依赖于传统的正则化方法,还通过 最坏扰动 的引入,使得模型的稳定性更强,尤其在面对新任务时,可以避免对旧任务的遗忘。