【提示学习论文五】Conditional Prompt Learning for Vision-Language Models论文原理及复现工作

Conditional Prompt Learning for Vision-Language Models 视觉语言模型的条件提示学习

文章介绍

  • 这篇文章于2022年发表在CVPR(Conference on Computer Vision and Pattern Recognition),作者是kaiyang.zhou, jingkang001, ccloy, ziwei.liu。
  • 研究发现CoOp的问题:泛化性差,CoOp在训练时对于已知类别(base classes)过拟合学习的上下文向量不能推广到同一数据集中的未知类
  • 作者提出Conditional Context Optimization(CoCoOp)。CoCoOp在CoOp基础上引入一个轻量级的神经网络为每张图像生成 input-conditional tokens(vectors),这些tokens会加到原本CoOp的learnable vectors上,从而可以学习到更泛化的prompt。

问题背景

  • CoOp是一种有效利用数据的方法,只需少量标记图像数据即可训练上下文向量,以提高模型性能。
  • 然而,CoOp存在一个问题,其学到的上下文信息无法推广到同一数据集中更广泛的未知类别,CoOp在训练中过于专注于特定类别 ,导致模型无法很好地泛化到其他类别上。
  • 作者认为,通过实例条件化上下文,可以更好地泛化,因为这使得模型不再专注于特定一组类别,而是关注于每个输入实例及整个任务
  • 为了解决这个问题,提出了CoCoOp方法。

设计

  • 简单实现方法: 构建 M M M个神经网络来生成 M M M个上下文标记,但这会增加计算资源的需求。
  • 参数效率设计: 作者提出了更高效的设计方案,该方案在M个上下文向量的基础上进一步学习一个轻量级的神经网络(Meta-Net) 。这个Meta-Net用于为每个输入图像生成一个条件化的标记并将其与上下文向量结合

模型结构

  • CoOp
  • CoCoOp:由两个可学习的组件组成,一组上下文向量和一个轻量级神经网络(Meta-Net),为每个图像生成一个输入条件token
  • 输入图像编码器生成的图像 x \mathbf{x} x 特征 ,通过 Meta-Net 生成相应的条件标记 t y ( x ) \mathbf{t}_y (\mathbf{x}) ty(x)
  • 计算输入图像 x \mathbf{x} x 与每个类别提示 t i ( x ) \mathbf{t}_i (\mathbf{x}) ti(x)之间的相似度
  • 对于每个类别 i i i ,将相似度值作为指数项应用于指数函数,同时用温度参数 τ \tau τ 进行缩放,将相似度映射为概率得分
  • 将所有类别的指数项相加并归一化,得到每个类别的归一化概率分布
  • 最终的预测概率表示为给定输入图像 x \mathbf{x} x下属于每个类别的可能性。

实现细节

p ( y ∣ x ) = exp ⁡ ( sim ⁡ ( x , g ( t y ( x ) ) ) / τ ) ∑ i = 1 K exp ⁡ ( sim ⁡ ( x , g ( t i ( x ) ) / τ ) p(y | \mathbf{x}) = \frac{\exp (\operatorname{sim} (\mathbf{x}, g(\mathbf{t}y (\mathbf{x}))) / \tau )}{\sum{i=1}^K \exp (\operatorname{sim} (\mathbf{x}, g(\mathbf{t}_i (\mathbf{x})) / \tau )} p(y∣x)=∑i=1Kexp(sim(x,g(ti(x))/τ)exp(sim(x,g(ty(x)))/τ)

  • 计算预测概率的公式,涉及了上下文标记和模型的预测函数。

  • 评估模型对给定输入图像的类别预测概率。

  • 训练过程中,更新了上下文向量 v m {v_m} vm 和 Meta-Net 的参数 θ θ θ 。

  • Meta-Net 结构: Meta-Net采用了一个两层的瓶颈结构,隐藏层将输入维度降低了16倍。

参数

  • p ( y ∣ x ) p(y | \mathbf{x}) p(y∣x):表示在给定输入图像 x \mathbf{x} x 的情况下,模型预测为类别 y y y 的概率。
  • t y ( x ) \mathbf{t}_y (\mathbf{x}) ty(x):表示输入图像 x \mathbf{x} x 对应类别 y y y 的提示(即条件化的标记),包括了关于这个图像的特定信息。
  • sim ⁡ ( x , g ( t i ( x ) ) ) \operatorname{sim} (\mathbf{x}, g(\mathbf{t}_i (\mathbf{x}))) sim(x,g(ti(x))):表示图像 x \mathbf{x} x 与类别 i i i的提示 t i ( x ) \mathbf{t}_i (\mathbf{x}) ti(x)之间的相似度。这个相似度函数可以是任何测量图像与提示之间相似程度的函数。
  • K K K:表示类别的总数。
  • τ \tau τ:表示温度参数,用于调整预测分布的平滑度。
相关推荐
摇滚侠2 小时前
如何选择 nodejs 版本,nodejs 版本号详解
学习
醇氧2 小时前
【学习】IP地址:数字世界的“门牌号”怎么读?
网络协议·学习·tcp/ip
talen_hx2963 小时前
《零基础入门Spark》学习笔记 Day 11
笔记·学习·spark
ZhiqianXia4 小时前
gem5 模拟器学习笔记(1):核心术语整理
笔记·学习
GHL2842710904 小时前
MCP学习
学习·ai
凌波粒5 小时前
D2L学习笔记:安装、张量与数据处理
笔记·python·学习·pandas
chools5 小时前
Java后端拥抱AI开发之个人学习路线 - - Spring AI【第一期】
java·人工智能·学习·spring·ai
忙什么果6 小时前
transformer学习笔记2
笔记·学习·transformer
ZhiqianXia6 小时前
Gem5 学习笔记(2) : Gem5 建模要点与基本思路
笔记·学习