[ICLR 2023] LPT: Long-tailed Prompt Tuning for Image Classification

Contents

  • Introduction
  • [Preliminary Study](#Preliminary Study)
    • [Performance Investigation of VPT (Visual Prompt Tuning)](#Performance Investigation of VPT (Visual Prompt Tuning))
    • [Analysis of Prompt Tuning](#Analysis of Prompt Tuning)
  • [Long-tailed Prompt Tuning (LPT)](#Long-tailed Prompt Tuning (LPT))
    • [Phase 1: Shared Prompt Tuning](#Phase 1: Shared Prompt Tuning)
    • [Phase 2: Group Prompts Tuning](#Phase 2: Group Prompts Tuning)
    • [Loss Function](#Loss Function)
  • Experiments
    • [Comparison with State-of-The-Art Methods](#Comparison with State-of-The-Art Methods)
    • [Robustness with Domain Shift](#Robustness with Domain Shift)
    • [Ablation Study](#Ablation Study)
  • References

Introduction

  • 作者提出 Long-tailed Prompt Tuning (LPT) ,通过 prompt learning 来解决长尾问题,包括 (1) 使用 shared prompt 学习 general features 并将预训练模型 adapt 到 target domain;(2) 使用 group-specific prompts 学习 group-specific features 来提高模型的 fine-grained discriminative ability

Preliminary Study

Performance Investigation of VPT (Visual Prompt Tuning)

  • 作者首先通过对比 VPT (Visual Prompt Tuning) 和 linear probing 在 Places-LT 数据集上的精度来说明 prompt tuning 对长尾数据集是有效的 (VPT 的输入为 input tokens 加上 learnable prompts (tokens),同时和 linear probing 一样在预训练模型最后加上 linear classifier)
  • 从下表中可以看出:a) prompt tuning 可以持续提高模型的 LTR 性能 ;b) prompt tuning 对长尾分布具有鲁棒性,能更好地学习尾部类别 。同时也可以注意到,简单的 prompt tuning 并不能直接让模型在长尾数据集上达到 SOTA

Analysis of Prompt Tuning

  • 作者接下来分析了为什么 prompt tuning 适合长尾识别 (但仍然没有从原理上分析为什么)
  • 由下图的 LDA 可视化可以看出 (use the pretrained ViT-B and the ViT-B fine-tuned by VPT on Places-LT to extract features of ImageNet val set and Places-LT val set),prompt tuning 可以很好地将下游任务数据分布 (Places-LT) 和预训练数据分布 (ImageNet) 对齐,可以更好地让预训练模型 adapt 到长尾任务的 target domain (from domain adaptation perspective)
  • 作者计算了 ViT-B 和 VPT 输出特征的平均类内距离、平均类间距离以及两者之商 γ \gamma γ,可以看到,VPT 的平均类内距离和 γ \gamma γ 都更小,KNN 分类准确率更高,说明 VPT 输出的特征更具有区分度

Long-tailed Prompt Tuning (LPT)

Phase 1: Shared Prompt Tuning

  • 类似于 VPT-Deep,给 ViT 的 L L L 层都各自加上额外的 prompts,因此 phase 1 需要优化 shared prompt u = [ u 1 , . . . , u L ] \mathbf u=[\mathbf u_1,...,\mathbf u_L] u=[u1,...,uL] 和 cosine classifier f f f ,其中 shared prompt 用于学习所有类别的共同特征,并带来了上节讨论的 prompt tuning 的各种好处,包括 domain adaptation 和输出更具区分度的特征
  • 每层里的前向过程
    其中, c \mathbf c c 为 [CLS], z \mathbf z z 为 token embed. 新添加的 prompts 不需要计算对应的自注意力输出,只需要作为 key 和 value 与 token embed 做交互即可
  • 损失函数

Phase 2: Group Prompts Tuning

  • 作者在 phase 2 加入了 m m m 组 group-specific prompts R = { ( k 1 , r 1 ) , . . . , ( k m , r m ) } \mathcal R=\{(\mathbf k_1,\mathbf r^1),...,(\mathbf k_m,\mathbf r^m)\} R={(k1,r1),...,(km,rm)} 用于学习 group-specific knowledge 从而增强模型的 fine-grained discriminative ability,其中 k i \mathbf k_i ki 为 i i i-th group 的 key, r i \mathbf r^i ri 为 i i i-th group 的 prompts,包含 L − K L-K L−K 个 prompt 序列 (只在后 L − K L-K L−K 层使用 group-specific prompts).
  • Phase 2 包含两个步骤:(1) 冻住 shared prompts,经过 L L L 层推理得到 c L \mathbf c_L cL 作为 query q \mathbf q q 与 m m m 个 keys 计算余弦相似度,选出相似度最高的 k k k 个 groups
    然后对选出的 k k k 个 groups 的 prompts 进行 prompt ensembling
    (2) 重新使用步骤 (1) 在前向传播中得到的 ( c K , z K ) (\mathbf c_K,\mathbf z_K) (cK,zK),在后 L − K L-K L−K 层重新进行前向传播,每层的输入包括 [CLS] embed c \mathbf c c、patch embed z \mathbf z z、shared prompt u \mathbf u u 和 group-specific prompt r \mathbf r r,每层里的前向过程为
  • 损失函数
    其中, β \beta β 为 scale factor,第二项损失函数被用于增大 q \mathbf q q 和其匹配的 k k k 个 groups 的 keys 之间的余弦相似度,这是由于 Phase 1 生成的特征已经比较 compact 并且在 Phase 2 是不变的,因此该损失项可以使得 keys 靠近特征空间中的不同聚类中心,使得不同 groups 对应不同的 group-specific feature
  • Dual Sampling . class-balanced sampling 和 instance-balanced sampling 分别容易使得模型对尾部和头部类别过拟合,作者采用 Dual Sampling,从 instance-balanced sampler 和 class-balanced sampler 分别采样一个 mini-batch { I } ins \{\mathbf I\}{\text{ins}} {I}ins 和 { I } bal \{\mathbf I\}{\text{bal}} {I}bal. { I } bal \{\mathbf I\}{\text{bal}} {I}bal 的损失函数对应 β = 1 \beta=1 β=1 时的 L P 2 \mathcal L{\mathbf P_2} LP2, { I } ins \{\mathbf I\}{\text{ins}} {I}ins 的损失函数对应 β = η ( E − e ) / E \beta=\eta(E-e)/E β=η(E−e)/E 时的 L P 2 \mathcal L{\mathbf P_2} LP2,其中 η = 0.5 \eta=0.5 η=0.5 为 initialized weight, E E E 为总的训练 epoch 数, e e e 为当前 epoch 数

Loss Function

  • phase 1/2 中使用的 L cls \mathcal L_{\text{cls}} Lcls 采用 asymmetric GCL loss L A-GCL \mathcal L_{\text{A-GCL}} LA-GCL.
  • 首先根据 GCL 对 logits s ^ \hat {\mathbf s} s^ 进行加上 bias 和 rescale
    其中, α \alpha α 为 scaling factor, ϵ \epsilon ϵ 为从高斯分布中采样的随机变量 ( ∥ ϵ ∥ \|\epsilon\| ∥ϵ∥ 为取绝对值), n i n_i ni 为训练集中类别 i i i 的样本数, n m a x n_{max} nmax 为训练集中的最大类别样本数. 对应的 per-class probability 为
  • 然后根据 ASL 进行 Asymmetric Focusing
    L A − G C L = − y j ( 1 − p j ) λ + log ⁡ ( p j ) − ∑ 1 ≤ i ≤ C , i ≠ j y i ( p i ) λ − log ⁡ ( p i ) \mathcal{L}{\mathrm{A}-\mathrm{GCL}}=-\mathbf y{\mathrm j}\left(1-\mathbf{p}{\mathrm{j}}\right)^{\lambda{+}} \log \left(\mathbf{p}{\mathrm{j}}\right)-\sum{1 \leq \mathrm{i} \leq \mathrm{C}, \mathrm{i} \neq \mathrm{j}}\mathbf y_{\mathrm i}\left(\mathbf{p}{\mathrm{i}}\right)^{\lambda{-}} \log \left(\mathbf{p}{\mathrm{i}}\right) LA−GCL=−yj(1−pj)λ+log(pj)−1≤i≤C,i=j∑yi(pi)λ−log(pi)其中, j j j 为输入样本的标签类别, λ + = 0 , λ − = 4 λ+=0,λ_−=4 λ+=0,λ−=4 为 focusing parameter, y \mathbf y y 为 label smoothing 后的类别标签向量,即 y j = 0.9 + 0.1 / C , y i = 0.1 / C \mathbf y_{\mathrm j}=0.9+0.1/C,\mathbf y_{\mathrm i}=0.1/C yj=0.9+0.1/C,yi=0.1/C (疑问 :ASL 本来是 BCE 上用的,但这里是 CE + label smoothing 之后再加上 ASL 的动态加权, ( 1 − p j ) λ + \left(1-\mathbf{p}{\mathrm{j}}\right)^{\lambda{+}} (1−pj)λ+ 的意义和 ASL 一样,都是筛选出难样本,但感觉 ( p i ) λ − \left(\mathbf{p}{\mathrm{i}}\right)^{\lambda{-}} (pi)λ− 的意义已经和 ASL 完全不同了,可以等进一步理解 label smoothing 为什么有用之后再来看)

Experiments

  • Model. ViT-B/16 with ImageNet-21k pretrained model.
  • Shared Prompt. default length of prompt as 10.
  • Group-specific Prompts . shared layer number K = 6 K = 6 K=6 and the size of prompt size m = 20 m = 20 m=20; for each prompt in the set, the prompt length is also set as 10 (Note that setting K = 6 K = 6 K=6 may lead to 1.5x inference cost compared to VPT). prompt ensemble number k = 2 k = 2 k=2.

Comparison with State-of-The-Art Methods

  • Comparison on Places-LT .
  • Comparison on CIFAR100-LT .
  • Comparison on iNaturalist 2018 .

Robustness with Domain Shift

Ablation Study

  • Different Model Size and Pretrained Models .

  • Effect of Each Phase .

  • Decoupled Training . during joint training, the shared prompt is still updated simultaneously, thus the query function is sub-optimal during training, resulting in worse matching results.

  • Query Function and Group Size m m m .
    when we further increase the size to 40, the final accuracy declines to 49.87%. A possible reason is that, some classes in the dataset may share some similar group-specific feature or knowledge

  • Effect of K K K . K K K 过大会导致无法学得有效的 group-specific knowledge,过小会导致 Phase 2 匹配 groups 时无法充分利用 Phase 1 得到的 adapted feature representation

  • Effect of Ensemble Number k k k .

  • Effect of Asymmetric GCL Loss .

  • Statistic of Prompt Matching .

References

相关推荐
此星光明3 个月前
GEE数据集——汉森全球森林变化数据集Hansen Global Forest Change v1.11 (2000-2023)
云计算·数据集·gee·2023·森林·全球·损失
岁月标记3 个月前
ACSAC 2023
2023·acsac
岁月标记4 个月前
CCS 2023
2023·ccs
岁月标记4 个月前
S&P 2023
2023·sp
ladymorgana4 个月前
JetBrains全家桶激活,分享 CLion 2024 激活的方案
jetbrains·2023·clion·2024·激活方法·3.2·3.4
ladymorgana4 个月前
2024 PhpStorm激活,分享几个PhpStorm激活的方案
phpstorm·2023·2024·激活方法·3.2·3.4
ladymorgana4 个月前
2024 RubyMine 激活,分享几个RubyMine 激活的方案
2023·rubymine·2024·激活方法·3.2·3.4
webmote6 个月前
2023年总结我所经历的技术大变革
chatgpt·webrtc·2023·总结·技术变革
TechMerger6 个月前
毕业 10 年,也成了 Android 10 年老开发|紧张充实的 2023
android·kotlin·鸿蒙·2023·总结
深竹清风6 个月前
2023我的编程之旅-地质人的山和水
2023·总结