结合噪声对比估计(Noise Contrastive Estimation,NCE)的思想,通过互信息(Mutual Information)最小化来优化大规模分类任务,2019年,DeepMind的研究人员提出InfoNCE损失。
相较于NCE损失,InfoNCE损失有如下区别:
- InfoNCE的噪声采样自原数据的分布,而NCE的噪声采样自设定的某一分布
- InfoNCE在实操层面通过多分类实现,而NCE在实操层面通过二分类实现
- InfoNCE实现了数据表征和类别表征的互信息最小化,而NCE未实现
问题建模
已知 t t t时间段上下文 c c c,预测 t + k t+k t+k时间段的数据为 x t + k x_{t+k} xt+k。该问题常见的应用场景有:
(1)文本、音频、图像生成任务,已知 t t t时间段之前的所有文本、音频、图像,生成 t + k t+k t+k时间段的文本、音频、图像
(2)图文对比学习任务,已知文本,预测其对应的图像是哪一个
(3)分类任务,已知聚类信息,预测某数据的类别
在case(2)和case(3)中, t t t时间段上下文 c c c分别表示文本和聚类信息,表明InfoNCE具备较强的泛化能力,能够应用到各种场景中。
NCE的具体做法为利用模型拟合 P ( x t + k ∣ c ) P(x_{t+k}|c) P(xt+k∣c),InfoNCE认为该拟合方式粒度不够细,例如借助高维隐层向量 c c c来复原图像 x t + k x_{t+k} xt+k任务,隐层向量维度可能太低了,不足以复原图像。
InfoNCE将上下文 c c c和数据 x t + k x_{t+k} xt+k同时建模成维度一致的向量表示,通过互信息最大化,缩小两者之间的互信息,使得借助上下文 c c c来复原数据 x t + k x_{t+k} xt+k成为可能。
互信息
互信息的公式表示为:
I ( x t + k , c t ) = ∑ x t + k , c t p ( x t + k , c t ) l o g p ( x t + k ∣ c t ) p ( x t + k ) (1) I(x_{t+k},c_t)=\sum_{x_{t+k},c_t}p(x_{t+k},c_t)log\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}\tag{1} I(xt+k,ct)=xt+k,ct∑p(xt+k,ct)logp(xt+k)p(xt+k∣ct)(1)
互信息表示熵的差值,有:
I ( x t + k , c t ) = H ( x t + k ) − H ( x t + k ∣ c t ) (2) I(x_{t+k},c_t)=H(x_{t+k})-H(x_{t+k}|c_t)\tag{2} I(xt+k,ct)=H(xt+k)−H(xt+k∣ct)(2)
其中, H ( ⋅ ) H(\cdot) H(⋅)表示熵。
熵表示信息的不确定性或者信息量。互信息等于 c t c_t ct变量引入后, x t + k x_{t+k} xt+k的熵的变化。变化越少,说明 x t + k x_{t+k} xt+k和 c t c_t ct越接近。也就是说,互信息可以代表两个变量的相关性。
优化目标
公式(1)的互信息可以写成:
I ( x t + k , c t ) = ∑ x t + k , c t p ( x t + k , c t ) l o g p ( x t + k ∣ c t ) p ( x t + k ) = E x t + k , c t ( l o g p ( x t + k ∣ c t ) p ( x t + k ) ) (3) \begin{equation}\begin{aligned} I(x_{t+k},c_t)&=\sum_{x_{t+k},c_t}p(x_{t+k},c_t)log\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}\\ &=E_{x_{t+k},c_t}\left({log\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})}}\right) \end{aligned} \end{equation} \tag{3} I(xt+k,ct)=xt+k,ct∑p(xt+k,ct)logp(xt+k)p(xt+k∣ct)=Ext+k,ct(logp(xt+k)p(xt+k∣ct))(3)
从互信息的角度来看,InfoNCE损失希望模型参数 f θ ( x t + k , c t ) f_{\theta}(x_{t+k},c_t) fθ(xt+k,ct)拟合到 p ( x t + k ∣ c t ) p ( x t + k ) \frac{p(x_{t+k}|c_{t})}{p(x_{t+k})} p(xt+k)p(xt+k∣ct),也就是 f θ ( x t + k , c t ) ∝ p ( x t + k ∣ c t ) p ( x t + k ) f_{\theta}(x_{t+k},c_t)\propto\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})} fθ(xt+k,ct)∝p(xt+k)p(xt+k∣ct)。
观察 p ( x t + k ∣ c t ) p ( x t + k ) \frac{p(x_{t+k}|c_{t})}{p(x_{t+k})} p(xt+k)p(xt+k∣ct),类似于NCE损失,InfoNCE将分子 p ( x t + k ∣ c t ) p(x_{t+k}|c_{t}) p(xt+k∣ct)作为正样本,分母 p ( x t + k ) p(x_{t+k}) p(xt+k)作为负样本。
- NCE损失通过二分类,将不同的负样本 p ( x t + k ) p(x_{t+k}) p(xt+k)合并成一个类,区分正样本类别和负样本类别。
- InfoNCE损失通过多分类,每一个负样本 p ( x t + k ) p(x_{t+k}) p(xt+k)独立成各自的类别,若有 N − 1 N-1 N−1个负样本, 1 1 1个正样本,通过 N N N多分类任务进行区分。
于是,训练数据有 N N N个样本 V = { v 1 , v 2 , . . . , v N } V=\{v_1,v_2,...,v_N\} V={v1,v2,...,vN},其中 1 1 1个采样自正样本 p ( x t + k ∣ c t ) p(x_{t+k}|c_{t}) p(xt+k∣ct), N − 1 N-1 N−1个采样自负样本 p ( x t + k ) p(x_{t+k}) p(xt+k), 每个样本具有各自的类别。那么 v i v_i vi采样自正样本,以及 v j ≠ i v_{j\neq i} vj=i采样自负样本的概率,即 v i v_i vi的数据分布为:
p v i = p ( x t + k ∣ c t ) ∏ l ≠ t + k p ( x l ) ∑ j p ( x j ∣ c t ) ∏ l ≠ j p ( x l ) = p ( x t + k ∣ c t ) p ( x t + k ) ∑ j p ( x j ∣ x t ) p ( x j ) (4) \begin{equation}\begin{aligned} p_{v_i}=\frac{p(x_{t+k}|c_t)\prod_{l\neq{t+k}}p(x_l)}{\sum_jp(x_j|c_t)\prod_{l\neq j}p(x_l)}=\frac{\frac{p(x_{t+k}|c_t)}{p(x_{t+k})}}{\sum_j\frac{p(x_j|x_t)}{p(x_j)}} \end{aligned}\end{equation} \tag{4} pvi=∑jp(xj∣ct)∏l=jp(xl)p(xt+k∣ct)∏l=t+kp(xl)=∑jp(xj)p(xj∣xt)p(xt+k)p(xt+k∣ct)(4)
观察公式(4),由于 f θ ( x t + k , c t ) ∝ p ( x t + k ∣ c t ) p ( x t + k ) f_{\theta}(x_{t+k},c_t)\propto\frac{p(x_{t+k}|c_{t})}{p(x_{t+k})} fθ(xt+k,ct)∝p(xt+k)p(xt+k∣ct),公式(4)可以写成: f θ ( x t + k , c t ) ∑ j f θ ( x j , c t ) (5) \begin{equation}\begin{aligned} \frac{f_{\theta}(x_{t+k},c_t)}{\sum_jf_{\theta}(x_j,c_t)} \end{aligned}\end{equation} \tag{5} ∑jfθ(xj,ct)fθ(xt+k,ct)(5)
损失函数直接等于对数似然函数均值的负数形式,有:
L = − E V l o g p v h = − E V l o g [ f θ ( x t + k , c t ) ∑ j f θ ( x j , c t ) ] (6) \begin{equation}\begin{aligned} L&=-\mathbb{E}Vlogp{v_h}\\&=-\mathbb{E}Vlog\left[\frac{f{\theta}(x_{t+k},c_t)}{\sum_jf_{\theta}(x_j,c_t)}\right] \end{aligned}\end{equation} \tag{6} L=−EVlogpvh=−EVlog[∑jfθ(xj,ct)fθ(xt+k,ct)](6)
在具体实现过程中, f θ ( x t + k , c t ) f_{\theta}(x_{t+k},c_t) fθ(xt+k,ct)一般表示 x t + k x_{t+k} xt+k和 c t c_t ct的向量表征之间的余弦相似度。
InfoNCE与互信息的关系
InfoNCE利用 f ( x t + k , c t ) f(x_{t+k},c_t) f(xt+k,ct)来拟合 p ( x t + k ∣ c t ) p ( x t + k ) \frac{p(x_{t+k}|c_t)}{p(x_{t+k})} p(xt+k)p(xt+k∣ct),如果拟合成功,则有最优损失形式:
L = − E V l o g [ f θ ( x t + k , c t ) ∑ j f θ ( x j , c t ) ] = − E V l o g [ p ( x t + k ∣ c t ) p ( x t + k ) p ( x t + k ∣ c t ) p ( x t + k ) + ∑ v j ∈ V n e g p ( x j ∣ c t ) p ( x j ) ] = E V l o g [ 1 + p ( x t + k ) p ( x t + k ∣ c t ) ∑ v j ∈ V n e g p ( x j ∣ c t ) p ( x j ) ] ≈ E V l o g [ 1 + p ( x t + k ) p ( x t + k ∣ c t ) ( N − 1 ) E v j p ( x j ∣ c t ) p ( x j ) ] = E V l o g [ 1 + p ( x t + k ) p ( x t + k ∣ c t ) ( N − 1 ) ] = E V l o g [ p ( x t + k ∣ c t ) p ( x t + k ∣ c t ) + p ( x t + k ) p ( x t + k ∣ c t ) ( N − 1 ) ] ≥ E V l o g [ p ( x t + k ) p ( x t + k ∣ c t ) + p ( x t + k ) p ( x t + k ∣ c t ) ( N − 1 ) ] = E V l o g [ p ( x t + k ) p ( x t + k ∣ c t ) N ] = E V l o g [ p ( x t + k ) p ( x t + k ∣ c t ) ] + l o g N = − I ( x t + k , c t ) + l o g ( N ) (7) \begin{equation}\begin{aligned} L&=-\mathbb{E}Vlog\left[\frac{f{\theta}(x_{t+k},c_t)}{\sum_jf_{\theta}(x_j,c_t)}\right]\\ &=-\mathbb{E}Vlog\left[\frac{\frac{p(x{t+k}|c_t)}{p(x_{t+k})}}{\frac{p(x_{t+k}|c_t)}{p(x_{t+k})}+\sum_{v_j \in V_{neg}}\frac{p(x_{j}|c_t)}{p(x_{j})}}\right]\\ &=\mathbb{E}Vlog\left[1+\frac{p(x{t+k})}{p(x_{t+k}|c_t)}\sum_{v_j \in V_{neg} }\frac{p(x_{j}|c_t)}{p(x_{j})}\right]\\ &\approx\mathbb{E}Vlog\left[1+\frac{p(x{t+k})}{p(x_{t+k}|c_t)}(N-1)\mathbb{E}{v_j}\frac{p(x{j}|c_t)}{p(x_{j})}\right]\\ &=\mathbb{E}Vlog\left[1+\frac{p(x{t+k})}{p(x_{t+k}|c_t)}(N-1)\right]\\ &=\mathbb{E}Vlog\left[\frac{p(x{t+k}|c_t)}{p(x_{t+k}|c_t)}+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\right]\\ &\geq\mathbb{E}Vlog\left[\frac{p(x{t+k})}{p(x_{t+k}|c_t)}+\frac{p(x_{t+k})}{p(x_{t+k}|c_t)}(N-1)\right]\\ &=\mathbb{E}Vlog\left[\frac{p(x{t+k})}{p(x_{t+k}|c_t)}N\right]\\ &=\mathbb{E}Vlog\left[\frac{p(x{t+k})}{p(x_{t+k}|c_t)}\right]+logN\\ &=-I(x_{t+k},c_t)+log(N) \end{aligned}\end{equation} \tag{7} L=−EVlog[∑jfθ(xj,ct)fθ(xt+k,ct)]=−EVlog p(xt+k)p(xt+k∣ct)+∑vj∈Vnegp(xj)p(xj∣ct)p(xt+k)p(xt+k∣ct) =EVlog 1+p(xt+k∣ct)p(xt+k)vj∈Vneg∑p(xj)p(xj∣ct) ≈EVlog[1+p(xt+k∣ct)p(xt+k)(N−1)Evjp(xj)p(xj∣ct)]=EVlog[1+p(xt+k∣ct)p(xt+k)(N−1)]=EVlog[p(xt+k∣ct)p(xt+k∣ct)+p(xt+k∣ct)p(xt+k)(N−1)]≥EVlog[p(xt+k∣ct)p(xt+k)+p(xt+k∣ct)p(xt+k)(N−1)]=EVlog[p(xt+k∣ct)p(xt+k)N]=EVlog[p(xt+k∣ct)p(xt+k)]+logN=−I(xt+k,ct)+log(N)(7)
代码实现
基于最经典的CLIP模型,来理解InfoNCE的代码实现。
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_i, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f= image_encoder(I) #[n, d_i]
T_f= text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_i), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t) / 2
CLIP模型分别计算了文本侧和图像侧的InfoNCE损失,具体而言,包含如下步骤:
- 明确CLIP做的是图文对比学习任务:假设给定 M M M个图文对 { ( I i , T i ) } \{(I_i,T_i)\} {(Ii,Ti)}( I I I表示图像, T T T表示文本),我们需要拉近 I i I_i Ii和 T i T_i Ti间的距离,拉远 I i I_i Ii和 T j ≠ i T_{j\neq i} Tj=i距离,拉远 T i T_i Ti和 I j ≠ i I_{j\neq i} Ij=i的距离。最直接的方式是将每一个图文对归成一个类,通过 M M M类别的多分类任务来实现。但图文对 M M M的数量会非常大,可能几百万,可能几亿,压根无法训练。
- 为此,CLIP通过InfoNCE损失来替代 M M M类别的多分类任务。在一个batch内,组成训练数据集合 V V V,共 n n n个,如公式(4)所示。
- 首先提取图像表征 I _ f I\_f I_f和文本表征 T _ f T\_f T_f。
- 对于图像表征 I _ f i I\_f_i I_fi来说
- 文本表征 T _ f i T\f_i T_fi可以理解为InfoNCE中的正样本 p ( x t + k ∣ c t ) p(x{t+k}|c_{t}) p(xt+k∣ct),有 1 1 1个正样本
- T _ f j ≠ i T\f{j \neq i} T_fj=i可以理解为InfoNCE中的负样本 p ( x t + k ) p(x_{t+k}) p(xt+k),共有 n − 1 n-1 n−1个负样本
- 计算图像表征 I _ f i I\_f_i I_fi到所有文本表征 { T _ f 1 , T _ f 2 , . . . , T _ f n } \{T\_f_1,T\_f_2,...,T\_f_n\} {T_f1,T_f2,...,T_fn}的相似度距离,CLIP中采用temperature对相似度距离进行平滑及锐化
- 调用torch的cross_entroy_loss方法,计算损失
- 有 n n n个图像表征 { I _ f 1 , I _ f 2 , . . . , I _ f n } \{I\_f_1,I\_f_2,...,I\_f_n\} {I_f1,I_f2,...,I_fn},每一个图像表征计算一个损失,平均后得到 l o s s i loss_i lossi,平均的操作已集成到cross_entropy_loss方法中
- 对于文本表征 T _ f i T\_f_i T_fi来说,上述流程类似,得到 l o s s t loss_t losst
- 两个损失相加,得到最终损失