CLIP's Core Training Code and an Explanation of the Contrastive Loss (Chinese-English Bilingual)

CLIP's Core Training Code and an Explanation of the Contrastive Loss

CLIP (Contrastive Language-Image Pretraining) uses contrastive learning to train an image-and-text embedding model so that matching image-text pairs end up closer together in the embedding space while unrelated pairs are pushed further apart. Its core loss is implemented with cross-entropy (F.cross_entropy) to realize this contrastive objective.


CLIP Training Code Framework

Below is a simplified version of CLIP's core training code:

python
import torch
import torch.nn.functional as F

# Use the GPU if available so the example also runs on CPU-only machines
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 4       # assume a batch size of 4
embedding_dim = 512  # assume 512-dimensional embedding vectors

# Random image and text embeddings standing in for encoder outputs;
# requires_grad=True so that loss.backward() has something to differentiate
image_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)  # image embeddings
text_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)   # text embeddings

# Normalize the embedding vectors (project onto the unit sphere)
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

# Compute the similarity matrices (logits)
logits_per_image = torch.matmul(image_embeddings, text_embeddings.T)  # image-to-text similarity
logits_per_text = torch.matmul(text_embeddings, image_embeddings.T)   # text-to-image similarity

# Labels (positives on the diagonal, negatives off the diagonal)
labels = torch.arange(batch_size, device=device)  # [0, 1, 2, 3]

# Cross-entropy loss
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2  # symmetric loss

# Backpropagation
loss.backward()
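
In a real training loop the gradients flow into the image and text encoders rather than into free tensors, and an optimizer step follows the backward pass. Below is a minimal, hypothetical sketch in which two linear layers stand in for the actual encoders; the module shapes, the AdamW optimizer, and the learning rate are illustrative assumptions, not CLIP's real architecture or hyperparameters.

python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy stand-ins for the real image/text encoders (hypothetical, for illustration only)
image_encoder = nn.Linear(2048, 512).to(device)  # e.g. pooled vision features -> embedding
text_encoder = nn.Linear(768, 512).to(device)    # e.g. pooled text features -> embedding
optimizer = torch.optim.AdamW(
    list(image_encoder.parameters()) + list(text_encoder.parameters()), lr=1e-4
)

images = torch.randn(4, 2048, device=device)  # fake pooled image features
texts = torch.randn(4, 768, device=device)    # fake pooled text features

image_emb = F.normalize(image_encoder(images), dim=-1)
text_emb = F.normalize(text_encoder(texts), dim=-1)

logits = image_emb @ text_emb.T
labels = torch.arange(4, device=device)
loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

optimizer.zero_grad()
loss.backward()
optimizer.step()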

How the Cross-Entropy Loss Works in CLIP
  1. The similarity matrix (logits)

    • logits_per_image is a $\text{batch\_size} \times \text{batch\_size}$ matrix.
    • For example, with a batch size of 4:
      $$\text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix}$$
    • The values on the diagonal are the similarities of positive pairs; all other values are the similarities of negative pairs.
  2. Computing the cross-entropy loss

    • For each row $i$, the cross-entropy loss wants the value in column $i$ (the positive pair) to be as large as possible and the values in the other columns (the negative pairs) to be suppressed.
    • The formula is as follows (a runnable check of its equivalence to F.cross_entropy appears right after this list):
      $$\text{CrossEntropyLoss} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^N \exp(\text{logits}[i, j])}$$
    • The denominator (the normalization term) models the positive and negative similarities jointly, putting them in direct competition.
  3. How positive and negative distances are adjusted

    • Pulling positives closer: the softmax probability of the positive pair (the diagonal value) is maximized.
    • Pushing negatives apart: the other, negative values are suppressed so that their softmax probabilities approach 0.
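
To make the link between the formula above and the PyTorch call explicit, the following small check (an illustrative sketch, not part of the original code) computes the formula by hand and compares it against F.cross_entropy:

python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 4)  # an arbitrary 4 x 4 similarity matrix
labels = torch.arange(4)    # positives sit on the diagonal

# The formula written out: mean over rows of -log softmax taken at position (i, i)
log_probs = F.log_softmax(logits, dim=-1)
manual_loss = -log_probs[torch.arange(4), labels].mean()

# PyTorch's built-in cross-entropy produces the same value
builtin_loss = F.cross_entropy(logits, labels)
print(torch.allclose(manual_loss, builtin_loss))  # True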

The Gradient Update Process

Consider a simplified example:

  • $\text{logits}[i, i] = 2.0$ (the positive-pair similarity).
  • The other entries in row $i$ are $\text{logits}[i, j] = [0.5, -1.0, 0.2]$ (the negative-pair similarities).

Compute the softmax distribution:
$$P(\text{positive}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(-1.0) + \exp(0.2)} \approx 0.70$$
$$P(\text{negative}, j) = \frac{\exp(\text{logits}[i, j])}{\text{denominator}}$$

Cross-entropy loss:
$$\text{Loss} = -\log(P(\text{positive})) \approx -\log(0.70) \approx 0.36$$
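
These numbers can be reproduced directly (an illustrative check, not part of the original code):

python
import torch

row = torch.tensor([2.0, 0.5, -1.0, 0.2])  # positive logit first, then the three negatives
probs = torch.softmax(row, dim=0)

print(probs[0])               # P(positive) ≈ 0.70
print(-torch.log(probs[0]))   # this row's loss ≈ 0.36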

Gradient Computation

  • For the positive pair, the gradient of the loss with respect to its logit is $P(\text{positive}) - 1 < 0$, so a gradient-descent step pushes its similarity higher.
  • For each negative pair, the gradient is $P(\text{negative}, j) > 0$, so the step pushes its similarity lower (the sketch below confirms this with autograd).
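
The sign pattern follows from the standard gradient of softmax cross-entropy with respect to the logits, $\text{softmax}(\text{logits}) - \text{one\_hot}(\text{label})$. A small illustrative check:

python
import torch
import torch.nn.functional as F

row = torch.tensor([2.0, 0.5, -1.0, 0.2], requires_grad=True)  # positive logit in position 0
loss = F.cross_entropy(row.unsqueeze(0), torch.tensor([0]))
loss.backward()

print(row.grad)
# ≈ [-0.30, 0.16, 0.03, 0.11]: negative for the positive logit (gradient descent raises it),
# positive for the negative logits (gradient descent lowers them)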

Characteristics of Contrastive Learning in CLIP
  1. Symmetry: logits_per_image and logits_per_text are optimized at the same time.
  2. Normalized embeddings: the embedding vectors are L2-normalized, so the similarity is a dot product of points on the unit sphere, i.e. a cosine similarity (verified in the sketch after this list).
  3. In-batch negatives: every sample in the batch serves as a negative for every other sample, which makes contrastive learning efficient.
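
Point 2 can be checked directly: after F.normalize, the dot product of two embeddings equals the cosine similarity of the raw vectors (an illustrative sketch):

python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(512)
b = torch.randn(512)

a_n = F.normalize(a, p=2, dim=-1)
b_n = F.normalize(b, p=2, dim=-1)

# Dot product of unit-length vectors == cosine similarity of the original vectors
print(torch.allclose(a_n @ b_n, F.cosine_similarity(a, b, dim=0), atol=1e-6))  # True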

Applications in Other Large Models
  • SimCLR: like CLIP, learns image representations with a contrastive loss.
  • BYOL: does not use explicit negatives, yet implicitly optimizes the similarity of positive pairs.
  • DINO: self-supervised vision learning that enforces consistency between different augmented views of an image via self-distillation.

Summary

CLIP's loss function is optimized with cross-entropy so that positive pairs move closer together and negative pairs move further apart. This contrastive-learning idea is central to how large models learn multimodal features and is widely used in image-text modeling, video retrieval, and related tasks.

Core Training Code of CLIP and Explanation of F.cross_entropy Loss

CLIP (Contrastive Language-Image Pretraining) uses contrastive learning to align image and text embeddings in a shared space. The loss function, implemented using F.cross_entropy, plays a crucial role in bringing related samples closer while pushing unrelated ones apart.


CLIP Training Code

Below is a simplified implementation of CLIP's core training loop:

python
import torch
import torch.nn.functional as F

# Use the GPU if available so the example also runs on CPU-only machines
device = "cuda" if torch.cuda.is_available() else "cpu"

# Simulated embeddings for images and texts
batch_size = 4       # example batch size
embedding_dim = 512  # dimensionality of embeddings

# Random image and text embeddings standing in for encoder outputs;
# requires_grad=True so that loss.backward() has something to differentiate
image_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)
text_embeddings = torch.randn(batch_size, embedding_dim, device=device, requires_grad=True)

# Normalize embeddings (to unit length)
image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

# Compute similarity matrices (logits)
logits_per_image = torch.matmul(image_embeddings, text_embeddings.T)  # Image-to-text similarity
logits_per_text = torch.matmul(text_embeddings, image_embeddings.T)   # Text-to-image similarity

# Labels (positive samples are the diagonal elements)
labels = torch.arange(batch_size, device=device)  # [0, 1, 2, 3]

# Cross-entropy loss
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2  # Symmetric loss

# Backpropagation
loss.backward()
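
In practice the same computation is often wrapped in a small helper so it can be reused on every batch. The sketch below is illustrative: the name clip_contrastive_loss and the fixed temperature argument are assumptions for this example (CLIP itself learns the temperature), not an existing API.

python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_features: torch.Tensor,
                          text_features: torch.Tensor,
                          temperature: float = 0.07) -> torch.Tensor:
    """Symmetric contrastive (InfoNCE-style) loss over a batch of paired features."""
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    logits = image_features @ text_features.T / temperature
    labels = torch.arange(logits.size(0), device=logits.device)

    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

# Example usage with random features
loss = clip_contrastive_loss(torch.randn(4, 512), torch.randn(4, 512))
print(loss.item())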

How Does F.cross_entropy Work in CLIP?

1. Similarity Matrix (Logits):
  • logits_per_image is a $\text{batch\_size} \times \text{batch\_size}$ matrix, where each entry represents the similarity score (dot product) between an image and a text embedding.
  • For example, if the batch size is 4:
    $$\text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix}$$
    • Diagonal elements represent positive pairs (e.g., matching image-text pairs).
    • Non-diagonal elements represent negative pairs (e.g., mismatched pairs).
2. Cross-Entropy Loss Calculation:
  • The loss function is designed to maximize the similarity of positive pairs (diagonal elements) and minimize that of negative pairs (off-diagonal elements).
  • Formula for cross-entropy loss:
    $$\text{Loss} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^N \exp(\text{logits}[i, j])}$$
    • Numerator: Exponentiated similarity score of the positive pair.
    • Denominator: Sum of exponentiated scores for all pairs (positive + negative).
3. Effects on Positive and Negative Samples:
  • Pull Positive Pairs Closer: The loss encourages diagonal elements (positive pairs) to dominate their respective rows by increasing their softmax probability.
  • Push Negative Pairs Further Apart: By competing for softmax probability, negative pairs are implicitly suppressed, making their similarity scores smaller (the runnable check after this list applies this to the example matrix).
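
Applying the loss to the example matrix above gives one term per row; the check below (illustrative, not part of the original code) prints the softmax probability assigned to each positive pair and the resulting cross-entropy:

python
import torch
import torch.nn.functional as F

logits_per_image = torch.tensor([
    [ 1.2,  0.3, -0.8,  0.5],
    [ 0.4,  1.5,  0.1, -0.2],
    [ 0.0, -0.3,  2.0,  0.6],
    [-0.5,  0.7,  0.8,  1.3],
])
labels = torch.arange(4)

# Probability assigned to the positive (diagonal) entry in each row; higher is better
row_probs = torch.softmax(logits_per_image, dim=-1)
print(row_probs[torch.arange(4), labels])

# Mean of the per-row -log probabilities, i.e. F.cross_entropy
print(F.cross_entropy(logits_per_image, labels))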

Example Walkthrough

Assume a simplified example:

  • Batch size = 3.
  • Similarity matrix (logits):
    $$\text{logits\_per\_image} = \begin{bmatrix} 2.0 & 0.5 & -1.0 \\ 0.3 & 1.8 & 0.2 \\ -0.5 & 0.4 & 1.5 \end{bmatrix}$$
Step 1: Compute Softmax Probabilities

For the first row:
$$P(\text{positive}) = \frac{\exp(2.0)}{\exp(2.0) + \exp(0.5) + \exp(-1.0)} \approx 0.79$$
$$P(\text{negative}, j=2) = \frac{\exp(0.5)}{\text{denominator}} \approx 0.18, \quad P(\text{negative}, j=3) = \frac{\exp(-1.0)}{\text{denominator}} \approx 0.04$$

Step 2: Compute Cross-Entropy Loss

The loss for the first row:
$$\text{Loss} = -\log(P(\text{positive})) \approx -\log(0.79) \approx 0.24$$
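
These values can be reproduced directly (an illustrative check, not part of the original code):

python
import torch

logits_per_image = torch.tensor([
    [ 2.0, 0.5, -1.0],
    [ 0.3, 1.8,  0.2],
    [-0.5, 0.4,  1.5],
])

# Softmax over the first row
probs = torch.softmax(logits_per_image[0], dim=0)
print(probs)                  # ≈ [0.79, 0.18, 0.04]

# Loss contribution of the first row
print(-torch.log(probs[0]))   # ≈ 0.24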

Step 3: Backpropagation

Gradients will:

  • Increase positive pair similarity (push 2.0 higher).
  • Decrease negative pair similarity (push 0.5 and -1.0 lower), as the autograd check below confirms.
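
A short autograd check on the example matrix (illustrative, not part of the original code) shows exactly these gradient directions for the first row:

python
import torch
import torch.nn.functional as F

logits = torch.tensor([
    [ 2.0, 0.5, -1.0],
    [ 0.3, 1.8,  0.2],
    [-0.5, 0.4,  1.5],
], requires_grad=True)
labels = torch.arange(3)

loss = F.cross_entropy(logits, labels)
loss.backward()

print(logits.grad[0])
# ≈ [-0.07, 0.06, 0.01]: negative on the diagonal entry (2.0 gets pushed up by gradient
# descent), positive on the off-diagonal entries (0.5 and -1.0 get pushed down)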

Key Features of CLIP's Contrastive Loss

  1. Symmetric Loss:

    • Both logits_per_image and logits_per_text are optimized.
    • Ensures embeddings for images and texts are equally aligned.
  2. Batch-Wise Negative Sampling:

    • Every other sample in the batch acts as a negative sample, increasing the diversity of training signals.
  3. Normalized Embeddings:

    • By normalizing embeddings to lie on a unit sphere, CLIP ensures consistent similarity scores and prevents unbounded gradients. (In the full CLIP objective these similarities are additionally scaled by a learnable temperature; see the sketch after this list.)
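
The code earlier omits one detail of the actual CLIP objective: the cosine similarities are multiplied by a learnable temperature (a logit scale, parameterized in log space and clamped during training) before the cross-entropy. A minimal sketch under that assumption, with the initialization value taken from the CLIP paper:

python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Learnable log-temperature, initialized to log(1/0.07) as in the CLIP paper
logit_scale = nn.Parameter(torch.tensor(np.log(1 / 0.07), dtype=torch.float32))

image_emb = F.normalize(torch.randn(4, 512), dim=-1)
text_emb = F.normalize(torch.randn(4, 512), dim=-1)

# Scale the cosine similarities before the softmax; clamping keeps the scale bounded
scale = logit_scale.exp().clamp(max=100.0)
logits_per_image = scale * (image_emb @ text_emb.T)
logits_per_text = logits_per_image.T

labels = torch.arange(4)
loss = (F.cross_entropy(logits_per_image, labels) +
        F.cross_entropy(logits_per_text, labels)) / 2
print(loss.item())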

How Contrastive Loss Widens Negative Distances

The use of cross-entropy loss in CLIP inherently ensures that:

  • Positive pairs are maximized in similarity.
  • Negative pairs are minimized through softmax competition, which redistributes probability away from negatives toward positives (illustrated by the toy run below).
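
The effect can be observed in a toy run that optimizes free embedding tensors (a stand-in for real encoders, assumed purely for illustration): over a few hundred steps the mean diagonal (positive) similarity rises while the mean off-diagonal (negative) similarity falls.

python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image_emb = torch.randn(8, 64, requires_grad=True)
text_emb = torch.randn(8, 64, requires_grad=True)
optimizer = torch.optim.SGD([image_emb, text_emb], lr=1.0)  # illustrative learning rate
labels = torch.arange(8)

for step in range(201):
    img = F.normalize(image_emb, dim=-1)
    txt = F.normalize(text_emb, dim=-1)
    logits = img @ txt.T
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        pos = logits.diagonal().mean().item()
        neg = (logits.sum() - logits.diagonal().sum()).item() / (8 * 7)
        print(f"step {step}: mean positive similarity {pos:.2f}, mean negative similarity {neg:.2f}")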

Broader Applications of Contrastive Loss

  • SimCLR: Self-supervised contrastive learning for images.
  • BYOL: Implicit contrastive learning without explicit negatives.
  • DINO: Self-distillation between augmented views for self-supervised vision transformers, encouraging view invariance.

Conclusion

CLIP's contrastive loss leverages softmax and cross-entropy to align multimodal embeddings effectively. This loss function drives positive samples closer and negative samples apart by optimizing for softmax probabilities, enabling powerful vision-language models for retrieval, captioning, and more.

Postscript

Written in Shanghai at 21:47 on December 13, 2024, with the assistance of the GPT-4o large model.
