后合成指数移动平均(post-hoc synthesized EMA)

总览

指数移动平均 EMA(Exponential Moving Average)是一种网络训练技巧。通过取一段时间的权重的平均值,能提高模型的泛化能力。

一言蔽之:模型训练完后还能获得 "免费" 的提升。

后合成指数移动平均(post-hoc synthesized EMA)是在论文 Analyzing and Improving the Training Dynamics of Diffusion Models 中提出的 EMA 改进方法。主要改进在于,比起原方法是在训练中完成平均,post-hoc EMA 则会先保存一系列关键节点的 EMA 权重,然后在训练结束后找一个最佳的的超参进行权重平均,进一步提升模型能力。

<math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ β \hat{\theta}_\beta </math>θ^β 的更新

使用 EMA 技巧训练模型时,会维护一份模型权重 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 的副本 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ β \hat{\theta}\beta </math>θ^β。传统方法下,每训练一个 step,会用以下公式更新 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ β \hat{\theta}\beta </math>θ^β:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ^ β ( t ) = β θ ^ β ( t − 1 ) + ( 1 − β ) θ ( t ) \hat{\theta}\beta(t)=\beta\hat{\theta}\beta (t-1)+(1-\beta)\theta(t) </math>θ^β(t)=βθ^β(t−1)+(1−β)θ(t)

<math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 代表当前训练 step。 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 是常量超参,通常非常接近 1。

Post-hoc EMA 没有使用这样的指数衰减策略,而是修改为了幂函数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ^ γ ( t ) = ∫ 0 t τ γ θ ( τ ) d τ ∫ 0 t τ γ d τ = γ + 1 t γ + 1 ∫ 0 t τ γ θ ( τ ) d τ \hat{\theta}_\gamma(t)=\frac{\int^t_0 \tau^\gamma\theta(\tau)\mathrm{d}\tau}{\int^t_0 \tau^\gamma\mathrm{d}\tau}=\frac{\gamma+1}{t^{\gamma+1}}\int^t_0\tau^\gamma\theta(\tau)\mathrm{d}\tau </math>θ^γ(t)=∫0tτγdτ∫0tτγθ(τ)dτ=tγ+1γ+1∫0tτγθ(τ)dτ

<math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 是控制尖锐程度的超参。通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ t = 0 \theta_{t=0} </math>θt=0 不采取随机初始化而是直接赋值为 0。

可见, <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 为 0 时,公式的含义就是直接取一段时间权重的均值。 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 越大, <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ^ γ ( t ) \hat{\theta}_\gamma(t) </math>θ^γ(t) 就越受临近 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 的 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ ( t ) \theta(t) </math>θ(t) 的影响。

实际计算时使用递归形式,节省内存:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ ^ β ( t ) = β γ ( t ) θ ^ γ ( t − 1 ) + ( 1 − β γ ( t ) ) θ ( t ) w h e r e β γ ( t ) = ( 1 − 1 / t ) γ + 1 \begin{aligned} &\hat{\theta}\beta(t)=\beta\gamma(t)\hat{\theta}\gamma(t-1)+(1-\beta\gamma(t))\theta(t)\\ \mathrm{where\ }& \beta_\gamma(t)=(1-1/t)^{\gamma+1} \end{aligned} </math>where θ^β(t)=βγ(t)θ^γ(t−1)+(1−βγ(t))θ(t)βγ(t)=(1−1/t)γ+1

与本节的第一个公式很像,唯一区别是 <math xmlns="http://www.w3.org/1998/Math/MathML"> β γ ( t ) \beta_\gamma(t) </math>βγ(t) 会随着 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 增加而减小,而不是一个恒定值。

<math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 可能不是很值观不方便配置,于是论文提出了 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 超参,含义是 "相邻峰值的宽度相对于整个训练时长的占比"。实际训练时,可以通过 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 确定 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> σ rel = ( γ + 1 ) 1 / 2 ( γ + 2 ) − 1 ( γ + 3 ) − 1 / 2 \sigma_{\text{rel}}=(\gamma+1)^{1/2}(\gamma+2)^{-1}(\gamma+3)^{-1/2} </math>σrel=(γ+1)1/2(γ+2)−1(γ+3)−1/2

即可以用 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 定义以训练时长为标尺的 EMA 长度,而不是像原始 EMA 公式中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 那样会随着 step 数量变化而剧烈影响训练效果。这是论文对原始 EMA 的另一个改进贡献。

顺带一提,论文在实际实验中发现, <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 超参控制下的最佳 EMA 长度其实仍会随着 step 数量增加而缓慢地变长。

从快照重建 EMA

Post-hoc EMA 的核心是滞后平均。如何合理地从权重快照重建出各种超参的 EMA 结果尤为重要。换句话说,需要有一种方法,能在训练完成后处理出任意 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ 值(或者是等效的 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 值)对应的平滑权重。

用什么策略创建快照?可以选取两个 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 值,隔一段 step 步数保存一次快照。

如何合并快照?论文在公式推导后给出了一个代码示例,用于获得快照的权重:

python 复制代码
def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    t_ratio = t_a / t_b
    t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a)
    t_max = torch.maximum(t_a , t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den

def solve_weights(t_i, gamma_i, t_r, gamma_r):
    rv = lambda x: x.double().reshape(-1, 1)
    cv = lambda x: x.double().reshape(1, -1)
    A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
    b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
    return np.linalg.solve(A, b)

我自己实测下来重建效果可以说是意外的优秀。即使总共只使用 40 个快照( <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 取 0.05 和 0.2),大致 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 在 0.04 到 0.22 范围内重建出的 EMA 权重几近完美。

该怎么写代码?

有一个开箱即用的 EMA 包,通过 pip install ema-pytorch 即可使用。

python 复制代码
import torch
from ema_pytorch import PostHocEMA

# your neural network as a pytorch module

model = ...

emas = PostHocEMA(
    model,
    sigma_rels = (0.05, 0.3),
    update_every = 1,   # 每调用 1 次 emas.update() 就更新一次 ema 权重
    checkpoint_every_num_steps = 50,
    checkpoint_folder = './post-hoc-ema-checkpoints'  # 保存快照的路径
)

model.train()

for epoch in range(300):
    for input, target in loader:
        ...
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    
        emas.update()  # 在此调用 emas.update()

# 重建
synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
# 直接调用 synthesized_ema 进行推理
synthesized_ema_output = synthesized_ema(data)

有些细节要注意,

  • PostHocEMA 不会主动清空 checkpoint_folder 里的内容。目录里原有的 pt 文件会对程序造成干扰
  • 使用 emas.model 获取原模型,使用 emas.ema_model 获取权重重建后的模型
  • 建议保存模型时保存 PostHocEMA。这样能保留 ema step 等信息

实验

为了探究重建效果,使用以下代码进行实验。

<math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 取 0.05 和 0.2,总步数 1000,两种 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 下各取 20 个快照。重建 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel 为 0.15 时的 ema 权重。

python 复制代码
import torch
from ema_pytorch import PostHocEMA
import plotly.graph_objects as go

net = torch.nn.Linear(1, 1000, bias=False)

emas = PostHocEMA(
    net,
    sigma_rels=(
        0.05,
        0.2,
    ),
    update_every=1,
    checkpoint_every_num_steps=50,
    checkpoint_folder=r'Z:\post-hoc-ema-checkpoints',
)

net.train()

for i in range(1000):
    with torch.no_grad():
        net.weight.zero_()
        channel_index = i % 1000
        net.weight[channel_index, 0] = 1.

    emas.update()

synthesized_ema = emas.synthesize_ema_model(sigma_rel=0.15)
ema_weights = synthesized_ema.ema_model.weight.detach().numpy().flatten()

fig = go.Figure(data=go.Scatter(x=list(range(1000)), y=ema_weights, mode='lines+markers'))
fig.update_layout(
    title='Synthesized EMA Model Weights',
    xaxis_title='Channel Index',
    yaxis_title='Weight Value',
    template='plotly_white',
)
fig.show()

这个曲线相当够用了。

思考

EMA 能在训练完成后提升模型效果。post-hoc EMA 更进一步,允许在训练完成后找到最合适 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ rel \sigma_{\text{rel}} </math>σrel,进一步提升模型能力。可以说是相当值得一用的 trick。

不过显而易见的,post-hoc EMA 要求维护不止一个模型的参数,对显存要求陡然上升。即使单独让 EMA 权重放在 CPU 侧,设备通信过程会大大拖慢训练进程。即使可以设置 update_every 参数减少通信次数,但势必会影响重建效果。另一个缺点就是快照占用磁盘空间。40 可不是小数目,更别说论文中提及的至少 160 个快照。

论文提出 post-hoc EMA 是针对扩散模型的,而扩散模型众所周知大块头一个......普通人很难很好地用上 post-hoc EMA 训练大模型吧。不过训练小玩具时,好歹有了个似乎很有提分希望的可选技巧。

参考来源

相关推荐
AI_NEW_COME34 分钟前
知识库管理系统可扩展性深度测评
人工智能
海棠AI实验室1 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
hunteritself1 小时前
AI Weekly『12月16-22日』:OpenAI公布o3,谷歌发布首个推理模型,GitHub Copilot免费版上线!
人工智能·gpt·chatgpt·github·openai·copilot
IT古董2 小时前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
centurysee2 小时前
【最佳实践】Anthropic:Agentic系统实践案例
人工智能
mahuifa2 小时前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai
四口鲸鱼爱吃盐2 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
蓝天星空2 小时前
Python调用open ai接口
人工智能·python
睡觉狂魔er2 小时前
自动驾驶控制与规划——Project 3: LQR车辆横向控制
人工智能·机器学习·自动驾驶
scan7242 小时前
LILAC采样算法
人工智能·算法·机器学习