半监督对比学习 (Semi-Supervised SimCLR) 实现

项目概述

本项目实现了一个基于对比学习的半监督学习框架,主要基于 SimCLR 模型架构。这个实现旨在解决标签数据有限场景下的机器学习问题,通过利用大量未标记数据来提高模型性能。

该项目的核心思想是:

"半监督学习结合了有标签和无标签数据的优势,而对比学习则是一种自监督学习方法,通过学习样本间的相似性来提取有意义的特征表示。"

核心功能

  1. 对比学习预训练:利用大量未标记数据进行自监督预训练,学习数据的内在结构
  2. 监督微调:使用少量标记数据对预训练模型进行微调
  3. 半监督学习:结合标记和未标记数据进行模型训练
  4. 数据增强:实现多种图像增强策略,提高模型的泛化能力
  5. 模型评估:支持不同评估指标,如准确率、损失曲线可视化等

支持的架构

  • 对比模型:实现了 SimCLR 风格的对比学习模型
  • 监督基线模型:提供标准监督学习作为性能比较基准
  • 编码器:支持多种基础模型作为编码器,如 ResNet 系列

技术栈

  • TensorFlow/Keras:深度学习框架
  • NumPy:数据处理
  • Matplotlib:结果可视化
  • 其他必要的机器学习和数据科学库

实现原理

对比学习的核心思想

根据 main.py 中的注释,对比学习的核心在于:

"通过学习样本间的相似性来提取有意义的特征表示。"

该实现基于以下关键原理:

  1. 数据增强对:对同一图像应用不同的数据增强策略,生成正样本对
  2. 对比损失:最大化正样本对之间的相似性,最小化负样本对之间的相似性
  3. 投影头:将编码器输出映射到对比学习空间的非线性变换层

模型架构

项目实现了两种主要模型:

1. 对比模型 (ContrastiveModel)
python 复制代码
class ContrastiveModel(keras.Model):
    def __init__(self, encoder, projection_units):
        # 编码器用于提取特征表示
        # 投影头将特征映射到对比学习空间
        # ...

该模型包含以下核心组件:

  • 编码器:从输入图像中提取特征表示
  • 投影头:将编码器输出转换为适合对比学习的向量
  • 温度参数:控制对比损失的锐度
2. 监督微调模型

在预训练后,通过添加分类层进行监督微调:

python 复制代码
# 创建用于微调的监督模型
finetuning_model = keras.Sequential([
    pretrained_encoder,
    keras.layers.Dense(num_classes, activation="softmax")
])

关键实现细节

数据增强

根据 main.py,项目实现了强大的数据增强策略:

python 复制代码
def get_augmenter(image_size):
    """创建图像增强流水线"""
    # 包含随机裁剪、翻转、颜色抖动等多种增强操作
    # ...

主要增强操作包括:

  • 随机裁剪和调整大小
  • 水平翻转
  • 颜色抖动(亮度、对比度、饱和度、色调)
  • 随机灰度转换
  • 高斯模糊
对比损失函数

实现了基于温度缩放的对比损失:

python 复制代码
def compute_contrastive_loss(features, temperature=0.1):
    # 计算批次内所有样本对之间的相似性
    # 应用温度缩放
    # 计算 NT-Xent 损失
    # ...
训练流程

项目实现了两阶段训练流程:

  1. 对比预训练:使用所有数据(标记和未标记)进行自监督学习
  2. 监督微调:仅使用标记数据对预训练编码器进行微调

使用方法、参数配置和运行示例

安装依赖

运行本项目需要安装以下依赖:

bash 复制代码
pip install tensorflow numpy matplotlib

超参数配置详解

根据 main.py 中的定义,以下是主要超参数及其说明:

python 复制代码
# 超参数设置
batch_size = 256        # 批次大小,影响训练速度和稳定性
image_size = 224        # 输入图像尺寸
projection_units = 128  # 投影头输出维度
num_epochs = 100        # 训练轮数
learning_rate = 0.001   # 学习率
temperature = 0.1       # 对比损失的温度参数,控制相似度分布的锐度

其他重要参数:

python 复制代码
# 数据集参数
data_dir = "./data"    # 数据集路径
num_classes = 10       # 分类任务的类别数
labeled_ratio = 0.1    # 标记数据的比例

# 模型参数
base_model = "resnet50"  # 基础编码器模型

运行示例

1. 完整训练流程
bash 复制代码
python main.py

默认情况下,脚本将执行以下步骤:

  1. 加载并准备数据集(包括标记和未标记数据)
  2. 创建数据增强流水线
  3. 初始化对比模型并进行预训练
  4. 使用标记数据对预训练模型进行微调
  5. 评估模型性能并生成可视化结果
2. 仅运行对比预训练

可以修改 main.py 中的相关代码,只执行预训练阶段:

python 复制代码
# 预训练对比模型
contrastive_history = contrastive_model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset
)

# 保存预训练模型
contrastive_model.encoder.save("pretrained_encoder")
3. 使用自定义数据集

要使用自定义数据集,需要修改数据加载部分:

python 复制代码
# 自定义数据集加载
def load_custom_dataset(data_dir):
    # 实现自定义数据加载逻辑
    # 返回标记和未标记的数据集
    pass

train_data, val_data, test_data = load_custom_dataset("./your_data")

结果评估

训练完成后,模型会自动在测试集上进行评估,并生成以下指标:

复制代码
# 测试集上的性能评估
accuracy = supervised_model.evaluate(test_dataset)
print(f"测试准确率: {accuracy:.4f}")

训练过程中的损失和准确率曲线会保存为图像文件,便于分析模型收敛情况。

项目结构和工作流程

详细项目结构

根据项目文件组织,项目包含以下关键组件:

复制代码
semisupervised_simclr/
├── main.py                 # 主实现文件,包含完整的训练和评估流程
│   ├── 文件头部说明和注释      # 项目概述和基本概念介绍
│   ├── 环境设置和跨平台兼容    # 资源限制设置和平台检测
│   ├── 超参数配置            # 所有可调参数的集中管理
│   ├── 数据处理模块          # 数据集加载和预处理功能
│   ├── 数据增强实现          # 自定义增强变换和流水线
│   ├── 模型架构定义          # 编码器、对比模型和监督模型
│   ├── 训练循环实现          # train_step 和 test_step 方法
│   ├── 主训练逻辑            # 预训练和微调流程
│   └── 结果评估和可视化       # 性能指标计算和图表生成
└── semisupervised_simclr.py  # 辅助模块,可能包含特定功能实现

main.py 主要组件详解

  1. 数据处理组件

    • 实现了数据集的加载、预处理和批次生成
    • 支持标记和未标记数据的混合使用
    • 包含数据分割和验证集构建功能
  2. 数据增强模块

    • 基于 main.py 中的 get_augmenter() 函数实现
    • 包含自定义的 RandomColorAffine 增强层
    • 支持多种增强操作的组合和参数调整
  3. 模型架构

    • get_encoder() 函数用于创建特征提取器
    • ContrastiveModel 类实现对比学习逻辑
    • 监督模型构建和微调功能
  4. 训练和优化

    • 自定义训练循环
    • 对比损失计算
    • 学习率调度和优化器配置

详细工作流程

项目的完整工作流程包括以下主要阶段:

1. 初始化阶段
python 复制代码
# 环境和资源配置
import necessary modules
set random seeds
configure platform-specific settings

# 超参数初始化
batch_size = 256
image_size = 224
# 其他参数设置...
2. 数据处理阶段
python 复制代码
# 加载数据集
train_data, val_data, test_data = load_dataset(data_dir)

# 创建数据增强器
augmenter = get_augmenter(image_size)

# 构建训练和验证数据集
train_dataset = create_contrastive_dataset(train_data, augmenter, batch_size)
val_dataset = create_contrastive_dataset(val_data, augmenter, batch_size)
3. 模型构建阶段
python 复制代码
# 创建编码器
encoder = get_encoder(image_size, base_model=base_model)

# 创建对比模型
contrastive_model = ContrastiveModel(
    encoder=encoder,
    projection_units=projection_units,
    temperature=temperature
)

# 编译模型
contrastive_model.compile(optimizer=keras.optimizers.Adam(learning_rate))
4. 对比预训练阶段
python 复制代码
# 执行预训练
contrastive_history = contrastive_model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset
)
5. 监督微调阶段
python 复制代码
# 提取预训练编码器
pretrained_encoder = contrastive_model.encoder

# 创建监督微调模型
finetuning_model = keras.Sequential([
    pretrained_encoder,
    keras.layers.Dense(num_classes, activation="softmax")
])

# 编译和训练微调模型
finetuning_model.compile(...)
finetuning_history = finetuning_model.fit(...)
6. 评估和可视化阶段
python 复制代码
# 评估模型性能
test_accuracy = finetuning_model.evaluate(test_dataset)

# 绘制训练曲线
plot_training_curves(contrastive_history, finetuning_history)

数据流和转换过程

下面是项目中数据的完整流程:

复制代码
┌─────────────┐     ┌─────────────┐     ┌─────────────────────┐
│  原始图像数据 │ ──> │  数据增强     │ ──> │  生成对比学习样本对  │
└─────────────┘     └─────────────┘     └─────────────────────┘
                                                │
                                                ▼
┌─────────────┐     ┌─────────────┐     ┌─────────────────────┐
│  评估结果    │ <── │  监督微调     │ <── │  对比预训练         │
└─────────────┘     └─────────────┘     └─────────────────────┘
        │                 │                   │
        ▼                 ▼                   ▼
┌─────────────┐     ┌─────────────┐     ┌─────────────────────┐
│  可视化结果   │     │  分类头      │     │  编码器 + 投影头    │
└─────────────┘     └─────────────┘     └─────────────────────┘

核心代码流程

根据 main.py 的实现,以下是项目的核心代码流程:

  1. 定义数据增强和预处理

    • 实现 RandomColorAffineget_augmenter()
  2. 创建模型架构

    • 使用 get_encoder() 创建基础编码器
    • 实现 ContrastiveModel
  3. 训练对比模型

    • 实现自定义 train_step()test_step()
    • 执行对比预训练
  4. 微调监督模型

    • 重用预训练编码器
    • 添加分类层并训练
  5. 评估和可视化

    • 计算性能指标
    • 生成训练曲线和结果报告

进一步改进和相关工作

根据 main.py 中的注释,该实现可以进一步改进:

"架构改进:

  • 实现 MoCo 或 BYOL 等其他对比学习方法

  • 添加动量编码器提高性能

  • 探索不同的投影头结构"
    "超参数调优笔记:

  • 温度参数对对比学习效果影响显著

  • 数据增强策略对性能有重要影响

  • 学习率调度对于稳定训练很重要"

相关工作包括 BYOL 和 SimSiam,它们在无负样本的情况下实现了对比学习。

总结

本项目实现了一个完整的半监督对比学习框架,通过结合标记和未标记数据,在数据有限的场景下提高模型性能。该实现基于 SimCLR 架构,包含了从数据准备到模型评估的完整工作流程,可作为半监督学习研究和应用的基础框架。

完整代码

python 复制代码
# -*- coding: utf-8 -*-
# 半监督对比学习框架实现(Semi-Supervised SimCLR)
# 该文件实现了基于SimCLR架构的半监督对比学习模型,用于图像分类任务
# 主要功能:
# 1. 使用未标记数据进行对比预训练,学习图像的有效表示
# 2. 使用少量标记数据进行线性探针评估和监督微调
# 3. 与纯监督基线模型进行性能比较
# 4. 可视化数据增强效果和训练曲线

import os
# 设置Keras后端为TensorFlow
os.environ["KERAS_BACKEND"] = "tensorflow"
# 调整系统资源限制,以支持更大的文件打开数量
import resource
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))

# 导入必要的库
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import ops
from keras import layers
# 数据集配置参数
unlabeled_dataset_size = 100000  # 未标记数据集大小
labeled_dataset_size = 5000      # 标记数据集大小
image_channels = 3              # 图像通道数(RGB)

# 训练配置参数
num_epochs = 20                 # 训练轮数
batch_size = 525                # 批量大小
width = 128                     # 模型中间层特征维度

# 对比学习配置参数
temperature = 0.1               # 温度参数,控制对比损失的平滑程度

# 图像增强配置
# 对比学习使用较强的数据增强,促进模型学习鲁棒特征
contrastive_augmentation = {
    "min_area": 0.25,    # 随机裁剪保留的最小区域比例
    "brightness": 0.6,   # 亮度调整范围
    "jitter": 0.2        # 颜色抖动强度
}

# 分类任务使用较弱的数据增强,保持图像语义信息
classification_augmentation = {
    "min_area": 0.75,    # 随机裁剪保留的最小区域比例
    "brightness": 0.3,   # 亮度调整范围
    "jitter": 0.1        # 颜色抖动强度
}
# 准备训练和测试数据集
# 功能:加载STL-10数据集,分离未标记数据、标记数据和测试数据,并进行预处理
# 返回:
#     train_dataset: 包含未标记数据和标记数据的组合数据集
#     labeled_train_dataset: 仅标记训练数据的数据集
#     test_dataset: 测试数据集
def prepare_dataset():
    # 计算每轮训练的步数
    steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
    
    # 计算未标记数据和标记数据各自的批量大小
    unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
    labeled_batch_size = labeled_dataset_size // steps_per_epoch
    
    print(
        f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
    )
    
    # 加载未标记数据集 (STL-10的'unlabelled'分割)
    unlabeled_train_dataset = (
        tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * unlabeled_batch_size)  # 打乱数据,增加随机性
        .batch(unlabeled_batch_size)  # 批量处理
    )
    
    # 加载标记训练数据集 (STL-10的'train'分割)
    labeled_train_dataset = (
        tfds.load("stl10", split="train", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * labeled_batch_size)  # 打乱数据
        .batch(labeled_batch_size)  # 批量处理
    )
    
    # 加载测试数据集 (STL-10的'test'分割)
    test_dataset = (
        tfds.load("stl10", split="test", as_supervised=True)
        .batch(batch_size)  # 批量处理
        .prefetch(buffer_size=tf.data.AUTOTUNE)  # 预加载数据以提高性能
    )
    
    # 组合未标记数据和标记数据,创建用于对比学习的训练数据集
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=tf.data.AUTOTUNE)  # 预加载提高性能
    
    return train_dataset, labeled_train_dataset, test_dataset

# 主程序入口
# SimCLR(对比学习)半监督学习实现的整体流程:
# 1. 准备数据集(大量未标记数据 + 少量标记数据)
# 2. 可视化数据增强效果
# 3. 训练三种模型进行对比:
#    a. 基线模型:纯监督学习(仅使用少量标记数据)
#    b. 预训练模型:自监督对比学习(使用大量未标记数据)+ 线性探针评估
#    c. 微调模型:在预训练模型基础上进行监督微调(使用少量标记数据)
# 4. 绘制训练曲线,比较三种方法的性能差异

# 准备数据集
# 生成三种数据集:
#   1. train_dataset: 包含未标记和标记数据,用于对比学习预训练
#   2. labeled_train_dataset: 仅包含标记数据,用于监督学习基线和微调
#   3. test_dataset: 测试集,用于评估模型性能
train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

# 自定义颜色变换层,用于实现图像亮度和颜色抖动增强
class RandomColorAffine(layers.Layer):
    def __init__(self, brightness=0, jitter=0, **kwargs):
        # 初始化颜色变换层
        # 参数:
        #   brightness: 亮度调整范围,0表示不调整,值越大调整范围越大
        #   jitter: 颜色抖动强度,0表示不抖动,值越大抖动越明显
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)  # 设置随机种子生成器
        self.brightness = brightness  # 亮度调整参数
        self.jitter = jitter  # 颜色抖动参数
    
    def get_config(self):
        # 获取层配置,用于模型保存和加载
        config = super().get_config()
        config.update({"brightness": self.brightness, "jitter": self.jitter})
        return config
    
    def call(self, images, training=True):
        # 执行颜色变换操作
        # 仅在训练模式下应用变换
        if training:
            batch_size = ops.shape(images)[0]  # 获取批量大小
            
            # 生成亮度缩放因子,形状为[batch_size, 1, 1, 1]
            brightness_scales = 1 + keras.random.uniform(
                (batch_size, 1, 1, 1),
                minval=-self.brightness,
                maxval=self.brightness,
                seed=self.seed_generator,
            )
            
            # 生成颜色抖动矩阵,形状为[batch_size, 1, 3, 3]
            jitter_matrices = keras.random.uniform(
                (batch_size, 1, 3, 3),
                minval=-self.jitter,
                maxval=self.jitter,
                seed=self.seed_generator,
            )
            
            # 构造颜色变换矩阵:单位矩阵 * 亮度缩放 + 抖动矩阵
            color_transforms = (
                ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
                * brightness_scales
                + jitter_matrices
            )
            
            # 应用颜色变换并裁剪到[0, 1]范围
            images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
        return images
# 创建数据增强器
# 参数:
#   min_area: 随机裁剪保留的最小区域比例
#   brightness: 亮度调整范围
#   jitter: 颜色抖动强度
# 返回:
#   数据增强序列模型
def get_augmenter(min_area, brightness, jitter):
    # 根据最小区域比例计算缩放因子
    zoom_factor = 1.0 - math.sqrt(min_area)
    
    # 构建数据增强管道,按顺序应用以下增强操作:
    # 1. 像素值归一化 (0-1范围)
    # 2. 水平随机翻转
    # 3. 随机平移
    # 4. 随机缩放
    # 5. 颜色和亮度变换
    return keras.Sequential(
        [
            layers.Rescaling(1 / 255),  # 像素值归一化
            layers.RandomFlip("horizontal"),  # 水平随机翻转
            layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2),  # 随机平移
            layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),  # 随机缩放
            RandomColorAffine(brightness, jitter),  # 颜色和亮度变换
        ]
    )
# 可视化数据增强效果函数
# 参数:
#   num_images: 要可视化的图像数量
# 功能:展示不同数据增强方法对同一图像的变换效果
#       帮助理解和验证数据增强策略的有效性
# 增强方法对比:
#   1. 原始图像 - 无变换
#   2. 弱增强 - 用于分类任务的轻量级变换
#   3. 强增强 - 用于对比学习的更显著变换(生成两个不同的增强视图)
def visualize_augmentations(num_images):
    # 从训练数据集中获取样例图像
    # 注意:这里使用的是训练数据集的未标记部分的第一个批次中的图像
    images = next(iter(train_dataset))[0][0][:num_images]
    
    # 为每张图像生成多种增强版本:
    # 1. 原始图像 - 未进行任何变换的基准图像
    # 2. 弱增强版本 - 用于分类任务的数据增强,保留较多语义信息
    # 3. 强增强版本1 - 用于对比学习的第一个增强视图
    # 4. 强增强版本2 - 用于对比学习的第二个增强视图
    # 对比学习需要为同一图像生成两个不同的增强视图作为正样本对
    augmented_images = zip(
        images,
        get_augmenter(**classification_augmentation)(images),  # 弱增强(用于分类任务)
        get_augmenter(**contrastive_augmentation)(images),    # 强增强1(用于对比学习)
        get_augmenter(**contrastive_augmentation)(images),    # 强增强2(用于对比学习)
    )
    
    # 行标题
    row_titles = [
        "Original:",           # 原始图像
        "Weakly augmented:",   # 弱增强(分类用)
        "Strongly augmented:", # 强增强1(对比学习用)
        "Strongly augmented:", # 强增强2(对比学习用)
    ]
    
    # 创建可视化图表
    # 图表布局:4行(num_images)列,每行展示一种增强类型,每列展示一个图像样本的不同增强结果
    # 图表尺寸根据图像数量动态调整,dpi=100保证图像清晰显示
    plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
    
    # 绘制图像网格
    # 外层循环遍历每个图像样本(column)
    # 内层循环遍历该样本的四种增强版本(row)
    for column, image_row in enumerate(augmented_images):
        for row, image in enumerate(image_row):
            # 设置子图位置:4行(num_images)列,当前子图索引为row * num_images + column + 1
            plt.subplot(4, num_images, row * num_images + column + 1)
            # 显示增强后的图像
            plt.imshow(image)
            # 只为第一列添加标题,避免重复标注
            if column == 0:
                plt.title(row_titles[row], loc="left")  # 标题左对齐,提高可读性
            plt.axis("off")  # 关闭坐标轴,突出图像内容
    
    plt.tight_layout()  # 调整布局,自动调整子图参数,避免标签重叠
    # 这种网格布局使不同增强策略的效果对比更加直观
    # 行方向比较同一图像的不同增强版本,列方向比较不同图像的增强一致性
# 可视化数据增强效果
# 展示原始图像、弱增强和强增强的视觉差异
# 弱增强用于分类任务,强增强用于对比学习
visualize_augmentations(num_images=8)
# 创建图像编码器(特征提取器)
# 返回:
#   卷积神经网络编码器,用于从图像中提取特征表示
def get_encoder():
    # 构建卷积神经网络编码器,包含以下层:
    # 1. 4个卷积层,每个卷积层使用3x3核,步长为2(进行下采样)
    # 2. 扁平化层,将卷积特征转换为一维向量
    # 3. 全连接层,输出固定维度的特征向量
    return keras.Sequential(
        [
            # 第一个卷积层,输出通道数为width,使用ReLU激活函数
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 第二个卷积层
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 第三个卷积层
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 第四个卷积层
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 扁平化层
            layers.Flatten(),
            # 输出特征向量的全连接层
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",  # 模型名称
    )
# 创建基线模型(纯监督学习)
# 作为性能基准,直接在标记数据上训练
# 用于验证对比学习预训练的有效性
baseline_model = keras.Sequential(
    [
        get_augmenter(**classification_augmentation),  # 使用弱数据增强
        get_encoder(),  # 使用与对比学习相同的编码器
        layers.Dense(10),  # 分类头(输出10个类别的logits)
    ],
    name="baseline_model",  # 模型名称
)
# 编译基线模型
baseline_model.compile(
    optimizer=keras.optimizers.Adam(),  # Adam优化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # 交叉熵损失
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],  # 评估指标
)
baseline_history = baseline_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
# 打印基线模型的最佳验证准确率
# 基线模型是纯监督学习方法,直接使用标记数据训练
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)
# 对比学习模型类
# 实现SimCLR算法的核心逻辑,用于无监督特征学习
class ContrastiveModel(keras.Model):
    # 初始化对比学习模型
    # 核心组件包括:数据增强器、特征编码器、投影头和线性评估器
    def __init__(self):
        super().__init__()
        # 温度参数:控制对比损失中相似度的缩放,值越小,分布越陡峭
        self.temperature = temperature
        # 强数据增强器:用于对比学习任务,生成多样化视图
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        # 弱数据增强器:用于分类任务,保留更多语义信息
        self.classification_augmenter = get_augmenter(**classification_augmentation)
        # 特征编码器:将图像转换为高维特征向量
        self.encoder = get_encoder()
        # 投影头:将编码器输出映射到低维潜在空间,促进对比学习
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        # 线性探针:用于评估学习到的特征表示质量
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)],
            name="linear_probe",
        )
        # 打印各组件架构信息
        self.encoder.summary()
        self.projection_head.summary()
        self.linear_probe.summary()
    
    # 模型前向传播方法
    # 参数:
    #   inputs: 输入图像批次,形状为(batch_size, height, width, channels)
    #   training: 是否为训练模式
    # 返回:
    #   训练模式:返回特征和投影
    #   推理模式:返回编码器输出的特征
    def call(self, inputs, training=None):
        # 前向传播核心流程:图像 -> 特征提取 -> 投影
        # 获取输入图像的特征表示
        features = self.encoder(inputs, training=training)
        
        # 获取投影特征
        # 在对比学习中,投影特征用于计算样本间的相似性
        projections = self.projection_head(features, training=training)
        
        # 根据训练模式返回不同结果
        if training:
            return features, projections
        else:
            return features  # 推理时仅返回编码器特征
    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)
        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )
        self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")
    @property
    def metrics(self):
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]
    # 计算对比损失的核心方法
    # 参数:
    #   projections_1: 第一个增强视图的投影特征
    #   projections_2: 第二个增强视图的投影特征
    # 返回:
    #   计算得到的对比损失值
    def contrastive_loss(self, projections_1, projections_2):
        # 对投影特征进行L2归一化
        # 确保特征在单位超球面上,使点积等价于余弦相似度
        projections_1 = ops.normalize(projections_1, axis=1)
        projections_2 = ops.normalize(projections_2, axis=1)
        
        # 计算两个视图投影特征之间的相似性矩阵
        # 使用温度参数缩放相似度,温度越低,分布越陡峭
        similarities = (
            ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
        )
        
        # 获取批次大小
        batch_size = ops.shape(projections_1)[0]
        
        # 创建标签:每个样本应该与另一个视图中的相同索引样本匹配
        contrastive_labels = ops.arange(batch_size)
        
        # 更新对比准确率指标
        # 分别计算从视图1到视图2和从视图2到视图1的准确率
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, ops.transpose(similarities)
        )
        
        # 计算两个方向的交叉熵损失
        # loss_1_2: 从视图1到视图2的对比损失
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        # loss_2_1: 从视图2到视图1的对比损失
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, ops.transpose(similarities), from_logits=True
        )
        
        # 返回两个方向损失的平均值
        # 这种双向计算确保模型能够从两个角度学习样本之间的关系
        return (loss_1_2 + loss_2_1) / 2
    # 模型训练步骤
    # 参数:
    #   data: 训练数据,包含未标记数据和标记数据的组合
    # 返回:
    #   字典形式的训练指标,包括对比损失、对比准确率、探针损失和探针准确率
    def train_step(self, data):
        # 处理输入数据结构,获取未标记数据和标记数据
        (unlabeled_images, _), (labeled_images, labels) = data
        
        # 对比学习训练部分
        # 1. 合并未标记数据和标记数据,用于对比学习
        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
        
        # 2. 为每张图像生成两个不同的增强视图
        # 这是对比学习的核心:同一图像的不同增强版本应该具有相似的表示
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)
        
        # 3. 在第一个梯度带上计算对比学习损失并更新编码器和投影头
        with tf.GradientTape() as tape:
            # 通过编码器提取两个增强视图的特征表示
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            
            # 通过投影头将特征映射到潜在空间
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            
            # 计算对比损失,衡量正负样本对的相似性
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        
        # 4. 计算并应用梯度,更新编码器和投影头的参数
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        
        # 5. 更新对比损失跟踪指标
        self.contrastive_loss_tracker.update_state(contrastive_loss)
        
        # 线性探针(分类任务)训练部分
        # 1. 对标记图像应用弱增强,保持语义信息
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=True
        )
        
        # 2. 在第二个梯度带上计算分类损失并更新线性探针
        with tf.GradientTape() as tape:
            # 使用编码器提取特征(注意:编码器在此阶段不更新参数)
            features = self.encoder(preprocessed_images, training=False)
            
            # 通过线性探针进行分类预测
            class_logits = self.linear_probe(features, training=True)
            
            # 计算分类损失
            probe_loss = self.probe_loss(labels, class_logits)
        
        # 3. 计算并应用梯度,只更新线性探针的参数
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        
        # 4. 更新线性探针的损失和准确率指标
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)
        
        # 5. 返回所有训练指标
        return {m.name: m.result() for m in self.metrics}
    # 模型测试/评估步骤
    # 参数:
    #   data: 测试数据,包含图像和对应的标签
    # 返回:
    #   字典形式的评估指标,包括探针损失和探针准确率
    def test_step(self, data):
        # 提取测试数据中的图像和标签
        labeled_images, labels = data
        
        # 对测试图像应用弱增强
        # 注意:即使在测试阶段,也应用相同的数据预处理以保持一致性
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        
        # 使用编码器提取特征
        # 训练=False表示不更新编码器参数,仅用于特征提取
        features = self.encoder(preprocessed_images, training=False)
        
        # 通过线性探针进行分类预测
        class_logits = self.linear_probe(features, training=False)
        
        # 计算分类损失
        probe_loss = self.probe_loss(labels, class_logits)
        
        # 更新评估指标
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)
        
        # 返回后两个评估指标(仅探针相关指标)
        return {m.name: m.result() for m in self.metrics[2:]}
# 初始化预训练模型
# 创建ContrastiveModel实例,用于无监督对比学习预训练
pretraining_model = ContrastiveModel()

# 编译预训练模型
# 为不同部分设置不同的优化器:
# 1. contrastive_optimizer: 用于更新编码器和投影头的参数
# 2. probe_optimizer: 仅用于更新线性探针的参数
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(learning_rate=0.001),  # 对比学习部分优化器
    probe_optimizer=keras.optimizers.Adam(learning_rate=0.001),       # 线性探针部分优化器
)
# 开始预训练模型
# 训练过程同时进行对比学习(无监督)和线性探针评估(有监督)
pretraining_history = pretraining_model.fit(
    train_dataset,                  # 训练数据(包含未标记和标记数据)
    epochs=num_epochs,              # 预训练轮数
    validation_data=test_dataset    # 验证数据,用于监控性能
)
# 打印对比预训练模型的最佳验证准确率
# 预训练模型使用未标记数据进行对比学习,同时在线性探针上评估性能
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(pretraining_history.history["val_p_acc"]) * 100
    )
)
# 创建微调模型
# 利用预训练好的编码器进行监督微调,添加新的分类层
finetuning_model = keras.Sequential(
    [
        get_augmenter(**classification_augmentation),  # 复用分类任务的弱数据增强
        pretraining_model.encoder,                     # 使用预训练好的编码器提取特征
        layers.Dense(10),                              # 新的分类层(输出10个类别的logits)
    ],
    name="finetuning_model",
)

# 编译微调模型
# 设置优化器、损失函数和评估指标
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),                 # Adam优化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # 交叉熵损失(from_logits=True表示模型输出logits)
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],  # 评估指标:稀疏分类准确率
)

# 开始监督微调
# 使用少量标记数据对模型进行微调,验证集用于监控性能
finetuning_history = finetuning_model.fit(
    labeled_train_dataset,      # 仅使用标记训练数据
    epochs=num_epochs,          # 微调轮数
    validation_data=test_dataset  # 测试集作为验证数据
)

# 打印微调后的最佳验证准确率
# 输出微调过程中达到的最高验证集准确率(百分比形式)
# 打印微调模型的最佳验证准确率
# 微调模型使用预训练好的编码器,在少量标记数据上进一步优化
# 通常表现优于纯监督基线模型和仅预训练模型
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(finetuning_history.history["val_acc"]) * 100
    )
)
# 绘制训练曲线函数
# 参数:
#   pretraining_history: 自监督预训练模型的训练历史
#   finetuning_history: 监督微调模型的训练历史
#   baseline_history: 纯监督基线模型的训练历史
# 功能:可视化和比较三种不同模型训练方法的验证性能
#       1. 纯监督基线模型
#       2. 自监督预训练模型
#       3. 自监督预训练+监督微调模型
# 图表内容:
#   1. 准确率曲线图:展示不同方法的验证准确率变化
#   2. 损失曲线图:展示不同方法的验证损失变化
# 绘制训练曲线函数
# 参数:
#   pretraining_history: 自监督预训练模型的训练历史
#   finetuning_history: 监督微调模型的训练历史
#   baseline_history: 纯监督基线模型的训练历史
# 功能:可视化和比较三种不同模型训练方法的验证性能
#       1. 纯监督基线模型
#       2. 自监督预训练模型
#       3. 自监督预训练+监督微调模型
# 图表内容:
#   1. 准确率曲线图:展示不同方法的验证准确率变化
#   2. 损失曲线图:展示不同方法的验证损失变化
def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
    # 循环绘制准确率和损失曲线
    # 首先绘制准确率曲线,然后绘制损失曲线
    for metric_key, metric_name in zip(["acc", "loss"], ["accuracy", "loss"]):
        # 创建图表,设置尺寸和分辨率
        plt.figure(figsize=(8, 5), dpi=100)
        # 绘制纯监督基线模型的验证曲线
        # 作为性能基准,直接使用标记数据训练
        plt.plot(
            baseline_history.history[f"val_{metric_key}"],
            label="supervised baseline",
        )
        # 绘制自监督预训练模型的验证曲线
        # 使用大量未标记数据进行对比学习预训练
        plt.plot(
            pretraining_history.history[f"val_p_{metric_key}"],
            label="self-supervised pretraining",
        )
        # 绘制监督微调模型的验证曲线
        # 在预训练基础上,使用少量标记数据进行微调
        plt.plot(
            finetuning_history.history[f"val_{metric_key}"],
            label="supervised finetuning",
        )
        # 添加图例
        plt.legend()
        # 设置图表标题
        plt.title(f"Classification {metric_name} during training")
        # 设置x轴标签
        plt.xlabel("epochs")
        # 设置y轴标签
        plt.ylabel(f"validation {metric_name}")

# 调用训练曲线绘制函数
# 可视化三种方法的性能对比
# 预期结果:
#   1. 监督微调模型 > 自监督预训练模型 > 纯监督基线模型(准确率)
#   2. 监督微调模型 < 自监督预训练模型 < 纯监督基线模型(损失值)
# 这验证了半监督学习方法在标记数据有限情况下的有效性
plot_training_curves(pretraining_history, finetuning_history, baseline_history)
相关推荐
西岸行者2 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意2 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码2 天前
嵌入式学习路线
学习
毛小茛2 天前
计算机系统概论——校验码
学习
babe小鑫2 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms2 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下2 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。2 天前
2026.2.25监控学习
学习
im_AMBER2 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J2 天前
从“Hello World“ 开始 C++
c语言·c++·学习