CLIP's Core Training Code and an Explanation of the Contrastive Loss
CLIP (Contrastive Language-Image Pretraining) trains an image and a text embedding model with contrastive learning, so that matching image-text pairs end up close together in the embedding space while mismatched pairs are pushed farther apart. Its core loss is implemented with cross-entropy (`F.cross_entropy`) over a similarity matrix to realize this contrastive objective.
A Skeleton of CLIP's Training Code
Below is a simplified version of CLIP's core training code:
```python
import torch
import torch.nn.functional as F

# Random image and text embeddings (stand-ins for real encoder outputs)
batch_size = 4        # assume a batch size of 4
embedding_dim = 512   # assume 512-dimensional embeddings
device = "cuda" if torch.cuda.is_available() else "cpu"
image_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)  # image embeddings
text_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)   # text embeddings

# Normalize the embeddings onto the unit sphere
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

# Compute the similarity matrices (logits)
logits_per_image = torch.matmul(image_embeddings, text_embeddings.T)  # image-to-text similarity
logits_per_text = torch.matmul(text_embeddings, image_embeddings.T)   # text-to-image similarity

# Labels (positives on the diagonal, negatives everywhere else)
labels = torch.arange(batch_size, device=device)  # [0, 1, 2, 3]

# Cross-entropy loss
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2  # symmetric loss

# Backpropagation (an optimizer step would follow in a full training loop)
loss.backward()
```
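A few details differ from the released CLIP implementation: in particular, CLIP multiplies the similarity matrix by a learnable temperature (`logit_scale`, stored as a log-scale and initialized to log(1/0.07)) before the cross-entropy, and an optimizer step follows the backward pass. The sketch below is a minimal illustration of one such optimization step; the random tensors still stand in for real encoder outputs, so this is not a faithful training loop.

```python
import torch
import torch.nn.functional as F

batch_size, embedding_dim = 4, 512
device = "cuda" if torch.cuda.is_available() else "cpu"

# Random tensors still stand in for the outputs of the image/text encoders.
image_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)
text_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)

# Learnable temperature, stored as a log-scale and initialized to log(1/0.07),
# as in the released CLIP code.
logit_scale = torch.nn.Parameter(torch.tensor(2.6593, device=device))

optimizer = torch.optim.AdamW([image_embeddings, text_embeddings, logit_scale], lr=1e-3)

image_features = F.normalize(image_embeddings, p=2, dim=-1)
text_features = F.normalize(text_embeddings, p=2, dim=-1)

# Scaled similarity matrix: a larger logit_scale makes the softmax sharper.
logits_per_image = logit_scale.exp() * image_features @ text_features.T
logits_per_text = logits_per_image.T

labels = torch.arange(batch_size, device=device)
loss = (F.cross_entropy(logits_per_image, labels) +
        F.cross_entropy(logits_per_text, labels)) / 2

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```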
How the Cross-Entropy Loss Works in CLIP
- Similarity matrix (logits):
  - `logits_per_image` is a $\text{batch\_size} \times \text{batch\_size}$ matrix. For example, with a batch size of 4:

    $$\text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix}$$

  - The diagonal entries are the similarities of positive pairs; all other entries are similarities of negative pairs.
- Computing the cross-entropy loss:
  - For each row $i$, the cross-entropy loss pushes the value in column $i$ (the positive) up while pushing the other columns (the negatives) down.
  - The loss is computed as

    $$\text{CrossEntropyLoss} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^{N} \exp(\text{logits}[i, j])}$$

  - The denominator (the normalization term) models the positive and the negatives jointly, putting them in direct competition.
- Adjusting the distances of positives and negatives:
  - Pulling positives closer: maximize the softmax probability of the positive (diagonal) entry in each row.
  - Pushing negatives apart: the same softmax competition suppresses the negative entries, driving their probabilities toward 0 (see the short verification after this list).
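The verification: applied row-wise to the similarity matrix, `F.cross_entropy` is exactly the averaged negative log-softmax of the diagonal entries in the formula above. A quick check using the 4×4 example matrix:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[ 1.2,  0.3, -0.8,  0.5],
                       [ 0.4,  1.5,  0.1, -0.2],
                       [ 0.0, -0.3,  2.0,  0.6],
                       [-0.5,  0.7,  0.8,  1.3]])
labels = torch.arange(4)

ce = F.cross_entropy(logits, labels)            # built-in cross-entropy over rows
probs = logits.softmax(dim=-1)
manual = -probs.diagonal().log().mean()         # -1/N * sum_i log softmax(row_i)[i]

print(ce.item(), manual.item())                 # both ≈ 0.63 (identical up to float error)
```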
How the Gradient Update Works
Consider a simplified example:
- The positive similarity is $\text{logits}[i, i] = 2.0$.
- The negative similarities are $\text{logits}[i, j] = [0.5, -1.0, 0.2]$.
Compute the softmax distribution:
$$P(\text{positive}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(-1.0) + \exp(0.2)} \approx 0.70$$
$$P(\text{negative}, j) = \frac{\exp(\text{logits}[i, j])}{\text{denominator}}$$
Cross-entropy loss for this row:
$$\text{Loss} = -\log(P(\text{positive})) \approx -\log(0.70) \approx 0.36$$
Gradient computation:
- The gradient of the loss with respect to the positive logit is $P(\text{positive}) - 1 < 0$, so a gradient-descent step increases that similarity.
- The gradient with respect to each negative logit is its softmax probability $P(\text{negative}, j) > 0$, so a gradient-descent step decreases those similarities (see the check below).
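The gradient of the row-wise cross-entropy with respect to the logits is softmax(logits) minus the one-hot target, i.e. negative at the positive entry and positive at every negative entry. This can be confirmed with autograd:

```python
import torch
import torch.nn.functional as F

# One row of logits: the positive pair first, then the three negatives.
row = torch.tensor([2.0, 0.5, -1.0, 0.2], requires_grad=True)
loss = F.cross_entropy(row.unsqueeze(0), torch.tensor([0]))
loss.backward()

print(row.grad)                                                      # ≈ [-0.30, 0.16, 0.03, 0.11]
print(row.detach().softmax(dim=-1) - F.one_hot(torch.tensor(0), 4))  # same values: softmax - one_hot
```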
Characteristics of CLIP's Contrastive Learning
- Symmetry: both `logits_per_image` and `logits_per_text` are optimized.
- Normalized embeddings: the embedding vectors are L2-normalized, so similarities are dot products of points on the unit sphere (i.e., cosine similarities).
- In-batch negatives: every sample in a batch acts as a negative for all other samples, which makes contrastive training sample-efficient.
Use in Other Large Models
- SimCLR: like CLIP, learns image features with a contrastive loss.
- BYOL: does not use explicit negatives, but implicitly optimizes the similarity of positive (augmented) views.
- DINO: self-supervised learning for vision transformers via self-distillation between augmented views; it enforces view invariance with a cross-entropy objective rather than explicit negative pairs.
Summary
CLIP's loss uses cross-entropy to pull positive pairs closer together and push negative pairs farther apart. This contrastive-learning idea is at the core of how large models learn multimodal features, and it is widely used in image-text modeling, video retrieval, and related tasks.
Core Training Code of CLIP and an Explanation of the `F.cross_entropy` Loss
CLIP (Contrastive Language-Image Pretraining) uses contrastive learning to align image and text embeddings in a shared space. The loss function, implemented with `F.cross_entropy`, plays a crucial role in bringing related samples closer while pushing unrelated ones apart.
CLIP Training Code
Below is a simplified implementation of CLIP's core training loop:
```python
import torch
import torch.nn.functional as F

# Simulated embeddings for images and texts
batch_size = 4       # Example batch size
embedding_dim = 512  # Dimensionality of embeddings
device = "cuda" if torch.cuda.is_available() else "cpu"

# Random image and text embeddings (stand-ins for encoder outputs)
image_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)
text_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)

# Normalize embeddings (to unit length)
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

# Compute similarity matrices (logits)
logits_per_image = torch.matmul(image_embeddings, text_embeddings.T)  # Image-to-text similarity
logits_per_text = torch.matmul(text_embeddings, image_embeddings.T)   # Text-to-image similarity

# Labels (positive samples are diagonal elements)
labels = torch.arange(batch_size, device=device)  # [0, 1, 2, 3]

# Cross-entropy loss
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2  # Symmetric loss

# Backpropagation (an optimizer step would follow in a full training loop)
loss.backward()
```
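In real training the embeddings come from an image encoder and a text encoder rather than from `torch.randn`. The sketch below shows where such encoders would plug into the same loss; `ImageEncoder` and `TextEncoder` are deliberately tiny placeholder modules (not CLIP's actual architectures), and the input shapes and vocabulary size are arbitrary choices for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageEncoder(nn.Module):
    """Placeholder for a real vision backbone (e.g. a ViT) plus projection head."""
    def __init__(self, dim=512):
        super().__init__()
        self.proj = nn.Linear(3 * 32 * 32, dim)  # toy: flatten tiny images

    def forward(self, images):
        return self.proj(images.flatten(1))

class TextEncoder(nn.Module):
    """Placeholder for a real text transformer plus projection head."""
    def __init__(self, vocab_size=1000, dim=512):
        super().__init__()
        self.embed = nn.EmbeddingBag(vocab_size, dim)  # toy: bag of token embeddings

    def forward(self, token_ids):
        return self.embed(token_ids)

def clip_loss(image_features, text_features):
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits = image_features @ text_features.T
    labels = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

image_encoder, text_encoder = ImageEncoder(), TextEncoder()
images = torch.randn(4, 3, 32, 32)           # dummy image batch
token_ids = torch.randint(0, 1000, (4, 16))  # dummy token ids
loss = clip_loss(image_encoder(images), text_encoder(token_ids))
loss.backward()  # gradients now flow into both encoders
```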
How Does `F.cross_entropy` Work in CLIP?
1. Similarity Matrix (Logits):
   - `logits_per_image` is a $\text{batch\_size} \times \text{batch\_size}$ matrix, where each entry represents the similarity score (dot product) between an image and a text embedding. For example, if the batch size is 4:

     $$\text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix}$$

   - Diagonal elements represent positive pairs (e.g., matching image-text pairs).
   - Off-diagonal elements represent negative pairs (e.g., mismatched pairs).
2. Cross-Entropy Loss Calculation:
   - The loss function is designed to maximize the similarity of positive pairs (diagonal elements) and minimize that of negative pairs (off-diagonal elements).
   - Formula for the cross-entropy loss:

     $$\text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^{N} \exp(\text{logits}[i, j])}$$

   - Numerator: exponentiated similarity score of the positive pair.
   - Denominator: sum of exponentiated scores over the whole row (positive + negatives).
3. Effects on Positive and Negative Samples:
- Pull Positive Pairs Closer: The loss encourages diagonal elements (positive pairs) to dominate their respective rows by increasing their softmax probability.
- Push Negative Pairs Further Apart: By competing for softmax probability, negative pairs are implicitly suppressed, making their similarity scores smaller.
Example Walkthrough
Assume a simplified example:
- Batch size = 3.
- Similarity matrix (logits):

$$\text{logits\_per\_image} = \begin{bmatrix} 2.0 & 0.5 & -1.0 \\ 0.3 & 1.8 & 0.2 \\ -0.5 & 0.4 & 1.5 \end{bmatrix}$$
Step 1: Compute Softmax Probabilities
For the first row:
$$P(\text{positive}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(-1.0)} \approx 0.79$$
$$P(\text{negative}, j=2) = \frac{\exp(0.5)}{\text{denominator}} \approx 0.18, \qquad P(\text{negative}, j=3) = \frac{\exp(-1.0)}{\text{denominator}} \approx 0.04$$
Step 2: Compute Cross-Entropy Loss
The loss for the first row:
$$\text{Loss} = -\log(P(\text{positive})) \approx -\log(0.79) \approx 0.24$$
Step 3: Backpropagation
Gradients will:
- Increase positive pair similarity (push 2.0 higher).
- Decrease negative pair similarity (push 0.5 and -1.0 lower).
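These numbers can be reproduced in a few lines (a quick sanity check, not part of CLIP itself):

```python
import torch
import torch.nn.functional as F

row = torch.tensor([2.0, 0.5, -1.0])                         # first row of the 3x3 logits
print(row.softmax(dim=-1))                                    # ≈ [0.79, 0.18, 0.04]
print(F.cross_entropy(row.unsqueeze(0), torch.tensor([0])))   # ≈ 0.24
```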
Key Features of CLIP's Contrastive Loss
- Symmetric Loss:
  - Both `logits_per_image` and `logits_per_text` are optimized.
  - This ensures that image and text embeddings are equally well aligned.
- Batch-Wise Negative Sampling:
  - Every other sample in the batch acts as a negative sample, increasing the diversity of training signals.
- Normalized Embeddings:
  - By normalizing embeddings to lie on a unit sphere, CLIP keeps similarity scores in a consistent, bounded range (cosine similarities), which stabilizes training (see the short check below).
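A minimal check of this property: after `F.normalize`, the dot product of two embeddings equals their cosine similarity and is bounded in [-1, 1].

```python
import torch
import torch.nn.functional as F

a, b = torch.randn(4, 512), torch.randn(4, 512)
a_n, b_n = F.normalize(a, dim=-1), F.normalize(b, dim=-1)

dots = (a_n * b_n).sum(dim=-1)                 # dot products of unit vectors
cos = F.cosine_similarity(a, b, dim=-1)        # cosine similarity of the raw vectors
print(torch.allclose(dots, cos, atol=1e-6))    # True
print(dots.abs().max() <= 1.0)                 # bounded similarity scores
```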
How Contrastive Loss Widens Negative Distances
The use of cross-entropy loss in CLIP inherently ensures that:
- Positive pairs are maximized in similarity.
- Negative pairs are minimized through softmax competition, which redistributes probability away from negatives toward positives.
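This redistribution can be made visible with a toy experiment that treats a random 4x4 logit matrix itself as the trainable parameter and takes a few hundred gradient steps of the same cross-entropy loss (a deliberately artificial setup, chosen only to isolate the loss's behavior):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 4, requires_grad=True)   # a toy "similarity matrix"
labels = torch.arange(4)
optimizer = torch.optim.SGD([logits], lr=1.0)

for _ in range(300):
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

probs = logits.softmax(dim=-1)
print(probs.diagonal())   # each close to 1: the positive dominates its row
print(probs.masked_fill(torch.eye(4, dtype=torch.bool), 0.0).max())  # off-diagonal mass near 0
```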
Broader Applications of Contrastive Loss
- SimCLR: Self-supervised contrastive learning for images.
- BYOL: Implicit contrastive learning without explicit negatives.
- DINO: Self-distillation for self-supervised vision transformers; related in spirit, but without explicit negative pairs.
Conclusion
CLIP's contrastive loss leverages softmax and cross-entropy to align multimodal embeddings effectively. This loss function drives positive samples closer and negative samples apart by optimizing for softmax probabilities, enabling powerful vision-language models for retrieval, captioning, and more.
Postscript
Written in Shanghai on December 13, 2024, at 21:47, with the assistance of the GPT-4o model.