Core Training Code of CLIP and an Explanation of the Contrastive Loss: Bilingual (Chinese and English)

Core Training Code of CLIP and an Explanation of the Contrastive Loss

CLIP (Contrastive Language-Image Pretraining) uses contrastive learning to train an image-and-text embedding model so that related images and texts end up close together in the embedding space, while unrelated samples are pushed further apart. Its core loss function uses cross-entropy (F.cross_entropy) to implement this contrastive objective.


CLIP Training Code Skeleton

Below is a simplified version of CLIP's core training code:

```python
import torch
import torch.nn.functional as F

# Random image and text embeddings (stand-ins for real encoder outputs)
batch_size = 4        # assume a batch size of 4
embedding_dim = 512   # assume 512-dimensional embeddings

# requires_grad=True so that loss.backward() below has something to differentiate
image_embeddings = torch.randn(batch_size, embedding_dim, device="cuda", requires_grad=True)  # image embeddings
text_embeddings = torch.randn(batch_size, embedding_dim, device="cuda", requires_grad=True)   # text embeddings

# Normalize the embedding vectors (project onto the unit sphere)
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

# Compute the similarity matrices (logits)
logits_per_image = torch.matmul(image_embeddings, text_embeddings.T)  # image-to-text similarity
logits_per_text = torch.matmul(text_embeddings, image_embeddings.T)   # text-to-image similarity

# Labels (positives are on the diagonal, negatives everywhere else)
labels = torch.arange(batch_size, device="cuda")  # [0, 1, 2, 3]

# Cross-entropy loss
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2  # symmetric loss

# Backpropagation (in practice followed by an optimizer step; see the sketch below)
loss.backward()
```
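In the actual CLIP setup the embeddings are produced by an image encoder and a text encoder, and the cosine similarities are multiplied by a learnable temperature before the cross-entropy. Below is a minimal sketch of one full optimizer step under these assumptions; the two nn.Linear "encoders", the feature dimensions, and the random batch are toy stand-ins of my own, while the log(1/0.07) temperature initialization follows the CLIP paper.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the real image/text encoders (assumption: real CLIP uses a ViT/ResNet and a text Transformer)
image_encoder = nn.Linear(2048, 512)  # maps pre-extracted image features to the shared embedding space
text_encoder = nn.Linear(768, 512)    # maps pre-extracted text features to the shared embedding space

# Learnable temperature, stored as a log value and initialized to log(1/0.07) as in the CLIP paper
logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

optimizer = torch.optim.AdamW(
    list(image_encoder.parameters()) + list(text_encoder.parameters()) + [logit_scale],
    lr=1e-4,
)

# One training step on a random batch (stand-in for a DataLoader of paired image/text features)
image_features = torch.randn(4, 2048)
text_features = torch.randn(4, 768)

image_embeddings = F.normalize(image_encoder(image_features), p=2, dim=-1)
text_embeddings = F.normalize(text_encoder(text_features), p=2, dim=-1)

# Scale the cosine similarities by the temperature before the softmax inside cross_entropy
logits_per_image = logit_scale.exp() * image_embeddings @ text_embeddings.t()
logits_per_text = logits_per_image.t()  # the text-to-image logits are simply the transpose

labels = torch.arange(image_features.size(0))
loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2

optimizer.zero_grad()
loss.backward()
optimizer.step()
```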

How the Cross-Entropy Loss Works in CLIP
  1. Similarity matrix (logits)

    • logits_per_image is a $\text{batch\_size} \times \text{batch\_size}$ matrix.
    • For example, with a batch size of 4:
      $$\text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix}$$
    • The diagonal entries are the similarities of positive pairs; all other entries are similarities of negative pairs.
  2. Computing the cross-entropy loss

    • For each row (say row $i$), the loss pushes the value in column $i$ (the positive pair) to be as large as possible while suppressing the other columns (the negative pairs); a numerical check is given in the snippet after this list.
    • The formula is:
      $$\text{CrossEntropyLoss} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^N \exp(\text{logits}[i, j])}$$
    • The denominator (the normalization term) models the positive and negative similarities jointly, so they compete with each other for probability mass.
  3. Adjusting positive and negative distances

    • Pulling positives closer: the loss maximizes the softmax probability of the positive pair (the diagonal entry).
    • Pushing negatives apart: the other entries are suppressed so that their softmax probabilities approach 0.
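To make the formula concrete, here is a small check (my own illustration, reusing the example 4×4 matrix from item 1): taking $-\log$ of each row's diagonal softmax probability and averaging reproduces F.cross_entropy exactly.

```python
import torch
import torch.nn.functional as F

# The example 4x4 similarity matrix from item 1
logits = torch.tensor([
    [ 1.2,  0.3, -0.8,  0.5],
    [ 0.4,  1.5,  0.1, -0.2],
    [ 0.0, -0.3,  2.0,  0.6],
    [-0.5,  0.7,  0.8,  1.3],
])
labels = torch.arange(4)  # positives sit on the diagonal

# Manual version of the formula: -log softmax of the diagonal entries, averaged over rows
probs = torch.softmax(logits, dim=-1)
manual_loss = -torch.log(probs[labels, labels]).mean()

print(manual_loss.item(), F.cross_entropy(logits, labels).item())  # identical values
```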

How the Gradient Update Works

Consider a simplified example:

  • The positive similarity is $\text{logits}[i, i] = 2.0$.
  • The negative similarities are $\text{logits}[i, j] = [0.5, -1.0, 0.2]$.

Compute the softmax distribution:
$$P(\text{positive}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(-1.0) + \exp(0.2)} \approx 0.70$$
$$P(\text{negative}, j) = \frac{\exp(\text{logits}[i, j])}{\text{denominator}}$$

Cross-entropy loss for this row:
$$\text{Loss} = -\log\big(P(\text{positive})\big) \approx -\log(0.70) \approx 0.36$$

Gradient computation

  • For the positive logit, the gradient of the loss is negative ($p_{\text{pos}} - 1 < 0$), so a gradient step pushes its similarity up.
  • For each negative logit, the gradient is positive ($p_j > 0$), so its similarity is pushed down (see the check below).
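These signs can be checked directly (a small illustration of my own, using the row from the example above; here index 0 plays the role of the positive column):

```python
import torch
import torch.nn.functional as F

# One row of logits: positive similarity 2.0 at index 0, negatives [0.5, -1.0, 0.2]
row = torch.tensor([2.0, 0.5, -1.0, 0.2], requires_grad=True)
target = torch.tensor([0])  # the positive entry is at column 0

loss = F.cross_entropy(row.unsqueeze(0), target)
loss.backward()

print(loss.item())  # ~0.36, matching the hand computation above
print(row.grad)     # softmax(row) - one_hot: ~[-0.30, 0.16, 0.03, 0.11]
                    # negative for the positive logit (its similarity gets pushed up),
                    # positive for the negative logits (their similarities get pushed down)
```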

Characteristics of Contrastive Learning in CLIP
  1. Symmetry: logits_per_image and logits_per_text are optimized jointly.
  2. Normalized embeddings: the embedding vectors are L2-normalized, so similarities are dot products between points on the unit sphere.
  3. In-batch negatives: every sample in the batch acts as a negative for all the other samples, which makes contrastive learning efficient.

Use in Other Large Models
  • SimCLR: learns image features with a contrastive loss, similar in spirit to CLIP.
  • BYOL: does not use explicit negatives, but implicitly optimizes the similarity of positive pairs.
  • DINO: applies a contrastive-style objective in self-supervised image learning to strengthen invariance across views.

Summary

CLIP's loss is optimized with cross-entropy so that positive pairs move closer together and negative pairs move further apart. This contrastive-learning idea is central to how large models learn multimodal features and is widely used in image-text modeling, video retrieval, and related tasks.

Core Training Code of CLIP and Explanation of F.cross_entropy Loss

CLIP (Contrastive Language-Image Pretraining) uses contrastive learning to align image and text embeddings in a shared space. The loss function, implemented using F.cross_entropy, plays a crucial role in bringing related samples closer while pushing unrelated ones apart.


CLIP Training Code

Below is a simplified implementation of CLIP's core training loop:

```python
import torch
import torch.nn.functional as F

# Simulated embeddings for images and texts
batch_size = 4       # Example batch size
embedding_dim = 512  # Dimensionality of embeddings

# Random image and text embeddings (requires_grad so loss.backward() works in this toy example)
image_embeddings = torch.randn(batch_size, embedding_dim, device="cuda", requires_grad=True)
text_embeddings = torch.randn(batch_size, embedding_dim, device="cuda", requires_grad=True)

# Normalize embeddings (to unit length)
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

# Compute similarity matrices (logits)
logits_per_image = torch.matmul(image_embeddings, text_embeddings.T)  # Image-to-text similarity
logits_per_text = torch.matmul(text_embeddings, image_embeddings.T)   # Text-to-image similarity

# Labels (positive samples are diagonal elements)
labels = torch.arange(batch_size, device="cuda")  # [0, 1, 2, 3]

# Cross-entropy loss
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2  # Symmetric loss

# Backpropagation (an optimizer step would follow in a real training loop)
loss.backward()
```
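For convenience, the two directions can be wrapped in a single helper. This is a sketch of my own rather than CLIP's reference code; clip_loss is a name chosen here, and logit_scale stands for the learnable temperature used in the CLIP paper. Note that only one matrix multiplication is needed, since the text-to-image logits are just the transpose of the image-to-text logits.

```python
import torch
import torch.nn.functional as F

def clip_loss(image_embeddings: torch.Tensor,
              text_embeddings: torch.Tensor,
              logit_scale: torch.Tensor) -> torch.Tensor:
    """Symmetric contrastive loss over one batch of paired image/text embeddings."""
    image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

    # Temperature-scaled cosine similarities; the transpose gives the other direction for free
    logits_per_image = logit_scale.exp() * image_embeddings @ text_embeddings.t()
    logits_per_text = logits_per_image.t()

    labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
    return (F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)) / 2

# Example usage with random embeddings and a fixed temperature of 1/0.07
loss = clip_loss(torch.randn(4, 512), torch.randn(4, 512), torch.log(torch.tensor(1 / 0.07)))
```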

How Does F.cross_entropy Work in CLIP?

1. Similarity Matrix (Logits):
  • logits_per_image is a $\text{batch\_size} \times \text{batch\_size}$ matrix, where each entry represents the similarity score (dot product) between an image and a text embedding.
  • For example, if the batch size is 4:
    $$\text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix}$$
    • Diagonal elements represent positive pairs (e.g., matching image-text pairs).
    • Non-diagonal elements represent negative pairs (e.g., mismatched pairs).
2. Cross-Entropy Loss Calculation:
  • The loss function is designed to maximize the similarity of positive pairs (diagonal elements) and minimize that of negative pairs (off-diagonal elements).
  • Formula for cross-entropy loss:
    $$\text{Loss} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^N \exp(\text{logits}[i, j])}$$
    • Numerator: Exponentiated similarity score of the positive pair.
    • Denominator: Sum of exponentiated scores for all pairs (positive + negative).
3. Effects on Positive and Negative Samples:
  • Pull Positive Pairs Closer: The loss encourages diagonal elements (positive pairs) to dominate their respective rows by increasing their softmax probability.
  • Push Negative Pairs Further Apart: By competing for softmax probability, negative pairs are implicitly suppressed, making their similarity scores smaller (see the small demo after this list).
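A tiny demo of this softmax competition (my own illustration): raising the positive logit alone is enough to shrink the probability assigned to every negative column.

```python
import torch

row = torch.tensor([1.0, 0.5, 0.2])  # positive logit first, two negatives
print(torch.softmax(row, dim=-1))    # ~[0.49, 0.30, 0.22]

row[0] = 3.0                         # make the positive pair more similar
print(torch.softmax(row, dim=-1))    # ~[0.87, 0.07, 0.05]: the negatives are suppressed
```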

Example Walkthrough

Assume a simplified example:

  • Batch size = 3.
  • Similarity matrix (logits):
    $$\text{logits\_per\_image} = \begin{bmatrix} 2.0 & 0.5 & -1.0 \\ 0.3 & 1.8 & 0.2 \\ -0.5 & 0.4 & 1.5 \end{bmatrix}$$
Step 1: Compute Softmax Probabilities

For the first row:
$$P(\text{positive}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(-1.0)} \approx 0.79$$
$$P(\text{negative}, j=2) = \frac{\exp(0.5)}{\text{denominator}} \approx 0.18, \quad P(\text{negative}, j=3) = \frac{\exp(-1.0)}{\text{denominator}} \approx 0.04$$

Step 2: Compute Cross-Entropy Loss

The loss for the first row:
$$\text{Loss} = -\log\big(P(\text{positive})\big) \approx -\log(0.79) \approx 0.24$$

Step 3: Backpropagation

Gradients will:

  • Increase positive pair similarity (push 2.0 higher).
  • Decrease negative pair similarity (push 0.5 and -1.0 lower).
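These probabilities, the loss, and the gradient signs can be verified directly (a small check of my own, using the 3×3 example matrix above):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([
    [ 2.0, 0.5, -1.0],
    [ 0.3, 1.8,  0.2],
    [-0.5, 0.4,  1.5],
], requires_grad=True)
labels = torch.arange(3)

print(torch.softmax(logits, dim=-1)[0])  # ~[0.79, 0.18, 0.04] for the first row

loss = F.cross_entropy(logits, labels)   # mean over the three rows (~0.33); the first-row term alone is ~0.24
loss.backward()

print(logits.grad[0])  # ~[-0.07, 0.06, 0.01]: negative for the positive pair (2.0 gets pushed up),
                       # positive for the negative pairs (0.5 and -1.0 get pushed down)
```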

Key Features of CLIP's Contrastive Loss

  1. Symmetric Loss:

    • Both logits_per_image and logits_per_text are optimized.
    • Ensures embeddings for images and texts are equally aligned.
  2. Batch-Wise Negative Sampling:

    • Every other sample in the batch acts as a negative sample, increasing the diversity of training signals.
  3. Normalized Embeddings:

    • By normalizing embeddings to lie on a unit sphere, CLIP ensures consistent similarity scores and prevents unbounded gradients.
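A quick illustration of point 3 (my own sketch, not from the original code): after L2 normalization every similarity is a cosine value in $[-1, 1]$, which is also why CLIP multiplies the logits by a learnable temperature ($\exp(\text{logit\_scale})$, initialized to $1/0.07 \approx 14.3$ in the paper) so the softmax is not stuck near uniform.

```python
import torch
import torch.nn.functional as F

image_embeddings = F.normalize(torch.randn(4, 512), p=2, dim=-1)
text_embeddings = F.normalize(torch.randn(4, 512), p=2, dim=-1)

sims = image_embeddings @ text_embeddings.t()
print(sims.min().item(), sims.max().item())   # always bounded within [-1, 1]

# For random high-dimensional unit vectors the cosine similarities cluster near 0,
# so without a temperature the softmax over a row is almost uniform.
print(torch.softmax(sims[0], dim=-1))         # close to [0.25, 0.25, 0.25, 0.25]
print(torch.softmax(14.3 * sims[0], dim=-1))  # much sharper after temperature scaling
```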

How Contrastive Loss Widens Negative Distances

The use of cross-entropy loss in CLIP inherently ensures that:

  • Positive pairs are maximized in similarity.
  • Negative pairs are minimized through softmax competition, which redistributes probability away from negatives toward positives.

Broader Applications of Contrastive Loss

  • SimCLR: Self-supervised contrastive learning for images.
  • BYOL: Implicit contrastive learning without explicit negatives.
  • DINO: Contrastive loss for self-supervised vision transformers.

Conclusion

CLIP's contrastive loss leverages softmax and cross-entropy to align multimodal embeddings effectively. This loss function drives positive samples closer and negative samples apart by optimizing for softmax probabilities, enabling powerful vision-language models for retrieval, captioning, and more.

Postscript

Completed in Shanghai at 21:47 on December 13, 2024, with the assistance of the GPT-4o large model.
