总览
指数移动平均 EMA(Exponential Moving Average)是一种网络训练技巧。通过取一段时间的权重的平均值,能提高模型的泛化能力。
一言蔽之:模型训练完后还能获得 "免费" 的提升。
后合成指数移动平均(post-hoc synthesized EMA)是在论文 Analyzing and Improving the Training Dynamics of Diffusion Models 中提出的 EMA 改进方法。主要改进在于,比起原方法是在训练中完成平均,post-hoc EMA 则会先保存一系列关键节点的 EMA 权重,然后在训练结束后找一个最佳的的超参进行权重平均,进一步提升模型能力。
<math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ β \hat{\theta}_\beta </math>θ^β 的更新
使用 EMA 技巧训练模型时,会维护一份模型权重 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的副本 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ β \hat{\theta}\beta </math>θ^β。传统方法下,每训练一个 step,会用以下公式更新 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ β \hat{\theta}\beta </math>θ^β:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ^ β ( t ) = β θ ^ β ( t − 1 ) + ( 1 − β ) θ ( t ) \hat{\theta}\beta(t)=\beta\hat{\theta}\beta (t-1)+(1-\beta)\theta(t) </math>θ^β(t)=βθ^β(t−1)+(1−β)θ(t)
<math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 代表当前训练 step。 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 是常量超参,通常非常接近 1。
Post-hoc EMA 没有使用这样的指数衰减策略,而是修改为了幂函数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ^ γ ( t ) = ∫ 0 t τ γ θ ( τ ) d τ ∫ 0 t τ γ d τ = γ + 1 t γ + 1 ∫ 0 t τ γ θ ( τ ) d τ \hat{\theta}_\gamma(t)=\frac{\int^t_0 \tau^\gamma\theta(\tau)\mathrm{d}\tau}{\int^t_0 \tau^\gamma\mathrm{d}\tau}=\frac{\gamma+1}{t^{\gamma+1}}\int^t_0\tau^\gamma\theta(\tau)\mathrm{d}\tau </math>θ^γ(t)=∫0tτγdτ∫0tτγθ(τ)dτ=tγ+1γ+1∫0tτγθ(τ)dτ
<math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 是控制尖锐程度的超参。通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ t = 0 \theta_{t=0} </math>θt=0 不采取随机初始化而是直接赋值为 0。
可见, <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 为 0 时,公式的含义就是直接取一段时间权重的均值。 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 越大, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ γ ( t ) \hat{\theta}_\gamma(t) </math>θ^γ(t) 就越受临近 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ( t ) \theta(t) </math>θ(t) 的影响。
实际计算时使用递归形式,节省内存:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ^ β ( t ) = β γ ( t ) θ ^ γ ( t − 1 ) + ( 1 − β γ ( t ) ) θ ( t ) w h e r e β γ ( t ) = ( 1 − 1 / t ) γ + 1 \begin{aligned} &\hat{\theta}\beta(t)=\beta\gamma(t)\hat{\theta}\gamma(t-1)+(1-\beta\gamma(t))\theta(t)\\ \mathrm{where\ }& \beta_\gamma(t)=(1-1/t)^{\gamma+1} \end{aligned} </math>where θ^β(t)=βγ(t)θ^γ(t−1)+(1−βγ(t))θ(t)βγ(t)=(1−1/t)γ+1
与本节的第一个公式很像,唯一区别是 <math xmlns="http://www.w3.org/1998/Math/MathML"> β γ ( t ) \beta_\gamma(t) </math>βγ(t) 会随着 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 增加而减小,而不是一个恒定值。
<math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 可能不是很值观不方便配置,于是论文提出了 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 超参,含义是 "相邻峰值的宽度相对于整个训练时长的占比"。实际训练时,可以通过 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 确定 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> σ rel = ( γ + 1 ) 1 / 2 ( γ + 2 ) − 1 ( γ + 3 ) − 1 / 2 \sigma_{\text{rel}}=(\gamma+1)^{1/2}(\gamma+2)^{-1}(\gamma+3)^{-1/2} </math>σrel=(γ+1)1/2(γ+2)−1(γ+3)−1/2
即可以用 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 定义以训练时长为标尺的 EMA 长度,而不是像原始 EMA 公式中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 那样会随着 step 数量变化而剧烈影响训练效果。这是论文对原始 EMA 的另一个改进贡献。
顺带一提,论文在实际实验中发现, <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 超参控制下的最佳 EMA 长度其实仍会随着 step 数量增加而缓慢地变长。
从快照重建 EMA
Post-hoc EMA 的核心是滞后平均。如何合理地从权重快照重建出各种超参的 EMA 结果尤为重要。换句话说,需要有一种方法,能在训练完成后处理出任意 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 值(或者是等效的 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 值)对应的平滑权重。
用什么策略创建快照?可以选取两个 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 值,隔一段 step 步数保存一次快照。
如何合并快照?论文在公式推导后给出了一个代码示例,用于获得快照的权重:
python
def p_dot_p(t_a, gamma_a, t_b, gamma_b):
t_ratio = t_a / t_b
t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
t_max = torch.maximum(t_a , t_b)
num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
den = (gamma_a + gamma_b + 1) * t_max
return num / den
def solve_weights(t_i, gamma_i, t_r, gamma_r):
rv = lambda x: x.double().reshape(-1, 1)
cv = lambda x: x.double().reshape(1, -1)
A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
return np.linalg.solve(A, b)
我自己实测下来重建效果可以说是意外的优秀。即使总共只使用 40 个快照( <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 取 0.05 和 0.2),大致 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 在 0.04 到 0.22 范围内重建出的 EMA 权重几近完美。
该怎么写代码?
有一个开箱即用的 EMA 包,通过 pip install ema-pytorch
即可使用。
python
import torch
from ema_pytorch import PostHocEMA
# your neural network as a pytorch module
model = ...
emas = PostHocEMA(
model,
sigma_rels = (0.05, 0.3),
update_every = 1, # 每调用 1 次 emas.update() 就更新一次 ema 权重
checkpoint_every_num_steps = 50,
checkpoint_folder = './post-hoc-ema-checkpoints' # 保存快照的路径
)
model.train()
for epoch in range(300):
for input, target in loader:
...
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
emas.update() # 在此调用 emas.update()
# 重建
synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
# 直接调用 synthesized_ema 进行推理
synthesized_ema_output = synthesized_ema(data)
有些细节要注意,
PostHocEMA
不会主动清空checkpoint_folder
里的内容。目录里原有的 pt 文件会对程序造成干扰- 使用
emas.model
获取原模型,使用emas.ema_model
获取权重重建后的模型 - 建议保存模型时保存
PostHocEMA
。这样能保留 ema step 等信息
实验
为了探究重建效果,使用以下代码进行实验。
<math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 取 0.05 和 0.2,总步数 1000,两种 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 下各取 20 个快照。重建 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 为 0.15 时的 ema 权重。
python
import torch
from ema_pytorch import PostHocEMA
import plotly.graph_objects as go
net = torch.nn.Linear(1, 1000, bias=False)
emas = PostHocEMA(
net,
sigma_rels=(
0.05,
0.2,
),
update_every=1,
checkpoint_every_num_steps=50,
checkpoint_folder=r'Z:\post-hoc-ema-checkpoints',
)
net.train()
for i in range(1000):
with torch.no_grad():
net.weight.zero_()
channel_index = i % 1000
net.weight[channel_index, 0] = 1.
emas.update()
synthesized_ema = emas.synthesize_ema_model(sigma_rel=0.15)
ema_weights = synthesized_ema.ema_model.weight.detach().numpy().flatten()
fig = go.Figure(data=go.Scatter(x=list(range(1000)), y=ema_weights, mode='lines+markers'))
fig.update_layout(
title='Synthesized EMA Model Weights',
xaxis_title='Channel Index',
yaxis_title='Weight Value',
template='plotly_white',
)
fig.show()
这个曲线相当够用了。
思考
EMA 能在训练完成后提升模型效果。post-hoc EMA 更进一步,允许在训练完成后找到最合适 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel,进一步提升模型能力。可以说是相当值得一用的 trick。
不过显而易见的,post-hoc EMA 要求维护不止一个模型的参数,对显存要求陡然上升。即使单独让 EMA 权重放在 CPU 侧,设备通信过程会大大拖慢训练进程。即使可以设置 update_every
参数减少通信次数,但势必会影响重建效果。另一个缺点就是快照占用磁盘空间。40 可不是小数目,更别说论文中提及的至少 160 个快照。
论文提出 post-hoc EMA 是针对扩散模型的,而扩散模型众所周知大块头一个......普通人很难很好地用上 post-hoc EMA 训练大模型吧。不过训练小玩具时,好歹有了个似乎很有提分希望的可选技巧。
参考来源
- github.com/lucidrains/...
- arxiv.org/abs/2312.02...
- Miika Aittala,"Rethinking How to Train Diffusion Models",developer.nvidia.com/blog/rethin...