
项目概述
本项目实现了一个基于对比学习的半监督学习框架,主要基于 SimCLR 模型架构。这个实现旨在解决标签数据有限场景下的机器学习问题,通过利用大量未标记数据来提高模型性能。
该项目的核心思想是:
"半监督学习结合了有标签和无标签数据的优势,而对比学习则是一种自监督学习方法,通过学习样本间的相似性来提取有意义的特征表示。"
核心功能
- 对比学习预训练:利用大量未标记数据进行自监督预训练,学习数据的内在结构
- 监督微调:使用少量标记数据对预训练模型进行微调
- 半监督学习:结合标记和未标记数据进行模型训练
- 数据增强:实现多种图像增强策略,提高模型的泛化能力
- 模型评估:支持不同评估指标,如准确率、损失曲线可视化等
支持的架构
- 对比模型:实现了 SimCLR 风格的对比学习模型
- 监督基线模型:提供标准监督学习作为性能比较基准
- 编码器:支持多种基础模型作为编码器,如 ResNet 系列
技术栈
- TensorFlow/Keras:深度学习框架
- NumPy:数据处理
- Matplotlib:结果可视化
- 其他必要的机器学习和数据科学库
实现原理
对比学习的核心思想
根据 main.py 中的注释,对比学习的核心在于:
"通过学习样本间的相似性来提取有意义的特征表示。"
该实现基于以下关键原理:
- 数据增强对:对同一图像应用不同的数据增强策略,生成正样本对
- 对比损失:最大化正样本对之间的相似性,最小化负样本对之间的相似性
- 投影头:将编码器输出映射到对比学习空间的非线性变换层
模型架构
项目实现了两种主要模型:
1. 对比模型 (ContrastiveModel)
python
class ContrastiveModel(keras.Model):
def __init__(self, encoder, projection_units):
# 编码器用于提取特征表示
# 投影头将特征映射到对比学习空间
# ...
该模型包含以下核心组件:
- 编码器:从输入图像中提取特征表示
- 投影头:将编码器输出转换为适合对比学习的向量
- 温度参数:控制对比损失的锐度
2. 监督微调模型
在预训练后,通过添加分类层进行监督微调:
python
# 创建用于微调的监督模型
finetuning_model = keras.Sequential([
pretrained_encoder,
keras.layers.Dense(num_classes, activation="softmax")
])
关键实现细节
数据增强
根据 main.py,项目实现了强大的数据增强策略:
python
def get_augmenter(image_size):
"""创建图像增强流水线"""
# 包含随机裁剪、翻转、颜色抖动等多种增强操作
# ...
主要增强操作包括:
- 随机裁剪和调整大小
- 水平翻转
- 颜色抖动(亮度、对比度、饱和度、色调)
- 随机灰度转换
- 高斯模糊
对比损失函数
实现了基于温度缩放的对比损失:
python
def compute_contrastive_loss(features, temperature=0.1):
# 计算批次内所有样本对之间的相似性
# 应用温度缩放
# 计算 NT-Xent 损失
# ...
训练流程
项目实现了两阶段训练流程:
- 对比预训练:使用所有数据(标记和未标记)进行自监督学习
- 监督微调:仅使用标记数据对预训练编码器进行微调
使用方法、参数配置和运行示例
安装依赖
运行本项目需要安装以下依赖:
bash
pip install tensorflow numpy matplotlib
超参数配置详解
根据 main.py 中的定义,以下是主要超参数及其说明:
python
# 超参数设置
batch_size = 256 # 批次大小,影响训练速度和稳定性
image_size = 224 # 输入图像尺寸
projection_units = 128 # 投影头输出维度
num_epochs = 100 # 训练轮数
learning_rate = 0.001 # 学习率
temperature = 0.1 # 对比损失的温度参数,控制相似度分布的锐度
其他重要参数:
python
# 数据集参数
data_dir = "./data" # 数据集路径
num_classes = 10 # 分类任务的类别数
labeled_ratio = 0.1 # 标记数据的比例
# 模型参数
base_model = "resnet50" # 基础编码器模型
运行示例
1. 完整训练流程
bash
python main.py
默认情况下,脚本将执行以下步骤:
- 加载并准备数据集(包括标记和未标记数据)
- 创建数据增强流水线
- 初始化对比模型并进行预训练
- 使用标记数据对预训练模型进行微调
- 评估模型性能并生成可视化结果
2. 仅运行对比预训练
可以修改 main.py 中的相关代码,只执行预训练阶段:
python
# 预训练对比模型
contrastive_history = contrastive_model.fit(
train_dataset,
epochs=num_epochs,
validation_data=val_dataset
)
# 保存预训练模型
contrastive_model.encoder.save("pretrained_encoder")
3. 使用自定义数据集
要使用自定义数据集,需要修改数据加载部分:
python
# 自定义数据集加载
def load_custom_dataset(data_dir):
# 实现自定义数据加载逻辑
# 返回标记和未标记的数据集
pass
train_data, val_data, test_data = load_custom_dataset("./your_data")
结果评估
训练完成后,模型会自动在测试集上进行评估,并生成以下指标:
# 测试集上的性能评估
accuracy = supervised_model.evaluate(test_dataset)
print(f"测试准确率: {accuracy:.4f}")
训练过程中的损失和准确率曲线会保存为图像文件,便于分析模型收敛情况。
项目结构和工作流程
详细项目结构
根据项目文件组织,项目包含以下关键组件:
semisupervised_simclr/
├── main.py # 主实现文件,包含完整的训练和评估流程
│ ├── 文件头部说明和注释 # 项目概述和基本概念介绍
│ ├── 环境设置和跨平台兼容 # 资源限制设置和平台检测
│ ├── 超参数配置 # 所有可调参数的集中管理
│ ├── 数据处理模块 # 数据集加载和预处理功能
│ ├── 数据增强实现 # 自定义增强变换和流水线
│ ├── 模型架构定义 # 编码器、对比模型和监督模型
│ ├── 训练循环实现 # train_step 和 test_step 方法
│ ├── 主训练逻辑 # 预训练和微调流程
│ └── 结果评估和可视化 # 性能指标计算和图表生成
└── semisupervised_simclr.py # 辅助模块,可能包含特定功能实现
main.py 主要组件详解
-
数据处理组件:
- 实现了数据集的加载、预处理和批次生成
- 支持标记和未标记数据的混合使用
- 包含数据分割和验证集构建功能
-
数据增强模块:
- 基于 main.py 中的
get_augmenter()函数实现 - 包含自定义的
RandomColorAffine增强层 - 支持多种增强操作的组合和参数调整
- 基于 main.py 中的
-
模型架构:
get_encoder()函数用于创建特征提取器ContrastiveModel类实现对比学习逻辑- 监督模型构建和微调功能
-
训练和优化:
- 自定义训练循环
- 对比损失计算
- 学习率调度和优化器配置
详细工作流程
项目的完整工作流程包括以下主要阶段:
1. 初始化阶段
python
# 环境和资源配置
import necessary modules
set random seeds
configure platform-specific settings
# 超参数初始化
batch_size = 256
image_size = 224
# 其他参数设置...
2. 数据处理阶段
python
# 加载数据集
train_data, val_data, test_data = load_dataset(data_dir)
# 创建数据增强器
augmenter = get_augmenter(image_size)
# 构建训练和验证数据集
train_dataset = create_contrastive_dataset(train_data, augmenter, batch_size)
val_dataset = create_contrastive_dataset(val_data, augmenter, batch_size)
3. 模型构建阶段
python
# 创建编码器
encoder = get_encoder(image_size, base_model=base_model)
# 创建对比模型
contrastive_model = ContrastiveModel(
encoder=encoder,
projection_units=projection_units,
temperature=temperature
)
# 编译模型
contrastive_model.compile(optimizer=keras.optimizers.Adam(learning_rate))
4. 对比预训练阶段
python
# 执行预训练
contrastive_history = contrastive_model.fit(
train_dataset,
epochs=num_epochs,
validation_data=val_dataset
)
5. 监督微调阶段
python
# 提取预训练编码器
pretrained_encoder = contrastive_model.encoder
# 创建监督微调模型
finetuning_model = keras.Sequential([
pretrained_encoder,
keras.layers.Dense(num_classes, activation="softmax")
])
# 编译和训练微调模型
finetuning_model.compile(...)
finetuning_history = finetuning_model.fit(...)
6. 评估和可视化阶段
python
# 评估模型性能
test_accuracy = finetuning_model.evaluate(test_dataset)
# 绘制训练曲线
plot_training_curves(contrastive_history, finetuning_history)
数据流和转换过程
下面是项目中数据的完整流程:
┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐
│ 原始图像数据 │ ──> │ 数据增强 │ ──> │ 生成对比学习样本对 │
└─────────────┘ └─────────────┘ └─────────────────────┘
│
▼
┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐
│ 评估结果 │ <── │ 监督微调 │ <── │ 对比预训练 │
└─────────────┘ └─────────────┘ └─────────────────────┘
│ │ │
▼ ▼ ▼
┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐
│ 可视化结果 │ │ 分类头 │ │ 编码器 + 投影头 │
└─────────────┘ └─────────────┘ └─────────────────────┘
核心代码流程
根据 main.py 的实现,以下是项目的核心代码流程:
-
定义数据增强和预处理:
- 实现
RandomColorAffine和get_augmenter()
- 实现
-
创建模型架构:
- 使用
get_encoder()创建基础编码器 - 实现
ContrastiveModel类
- 使用
-
训练对比模型:
- 实现自定义
train_step()和test_step() - 执行对比预训练
- 实现自定义
-
微调监督模型:
- 重用预训练编码器
- 添加分类层并训练
-
评估和可视化:
- 计算性能指标
- 生成训练曲线和结果报告
进一步改进和相关工作
根据 main.py 中的注释,该实现可以进一步改进:
"架构改进:
实现 MoCo 或 BYOL 等其他对比学习方法
添加动量编码器提高性能
探索不同的投影头结构"
"超参数调优笔记:温度参数对对比学习效果影响显著
数据增强策略对性能有重要影响
学习率调度对于稳定训练很重要"
相关工作包括 BYOL 和 SimSiam,它们在无负样本的情况下实现了对比学习。
总结
本项目实现了一个完整的半监督对比学习框架,通过结合标记和未标记数据,在数据有限的场景下提高模型性能。该实现基于 SimCLR 架构,包含了从数据准备到模型评估的完整工作流程,可作为半监督学习研究和应用的基础框架。
完整代码
python
# -*- coding: utf-8 -*-
# 半监督对比学习框架实现(Semi-Supervised SimCLR)
# 该文件实现了基于SimCLR架构的半监督对比学习模型,用于图像分类任务
# 主要功能:
# 1. 使用未标记数据进行对比预训练,学习图像的有效表示
# 2. 使用少量标记数据进行线性探针评估和监督微调
# 3. 与纯监督基线模型进行性能比较
# 4. 可视化数据增强效果和训练曲线
import os
# 设置Keras后端为TensorFlow
os.environ["KERAS_BACKEND"] = "tensorflow"
# 调整系统资源限制,以支持更大的文件打开数量
import resource
low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
# 导入必要的库
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import ops
from keras import layers
# 数据集配置参数
unlabeled_dataset_size = 100000 # 未标记数据集大小
labeled_dataset_size = 5000 # 标记数据集大小
image_channels = 3 # 图像通道数(RGB)
# 训练配置参数
num_epochs = 20 # 训练轮数
batch_size = 525 # 批量大小
width = 128 # 模型中间层特征维度
# 对比学习配置参数
temperature = 0.1 # 温度参数,控制对比损失的平滑程度
# 图像增强配置
# 对比学习使用较强的数据增强,促进模型学习鲁棒特征
contrastive_augmentation = {
"min_area": 0.25, # 随机裁剪保留的最小区域比例
"brightness": 0.6, # 亮度调整范围
"jitter": 0.2 # 颜色抖动强度
}
# 分类任务使用较弱的数据增强,保持图像语义信息
classification_augmentation = {
"min_area": 0.75, # 随机裁剪保留的最小区域比例
"brightness": 0.3, # 亮度调整范围
"jitter": 0.1 # 颜色抖动强度
}
# 准备训练和测试数据集
# 功能:加载STL-10数据集,分离未标记数据、标记数据和测试数据,并进行预处理
# 返回:
# train_dataset: 包含未标记数据和标记数据的组合数据集
# labeled_train_dataset: 仅标记训练数据的数据集
# test_dataset: 测试数据集
def prepare_dataset():
# 计算每轮训练的步数
steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
# 计算未标记数据和标记数据各自的批量大小
unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
labeled_batch_size = labeled_dataset_size // steps_per_epoch
print(
f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
)
# 加载未标记数据集 (STL-10的'unlabelled'分割)
unlabeled_train_dataset = (
tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=False)
.shuffle(buffer_size=10 * unlabeled_batch_size) # 打乱数据,增加随机性
.batch(unlabeled_batch_size) # 批量处理
)
# 加载标记训练数据集 (STL-10的'train'分割)
labeled_train_dataset = (
tfds.load("stl10", split="train", as_supervised=True, shuffle_files=False)
.shuffle(buffer_size=10 * labeled_batch_size) # 打乱数据
.batch(labeled_batch_size) # 批量处理
)
# 加载测试数据集 (STL-10的'test'分割)
test_dataset = (
tfds.load("stl10", split="test", as_supervised=True)
.batch(batch_size) # 批量处理
.prefetch(buffer_size=tf.data.AUTOTUNE) # 预加载数据以提高性能
)
# 组合未标记数据和标记数据,创建用于对比学习的训练数据集
train_dataset = tf.data.Dataset.zip(
(unlabeled_train_dataset, labeled_train_dataset)
).prefetch(buffer_size=tf.data.AUTOTUNE) # 预加载提高性能
return train_dataset, labeled_train_dataset, test_dataset
# 主程序入口
# SimCLR(对比学习)半监督学习实现的整体流程:
# 1. 准备数据集(大量未标记数据 + 少量标记数据)
# 2. 可视化数据增强效果
# 3. 训练三种模型进行对比:
# a. 基线模型:纯监督学习(仅使用少量标记数据)
# b. 预训练模型:自监督对比学习(使用大量未标记数据)+ 线性探针评估
# c. 微调模型:在预训练模型基础上进行监督微调(使用少量标记数据)
# 4. 绘制训练曲线,比较三种方法的性能差异
# 准备数据集
# 生成三种数据集:
# 1. train_dataset: 包含未标记和标记数据,用于对比学习预训练
# 2. labeled_train_dataset: 仅包含标记数据,用于监督学习基线和微调
# 3. test_dataset: 测试集,用于评估模型性能
train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()
# 自定义颜色变换层,用于实现图像亮度和颜色抖动增强
class RandomColorAffine(layers.Layer):
def __init__(self, brightness=0, jitter=0, **kwargs):
# 初始化颜色变换层
# 参数:
# brightness: 亮度调整范围,0表示不调整,值越大调整范围越大
# jitter: 颜色抖动强度,0表示不抖动,值越大抖动越明显
super().__init__(**kwargs)
self.seed_generator = keras.random.SeedGenerator(1337) # 设置随机种子生成器
self.brightness = brightness # 亮度调整参数
self.jitter = jitter # 颜色抖动参数
def get_config(self):
# 获取层配置,用于模型保存和加载
config = super().get_config()
config.update({"brightness": self.brightness, "jitter": self.jitter})
return config
def call(self, images, training=True):
# 执行颜色变换操作
# 仅在训练模式下应用变换
if training:
batch_size = ops.shape(images)[0] # 获取批量大小
# 生成亮度缩放因子,形状为[batch_size, 1, 1, 1]
brightness_scales = 1 + keras.random.uniform(
(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
seed=self.seed_generator,
)
# 生成颜色抖动矩阵,形状为[batch_size, 1, 3, 3]
jitter_matrices = keras.random.uniform(
(batch_size, 1, 3, 3),
minval=-self.jitter,
maxval=self.jitter,
seed=self.seed_generator,
)
# 构造颜色变换矩阵:单位矩阵 * 亮度缩放 + 抖动矩阵
color_transforms = (
ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
* brightness_scales
+ jitter_matrices
)
# 应用颜色变换并裁剪到[0, 1]范围
images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
return images
# 创建数据增强器
# 参数:
# min_area: 随机裁剪保留的最小区域比例
# brightness: 亮度调整范围
# jitter: 颜色抖动强度
# 返回:
# 数据增强序列模型
def get_augmenter(min_area, brightness, jitter):
# 根据最小区域比例计算缩放因子
zoom_factor = 1.0 - math.sqrt(min_area)
# 构建数据增强管道,按顺序应用以下增强操作:
# 1. 像素值归一化 (0-1范围)
# 2. 水平随机翻转
# 3. 随机平移
# 4. 随机缩放
# 5. 颜色和亮度变换
return keras.Sequential(
[
layers.Rescaling(1 / 255), # 像素值归一化
layers.RandomFlip("horizontal"), # 水平随机翻转
layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2), # 随机平移
layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)), # 随机缩放
RandomColorAffine(brightness, jitter), # 颜色和亮度变换
]
)
# 可视化数据增强效果函数
# 参数:
# num_images: 要可视化的图像数量
# 功能:展示不同数据增强方法对同一图像的变换效果
# 帮助理解和验证数据增强策略的有效性
# 增强方法对比:
# 1. 原始图像 - 无变换
# 2. 弱增强 - 用于分类任务的轻量级变换
# 3. 强增强 - 用于对比学习的更显著变换(生成两个不同的增强视图)
def visualize_augmentations(num_images):
# 从训练数据集中获取样例图像
# 注意:这里使用的是训练数据集的未标记部分的第一个批次中的图像
images = next(iter(train_dataset))[0][0][:num_images]
# 为每张图像生成多种增强版本:
# 1. 原始图像 - 未进行任何变换的基准图像
# 2. 弱增强版本 - 用于分类任务的数据增强,保留较多语义信息
# 3. 强增强版本1 - 用于对比学习的第一个增强视图
# 4. 强增强版本2 - 用于对比学习的第二个增强视图
# 对比学习需要为同一图像生成两个不同的增强视图作为正样本对
augmented_images = zip(
images,
get_augmenter(**classification_augmentation)(images), # 弱增强(用于分类任务)
get_augmenter(**contrastive_augmentation)(images), # 强增强1(用于对比学习)
get_augmenter(**contrastive_augmentation)(images), # 强增强2(用于对比学习)
)
# 行标题
row_titles = [
"Original:", # 原始图像
"Weakly augmented:", # 弱增强(分类用)
"Strongly augmented:", # 强增强1(对比学习用)
"Strongly augmented:", # 强增强2(对比学习用)
]
# 创建可视化图表
# 图表布局:4行(num_images)列,每行展示一种增强类型,每列展示一个图像样本的不同增强结果
# 图表尺寸根据图像数量动态调整,dpi=100保证图像清晰显示
plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
# 绘制图像网格
# 外层循环遍历每个图像样本(column)
# 内层循环遍历该样本的四种增强版本(row)
for column, image_row in enumerate(augmented_images):
for row, image in enumerate(image_row):
# 设置子图位置:4行(num_images)列,当前子图索引为row * num_images + column + 1
plt.subplot(4, num_images, row * num_images + column + 1)
# 显示增强后的图像
plt.imshow(image)
# 只为第一列添加标题,避免重复标注
if column == 0:
plt.title(row_titles[row], loc="left") # 标题左对齐,提高可读性
plt.axis("off") # 关闭坐标轴,突出图像内容
plt.tight_layout() # 调整布局,自动调整子图参数,避免标签重叠
# 这种网格布局使不同增强策略的效果对比更加直观
# 行方向比较同一图像的不同增强版本,列方向比较不同图像的增强一致性
# 可视化数据增强效果
# 展示原始图像、弱增强和强增强的视觉差异
# 弱增强用于分类任务,强增强用于对比学习
visualize_augmentations(num_images=8)
# 创建图像编码器(特征提取器)
# 返回:
# 卷积神经网络编码器,用于从图像中提取特征表示
def get_encoder():
# 构建卷积神经网络编码器,包含以下层:
# 1. 4个卷积层,每个卷积层使用3x3核,步长为2(进行下采样)
# 2. 扁平化层,将卷积特征转换为一维向量
# 3. 全连接层,输出固定维度的特征向量
return keras.Sequential(
[
# 第一个卷积层,输出通道数为width,使用ReLU激活函数
layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
# 第二个卷积层
layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
# 第三个卷积层
layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
# 第四个卷积层
layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
# 扁平化层
layers.Flatten(),
# 输出特征向量的全连接层
layers.Dense(width, activation="relu"),
],
name="encoder", # 模型名称
)
# 创建基线模型(纯监督学习)
# 作为性能基准,直接在标记数据上训练
# 用于验证对比学习预训练的有效性
baseline_model = keras.Sequential(
[
get_augmenter(**classification_augmentation), # 使用弱数据增强
get_encoder(), # 使用与对比学习相同的编码器
layers.Dense(10), # 分类头(输出10个类别的logits)
],
name="baseline_model", # 模型名称
)
# 编译基线模型
baseline_model.compile(
optimizer=keras.optimizers.Adam(), # Adam优化器
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 交叉熵损失
metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")], # 评估指标
)
baseline_history = baseline_model.fit(
labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
# 打印基线模型的最佳验证准确率
# 基线模型是纯监督学习方法,直接使用标记数据训练
print(
"Maximal validation accuracy: {:.2f}%".format(
max(baseline_history.history["val_acc"]) * 100
)
)
# 对比学习模型类
# 实现SimCLR算法的核心逻辑,用于无监督特征学习
class ContrastiveModel(keras.Model):
# 初始化对比学习模型
# 核心组件包括:数据增强器、特征编码器、投影头和线性评估器
def __init__(self):
super().__init__()
# 温度参数:控制对比损失中相似度的缩放,值越小,分布越陡峭
self.temperature = temperature
# 强数据增强器:用于对比学习任务,生成多样化视图
self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
# 弱数据增强器:用于分类任务,保留更多语义信息
self.classification_augmenter = get_augmenter(**classification_augmentation)
# 特征编码器:将图像转换为高维特征向量
self.encoder = get_encoder()
# 投影头:将编码器输出映射到低维潜在空间,促进对比学习
self.projection_head = keras.Sequential(
[
keras.Input(shape=(width,)),
layers.Dense(width, activation="relu"),
layers.Dense(width),
],
name="projection_head",
)
# 线性探针:用于评估学习到的特征表示质量
self.linear_probe = keras.Sequential(
[layers.Input(shape=(width,)), layers.Dense(10)],
name="linear_probe",
)
# 打印各组件架构信息
self.encoder.summary()
self.projection_head.summary()
self.linear_probe.summary()
# 模型前向传播方法
# 参数:
# inputs: 输入图像批次,形状为(batch_size, height, width, channels)
# training: 是否为训练模式
# 返回:
# 训练模式:返回特征和投影
# 推理模式:返回编码器输出的特征
def call(self, inputs, training=None):
# 前向传播核心流程:图像 -> 特征提取 -> 投影
# 获取输入图像的特征表示
features = self.encoder(inputs, training=training)
# 获取投影特征
# 在对比学习中,投影特征用于计算样本间的相似性
projections = self.projection_head(features, training=training)
# 根据训练模式返回不同结果
if training:
return features, projections
else:
return features # 推理时仅返回编码器特征
def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
super().compile(**kwargs)
self.contrastive_optimizer = contrastive_optimizer
self.probe_optimizer = probe_optimizer
self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
name="c_acc"
)
self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")
@property
def metrics(self):
return [
self.contrastive_loss_tracker,
self.contrastive_accuracy,
self.probe_loss_tracker,
self.probe_accuracy,
]
# 计算对比损失的核心方法
# 参数:
# projections_1: 第一个增强视图的投影特征
# projections_2: 第二个增强视图的投影特征
# 返回:
# 计算得到的对比损失值
def contrastive_loss(self, projections_1, projections_2):
# 对投影特征进行L2归一化
# 确保特征在单位超球面上,使点积等价于余弦相似度
projections_1 = ops.normalize(projections_1, axis=1)
projections_2 = ops.normalize(projections_2, axis=1)
# 计算两个视图投影特征之间的相似性矩阵
# 使用温度参数缩放相似度,温度越低,分布越陡峭
similarities = (
ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
)
# 获取批次大小
batch_size = ops.shape(projections_1)[0]
# 创建标签:每个样本应该与另一个视图中的相同索引样本匹配
contrastive_labels = ops.arange(batch_size)
# 更新对比准确率指标
# 分别计算从视图1到视图2和从视图2到视图1的准确率
self.contrastive_accuracy.update_state(contrastive_labels, similarities)
self.contrastive_accuracy.update_state(
contrastive_labels, ops.transpose(similarities)
)
# 计算两个方向的交叉熵损失
# loss_1_2: 从视图1到视图2的对比损失
loss_1_2 = keras.losses.sparse_categorical_crossentropy(
contrastive_labels, similarities, from_logits=True
)
# loss_2_1: 从视图2到视图1的对比损失
loss_2_1 = keras.losses.sparse_categorical_crossentropy(
contrastive_labels, ops.transpose(similarities), from_logits=True
)
# 返回两个方向损失的平均值
# 这种双向计算确保模型能够从两个角度学习样本之间的关系
return (loss_1_2 + loss_2_1) / 2
# 模型训练步骤
# 参数:
# data: 训练数据,包含未标记数据和标记数据的组合
# 返回:
# 字典形式的训练指标,包括对比损失、对比准确率、探针损失和探针准确率
def train_step(self, data):
# 处理输入数据结构,获取未标记数据和标记数据
(unlabeled_images, _), (labeled_images, labels) = data
# 对比学习训练部分
# 1. 合并未标记数据和标记数据,用于对比学习
images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
# 2. 为每张图像生成两个不同的增强视图
# 这是对比学习的核心:同一图像的不同增强版本应该具有相似的表示
augmented_images_1 = self.contrastive_augmenter(images, training=True)
augmented_images_2 = self.contrastive_augmenter(images, training=True)
# 3. 在第一个梯度带上计算对比学习损失并更新编码器和投影头
with tf.GradientTape() as tape:
# 通过编码器提取两个增强视图的特征表示
features_1 = self.encoder(augmented_images_1, training=True)
features_2 = self.encoder(augmented_images_2, training=True)
# 通过投影头将特征映射到潜在空间
projections_1 = self.projection_head(features_1, training=True)
projections_2 = self.projection_head(features_2, training=True)
# 计算对比损失,衡量正负样本对的相似性
contrastive_loss = self.contrastive_loss(projections_1, projections_2)
# 4. 计算并应用梯度,更新编码器和投影头的参数
gradients = tape.gradient(
contrastive_loss,
self.encoder.trainable_weights + self.projection_head.trainable_weights,
)
self.contrastive_optimizer.apply_gradients(
zip(
gradients,
self.encoder.trainable_weights + self.projection_head.trainable_weights,
)
)
# 5. 更新对比损失跟踪指标
self.contrastive_loss_tracker.update_state(contrastive_loss)
# 线性探针(分类任务)训练部分
# 1. 对标记图像应用弱增强,保持语义信息
preprocessed_images = self.classification_augmenter(
labeled_images, training=True
)
# 2. 在第二个梯度带上计算分类损失并更新线性探针
with tf.GradientTape() as tape:
# 使用编码器提取特征(注意:编码器在此阶段不更新参数)
features = self.encoder(preprocessed_images, training=False)
# 通过线性探针进行分类预测
class_logits = self.linear_probe(features, training=True)
# 计算分类损失
probe_loss = self.probe_loss(labels, class_logits)
# 3. 计算并应用梯度,只更新线性探针的参数
gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
self.probe_optimizer.apply_gradients(
zip(gradients, self.linear_probe.trainable_weights)
)
# 4. 更新线性探针的损失和准确率指标
self.probe_loss_tracker.update_state(probe_loss)
self.probe_accuracy.update_state(labels, class_logits)
# 5. 返回所有训练指标
return {m.name: m.result() for m in self.metrics}
# 模型测试/评估步骤
# 参数:
# data: 测试数据,包含图像和对应的标签
# 返回:
# 字典形式的评估指标,包括探针损失和探针准确率
def test_step(self, data):
# 提取测试数据中的图像和标签
labeled_images, labels = data
# 对测试图像应用弱增强
# 注意:即使在测试阶段,也应用相同的数据预处理以保持一致性
preprocessed_images = self.classification_augmenter(
labeled_images, training=False
)
# 使用编码器提取特征
# 训练=False表示不更新编码器参数,仅用于特征提取
features = self.encoder(preprocessed_images, training=False)
# 通过线性探针进行分类预测
class_logits = self.linear_probe(features, training=False)
# 计算分类损失
probe_loss = self.probe_loss(labels, class_logits)
# 更新评估指标
self.probe_loss_tracker.update_state(probe_loss)
self.probe_accuracy.update_state(labels, class_logits)
# 返回后两个评估指标(仅探针相关指标)
return {m.name: m.result() for m in self.metrics[2:]}
# 初始化预训练模型
# 创建ContrastiveModel实例,用于无监督对比学习预训练
pretraining_model = ContrastiveModel()
# 编译预训练模型
# 为不同部分设置不同的优化器:
# 1. contrastive_optimizer: 用于更新编码器和投影头的参数
# 2. probe_optimizer: 仅用于更新线性探针的参数
pretraining_model.compile(
contrastive_optimizer=keras.optimizers.Adam(learning_rate=0.001), # 对比学习部分优化器
probe_optimizer=keras.optimizers.Adam(learning_rate=0.001), # 线性探针部分优化器
)
# 开始预训练模型
# 训练过程同时进行对比学习(无监督)和线性探针评估(有监督)
pretraining_history = pretraining_model.fit(
train_dataset, # 训练数据(包含未标记和标记数据)
epochs=num_epochs, # 预训练轮数
validation_data=test_dataset # 验证数据,用于监控性能
)
# 打印对比预训练模型的最佳验证准确率
# 预训练模型使用未标记数据进行对比学习,同时在线性探针上评估性能
print(
"Maximal validation accuracy: {:.2f}%".format(
max(pretraining_history.history["val_p_acc"]) * 100
)
)
# 创建微调模型
# 利用预训练好的编码器进行监督微调,添加新的分类层
finetuning_model = keras.Sequential(
[
get_augmenter(**classification_augmentation), # 复用分类任务的弱数据增强
pretraining_model.encoder, # 使用预训练好的编码器提取特征
layers.Dense(10), # 新的分类层(输出10个类别的logits)
],
name="finetuning_model",
)
# 编译微调模型
# 设置优化器、损失函数和评估指标
finetuning_model.compile(
optimizer=keras.optimizers.Adam(), # Adam优化器
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 交叉熵损失(from_logits=True表示模型输出logits)
metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")], # 评估指标:稀疏分类准确率
)
# 开始监督微调
# 使用少量标记数据对模型进行微调,验证集用于监控性能
finetuning_history = finetuning_model.fit(
labeled_train_dataset, # 仅使用标记训练数据
epochs=num_epochs, # 微调轮数
validation_data=test_dataset # 测试集作为验证数据
)
# 打印微调后的最佳验证准确率
# 输出微调过程中达到的最高验证集准确率(百分比形式)
# 打印微调模型的最佳验证准确率
# 微调模型使用预训练好的编码器,在少量标记数据上进一步优化
# 通常表现优于纯监督基线模型和仅预训练模型
print(
"Maximal validation accuracy: {:.2f}%".format(
max(finetuning_history.history["val_acc"]) * 100
)
)
# 绘制训练曲线函数
# 参数:
# pretraining_history: 自监督预训练模型的训练历史
# finetuning_history: 监督微调模型的训练历史
# baseline_history: 纯监督基线模型的训练历史
# 功能:可视化和比较三种不同模型训练方法的验证性能
# 1. 纯监督基线模型
# 2. 自监督预训练模型
# 3. 自监督预训练+监督微调模型
# 图表内容:
# 1. 准确率曲线图:展示不同方法的验证准确率变化
# 2. 损失曲线图:展示不同方法的验证损失变化
# 绘制训练曲线函数
# 参数:
# pretraining_history: 自监督预训练模型的训练历史
# finetuning_history: 监督微调模型的训练历史
# baseline_history: 纯监督基线模型的训练历史
# 功能:可视化和比较三种不同模型训练方法的验证性能
# 1. 纯监督基线模型
# 2. 自监督预训练模型
# 3. 自监督预训练+监督微调模型
# 图表内容:
# 1. 准确率曲线图:展示不同方法的验证准确率变化
# 2. 损失曲线图:展示不同方法的验证损失变化
def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
# 循环绘制准确率和损失曲线
# 首先绘制准确率曲线,然后绘制损失曲线
for metric_key, metric_name in zip(["acc", "loss"], ["accuracy", "loss"]):
# 创建图表,设置尺寸和分辨率
plt.figure(figsize=(8, 5), dpi=100)
# 绘制纯监督基线模型的验证曲线
# 作为性能基准,直接使用标记数据训练
plt.plot(
baseline_history.history[f"val_{metric_key}"],
label="supervised baseline",
)
# 绘制自监督预训练模型的验证曲线
# 使用大量未标记数据进行对比学习预训练
plt.plot(
pretraining_history.history[f"val_p_{metric_key}"],
label="self-supervised pretraining",
)
# 绘制监督微调模型的验证曲线
# 在预训练基础上,使用少量标记数据进行微调
plt.plot(
finetuning_history.history[f"val_{metric_key}"],
label="supervised finetuning",
)
# 添加图例
plt.legend()
# 设置图表标题
plt.title(f"Classification {metric_name} during training")
# 设置x轴标签
plt.xlabel("epochs")
# 设置y轴标签
plt.ylabel(f"validation {metric_name}")
# 调用训练曲线绘制函数
# 可视化三种方法的性能对比
# 预期结果:
# 1. 监督微调模型 > 自监督预训练模型 > 纯监督基线模型(准确率)
# 2. 监督微调模型 < 自监督预训练模型 < 纯监督基线模型(损失值)
# 这验证了半监督学习方法在标记数据有限情况下的有效性
plot_training_curves(pretraining_history, finetuning_history, baseline_history)