《Learning to Reweight Examples for Robust Deep Learning》笔记

[1] 用 meta-learning 学样本权重,可用于 class imbalance、noisy label 场景。之前对其 (7) 式中 ϵ i , t = 0 \epsilon_{i,t}=0 ϵi,t=0(对应 Algorithm 1 第 5 句、代码 ex_wts_a = tf.zeros([bsize_a], dtype=tf.float32) )不理解:如果 ϵ \epsilon ϵ 已知是 0,那 (4) 式的加权 loss 不是恒为零吗?(5) 式不是优化了个吉而 θ ^ t + 1 ( ϵ ) ≡ θ t \hat\theta_{t+1}(\epsilon) \equiv \theta_t θ^t+1(ϵ)≡θt ?有人在 issue 提了这个问题^[2]^,但其人想通了没解释就关了 issue。

看到 [3] 代码中对 ϵ \epsilon ϵ 设了 requires_grad=True 才反应过来:用编程的话说, ϵ \epsilon ϵ 不应理解成常量,而是变量; 用数学的话说,(5) 的求梯度( ∇ \nabla ∇)是算子,而不是函数,即 (5) 只是在借梯度下降建立 θ ^ t + 1 \hat\theta_{t+1} θ^t+1 与 ϵ \epsilon ϵ 之间的函数(或用 TensorFlow 的话说,只是在建图 ),即 θ ^ t + 1 ( ϵ ) \hat\theta_{t+1}(\epsilon) θ^t+1(ϵ),而不是基于常量 θ t \theta_t θt、 ϵ = 0 \epsilon=0 ϵ=0 算了一步 SGD 得到一个常量 θ ^ t + 1 \hat\theta_{t+1} θ^t+1。

一个符号细节:无 hat 的 θ t + 1 \theta_{t+1} θt+1 指由 (3) 用无 perturbation 的 loss 经 SGD 从 θ t \theta_t θt 优化一步所得; θ ^ t + 1 \hat\theta_{t+1} θ^t+1 则是用 (4) perturbed loss。文中 (6)、(7) 有错用作 θ t + 1 \theta_{t+1} θt+1 的嫌疑。

所以大思路是用 clean validation set 构造一条关于 ϵ \epsilon ϵ 的 loss J ( ϵ ) J(\epsilon) J(ϵ),然后用优化器求它,即 ϵ t ∗ = arg ⁡ min ⁡ ϵ J ( ϵ ) \epsilon_t^*=\arg\min_\epsilon J(\epsilon) ϵt∗=argminϵJ(ϵ)。由 (4) - (6) 有: J ( ϵ ) = 1 M ∑ j = 1 M f j v ( θ ^ t + 1 ( ϵ ) ) ( 6 ) = 1 M ∑ j = 1 M f j v ( θ t − α [ ∇ θ ∑ i = 1 n f i , ϵ ( θ ) ] ∣ θ = θ t ⏟ g 1 ( ϵ ; θ t ) ) ( 5 ) = 1 M ∑ j = 1 M f j v ( θ t − α [ ∇ θ ∑ i = 1 n ϵ i f i ( θ ) ] ∣ θ = θ t ) ( 4 ) = g 2 ( ϵ ; θ t ) \begin{aligned} J(\epsilon) &= \frac{1}{M}\sum_{j=1}^M f_j^v \left(\hat\theta_{t+1}(\epsilon) \right) & (6) \\ &= \frac{1}{M}\sum_{j=1}^M f_j^v \left(\theta_t - \alpha \underbrace{\left[ \nabla_{\theta} \sum_{i=1}^n f_{i,\epsilon}(\theta) \right] \bigg|{\theta=\theta_t}}{g_1(\epsilon; \theta_t)} \right) & (5) \\ &= \frac{1}{M}\sum_{j=1}^M f_j^v \left(\theta_t - \alpha \left[ \nabla_{\theta} \sum_{i=1}^n \epsilon_i f_i(\theta) \right] \bigg|{\theta=\theta_t} \right) & (4) \\ &= g_2(\epsilon; \theta_t) \end{aligned} J(ϵ)=M1j=1∑Mfjv(θ^t+1(ϵ))=M1j=1∑Mfjv θt−αg1(ϵ;θt) [∇θi=1∑nfi,ϵ(θ)] θ=θt =M1j=1∑Mfjv(θt−α[∇θi=1∑nϵifi(θ)] θ=θt)=g2(ϵ;θt)(6)(5)(4) 要注意的就是 (5) 那求导式,本质是个函数,而不是常量,其中 ϵ \epsilon ϵ 是自由的, θ \theta θ 由于被 ∣ θ = θ t |{\theta=\theta_t} ∣θ=θt 指定了,所以看成常量,所以记为 g 1 ( ϵ ; θ t ) g_1(\epsilon;\theta_t) g1(ϵ;θt),于是整个 J ( ϵ ) J(\epsilon) J(ϵ) 也可以看成一个 g 2 ( ϵ ; θ t ) g_2(\epsilon; \theta_t) g2(ϵ;θt)。

按 (6) 求 ϵ t ∗ \epsilon_t^* ϵt∗ 的思路就是:

  1. 随机初始化 ϵ t ( 0 ) \epsilon_t^{(0)} ϵt(0);
  2. ϵ t ( s + 1 ) ← ϵ t ( s ) − η ∇ ϵ J ( ϵ ) ∣ ϵ = ϵ t ( s ) \epsilon^{(s+1)}t \leftarrow \epsilon^{(s)}t - \eta \nabla{\epsilon} J(\epsilon) \big|{\epsilon=\epsilon^{(s)}_t} ϵt(s+1)←ϵt(s)−η∇ϵJ(ϵ) ϵ=ϵt(s),即 (7) 右边。可能由于 J ( ϵ ) J(\epsilon) J(ϵ) 形式上是带梯度的表达式, § \S § 3.3 就称此为「unroll the gradient graph」,而求 ϵ t ( s + 1 ) \epsilon^{(s+1)}_t ϵt(s+1) 的这一步就称为「backward-on-backward」吧。

而文章的 online approximation 就是:

  • ϵ t ( 0 ) = 0 \epsilon^{(0)}_t=0 ϵt(0)=0
  • ϵ t ∗ ≈ ϵ t ( 1 ) \epsilon^*_t \approx \epsilon^{(1)}_t ϵt∗≈ϵt(1)

初始化为 0 可能不是最好的初始化方法,但不影响后续迭代优化,可参考 LoRA^[7]^,它也用到全零初始化。

References

  1. (ICML'18) Learning to Reweight Examples for Robust Deep Learning - paper, code
  2. gradients of noisy loss w.r.t parameter \theta #2
  3. (PyTorch 复现 1)TinfoilHat0/Learning-to-Reweight-Examples-for-Robust-Deep-Learning-with-PyTorch-Higher
  4. (PyTorch 复现 2)danieltan07/learning-to-reweight-examples
  5. facebookresearch/higher
  6. Stateful vs stateless
  7. (ICLR'22) LoRA: Low-Rank Adaptation of Large Language Models - paper, code
相关推荐
前行居士1 个月前
元学习与机器学习
人工智能·神经网络·学习·机器学习·元学习
YiPeng_Deng3 个月前
【Deep Learning】Meta-Learning:训练训练神经网络的神经网络
人工智能·深度学习·神经网络·元学习·deep learning
Better Bench3 个月前
【博士每天一篇文献-综述】A survey on few-shot class-incremental learning
元学习·小样本学习·持续学习·灾难性遗忘·增量学习·过拟合·少量样本增量学习
zh-jp6 个月前
Delving into Sample Loss Curve to Embrace Noisy and Imbalanced Data
机器学习·元学习·long-tailed data·noisy labels
LabVIEW开发6 个月前
LabVIEW齿轮箱噪声监测系统
labview·labview开发·噪声·齿轮箱
uncle_ll6 个月前
机器学习——元学习
人工智能·学习·机器学习·meta·元学习
ErizJ7 个月前
论文阅读笔记 | MetaIQA: Deep Meta-learning for No-Reference Image Quality Assessment
论文阅读·笔记·深度学习·元学习·图像质量评价·iqa
前端扛把子9 个月前
微信小程序实时噪声分贝
微信小程序·噪声·分贝测试
zh-jp1 年前
论文精读:用于少样本目标检测的元调整损失函数和数据增强(Meta-tuning Loss Functions and Data Augmentation for Few-shot Object Detection)
强化学习·元学习·文献·few-shot·有监督学习