DynaPrompt(ICLR 2025)论文总结
论文链接:https://arxiv.org/pdf/2501.16404
代码地址:https://github.com/zzzx1224/DynaPrompt
一、研究背景与核心问题
当前视觉 - 语言模型(如 CLIP)在下游任务适应中面临分布偏移问题 ,现有测试时提示调优 (Test-Time Prompt Tuning, TPT)方法存在显著缺陷:TPT(Shu et al., 2022)为每个测试样本独立调优提示,忽略了测试样本间的关联性及测试数据分布信息;而在线测试时提示调优(Online TPT)虽尝试利用历史测试样本信息,却因误差累积导致 "提示崩溃"------ 即提示逐渐积累噪声,最终预测准确率骤降(Online TPT 在 ImageNet-A 上最终准确率仅 6.96%)。为解决这一矛盾,论文提出DynaPrompt(Dynamic Test-Time Prompt Tuning) ,通过动态提示管理机制,在利用测试样本关联性的同时缓解误差累积。
论文基于预训练 CLIP 模型(默认采用 ViT-Base-16 架构),全程冻结 CLIP 的图像编码器( F θ I ( ⋅ ) F_{\theta_I}(\cdot) FθI(⋅))与文本编码器参数,仅在测试阶段对 "可学习提示" 的嵌入向量进行优化,无需针对下游任务额外训练。
二、DynaPrompt 核心方法设计
DynaPrompt 的核心是在线提示缓冲区(Prompt Buffer) ,通过 "动态提示选择""动态提示追加""提示优化与缓冲区更新" 三个模块,实现对每个测试样本的自适应提示调优,具体流程与技术细节如下:
1. 预处理:测试样本增强
对每个输入测试样本 x n x_n xn,采用 AugMix 策略(Hendrycks et al., 2020)生成 63 个增强样本 X n X_n Xn( X n = { 增强样本 } X_n = \{增强样本\} Xn={增强样本}),与原始样本共同构成含 64 个样本的集合。该集合用于后续计算预测熵、概率差及提示优化,目的是通过数据增强捕捉样本多视角特征,提升提示调优的鲁棒性。
2. 动态提示选择:筛选相关提示 Dynamic Prompt Selection
提示缓冲区初始化时为空( V 0 = ∅ V_0 = \emptyset V0=∅),后续存储历史优化后的提示嵌入向量(记为 v i ∈ V n v_i \in V_n vi∈Vn, M n M_n Mn为缓冲区第 n n n步的提示数量,最大容量 M = 10 M=10 M=10)。对每个测试样本 x n x_n xn,从缓冲区 V n V_n Vn中筛选与 x n x_n xn相关的提示子集 S n S_n Sn,筛选依赖两个核心指标:
(1)预测熵(Entropy):衡量提示置信度
预测熵量化提示对测试样本预测的不确定性,熵值越低,提示对样本的预测越自信,蕴含的相关分布信息越多。计算公式为:
D e n t ( x n , v i ) = − ∑ c = 1 C p ( y ^ = c ∣ X n , v i ) log p ( y ^ = c ∣ X n , v i ) \mathcal{D}{ent}(x_n, v_i) = -\sum{c=1}^{C} p(\hat{y}=c \mid X_n, v_i) \log p(\hat{y}=c \mid X_n, v_i) Dent(xn,vi)=−c=1∑Cp(y^=c∣Xn,vi)logp(y^=c∣Xn,vi)
-
变量解释: C C C为下游任务类别数; p ( y ^ = c ∣ X n , v i ) p(\hat{y}=c \mid X_n, v_i) p(y^=c∣Xn,vi)是提示 v i v_i vi对增强样本集 X n X_n Xn的 "平均预测概率"------ 先对 X n X_n Xn中每个样本 x ∈ X n x \in X_n x∈Xn,通过 CLIP 计算类别 c c c的概率( p ( y ^ = c ∣ x , v i ) = exp ( cos ( F θ I ( x ) , F θ T ( t c v ) ) / T ) ∑ c ′ = 1 C exp ( cos ( F θ I ( x ) , F θ T ( t c ′ v ) ) / T ) p(\hat{y}=c \mid x, v_i) = \frac{\exp(\cos(F_{\theta_I}(x), F_{\theta_T}(t_c^v))/T)}{\sum_{c'=1}^C \exp(\cos(F_{\theta_I}(x), F_{\theta_T}(t_c'^v))/T)} p(y^=c∣x,vi)=∑c′=1Cexp(cos(FθI(x),FθT(tc′v))/T)exp(cos(FθI(x),FθT(tcv))/T),其中 t c v = [ v i ] [ c l a s s c ] t_c^v = [v_i][class\ c] tcv=[vi][class c]为提示 v i v_i vi与类别 c c c构成的文本输入, F θ T ( ⋅ ) F_{\theta_T}(\cdot) FθT(⋅)为 CLIP 文本编码器, T T T为 CLIP 预训练的温度参数),再对所有 x ∈ X n x \in X_n x∈Xn的概率取平均,得到 p ( y ^ = c ∣ X n , v i ) p(\hat{y}=c \mid X_n, v_i) p(y^=c∣Xn,vi);
-
筛选规则:以初始提示 v 0 v_0 v0(手动构建如 "a photo of a" 的嵌入,或复用 CoOp/MaPLe 的预训练提示嵌入)的熵值 D e n t ( x n , v 0 ) \mathcal{D}{ent}(x_n, v_0) Dent(xn,v0)为阈值,筛选出熵值≤该阈值的提示,构成子集 E n = { v i ∈ V n ∣ D e n t ( x n , v i ) ≤ D e n t ( x n , v 0 ) } \mathcal{E}n = \{v_i \in V_n \mid \mathcal{D}{ent}(x_n, v_i) \leq \mathcal{D}{ent}(x_n, v_0)\} En={vi∈Vn∣Dent(xn,vi)≤Dent(xn,v0)}。
(2)概率差(Probability Difference):避免过度自信
概率差量化提示对样本结构变化的敏感性,差值越高,提示越能区分原始样本与增强样本,不易因 "过度自信" 导致提示崩溃。计算公式为:
D p r o ( x n , v i ) = p ( y ^ = c ∗ ∣ x n , v i ) − p ( y ^ = c ∗ ∣ X n , v i ) \mathcal{D}_{pro}(x_n, v_i) = p(\hat{y}=c^* \mid x_n, v_i) - p(\hat{y}=c^* \mid X_n, v_i) Dpro(xn,vi)=p(y^=c∗∣xn,vi)−p(y^=c∗∣Xn,vi)
-
变量解释: c ∗ = arg max c p ( y ^ = c ∣ x n , v i ) c^* = \arg\max_c p(\hat{y}=c \mid x_n, v_i) c∗=argmaxcp(y^=c∣xn,vi)是提示 v i v_i vi对原始样本 x n x_n xn的 "伪标签"(预测概率最高的类别); p ( y ^ = c ∗ ∣ x n , v i ) p(\hat{y}=c^* \mid x_n, v_i) p(y^=c∗∣xn,vi)是提示 v i v_i vi对原始样本 x n x_n xn的类别 c ∗ c^* c∗概率, p ( y ^ = c ∗ ∣ X n , v i ) p(\hat{y}=c^* \mid X_n, v_i) p(y^=c∗∣Xn,vi)是对增强样本集 X n X_n Xn的类别 c ∗ c^* c∗平均概率;
-
筛选规则:同样以 v 0 v_0 v0的概率差 D p r o ( x n , v 0 ) \mathcal{D}{pro}(x_n, v_0) Dpro(xn,v0)为阈值,筛选出差值≥该阈值的提示,构成子集 R n = { v i ∈ V n ∣ D p r o ( x n , v i ) ≥ D p r o ( x n , v 0 ) } \mathcal{R}n = \{v_i \in V_n \mid \mathcal{D}{pro}(x_n, v_i) \geq \mathcal{D}{pro}(x_n, v_0)\} Rn={vi∈Vn∣Dpro(xn,vi)≥Dpro(xn,v0)}。
(3)最终筛选子集
取 E n \mathcal{E}_n En与 R n \mathcal{R}_n Rn的交集,得到最终相关提示子集 S n = E n ∩ R n S_n = \mathcal{E}_n \cap \mathcal{R}_n Sn=En∩Rn。该子集同时满足 "高置信度" 与 "高敏感性",既利用历史分布信息,又规避提示崩溃风险。
3. 动态提示追加:处理无相关提示场景 Dynamic Prompt Appending
若缓冲区 V n V_n Vn中无符合条件的提示( S n = ∅ S_n = \emptyset Sn=∅),则自动将初始提示 v 0 v_0 v0追加到 S n S_n Sn中(即 S n = { v 0 } S_n = \{v_0\} Sn={v0}),避免因无可用提示导致优化方向冲突。这一步是缓解误差累积的关键:当现有缓冲区提示均无关或崩溃时,引入全新初始提示,切断历史误差传递链。
4. 提示优化:熵最小化目标
对筛选后的提示子集 S n S_n Sn,以 "最小化预测熵" 为目标进行梯度更新,优化其嵌入向量。损失函数为:
L e n t ( S n ; x n ) = − ∑ c = 1 C p ( y ^ = c ∣ X n , S n ) log p ( y ^ = c ∣ X n , S n ) \mathcal{L}{ent}(S_n; x_n) = -\sum{c=1}^{C} p(\hat{y}=c \mid X_n, S_n) \log p(\hat{y}=c \mid X_n, S_n) Lent(Sn;xn)=−c=1∑Cp(y^=c∣Xn,Sn)logp(y^=c∣Xn,Sn)
-
变量解释: p ( y ^ = c ∣ X n , S n ) p(\hat{y}=c \mid X_n, S_n) p(y^=c∣Xn,Sn)是子集 S n S_n Sn中所有提示对 X n X_n Xn的 "平均预测概率"------ 先计算每个 v i ∈ S n v_i \in S_n vi∈Sn的 p ( y ^ = c ∣ X n , v i ) p(\hat{y}=c \mid X_n, v_i) p(y^=c∣Xn,vi),再对所有 v i v_i vi取平均;
-
优化操作:采用 AdamW 优化器,学习率根据任务场景设定(领域泛化场景 α = 0.005 \alpha=0.005 α=0.005,跨数据集场景 α = 0.003 \alpha=0.003 α=0.003),通过梯度下降更新提示嵌入: S ~ n ← S n − α ∇ L e n t ( x n , S n ) \tilde{S}n \leftarrow S_n - \alpha \nabla \mathcal{L}{ent}(x_n, S_n) S~n←Sn−α∇Lent(xn,Sn),其中 S ~ n \tilde{S}_n S~n为优化后的提示子集。
5. 测试预测与缓冲区更新 Prompt Buffer Updating
(1)预测
用优化后的提示子集 S ~ n \tilde{S}_n S~n对原始样本 x n x_n xn预测,计算每个 v i ∈ S ~ n v_i \in \tilde{S}_n vi∈S~n的 p ( y ^ = c ∣ x n , v i ) p(\hat{y}=c \mid x_n, v_i) p(y^=c∣xn,vi),取平均后概率最大的类别作为最终预测结果:
y ^ = arg max c 1 ∣ S ~ n ∣ ∑ v i ∈ S ~ n p ( y ^ = c ∣ x n , v i ) \hat{y} = \arg\max_c \frac{1}{|\tilde{S}n|} \sum{v_i \in \tilde{S}_n} p(\hat{y}=c \mid x_n, v_i) y^=argcmax∣S~n∣1vi∈S~n∑p(y^=c∣xn,vi)
(2)缓冲区更新
根据缓冲区容量与是否追加新提示,更新缓冲区 V n V_n Vn为 V n + 1 V_{n+1} Vn+1:
-
若 S n S_n Sn是追加的 v 0 v_0 v0(即 E n ∩ R n = ∅ \mathcal{E}_n \cap \mathcal{R}_n = \emptyset En∩Rn=∅)且缓冲区已满( M n = M M_n = M Mn=M):将 S ~ n \tilde{S}n S~n追加到缓冲区顶部,同时删除缓冲区底部 "最不活跃提示"(长期未被筛选的提示),即 V n + 1 = V n + S ~ n − { v i n a c t i v e } V{n+1} = V_n + \tilde{S}n - \{v{inactive}\} Vn+1=Vn+S~n−{vinactive};
-
其他情况:直接将 S ~ n \tilde{S}n S~n追加到缓冲区顶部,并移除原 S n S_n Sn在缓冲区中的旧版本(避免重复存储),即 V n + 1 = V n + S ~ n − S n V{n+1} = V_n + \tilde{S}_n - S_n Vn+1=Vn+S~n−Sn。
三、实验验证与关键结果
1. 实验设置
-
数据集:14 个基准数据集,涵盖两类场景 ------ 领域泛化(ImageNet 及 4 个变体:ImageNet-V2/S/A/R)、跨数据集(Caltech101、OxfordPets 等 10 个图像分类数据集);
-
对比方法:包括 CLIP、Prompt Learning 方法(CoOp、MaPLe 等)、测试时提示调优方法(TPT、DiffTPT、C-TPT 等)。
2. 核心性能结果
-
领域泛化场景:DynaPrompt 在 ImageNet-A 上准确率达 56.17%,优于 TPT(54.77%)、DiffTPT(55.68%),且与 CoOp 结合后准确率提升至 60.55%;平均准确率(OoD Mean)达 61.81%,高于所有对比方法;
-
跨数据集场景:DynaPrompt 在 8 个数据集上优于 TPT,与 MaPLe 结合后平均准确率达 67.29%,为所有方法最优;
-
消融实验:① 去掉预测熵筛选,平均准确率降至 58.69%;② 去掉概率差筛选,降至 59.23%;③ 去掉动态提示追加,平均准确率骤降至 32.63%,证明三个模块均不可或缺。④ prompt buffer的影响,buffer增大,平均准确率提升存在上界,但额外需要的时间单调增长。⑤ 对样本顺序的敏感度,样本量越大越稳定,越小波动更大。⑥ 初始prompt的影响。⑦ 文本提示的长度的影响。⑧ 不同backbone的影响。
3. 误差累积分析
通过 Oracle 方法(仅用正确预测的提示更新)验证误差累积的影响:Oracle 在 ImageNet-A 上准确率达 59.38%,远高于 Online TPT(6.96%),而 DynaPrompt 通过动态机制接近 Oracle 性能(56.17%),证明其有效缓解了误差累积。
四、结论
DynaPrompt 通过 "动态提示缓冲区 + 选择 / 追加策略",解决了测试时提示调优中 "样本关联性利用" 与 "误差累积" 的矛盾。其核心优势在于:无需下游任务训练,仅在测试阶段优化提示嵌入;通过双指标筛选与动态追加,确保提示的相关性与鲁棒性;可无缝结合现有 Prompt Learning 方法提升性能。实验表明,DynaPrompt 在 14 个数据集上均验证了有效性,为视觉 - 语言模型的测试时自适应提供了高效解决方案。
(注:文档部分内容可能存在理解偏差,欢迎指正)