半监督对比学习 (Semi-Supervised SimCLR) 实现

项目概述

本项目实现了一个基于对比学习的半监督学习框架,主要基于 SimCLR 模型架构。这个实现旨在解决标签数据有限场景下的机器学习问题,通过利用大量未标记数据来提高模型性能。

该项目的核心思想是:

"半监督学习结合了有标签和无标签数据的优势,而对比学习则是一种自监督学习方法,通过学习样本间的相似性来提取有意义的特征表示。"

核心功能

  1. 对比学习预训练:利用大量未标记数据进行自监督预训练,学习数据的内在结构
  2. 监督微调:使用少量标记数据对预训练模型进行微调
  3. 半监督学习:结合标记和未标记数据进行模型训练
  4. 数据增强:实现多种图像增强策略,提高模型的泛化能力
  5. 模型评估:支持不同评估指标,如准确率、损失曲线可视化等

支持的架构

  • 对比模型:实现了 SimCLR 风格的对比学习模型
  • 监督基线模型:提供标准监督学习作为性能比较基准
  • 编码器:支持多种基础模型作为编码器,如 ResNet 系列

技术栈

  • TensorFlow/Keras:深度学习框架
  • NumPy:数据处理
  • Matplotlib:结果可视化
  • 其他必要的机器学习和数据科学库

实现原理

对比学习的核心思想

根据 main.py 中的注释,对比学习的核心在于:

"通过学习样本间的相似性来提取有意义的特征表示。"

该实现基于以下关键原理:

  1. 数据增强对:对同一图像应用不同的数据增强策略,生成正样本对
  2. 对比损失:最大化正样本对之间的相似性,最小化负样本对之间的相似性
  3. 投影头:将编码器输出映射到对比学习空间的非线性变换层

模型架构

项目实现了两种主要模型:

1. 对比模型 (ContrastiveModel)
python 复制代码
class ContrastiveModel(keras.Model):
    def __init__(self, encoder, projection_units):
        # 编码器用于提取特征表示
        # 投影头将特征映射到对比学习空间
        # ...

该模型包含以下核心组件:

  • 编码器:从输入图像中提取特征表示
  • 投影头:将编码器输出转换为适合对比学习的向量
  • 温度参数:控制对比损失的锐度
2. 监督微调模型

在预训练后,通过添加分类层进行监督微调:

python 复制代码
# 创建用于微调的监督模型
finetuning_model = keras.Sequential([
    pretrained_encoder,
    keras.layers.Dense(num_classes, activation="softmax")
])

关键实现细节

数据增强

根据 main.py,项目实现了强大的数据增强策略:

python 复制代码
def get_augmenter(image_size):
    """创建图像增强流水线"""
    # 包含随机裁剪、翻转、颜色抖动等多种增强操作
    # ...

主要增强操作包括:

  • 随机裁剪和调整大小
  • 水平翻转
  • 颜色抖动(亮度、对比度、饱和度、色调)
  • 随机灰度转换
  • 高斯模糊
对比损失函数

实现了基于温度缩放的对比损失:

python 复制代码
def compute_contrastive_loss(features, temperature=0.1):
    # 计算批次内所有样本对之间的相似性
    # 应用温度缩放
    # 计算 NT-Xent 损失
    # ...
训练流程

项目实现了两阶段训练流程:

  1. 对比预训练:使用所有数据(标记和未标记)进行自监督学习
  2. 监督微调:仅使用标记数据对预训练编码器进行微调

使用方法、参数配置和运行示例

安装依赖

运行本项目需要安装以下依赖:

bash 复制代码
pip install tensorflow numpy matplotlib

超参数配置详解

根据 main.py 中的定义,以下是主要超参数及其说明:

python 复制代码
# 超参数设置
batch_size = 256        # 批次大小,影响训练速度和稳定性
image_size = 224        # 输入图像尺寸
projection_units = 128  # 投影头输出维度
num_epochs = 100        # 训练轮数
learning_rate = 0.001   # 学习率
temperature = 0.1       # 对比损失的温度参数,控制相似度分布的锐度

其他重要参数:

python 复制代码
# 数据集参数
data_dir = "./data"    # 数据集路径
num_classes = 10       # 分类任务的类别数
labeled_ratio = 0.1    # 标记数据的比例

# 模型参数
base_model = "resnet50"  # 基础编码器模型

运行示例

1. 完整训练流程
bash 复制代码
python main.py

默认情况下,脚本将执行以下步骤:

  1. 加载并准备数据集(包括标记和未标记数据)
  2. 创建数据增强流水线
  3. 初始化对比模型并进行预训练
  4. 使用标记数据对预训练模型进行微调
  5. 评估模型性能并生成可视化结果
2. 仅运行对比预训练

可以修改 main.py 中的相关代码,只执行预训练阶段:

python 复制代码
# 预训练对比模型
contrastive_history = contrastive_model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset
)

# 保存预训练模型
contrastive_model.encoder.save("pretrained_encoder")
3. 使用自定义数据集

要使用自定义数据集,需要修改数据加载部分:

python 复制代码
# 自定义数据集加载
def load_custom_dataset(data_dir):
    # 实现自定义数据加载逻辑
    # 返回标记和未标记的数据集
    pass

train_data, val_data, test_data = load_custom_dataset("./your_data")

结果评估

训练完成后,模型会自动在测试集上进行评估,并生成以下指标:

复制代码
# 测试集上的性能评估
accuracy = supervised_model.evaluate(test_dataset)
print(f"测试准确率: {accuracy:.4f}")

训练过程中的损失和准确率曲线会保存为图像文件,便于分析模型收敛情况。

项目结构和工作流程

详细项目结构

根据项目文件组织,项目包含以下关键组件:

复制代码
semisupervised_simclr/
├── main.py                 # 主实现文件,包含完整的训练和评估流程
│   ├── 文件头部说明和注释      # 项目概述和基本概念介绍
│   ├── 环境设置和跨平台兼容    # 资源限制设置和平台检测
│   ├── 超参数配置            # 所有可调参数的集中管理
│   ├── 数据处理模块          # 数据集加载和预处理功能
│   ├── 数据增强实现          # 自定义增强变换和流水线
│   ├── 模型架构定义          # 编码器、对比模型和监督模型
│   ├── 训练循环实现          # train_step 和 test_step 方法
│   ├── 主训练逻辑            # 预训练和微调流程
│   └── 结果评估和可视化       # 性能指标计算和图表生成
└── semisupervised_simclr.py  # 辅助模块,可能包含特定功能实现

main.py 主要组件详解

  1. 数据处理组件

    • 实现了数据集的加载、预处理和批次生成
    • 支持标记和未标记数据的混合使用
    • 包含数据分割和验证集构建功能
  2. 数据增强模块

    • 基于 main.py 中的 get_augmenter() 函数实现
    • 包含自定义的 RandomColorAffine 增强层
    • 支持多种增强操作的组合和参数调整
  3. 模型架构

    • get_encoder() 函数用于创建特征提取器
    • ContrastiveModel 类实现对比学习逻辑
    • 监督模型构建和微调功能
  4. 训练和优化

    • 自定义训练循环
    • 对比损失计算
    • 学习率调度和优化器配置

详细工作流程

项目的完整工作流程包括以下主要阶段:

1. 初始化阶段
python 复制代码
# 环境和资源配置
import necessary modules
set random seeds
configure platform-specific settings

# 超参数初始化
batch_size = 256
image_size = 224
# 其他参数设置...
2. 数据处理阶段
python 复制代码
# 加载数据集
train_data, val_data, test_data = load_dataset(data_dir)

# 创建数据增强器
augmenter = get_augmenter(image_size)

# 构建训练和验证数据集
train_dataset = create_contrastive_dataset(train_data, augmenter, batch_size)
val_dataset = create_contrastive_dataset(val_data, augmenter, batch_size)
3. 模型构建阶段
python 复制代码
# 创建编码器
encoder = get_encoder(image_size, base_model=base_model)

# 创建对比模型
contrastive_model = ContrastiveModel(
    encoder=encoder,
    projection_units=projection_units,
    temperature=temperature
)

# 编译模型
contrastive_model.compile(optimizer=keras.optimizers.Adam(learning_rate))
4. 对比预训练阶段
python 复制代码
# 执行预训练
contrastive_history = contrastive_model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset
)
5. 监督微调阶段
python 复制代码
# 提取预训练编码器
pretrained_encoder = contrastive_model.encoder

# 创建监督微调模型
finetuning_model = keras.Sequential([
    pretrained_encoder,
    keras.layers.Dense(num_classes, activation="softmax")
])

# 编译和训练微调模型
finetuning_model.compile(...)
finetuning_history = finetuning_model.fit(...)
6. 评估和可视化阶段
python 复制代码
# 评估模型性能
test_accuracy = finetuning_model.evaluate(test_dataset)

# 绘制训练曲线
plot_training_curves(contrastive_history, finetuning_history)

数据流和转换过程

下面是项目中数据的完整流程:

复制代码
┌─────────────┐     ┌─────────────┐     ┌─────────────────────┐
│  原始图像数据 │ ──> │  数据增强     │ ──> │  生成对比学习样本对  │
└─────────────┘     └─────────────┘     └─────────────────────┘
                                                │
                                                ▼
┌─────────────┐     ┌─────────────┐     ┌─────────────────────┐
│  评估结果    │ <── │  监督微调     │ <── │  对比预训练         │
└─────────────┘     └─────────────┘     └─────────────────────┘
        │                 │                   │
        ▼                 ▼                   ▼
┌─────────────┐     ┌─────────────┐     ┌─────────────────────┐
│  可视化结果   │     │  分类头      │     │  编码器 + 投影头    │
└─────────────┘     └─────────────┘     └─────────────────────┘

核心代码流程

根据 main.py 的实现,以下是项目的核心代码流程:

  1. 定义数据增强和预处理

    • 实现 RandomColorAffineget_augmenter()
  2. 创建模型架构

    • 使用 get_encoder() 创建基础编码器
    • 实现 ContrastiveModel
  3. 训练对比模型

    • 实现自定义 train_step()test_step()
    • 执行对比预训练
  4. 微调监督模型

    • 重用预训练编码器
    • 添加分类层并训练
  5. 评估和可视化

    • 计算性能指标
    • 生成训练曲线和结果报告

进一步改进和相关工作

根据 main.py 中的注释,该实现可以进一步改进:

"架构改进:

  • 实现 MoCo 或 BYOL 等其他对比学习方法

  • 添加动量编码器提高性能

  • 探索不同的投影头结构"
    "超参数调优笔记:

  • 温度参数对对比学习效果影响显著

  • 数据增强策略对性能有重要影响

  • 学习率调度对于稳定训练很重要"

相关工作包括 BYOL 和 SimSiam,它们在无负样本的情况下实现了对比学习。

总结

本项目实现了一个完整的半监督对比学习框架,通过结合标记和未标记数据,在数据有限的场景下提高模型性能。该实现基于 SimCLR 架构,包含了从数据准备到模型评估的完整工作流程,可作为半监督学习研究和应用的基础框架。

完整代码

python 复制代码
# -*- coding: utf-8 -*-
# 半监督对比学习框架实现(Semi-Supervised SimCLR)
# 该文件实现了基于SimCLR架构的半监督对比学习模型,用于图像分类任务
# 主要功能:
# 1. 使用未标记数据进行对比预训练,学习图像的有效表示
# 2. 使用少量标记数据进行线性探针评估和监督微调
# 3. 与纯监督基线模型进行性能比较
# 4. 可视化数据增强效果和训练曲线

import os
# 设置Keras后端为TensorFlow
os.environ["KERAS_BACKEND"] = "tensorflow"
# 调整系统资源限制,以支持更大的文件打开数量
import resource
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))

# 导入必要的库
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import ops
from keras import layers
# 数据集配置参数
unlabeled_dataset_size = 100000  # 未标记数据集大小
labeled_dataset_size = 5000      # 标记数据集大小
image_channels = 3              # 图像通道数(RGB)

# 训练配置参数
num_epochs = 20                 # 训练轮数
batch_size = 525                # 批量大小
width = 128                     # 模型中间层特征维度

# 对比学习配置参数
temperature = 0.1               # 温度参数,控制对比损失的平滑程度

# 图像增强配置
# 对比学习使用较强的数据增强,促进模型学习鲁棒特征
contrastive_augmentation = {
    "min_area": 0.25,    # 随机裁剪保留的最小区域比例
    "brightness": 0.6,   # 亮度调整范围
    "jitter": 0.2        # 颜色抖动强度
}

# 分类任务使用较弱的数据增强,保持图像语义信息
classification_augmentation = {
    "min_area": 0.75,    # 随机裁剪保留的最小区域比例
    "brightness": 0.3,   # 亮度调整范围
    "jitter": 0.1        # 颜色抖动强度
}
# 准备训练和测试数据集
# 功能:加载STL-10数据集,分离未标记数据、标记数据和测试数据,并进行预处理
# 返回:
#     train_dataset: 包含未标记数据和标记数据的组合数据集
#     labeled_train_dataset: 仅标记训练数据的数据集
#     test_dataset: 测试数据集
def prepare_dataset():
    # 计算每轮训练的步数
    steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
    
    # 计算未标记数据和标记数据各自的批量大小
    unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
    labeled_batch_size = labeled_dataset_size // steps_per_epoch
    
    print(
        f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
    )
    
    # 加载未标记数据集 (STL-10的'unlabelled'分割)
    unlabeled_train_dataset = (
        tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * unlabeled_batch_size)  # 打乱数据,增加随机性
        .batch(unlabeled_batch_size)  # 批量处理
    )
    
    # 加载标记训练数据集 (STL-10的'train'分割)
    labeled_train_dataset = (
        tfds.load("stl10", split="train", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * labeled_batch_size)  # 打乱数据
        .batch(labeled_batch_size)  # 批量处理
    )
    
    # 加载测试数据集 (STL-10的'test'分割)
    test_dataset = (
        tfds.load("stl10", split="test", as_supervised=True)
        .batch(batch_size)  # 批量处理
        .prefetch(buffer_size=tf.data.AUTOTUNE)  # 预加载数据以提高性能
    )
    
    # 组合未标记数据和标记数据,创建用于对比学习的训练数据集
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=tf.data.AUTOTUNE)  # 预加载提高性能
    
    return train_dataset, labeled_train_dataset, test_dataset

# 主程序入口
# SimCLR(对比学习)半监督学习实现的整体流程:
# 1. 准备数据集(大量未标记数据 + 少量标记数据)
# 2. 可视化数据增强效果
# 3. 训练三种模型进行对比:
#    a. 基线模型:纯监督学习(仅使用少量标记数据)
#    b. 预训练模型:自监督对比学习(使用大量未标记数据)+ 线性探针评估
#    c. 微调模型:在预训练模型基础上进行监督微调(使用少量标记数据)
# 4. 绘制训练曲线,比较三种方法的性能差异

# 准备数据集
# 生成三种数据集:
#   1. train_dataset: 包含未标记和标记数据,用于对比学习预训练
#   2. labeled_train_dataset: 仅包含标记数据,用于监督学习基线和微调
#   3. test_dataset: 测试集,用于评估模型性能
train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

# 自定义颜色变换层,用于实现图像亮度和颜色抖动增强
class RandomColorAffine(layers.Layer):
    def __init__(self, brightness=0, jitter=0, **kwargs):
        # 初始化颜色变换层
        # 参数:
        #   brightness: 亮度调整范围,0表示不调整,值越大调整范围越大
        #   jitter: 颜色抖动强度,0表示不抖动,值越大抖动越明显
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)  # 设置随机种子生成器
        self.brightness = brightness  # 亮度调整参数
        self.jitter = jitter  # 颜色抖动参数
    
    def get_config(self):
        # 获取层配置,用于模型保存和加载
        config = super().get_config()
        config.update({"brightness": self.brightness, "jitter": self.jitter})
        return config
    
    def call(self, images, training=True):
        # 执行颜色变换操作
        # 仅在训练模式下应用变换
        if training:
            batch_size = ops.shape(images)[0]  # 获取批量大小
            
            # 生成亮度缩放因子,形状为[batch_size, 1, 1, 1]
            brightness_scales = 1 + keras.random.uniform(
                (batch_size, 1, 1, 1),
                minval=-self.brightness,
                maxval=self.brightness,
                seed=self.seed_generator,
            )
            
            # 生成颜色抖动矩阵,形状为[batch_size, 1, 3, 3]
            jitter_matrices = keras.random.uniform(
                (batch_size, 1, 3, 3),
                minval=-self.jitter,
                maxval=self.jitter,
                seed=self.seed_generator,
            )
            
            # 构造颜色变换矩阵:单位矩阵 * 亮度缩放 + 抖动矩阵
            color_transforms = (
                ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
                * brightness_scales
                + jitter_matrices
            )
            
            # 应用颜色变换并裁剪到[0, 1]范围
            images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
        return images
# 创建数据增强器
# 参数:
#   min_area: 随机裁剪保留的最小区域比例
#   brightness: 亮度调整范围
#   jitter: 颜色抖动强度
# 返回:
#   数据增强序列模型
def get_augmenter(min_area, brightness, jitter):
    # 根据最小区域比例计算缩放因子
    zoom_factor = 1.0 - math.sqrt(min_area)
    
    # 构建数据增强管道,按顺序应用以下增强操作:
    # 1. 像素值归一化 (0-1范围)
    # 2. 水平随机翻转
    # 3. 随机平移
    # 4. 随机缩放
    # 5. 颜色和亮度变换
    return keras.Sequential(
        [
            layers.Rescaling(1 / 255),  # 像素值归一化
            layers.RandomFlip("horizontal"),  # 水平随机翻转
            layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2),  # 随机平移
            layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),  # 随机缩放
            RandomColorAffine(brightness, jitter),  # 颜色和亮度变换
        ]
    )
# 可视化数据增强效果函数
# 参数:
#   num_images: 要可视化的图像数量
# 功能:展示不同数据增强方法对同一图像的变换效果
#       帮助理解和验证数据增强策略的有效性
# 增强方法对比:
#   1. 原始图像 - 无变换
#   2. 弱增强 - 用于分类任务的轻量级变换
#   3. 强增强 - 用于对比学习的更显著变换(生成两个不同的增强视图)
def visualize_augmentations(num_images):
    # 从训练数据集中获取样例图像
    # 注意:这里使用的是训练数据集的未标记部分的第一个批次中的图像
    images = next(iter(train_dataset))[0][0][:num_images]
    
    # 为每张图像生成多种增强版本:
    # 1. 原始图像 - 未进行任何变换的基准图像
    # 2. 弱增强版本 - 用于分类任务的数据增强,保留较多语义信息
    # 3. 强增强版本1 - 用于对比学习的第一个增强视图
    # 4. 强增强版本2 - 用于对比学习的第二个增强视图
    # 对比学习需要为同一图像生成两个不同的增强视图作为正样本对
    augmented_images = zip(
        images,
        get_augmenter(**classification_augmentation)(images),  # 弱增强(用于分类任务)
        get_augmenter(**contrastive_augmentation)(images),    # 强增强1(用于对比学习)
        get_augmenter(**contrastive_augmentation)(images),    # 强增强2(用于对比学习)
    )
    
    # 行标题
    row_titles = [
        "Original:",           # 原始图像
        "Weakly augmented:",   # 弱增强(分类用)
        "Strongly augmented:", # 强增强1(对比学习用)
        "Strongly augmented:", # 强增强2(对比学习用)
    ]
    
    # 创建可视化图表
    # 图表布局:4行(num_images)列,每行展示一种增强类型,每列展示一个图像样本的不同增强结果
    # 图表尺寸根据图像数量动态调整,dpi=100保证图像清晰显示
    plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
    
    # 绘制图像网格
    # 外层循环遍历每个图像样本(column)
    # 内层循环遍历该样本的四种增强版本(row)
    for column, image_row in enumerate(augmented_images):
        for row, image in enumerate(image_row):
            # 设置子图位置:4行(num_images)列,当前子图索引为row * num_images + column + 1
            plt.subplot(4, num_images, row * num_images + column + 1)
            # 显示增强后的图像
            plt.imshow(image)
            # 只为第一列添加标题,避免重复标注
            if column == 0:
                plt.title(row_titles[row], loc="left")  # 标题左对齐,提高可读性
            plt.axis("off")  # 关闭坐标轴,突出图像内容
    
    plt.tight_layout()  # 调整布局,自动调整子图参数,避免标签重叠
    # 这种网格布局使不同增强策略的效果对比更加直观
    # 行方向比较同一图像的不同增强版本,列方向比较不同图像的增强一致性
# 可视化数据增强效果
# 展示原始图像、弱增强和强增强的视觉差异
# 弱增强用于分类任务,强增强用于对比学习
visualize_augmentations(num_images=8)
# 创建图像编码器(特征提取器)
# 返回:
#   卷积神经网络编码器,用于从图像中提取特征表示
def get_encoder():
    # 构建卷积神经网络编码器,包含以下层:
    # 1. 4个卷积层,每个卷积层使用3x3核,步长为2(进行下采样)
    # 2. 扁平化层,将卷积特征转换为一维向量
    # 3. 全连接层,输出固定维度的特征向量
    return keras.Sequential(
        [
            # 第一个卷积层,输出通道数为width,使用ReLU激活函数
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 第二个卷积层
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 第三个卷积层
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 第四个卷积层
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            # 扁平化层
            layers.Flatten(),
            # 输出特征向量的全连接层
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",  # 模型名称
    )
# 创建基线模型(纯监督学习)
# 作为性能基准,直接在标记数据上训练
# 用于验证对比学习预训练的有效性
baseline_model = keras.Sequential(
    [
        get_augmenter(**classification_augmentation),  # 使用弱数据增强
        get_encoder(),  # 使用与对比学习相同的编码器
        layers.Dense(10),  # 分类头(输出10个类别的logits)
    ],
    name="baseline_model",  # 模型名称
)
# 编译基线模型
baseline_model.compile(
    optimizer=keras.optimizers.Adam(),  # Adam优化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # 交叉熵损失
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],  # 评估指标
)
baseline_history = baseline_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
# 打印基线模型的最佳验证准确率
# 基线模型是纯监督学习方法,直接使用标记数据训练
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)
# 对比学习模型类
# 实现SimCLR算法的核心逻辑,用于无监督特征学习
class ContrastiveModel(keras.Model):
    # 初始化对比学习模型
    # 核心组件包括:数据增强器、特征编码器、投影头和线性评估器
    def __init__(self):
        super().__init__()
        # 温度参数:控制对比损失中相似度的缩放,值越小,分布越陡峭
        self.temperature = temperature
        # 强数据增强器:用于对比学习任务,生成多样化视图
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        # 弱数据增强器:用于分类任务,保留更多语义信息
        self.classification_augmenter = get_augmenter(**classification_augmentation)
        # 特征编码器:将图像转换为高维特征向量
        self.encoder = get_encoder()
        # 投影头:将编码器输出映射到低维潜在空间,促进对比学习
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        # 线性探针:用于评估学习到的特征表示质量
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)],
            name="linear_probe",
        )
        # 打印各组件架构信息
        self.encoder.summary()
        self.projection_head.summary()
        self.linear_probe.summary()
    
    # 模型前向传播方法
    # 参数:
    #   inputs: 输入图像批次,形状为(batch_size, height, width, channels)
    #   training: 是否为训练模式
    # 返回:
    #   训练模式:返回特征和投影
    #   推理模式:返回编码器输出的特征
    def call(self, inputs, training=None):
        # 前向传播核心流程:图像 -> 特征提取 -> 投影
        # 获取输入图像的特征表示
        features = self.encoder(inputs, training=training)
        
        # 获取投影特征
        # 在对比学习中,投影特征用于计算样本间的相似性
        projections = self.projection_head(features, training=training)
        
        # 根据训练模式返回不同结果
        if training:
            return features, projections
        else:
            return features  # 推理时仅返回编码器特征
    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)
        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )
        self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")
    @property
    def metrics(self):
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]
    # 计算对比损失的核心方法
    # 参数:
    #   projections_1: 第一个增强视图的投影特征
    #   projections_2: 第二个增强视图的投影特征
    # 返回:
    #   计算得到的对比损失值
    def contrastive_loss(self, projections_1, projections_2):
        # 对投影特征进行L2归一化
        # 确保特征在单位超球面上,使点积等价于余弦相似度
        projections_1 = ops.normalize(projections_1, axis=1)
        projections_2 = ops.normalize(projections_2, axis=1)
        
        # 计算两个视图投影特征之间的相似性矩阵
        # 使用温度参数缩放相似度,温度越低,分布越陡峭
        similarities = (
            ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
        )
        
        # 获取批次大小
        batch_size = ops.shape(projections_1)[0]
        
        # 创建标签:每个样本应该与另一个视图中的相同索引样本匹配
        contrastive_labels = ops.arange(batch_size)
        
        # 更新对比准确率指标
        # 分别计算从视图1到视图2和从视图2到视图1的准确率
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, ops.transpose(similarities)
        )
        
        # 计算两个方向的交叉熵损失
        # loss_1_2: 从视图1到视图2的对比损失
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        # loss_2_1: 从视图2到视图1的对比损失
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, ops.transpose(similarities), from_logits=True
        )
        
        # 返回两个方向损失的平均值
        # 这种双向计算确保模型能够从两个角度学习样本之间的关系
        return (loss_1_2 + loss_2_1) / 2
    # 模型训练步骤
    # 参数:
    #   data: 训练数据,包含未标记数据和标记数据的组合
    # 返回:
    #   字典形式的训练指标,包括对比损失、对比准确率、探针损失和探针准确率
    def train_step(self, data):
        # 处理输入数据结构,获取未标记数据和标记数据
        (unlabeled_images, _), (labeled_images, labels) = data
        
        # 对比学习训练部分
        # 1. 合并未标记数据和标记数据,用于对比学习
        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
        
        # 2. 为每张图像生成两个不同的增强视图
        # 这是对比学习的核心:同一图像的不同增强版本应该具有相似的表示
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)
        
        # 3. 在第一个梯度带上计算对比学习损失并更新编码器和投影头
        with tf.GradientTape() as tape:
            # 通过编码器提取两个增强视图的特征表示
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            
            # 通过投影头将特征映射到潜在空间
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            
            # 计算对比损失,衡量正负样本对的相似性
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        
        # 4. 计算并应用梯度,更新编码器和投影头的参数
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        
        # 5. 更新对比损失跟踪指标
        self.contrastive_loss_tracker.update_state(contrastive_loss)
        
        # 线性探针(分类任务)训练部分
        # 1. 对标记图像应用弱增强,保持语义信息
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=True
        )
        
        # 2. 在第二个梯度带上计算分类损失并更新线性探针
        with tf.GradientTape() as tape:
            # 使用编码器提取特征(注意:编码器在此阶段不更新参数)
            features = self.encoder(preprocessed_images, training=False)
            
            # 通过线性探针进行分类预测
            class_logits = self.linear_probe(features, training=True)
            
            # 计算分类损失
            probe_loss = self.probe_loss(labels, class_logits)
        
        # 3. 计算并应用梯度,只更新线性探针的参数
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        
        # 4. 更新线性探针的损失和准确率指标
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)
        
        # 5. 返回所有训练指标
        return {m.name: m.result() for m in self.metrics}
    # 模型测试/评估步骤
    # 参数:
    #   data: 测试数据,包含图像和对应的标签
    # 返回:
    #   字典形式的评估指标,包括探针损失和探针准确率
    def test_step(self, data):
        # 提取测试数据中的图像和标签
        labeled_images, labels = data
        
        # 对测试图像应用弱增强
        # 注意:即使在测试阶段,也应用相同的数据预处理以保持一致性
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        
        # 使用编码器提取特征
        # 训练=False表示不更新编码器参数,仅用于特征提取
        features = self.encoder(preprocessed_images, training=False)
        
        # 通过线性探针进行分类预测
        class_logits = self.linear_probe(features, training=False)
        
        # 计算分类损失
        probe_loss = self.probe_loss(labels, class_logits)
        
        # 更新评估指标
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)
        
        # 返回后两个评估指标(仅探针相关指标)
        return {m.name: m.result() for m in self.metrics[2:]}
# 初始化预训练模型
# 创建ContrastiveModel实例,用于无监督对比学习预训练
pretraining_model = ContrastiveModel()

# 编译预训练模型
# 为不同部分设置不同的优化器:
# 1. contrastive_optimizer: 用于更新编码器和投影头的参数
# 2. probe_optimizer: 仅用于更新线性探针的参数
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(learning_rate=0.001),  # 对比学习部分优化器
    probe_optimizer=keras.optimizers.Adam(learning_rate=0.001),       # 线性探针部分优化器
)
# 开始预训练模型
# 训练过程同时进行对比学习(无监督)和线性探针评估(有监督)
pretraining_history = pretraining_model.fit(
    train_dataset,                  # 训练数据(包含未标记和标记数据)
    epochs=num_epochs,              # 预训练轮数
    validation_data=test_dataset    # 验证数据,用于监控性能
)
# 打印对比预训练模型的最佳验证准确率
# 预训练模型使用未标记数据进行对比学习,同时在线性探针上评估性能
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(pretraining_history.history["val_p_acc"]) * 100
    )
)
# 创建微调模型
# 利用预训练好的编码器进行监督微调,添加新的分类层
finetuning_model = keras.Sequential(
    [
        get_augmenter(**classification_augmentation),  # 复用分类任务的弱数据增强
        pretraining_model.encoder,                     # 使用预训练好的编码器提取特征
        layers.Dense(10),                              # 新的分类层(输出10个类别的logits)
    ],
    name="finetuning_model",
)

# 编译微调模型
# 设置优化器、损失函数和评估指标
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),                 # Adam优化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # 交叉熵损失(from_logits=True表示模型输出logits)
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],  # 评估指标:稀疏分类准确率
)

# 开始监督微调
# 使用少量标记数据对模型进行微调,验证集用于监控性能
finetuning_history = finetuning_model.fit(
    labeled_train_dataset,      # 仅使用标记训练数据
    epochs=num_epochs,          # 微调轮数
    validation_data=test_dataset  # 测试集作为验证数据
)

# 打印微调后的最佳验证准确率
# 输出微调过程中达到的最高验证集准确率(百分比形式)
# 打印微调模型的最佳验证准确率
# 微调模型使用预训练好的编码器,在少量标记数据上进一步优化
# 通常表现优于纯监督基线模型和仅预训练模型
print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(finetuning_history.history["val_acc"]) * 100
    )
)
# 绘制训练曲线函数
# 参数:
#   pretraining_history: 自监督预训练模型的训练历史
#   finetuning_history: 监督微调模型的训练历史
#   baseline_history: 纯监督基线模型的训练历史
# 功能:可视化和比较三种不同模型训练方法的验证性能
#       1. 纯监督基线模型
#       2. 自监督预训练模型
#       3. 自监督预训练+监督微调模型
# 图表内容:
#   1. 准确率曲线图:展示不同方法的验证准确率变化
#   2. 损失曲线图:展示不同方法的验证损失变化
# 绘制训练曲线函数
# 参数:
#   pretraining_history: 自监督预训练模型的训练历史
#   finetuning_history: 监督微调模型的训练历史
#   baseline_history: 纯监督基线模型的训练历史
# 功能:可视化和比较三种不同模型训练方法的验证性能
#       1. 纯监督基线模型
#       2. 自监督预训练模型
#       3. 自监督预训练+监督微调模型
# 图表内容:
#   1. 准确率曲线图:展示不同方法的验证准确率变化
#   2. 损失曲线图:展示不同方法的验证损失变化
def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
    # 循环绘制准确率和损失曲线
    # 首先绘制准确率曲线,然后绘制损失曲线
    for metric_key, metric_name in zip(["acc", "loss"], ["accuracy", "loss"]):
        # 创建图表,设置尺寸和分辨率
        plt.figure(figsize=(8, 5), dpi=100)
        # 绘制纯监督基线模型的验证曲线
        # 作为性能基准,直接使用标记数据训练
        plt.plot(
            baseline_history.history[f"val_{metric_key}"],
            label="supervised baseline",
        )
        # 绘制自监督预训练模型的验证曲线
        # 使用大量未标记数据进行对比学习预训练
        plt.plot(
            pretraining_history.history[f"val_p_{metric_key}"],
            label="self-supervised pretraining",
        )
        # 绘制监督微调模型的验证曲线
        # 在预训练基础上,使用少量标记数据进行微调
        plt.plot(
            finetuning_history.history[f"val_{metric_key}"],
            label="supervised finetuning",
        )
        # 添加图例
        plt.legend()
        # 设置图表标题
        plt.title(f"Classification {metric_name} during training")
        # 设置x轴标签
        plt.xlabel("epochs")
        # 设置y轴标签
        plt.ylabel(f"validation {metric_name}")

# 调用训练曲线绘制函数
# 可视化三种方法的性能对比
# 预期结果:
#   1. 监督微调模型 > 自监督预训练模型 > 纯监督基线模型(准确率)
#   2. 监督微调模型 < 自监督预训练模型 < 纯监督基线模型(损失值)
# 这验证了半监督学习方法在标记数据有限情况下的有效性
plot_training_curves(pretraining_history, finetuning_history, baseline_history)
相关推荐
952362 小时前
数据结构-堆
java·数据结构·学习·算法
im_AMBER3 小时前
Leetcode 57
笔记·学习·算法·leetcode
im_AMBER3 小时前
Leetcode 58 | 附:滑动窗口题单
笔记·学习·算法·leetcode
伯明翰java3 小时前
Redis学习笔记-List列表(2)
redis·笔记·学习
云帆小二3 小时前
从开发语言出发如何选择学习考试系统
开发语言·学习
Elias不吃糖4 小时前
总结我的小项目里现在用到的Redis
c++·redis·学习
BullSmall4 小时前
《道德经》第六十三章
学习
AA陈超4 小时前
使用UnrealEngine引擎,实现鼠标点击移动
c++·笔记·学习·ue5·虚幻引擎
BullSmall5 小时前
《道德经》第六十二章
学习