【Domain Generalization(2)】领域泛化在文生图领域的工作之——PromptStyler(ICCV23)

系列文章目录


文章目录


研究背景

  • 深度神经网络的分布偏移问题 :深度神经网络通常在训练和测试数据独立同分布的假设下进行训练,但在实际应用中,训练和测试数据之间可能存在显著的分布偏移,这严重影响了模型的性能,成为其在现实应用中的主要障碍之一。
  • 领域适应(DA)的研究 :为提高模型对分布偏移的鲁棒性,领域适应(DA)旨在利用训练中可用的目标域数据使神经网络适应目标域,但目标域在常见训练场景中往往难以获取,限制了DA的应用。
  • 领域泛化(DG)的研究 :领域泛化(DG)旨在提高模型对任意未见域的泛化能力,常见做法是利用多个源域学习域不变特征,但存在难以确定理想源域以及收集和标注多源域数据成本高且有时不可行的问题。
  • 联合视觉语言空间的特性 :大规模预训练模型已观察到多种领域,其联合视觉语言空间中,文本特征可有效表示相关图像特征,且尽管存在模态差异,但存在跨模态迁移性现象,即可以使用文本特征训练分类器并使用图像特征进行推理,这为解决无源域数据的领域泛化问题提供了思路。

研究动机

  • 利用大规模模型模拟分布偏移 :作者思考能否在不使用任何源域数据的情况下,通过模拟大规模模型潜在空间中的各种分布偏移来进一步提高模型的泛化能力,若可行将使DG更具实用性,但该方法具有挑战性,因为无法获取源域和目标域的实际数据,仅知道目标任务定义(如类名)。
  • 基于视觉语言模型解决无源DG问题 :作者认为大规模视觉语言模型有助于解决无源域泛化 这一具有挑战性的问题。
    • 通过在联合视觉语言空间中利用文本特征表示图像特征的特性 以及跨模态迁移性 ,可通过提示(prompt)模拟分布偏移,从而提出了PromptStyler方法。
    • 该方法通过学习可学习的风格词向量 来合成多样化风格,以模拟超球联合视觉语言空间(hyperspherical joint vision-language space) 中的分布偏移,同时考虑风格多样性和内容一致性,最后使用合成的特征训练分类器来实现领域泛化。

🔥提出的方法-PromptStyler

为解决无源域泛化问题,本文提出了PromptStyler方法,该方法主要通过在联合视觉语言空间中合成多样化风格来模拟分布偏移,从而提高模型的泛化能力,具体如下:

1. 风格词向量学习
  1. 在超球联合视觉语言空间(如CLIP潜在空间)中,通过随机初始化风格词向量 s i s_i si,并使用风格多样性损失 L s t y l e \mathcal{L}{style} Lstyle和内容一致性损失 L c o n t e n t \mathcal{L}{content} Lcontent来优化,以学习到多样化且不扭曲内容信息的风格词向量。

    • 风格多样性损失 :为最大化风格多样性,使学习到的风格特征相互正交,其计算方式为
      L s t y l e = 1 i − 1 ∑ j = 1 i − 1 ∣ T ( P i s t y l e ) ∥ T ( P i s t y l e ) ∥ 2 ⋅ T ( P j s t y l e ) ∥ T ( P j s t y l e ) ∥ 2 ∣ \mathcal{L}{style}=\frac{1}{i - 1}\sum{j = 1}^{i - 1}\left|\frac{T(\mathcal{P}{i}^{style})}{\|T(\mathcal{P}{i}^{style})\|{2}}\cdot\frac{T(\mathcal{P}{j}^{style})}{\|T(\mathcal{P}{j}^{style})\|{2}}\right| Lstyle=i−11∑j=1i−1 ∥T(Pistyle)∥2T(Pistyle)⋅∥T(Pjstyle)∥2T(Pjstyle)

      其中 T ( ⋅ ) T(\cdot) T(⋅)是预训练文本编码器, P i s t y l e \mathcal{P}_{i}^{style} Pistyle是风格提示。

    • 内容一致性损失 :为防止风格扭曲内容信息,使风格 - 内容特征与相应内容特征具有最高余弦相似度,计算方式为
      L c o n t e n t = − 1 N ∑ m = 1 N log ⁡ ( exp ⁡ ( z i m m ) ∑ n = 1 N exp ⁡ ( z i m n ) ) \mathcal{L}{content}=-\frac{1}{N}\sum{m = 1}^{N}\log\left(\frac{\exp(z_{imm})}{\sum_{n = 1}^{N}\exp(z_{imn})}\right) Lcontent=−N1∑m=1Nlog(∑n=1Nexp(zimn)exp(zimm))

      其中 z i m n z_{imn} zimn是风格 - 内容特征与内容特征之间的余弦相似度得分,具体为
      z i m n = T ( P i s t y l e ∘ P m c o n t e n t ) ∥ T ( P i s t y l e ∘ P m c o n t e n t ) ∥ 2 ⋅ T ( P n c o n t e n t ) ∥ T ( P n c o n t e n t ) ∥ 2 z_{imn}=\frac{T(\mathcal{P}{i}^{style} \circ \mathcal{P}{m}^{content})}{\|T(\mathcal{P}{i}^{style} \circ \mathcal{P}{m}^{content})\|{2}} \cdot \frac{T(\mathcal{P}{n}^{content})}{\|T(\mathcal{P}{n}^{content})\|{2}} zimn=∥T(Pistyle∘Pmcontent)∥2T(Pistyle∘Pmcontent)⋅∥T(Pncontent)∥2T(Pncontent)

    • 总提示损失 :将风格多样性损失和内容一致性损失相加得到总提示损失 L p r o m p t = L s t y l e + L c o n t e n t \mathcal{L}{prompt}=\mathcal{L}{style}+\mathcal{L}_{content} Lprompt=Lstyle+Lcontent,用于顺序学习 K K K个风格词向量。

2. 使用多样化风格训练线性分类器
  • 学习到 K K K个风格词向量后,利用预训练文本编码器 T ( ⋅ ) T(\cdot) T(⋅)和预定义的 N N N个类名合成 K N KN KN个风格 - 内容特征,用于训练线性分类器。
  • 分类器采用 ArcFace 损失 L c l a s s \mathcal{L}_{class} Lclass进行训练,ArcFace 损失是一种角度 Softmax 损失,通过添加类间 角度边际惩罚( additive angular margin penalty) 来计算分类器输入特征与权重之间的余弦相似度,使不同类别的特征分得更开,从而实现更具判别性的预测。
3. 使用训练好的分类器进行推理。
  • 在推理时,预训练图像编码器 I ( ⋅ ) I(\cdot) I(⋅)从输入图像中提取图像特征,经 ℓ 2 \ell_2 ℓ2归一化后映射到联合视觉语言空间,然后输入到训练好的分类器中产生类别分数。
  • 此时文本编码器 T ( ⋅ ) T(\cdot) T(⋅)不参与推理过程,仅图像编码器 I ( ⋅ ) I(\cdot) I(⋅)被使用。
相关推荐
掐死你滴温柔3 小时前
SQLALchemy如何将SQL语句编译为特定数据库方言
数据结构·数据库·python·sql
西猫雷婶3 小时前
python学opencv|读取图像(二十三)使用cv2.putText()绘制文字
开发语言·python·opencv
三掌柜6664 小时前
2025三掌柜赠书活动第一期:动手学深度学习(PyTorch版)
人工智能·pytorch·深度学习
唯创知音5 小时前
基于W2605C语音识别合成芯片的智能语音交互闹钟方案-AI对话享受智能生活
人工智能·单片机·物联网·生活·智能家居·语音识别
说私域5 小时前
数字化供应链创新解决方案在零售行业的应用研究——以开源AI智能名片S2B2C商城小程序为例
人工智能·开源·零售
yvestine6 小时前
数据挖掘——支持向量机分类器
人工智能·算法·机器学习·支持向量机·分类·数据挖掘·svm
阿正的梦工坊6 小时前
PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例
c++·人工智能·pytorch
三万棵雪松6 小时前
5.系统学习-PyTorch与多层感知机
人工智能·pytorch·学习
AIGC大时代6 小时前
不只是工具:ChatGPT写作在学术中的创新思维与深度思考
人工智能·chatgpt·prompt·aigc·ai写作
陈序缘7 小时前
PyTorch快速入门
人工智能·pytorch·python·深度学习·算法·机器学习