损失函数的目标是让图像 x i x_i xi;与正确文本描述 t i t_i ti的相似度最大化,同时与所有其他不相关文本 t j t_j tj 的相似度最小化,公式为:
L = − log exp ( sim ( x i , t i ) / τ ) ∑ j = 1 N exp ( sim ( x i , t j ) / τ ) \mathcal{L} = - \log\frac{\exp(\text{sim}(x_i,t_i)/\tau)}{\sum_{j = 1}^{N}\exp(\text{sim}(x_i,t_j)/\tau)} L=−log∑j=1Nexp(sim(xi,tj)/τ)exp(sim(xi,ti)/τ)
( x i ) (x_i) (xi):第 ( i ) (i) (i)个图像样本。
( t i ) (t_i) (ti):第 ( i ) (i) (i)个图像样本的正确文本描述。
( t j ) (t_j) (tj):其他文本描述(包括 ( t i ) (t_i) (ti)和其他与 ( x i ) (x_i) (xi)不匹配的文本描述)。
sim ( x i , t j ) \text{sim}(x_i,t_j) sim(xi,tj):图像 ( x i ) (x_i) (xi)和文本 ( t j ) (t_j) (tj)或者 ( t i ) (t_i) (ti)的相似度,一般使用余弦相似度来计算。
( τ ) (\tau) (τ):温度参数,用于控制相似度分布的平滑程度。
sim ( x i , t j ) \text{sim}(x_i, t_j) sim(xi,tj) 可以使用余弦相似度:
sim ( v i , t j ) = v i ⋅ t j ∥ v i ∥ ∥ t j ∥ \text{sim}(v_i,t_j)=\frac{v_i\cdot t_j}{\|v_i\|\|t_j\|} sim(vi,tj)=∥vi∥∥tj∥vi⋅tj
其中 ( v i v_i vi ) 是图像 ( x i x_i xi ) 的嵌入向量,( t j t_j tj ) 是文本 ( t j t_j tj ) 的嵌入向量。这样计算得到一个 相似度矩阵,矩阵中的每个元素表示批次中任意一对图像和文本的相似度。
图像损失部分 :对于每一个图像 ( x i x_i xi ),该部分的损失最大化它与正确文本 ( t i t_i ti ) 的相似度,同时最小化它与其他错误文本 ( t j t_j tj ) 的相似度。这一部分确保了图像能够找到正确的文本,也就是说图像编码器能够将图像嵌入到一个空间中,使得匹配的文本描述与它更接近。
L image = − 1 N ∑ i = 1 N log exp ( sim ( v i , t i ) / τ ) ∑ j = 1 N exp ( sim ( v i , t j ) / τ ) \mathcal{L}{\text{image}} = - \frac{1}{N}\sum{i = 1}^{N}\log\frac{\exp(\text{sim}(v_i,t_i)/\tau)}{\sum_{j = 1}^{N}\exp(\text{sim}(v_i,t_j)/\tau)} Limage=−N1i=1∑Nlog∑j=1Nexp(sim(vi,tj)/τ)exp(sim(vi,ti)/τ)
文本编码器损失函数
作用于文本检索图像:给定一个文本描述,可以找到与之最匹配的图像。
文本损失部分 :对于每一个文本 ( t i t_i ti ),该部分的损失最大化它与正确图像 ( x i x_i xi ) 的相似度,同时最小化它与其他错误图像 ( x j x_j xj ) 的相似度。这一部分确保了文本能够找到正确的图像 ,也就是说文本编码器能够将文本嵌入到一个空间中,使得匹配的图像与它更接近。
L text = − 1 N ∑ i = 1 N log exp ( sim ( v i , t i ) / τ ) ∑ j = 1 N exp ( sim ( v j , t i ) / τ ) \mathcal{L}{\text{text}} = - \frac{1}{N}\sum{i = 1}^{N}\log\frac{\exp(\text{sim}(v_i,t_i)/\tau)}{\sum_{j = 1}^{N}\exp(\text{sim}(v_j,t_i)/\tau)} Ltext=−N1i=1∑Nlog∑j=1Nexp(sim(vj,ti)/τ)exp(sim(vi,ti)/τ)
总损失函数
最大化图像和其正确文本描述之间的相似度,同时最小化图像和其他不匹配文本描述之间的相似度。
L CLIP = 1 2 ( L image + L text ) \mathcal{L}{\text{CLIP}}=\frac{1}{2}(\mathcal{L}{\text{image}}+\mathcal{L}_{\text{text}}) LCLIP=21(Limage+Ltext)
( L image \mathcal{L}_{\text{image}} Limage ):文本编码器损失函数
( L image \mathcal{L}_{\text{image}} Limage ):图像编码器损失函数