对比学习(Contrastive Learning)

一直想要学习一下对比学习,这里正好有机会学习整理一下,方便以后查阅。

相关博客链接:监督(Supervised)、半监督(Semi-Supervised)、无监督(Unsupervised)、自监督(Self-Supervised)学习

对比学习(Contrastive Learning)

  • [Contrastive Learning](#Contrastive Learning)
    • [1. 什么是对比学习](#1. 什么是对比学习)
    • [2. 工作流程](#2. 工作流程)
    • [3. 损失函数](#3. 损失函数)
    • [4. 发展历程](#4. 发展历程)
      • [4.1 早期工作(2000s-2010s)](#4.1 早期工作(2000s-2010s))
      • [4.2 计算机视觉中的突破(2018-2020)](#4.2 计算机视觉中的突破(2018-2020))
      • [4.3 无负样本对比学习(2020-2021)](#4.3 无负样本对比学习(2020-2021))
      • [4.4 多模态与跨领域扩展(2021-至今)](#4.4 多模态与跨领域扩展(2021-至今))
      • [4.5 理论与优化(2022-至今)](#4.5 理论与优化(2022-至今))
    • [5. 主要方法对比](#5. 主要方法对比)
    • [6. 应用](#6. 应用)
    • [7. 优势 and 缺点](#7. 优势 and 缺点)
      • [7.1 优势](#7.1 优势)
      • [7.2 缺点](#7.2 缺点)
    • [8. 未来趋势](#8. 未来趋势)
  • 对比学习在有监督分类任务上的应用(InfoNCE+Cross-Entropy)
  • 参考资料

Contrastive Learning

1. 什么是对比学习

对比学习(Contrastive Learning)是一种自监督学习方法。


对比学习在深度学习中的位置

通过对比正样本(相似数据)和负样本(不相似数据),学习数据的高质量表示。它在无需标注数据的情况下,利用数据的内在结构(如图像的空间一致性、文本的语义关联),生成监督信号,广泛应用于计算机视觉、自然语言处理和多模态任务。


对比学习原理示意


有监督对比学习原理示意

2. 工作流程

对比学习的工作流程通常包括以下步骤:

  1. 数据增强:对输入数据(如图像)应用随机增强(如随机裁剪、翻转、颜色抖动),生成正样本对。例如,同一图像的两个增强视图被视为正样本,其他图像的视图为负样本。
  2. 特征提取:使用神经网络(如卷积神经网络CNN或Transformer)将输入数据映射到特征空间,生成特征向量。
  3. 对比损失优化:计算正样本和负样本的相似度,优化InfoNCE损失或其他对比损失(如Cosine Loss)。
  4. 下游任务微调:使用预训练的特征表示,微调模型以适配特定任务。

3. 损失函数

损失函数是对比学习的核心,定义了如何衡量正样本和负样本的相似性。以下是主要损失函数的技术细节:

  • 对比损失(Contrastive Loss)

    最早用于深度度量学习,目标是使同类样本的嵌入距离最小化,不同类样本的嵌入距离最大化。公式为:
    L = 1 2 N ∑ i = 1 N ( y i d i 2 + ( 1 − y i ) max ⁡ ( m a r g i n − d i , 0 ) 2 ) L = \frac{1}{2N} \sum_{i=1}^{N} \left( y_i d_i^2 + (1 - y_i) \max(margin - d_i, 0)^2 \right) L=2N1i=1∑N(yidi2+(1−yi)max(margin−di,0)2)

    其中, N N N是样本对数, y i = 0 y_i=0 yi=0表示相似, y i = 1 y_i=1 yi=1表示不相似, d i d_i di是嵌入之间的距离, m a r g i n margin margin是超参数,定义不相似样本的最小距离。

    来源:Learning a Similarity Metric Discriminatively, with Application to Face Verification

  • Triplet损失(Triplet Loss)

    由FaceNet提出,使用一个锚点(anchor)、一个正样本(positive)和一个负样本(negative),确保锚点与正样本的距离小于锚点与负样本的距离,公式为:
    L = ∑ x ∈ X max ⁡ ( 0 , ∥ f ( x ) − f ( x + ) ∥ 2 2 − ∥ f ( x ) − f ( x − ) ∥ 2 2 + ϵ ) L = \sum_{\mathbf{x} \in \mathcal{X}} \max(0, \|f(\mathbf{x}) - f(\mathbf{x}^+)\|^2_2 - \|f(\mathbf{x}) - f(\mathbf{x}^-)\|^2_2 + \epsilon) L=x∈X∑max(0,∥f(x)−f(x+)∥22−∥f(x)−f(x−)∥22+ϵ)

    其中, ϵ \epsilon ϵ是边距参数,选择具有挑战性的负样本至关重要。

    来源:FaceNet: A Unified Embedding for Face Recognition and Clustering

  • N-pair损失(N-pair Loss)
    推广Triplet损失,包含一个正样本和 N − 1 N-1 N−1个负样本,公式为:
    L = log ⁡ ( 1 + ∑ i = 1 N − 1 exp ⁡ ( f ( x ) ⊤ f ( x i − ) − f ( x ) ⊤ f ( x + ) ) ) L = \log(1 + \sum_{i=1}^{N-1} \exp(f(\mathbf{x})^\top f(\mathbf{x}^-_i) - f(\mathbf{x})^\top f(\mathbf{x}^+))) L=log(1+i=1∑N−1exp(f(x)⊤f(xi−)−f(x)⊤f(x+)))
    等价于使用一个负样本的softmax损失,提高效率。
    来源:Improved Deep Metric Learning with Multi-class N-pair Loss Objective
  • InfoNCE损失(NT-Xent)
    用于SimCLR,最大化不同增强视图之间的协议,公式为:
    L = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] exp ⁡ ( s i m ( z i , z k ) / τ ) L = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(sim(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(sim(z_i, z_k) / \tau)} L=−N1i=1∑Nlog∑k=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
    其中, s i m ( u , v ) = u ⊤ v ∣ ∣ u ∣ ∣ ∣ ∣ v ∣ ∣ sim(u, v) = \frac{u^\top v}{||u|| ||v||} sim(u,v)=∣∣u∣∣∣∣v∣∣u⊤v是余弦相似度, τ \tau τ是温度参数,批次中的其他样本作为负样本。
    来源:A Simple Framework for Contrastive Learning of Visual Representations

其他损失函数包括:

  • Lifted Structured Loss:利用批次中的所有成对边,提高效率,挖掘困难负样本。
  • Noise Contrastive Estimation (NCE):通过逻辑回归区分目标和噪声,扩展到多个负样本。
  • Soft-Nearest Neighbors Loss:扩展到多个正样本,使用温度(\tau)调整特征集中度。

以下是关键损失函数和技术参数的对比:

损失函数 公式 关键参数 应用场景
对比损失 L = 1 2 N ∑ i = 1 N ( y i d i 2 + ( 1 − y i ) max ⁡ ( m a r g i n − d i , 0 ) 2 ) L = \frac{1}{2N} \sum_{i=1}^{N} \left( y_i d_i^2 + (1 - y_i) \max(margin - d_i, 0)^2 \right) L=2N1∑i=1N(yidi2+(1−yi)max(margin−di,0)2) m a r g i n margin margin 早期度量学习
Triplet损失 L = ∑ x ∈ X max ⁡ ( 0 , ∣ f ( x ) − f ( x + ) ∣ 2 2 − ∣ f ( x ) − f ( x − ) ∣ 2 2 + ϵ ) L = \sum_{\mathbf{x} \in \mathcal{X}} \max(0, |f(\mathbf{x}) - f(\mathbf{x}^+)|^2_2 - |f(\mathbf{x}) - f(\mathbf{x}^-)|^2_2 + \epsilon) L=∑x∈Xmax(0,∣f(x)−f(x+)∣22−∣f(x)−f(x−)∣22+ϵ) ϵ \epsilon ϵ 人脸识别
N-pair损失 L = log ⁡ ( 1 + ∑ i = 1 N − 1 exp ⁡ ( f ( x ) ⊤ f ( x i − ) − f ( x ) ⊤ f ( x + ) ) ) L = \log(1 + \sum_{i=1}^{N-1} \exp(f(\mathbf{x})^\top f(\mathbf{x}^-_i) - f(\mathbf{x})^\top f(\mathbf{x}^+))) L=log(1+∑i=1N−1exp(f(x)⊤f(xi−)−f(x)⊤f(x+))) N − 1 N-1 N−1负样本 高效多负样本对比
InfoNCE(NT-Xent) L = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] exp ⁡ ( s i m ( z i , z k ) / τ ) L = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(sim(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(sim(z_i, z_k) / \tau)} L=−N1∑i=1Nlog∑k=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ) τ \tau τ温度参数 SimCLR, 大批量训练

4. 发展历程

对比学习的发展经历了从早期理论到现代高效方法的演变,以下按时间线详细阐述:

4.1 早期工作(2000s-2010s)

  • 理论基础:对比学习的根源可以追溯到噪声对比估计(Noise-Contrastive Estimation, NCE)和负采样(Negative Sampling),最初用于语言模型(如Word2Vec)。这些方法通过对比目标词与上下文词(正样本)和其他随机词(负样本),学习词向量。
  • 代表工作
    • Word2Vec (2013):Mikolov等提出的Word2Vec通过负采样学习词嵌入,奠定了对比学习在NLP中的基础。
    • Deep InfoMax (2018):Hjelm等提出通过最大化全局和局部特征的互信息,学习图像表示,扩展对比学习到视觉领域。

4.2 计算机视觉中的突破(2018-2020)

  • 背景:监督学习(如ImageNet预训练)依赖昂贵的标注数据,促使研究者探索自监督方法。对比学习在视觉领域爆发,受益于深度网络(如ResNet、ViT)和数据增强技术。
  • 代表工作
    • CPC (Contrastive Predictive Coding, 2018):Oord等提出,通过预测序列中未来部分的表示,学习图像和语音的特征,引入时间序列对比任务。
    • MoCo (Momentum Contrast, 2019):He等提出动量编码器(Momentum Encoder)和队列机制,存储大量负样本,解决大批量训练的内存限制,在ImageNet上取得显著成果。
    • SimCLR (Simple Framework for Contrastive Learning, 2020):Chen等简化对比学习框架,强调强数据增强(如随机裁剪、颜色变换)和大批量训练,证明数据增强是关键因素,在ImageNet上达到接近监督学习的性能。
    • MoCo v2 (2020):Chen等结合SimCLR的改进(如投影头、增强策略),进一步提升MoCo性能。

4.3 无负样本对比学习(2020-2021)

  • 背景:传统对比学习依赖大量负样本,计算和存储成本高。新方法探索避免负样本,聚焦正样本之间的关系,降低计算复杂性。
  • 代表工作
    • BYOL (Bootstrap Your Own Latent, 2020):Grill等提出不使用负样本,通过两个网络(在线网络和目标网络)自举学习,目标网络通过动量更新,稳定训练过程。
    • SimSiam (2021):Chen等进一步简化BYOL,去除动量编码器,直接优化正样本对的相似性,证明对比学习无需负样本也能学习有效表示。
    • SwAV (Swapping Assignments between Views, 2020):Caron等结合聚类和对比学习,通过在线聚类分配伪标签,增强表示一致性。

4.4 多模态与跨领域扩展(2021-至今)

  • 背景:对比学习扩展到多模态(如图像-文本、视频-音频)和非视觉领域(如NLP、图数据)。大规模数据集(如WebImageText、LAION)和Transformer架构推动了性能提升。
  • 代表工作
    • CLIP (Contrastive Language-Image Pretraining, 2021):Radford等使用图像-文本对进行对比学习,学习跨模态表示,在零样本分类、图像检索等任务中表现出色。
    • ALIGN (2021):Jia等类似CLIP,强调大规模噪声数据的高效利用。
    • Video Contrastive Learning:方法如VideoMoCo、CVRL扩展对比学习到视频领域,利用时序信息。

4.5 理论与优化(2022-至今)

  • 理论进展:研究者分析对比学习的收敛性、表示崩塌问题(Representation Collapse)以及正负样本的作用,提出正则化方法(如Spectral Contrastive Loss)防止表示退化。
  • 优化方法:蒸馏技术(如DINO)结合对比学习和自监督蒸馏;数据高效方法(如iBOT)结合掩码建模和对比学习。
  • 大规模应用:对比学习成为视觉基础模型(如DINOv2、EVA-CLIP)的重要组成部分。

5. 主要方法对比

以下表格总结了主要对比学习方法的对比:

方法 年份 核心创新 正样本生成 负样本 损失函数 适用场景
CPC 2018 预测序列未来表示 时序上下文 InfoNCE 图像、语音
MoCo 2019 动量编码器+负样本队列 数据增强视图 InfoNCE 图像分类、检测
SimCLR 2020 强数据增强+大批量训练 数据增强视图 InfoNCE 图像分类、检测
BYOL 2020 无负样本,自举学习 数据增强视图 MSE(预测误差) 图像分类
SimSiam 2021 无负样本,简化BYOL 数据增强视图 Cosine Loss 图像分类
SwAV 2020 在线聚类+对比学习 数据增强视图+伪标签 Sinkhorn 图像分类、检测
CLIP 2021 图像-文本跨模态对比 图像-文本对 InfoNCE 多模态任务
DINO 2021 自监督蒸馏+对比学习 数据增强视图 Cross-Entropy 图像分类、分割

6. 应用

对比学习生成的表示广泛应用于:

  • 计算机视觉:图像分类(如ImageNet)、目标检测(如COCO)、语义分割(如ADE20K)、图像检索和生成。
  • 多模态任务:图像-文本检索(如CLIP)、视觉问答(VQA)、图像标注。
  • 视频处理:动作识别(如Kinetics)、视频分类和检索。
  • 自然语言处理:句嵌入(如Sentence-BERT)、文本分类和匹配。
  • 图数据:图节点分类、图表示学习(如GraphCL)。

7. 优势 and 缺点

7.1 优势

  • 无需标注:利用无标注数据,降低数据准备成本。
  • 通用性:学习的表示适用于多种下游任务。
  • 高效性:相较监督学习,预训练阶段更灵活,微调成本低。
  • 鲁棒性:数据增强和对比机制使表示对噪声和变换鲁棒。
  • 多模态支持:跨模态对比学习(如CLIP)扩展了应用场景。

7.2 缺点

  • 计算成本高:大批量训练(如SimCLR)或负样本队列(如MoCo)需要大量内存和算力。
  • 负样本依赖:传统方法需要大量负样本,可能引入噪声或不相关信息。
  • 表示崩塌:若正负样本设计不当,特征可能退化为常数(需正则化或无负样本方法)。
  • 数据增强依赖:性能高度依赖增强策略,需针对任务精心设计。
  • 领域迁移性:在特定领域(如医学影像)可能需要重新设计增强或正样本策略。

8. 未来趋势

  • 无负样本方法:进一步探索BYOL、SimSiam等无负样本框架,降低计算复杂性。
  • 多模态融合:结合图像、文本、音频、视频等多模态数据,构建统一表示。
  • 数据高效学习:针对小规模或特定领域数据优化对比学习。
  • 理论深入:分析对比学习的收敛性、表示质量和泛化能力。
  • 与生成模型结合:结合MAE、扩散模型等生成式自监督方法,提升表示的语义性。
  • 高效部署:通过蒸馏、量化等技术,降低模型规模和推理成本。

对比学习在有监督分类任务上的应用(InfoNCE+Cross-Entropy)

可以通过联合 InfoNCE 和交叉熵损失(Cross-Entropy Loss)来训练一个分类模型。这种方法结合了对比学习的优势(学习鲁棒的特征表示)和交叉熵损失的分类能力,能够提升模型的性能,尤其在有标签的数据集上。联合损失的训练方式通常是将 InfoNCE 损失(用于增强类内紧凑性和类间分离性)与交叉熵损失(用于直接优化分类性能)进行加权组合。

相关论文

以下是一些与联合 InfoNCE 和交叉熵损失相关的代表性论文,提供了理论依据和实验验证:

  1. "Supervised Contrastive Learning" (Khosla et al., 2020)

    • 链接 : arXiv:2004.11362
    • 贡献: 提出了 Supervised Contrastive Loss(SupCon),利用标签信息构造正负样本对,并展示了 SupCon 损失与交叉熵损失联合训练的优越性。论文中明确提到可以通过加权组合 SupCon 损失和交叉熵损失来优化分类模型。
    • 相关内容: 论文实验表明,SupCon 损失在预训练阶段学习到更鲁棒的特征表示,结合交叉熵损失微调后,分类性能显著提升。
  2. "Decoupled Contrastive Learning" (Yeh et al., 2021)

    • 链接 : arXiv:2110.06848
    • 贡献: 提出了一种改进的对比学习框架,探讨了如何在有监督场景中结合对比损失和分类损失。论文中提到联合训练可以缓解对比学习中的表示坍塌问题。
  3. "A Simple Framework for Contrastive Learning of Visual Representations" (Chen et al., 2020, SimCLR)

    • 链接 : arXiv:2002.05709
    • 贡献: 虽然主要聚焦无监督对比学习,但 SimCLR 的 InfoNCE 损失被广泛扩展到有监督场景。许多后续工作(如 SupCon)基于 InfoNCE 提出联合训练策略。
  4. "Combining Contrastive and Supervised Learning for Visual Representation Learning" (Various works)

    • 多篇后续论文探讨了如何在有监督场景中结合 InfoNCE 和交叉熵损失。例如,CLIP(Radford et al., 2021, arXiv:2103.00020)的跨模态对比学习也启发了有监督联合训练的研究。

这些论文提供了联合 InfoNCE 和交叉熵损失的理论支持,SupCon 论文是最直接相关的参考。

代码示例

以下是一个基于 PyTorch 的代码示例,展示如何在 CIFAR-10 数据集上使用联合 InfoNCE(Supervised Contrastive Loss)和交叉熵损失训练一个分类模型。代码包括数据增强、模型架构、联合损失函数和两阶段训练(联合训练 + 微调)。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

# 超参数
batch_size = 256
embedding_dim = 128
epochs = 10
temperature = 0.1
learning_rate = 0.001
supcon_weight = 1.0  # SupCon 损失的权重
ce_weight = 1.0      # 交叉熵损失的权重
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. 数据增强
def get_transforms():
    transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    return transform

# 2. 模型架构(编码器 + 投影头 + 分类头)
class JointModel(nn.Module):
    def __init__(self, base_encoder, projection_dim=128, num_classes=10):
        super(JointModel, self).__init__()
        self.encoder = base_encoder
        # 投影头:用于对比学习
        self.projector = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )
        # 分类头:用于交叉熵损失
        self.classifier = nn.Linear(512, num_classes)
    
    def forward(self, x):
        # 提取特征
        h = self.encoder(x)
        # 投影到对比学习空间
        z = self.projector(h)
        # 分类输出
        logits = self.classifier(h)
        return h, z, logits

# 3. Supervised Contrastive Loss (InfoNCE-based)
def supcon_loss(features, labels, temperature):
    """
    features: [batch_size, projection_dim]
    labels: [batch_size]
    """
    batch_size = features.shape[0]
    # 归一化特征
    features = F.normalize(features, dim=1)
    # 计算相似度矩阵
    similarity_matrix = torch.matmul(features, features.T) / temperature
    
    # 创建标签掩码:同一类的样本为正样本
    labels = labels.contiguous().view(-1, 1)
    mask = torch.eq(labels, labels.T).float().to(device)
    
    # 去除自比较
    mask_self = torch.eye(batch_size, dtype=torch.bool).to(device)
    mask = mask.masked_fill(mask_self, 0)
    
    # 计算正样本和负样本的相似度
    exp_sim = torch.exp(similarity_matrix) * (1 - mask_self.float())
    pos_sum = torch.sum(exp_sim * mask, dim=1, keepdim=True)
    neg_sum = torch.sum(exp_sim * (1 - mask), dim=1, keepdim=True)
    
    # 避免除零
    pos_sum = torch.clamp(pos_sum, min=1e-9)
    neg_sum = torch.clamp(neg_sum, min=1e-9)
    
    # 计算损失
    loss = -torch.log(pos_sum / (pos_sum + neg_sum))
    loss = loss.mean()
    return loss

# 4. 联合训练函数
def train_joint(model, train_loader, optimizer, epochs):
    model.train()
    ce_criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        total_loss = 0
        total_supcon_loss = 0
        total_ce_loss = 0
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            
            # 前向传播
            h, z, logits = model(images)
            
            # 计算 SupCon 损失
            supcon_loss_val = supcon_loss(z, labels, temperature)
            # 计算交叉熵损失
            ce_loss_val = ce_criterion(logits, labels)
            # 联合损失
            loss = supcon_weight * supcon_loss_val + ce_weight * ce_loss_val
            
            # 反向传播
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_supcon_loss += supcon_loss_val.item()
            total_ce_loss += ce_loss_val.item()
        
        avg_loss = total_loss / len(train_loader)
        avg_supcon_loss = total_supcon_loss / len(train_loader)
        avg_ce_loss = total_ce_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}], Total Loss: {avg_loss:.4f}, "
              f"SupCon Loss: {avg_supcon_loss:.4f}, CE Loss: {avg_ce_loss:.4f}")

# 5. 测试函数
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            _, _, logits = model(images)
            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    accuracy = 100. * correct / total
    return accuracy

# 6. 主函数
def main():
    # 数据加载
    transform = get_transforms()
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    
    # 模型
    base_encoder = torchvision.models.resnet18(pretrained=False)
    base_encoder.fc = nn.Identity()  # 移除原始全连接层
    model = JointModel(base_encoder, projection_dim=embedding_dim, num_classes=10).to(device)
    
    # 优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # 训练
    print("Training with Joint SupCon and Cross-Entropy Loss...")
    train_joint(model, train_loader, optimizer, epochs)
    
    # 测试
    accuracy = test_model(model, test_loader)
    print(f"\nTest Accuracy: {accuracy:.2f}%")

if __name__ == "__main__":
    main()

代码说明

  1. 数据增强

    • 使用与 SimCLR 和 SupCon 类似的数据增强策略(如随机裁剪、颜色抖动等),增强模型对数据变换的鲁棒性。
    • 直接使用原始图像和标签,无需生成两组增强视图(与无监督 SimCLR 不同)。
  2. 模型架构

    • 使用 ResNet-18 作为基础编码器,移除原始全连接层。
    • 添加投影头(两层 MLP),用于计算 SupCon 损失。
    • 添加分类头(线性层),用于计算交叉熵损失。
    • 模型同时输出编码器特征(h)、投影特征(z)和分类 logits(logits)。
  3. 联合损失函数

    • SupCon Loss:基于 InfoNCE 的有监督对比损失,利用标签构造正样本对(同一类)和负样本对(不同类)。
    • Cross-Entropy Loss:标准分类损失,直接优化分类性能。
    • 联合损失:loss = supcon_weight * supcon_loss + ce_weight * ce_loss,其中 supcon_weightce_weight 控制两种损失的相对重要性(默认均为 1.0)。
  4. 训练流程

    • 在每个 batch 中,模型同时计算 SupCon 损失(基于投影特征 z)和交叉熵损失(基于分类 logits)。
    • 联合损失用于反向传播,优化整个模型(编码器、投影头和分类头)。
    • 每轮输出总损失、SupCon 损失和交叉熵损失,便于监控训练过程。
  5. 测试

    • 在训练完成后,使用测试集评估分类准确率,仅使用分类头的输出(logits)。

关键特点

  1. 联合训练

    • SupCon 损失增强类内紧凑性和类间分离性,生成鲁棒的特征表示。
    • 交叉熵损失直接优化分类边界,提高分类性能。
    • 联合训练结合了两者的优势,避免了传统两阶段训练(先对比学习,后微调)的复杂性。
  2. 灵活性

    • 通过调整 supcon_weightce_weight,可以平衡对比学习和分类任务的优先级。例如:
      • 增大 supcon_weight 强调特征学习。
      • 增大 ce_weight 强调分类性能。
  3. 性能提升

    • 联合损失通常比单独使用交叉熵损失生成更鲁棒的特征,尤其在数据量较少或类别不平衡时。

参考资料

相关推荐
冬奇Lab6 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab6 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP10 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年10 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼10 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS10 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区11 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈11 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang12 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx
shengjk113 小时前
NanoClaw 深度剖析:一个"AI 原生"架构的个人助手是如何运转的?
人工智能