Paper Reading Notes: "Curriculum Coarse-to-Fine Selection for High-IPC Dataset Distillation"

CVPR 2025 · GitHub

One-sentence summary:

CCFS builds on the combined paradigm (trajectory matching + selection of real images). Through a curriculum-style "coarse filtering + fine selection" framework, it dynamically patches the weaknesses of the synthetic set, substantially improving dataset distillation under high-IPC settings, and is the current SOTA for high-IPC scenarios.


1. Background and Motivation

  • Dataset Distillation: compress a large training set into a small synthetic dataset such that a model trained on the synthetic set performs close to one trained on the full original data.
  • IPC (Images Per Class): the number of synthetic images per class. Existing methods do well in low-IPC settings (a few images per class), but as IPC grows (more images must be synthesized) performance often degrades, sometimes falling below simple random sampling.
  • Core problem: at high IPC the synthetic set becomes overly "averaged" and lacks rare/complex features (hard samples), so its coverage is insufficient. Existing hybrid distillation + real-sample methods (e.g., SelMatch) perform a one-shot static selection and lack dynamic complementarity with the synthetic set.

2. Core Contributions

  1. Incompatibility diagnosis: analyzes the "select real samples first, then distill" paradigm and shows that statically selected real samples complement the dynamically distilled set poorly.
  2. CCFS: a curriculum-style, coarse-to-fine framework that dynamically selects real samples in two stages:
    • Coarse filtering: a filter model trained on the current synthetic set identifies real samples that have "not yet been learned" (i.e., misclassified samples).
    • Fine selection: among these candidates, use a difficulty score or the filter's logits to pick the "simplest yet unlearned" samples and add them to the synthetic set step by step.
  3. Empirical results: under high-IPC settings (compression ratios 5%–30%) on CIFAR-10/100 and Tiny-ImageNet, CCFS sets multiple new SOTA results; in some settings its accuracy is only 0.3% below full-data training.

3. Method Details

Overall pipeline:

  1. Initialization
    • Obtain an initial synthetic set $D_{\mathrm{distill}}$ from any base distillation algorithm (e.g., CDA).
    • Set the current synthetic set $S_0 = D_{\mathrm{distill}}$.
  2. Curriculum loop (J stages in total)
    • Train a filter: distill-train a filter model $\phi_j$ on $S_{j-1}$ so that it learns the decision boundary of the current synthetic set.
    • Coarse: run $\phi_j$ over the original training set $T$ and collect the misclassified samples $D_{\mathrm{mis}}^j$.
    • Fine: rank the samples inside $D_{\mathrm{mis}}^j$ and pick, per class, the top $k_j$ "simplest yet unlearned" ones to form $D_{\mathrm{real}}^j$.
    • Update: $S_j = S_{j-1} \cup D_{\mathrm{real}}^j$.

Why select the "simple but not yet learned" samples?

The misclassified set $D_{\mathrm{mis}}$ reflects the limitations of $S$. Among these limitations, simpler features benefit training more than complex ones, because they are easier to learn. A pre-computed difficulty score measures the relative difficulty of sample features from a global perspective and guides the fine selection. By picking the simplest features among the misclassified samples we obtain a good $D_{\mathrm{real}}$, while avoiding overly complex features that could hinder the performance of $S$.
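
To make the Coarse and Fine steps concrete, here is a minimal, runnable sketch of one selection round on toy data (NumPy only; the class count, sample count, and the random "filter" logits are made up for illustration). It mirrors the class-balanced "simple" strategy implemented in the code further below.

python

import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins: true labels and "filter" logits for 12 samples over 3 classes.
num_classes, k_per_class = 3, 1
labels = np.repeat(np.arange(num_classes), 4)
logits = rng.normal(size=(len(labels), num_classes))

# Coarse: keep the samples the filter misclassifies (the set D_mis).
pred = logits.argmax(axis=1)
mis_idx = np.where(pred != labels)[0]

# Fine: within D_mis, per class, pick the "simplest yet unlearned" samples,
# i.e. those with the highest logit on their true class (closest to being learned).
selected = []
for c in range(num_classes):
    cand = [i for i in mis_idx if labels[i] == c]
    cand.sort(key=lambda i: logits[i, labels[i]], reverse=True)  # simplest first
    selected += cand[:k_per_class]

print("misclassified:", mis_idx.tolist())
print("selected (simplest-unlearned, per class):", selected)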

4. Experimental Results

  • Datasets: CIFAR-10/100, Tiny-ImageNet
  • New SOTA in high-IPC settings
    • CIFAR-10/100 at 10% IPC: roughly +6.1% / +5.8% over the best baseline;
    • Tiny-ImageNet at 20% IPC: only 0.3% below full-data training.
  • Cross-architecture generalization
    Synthetic sets generated with ResNet-18 also work when training ResNet-50/101, DenseNet-121, RegNet, etc., consistently outperforming CDA, SelMatch, and other methods.
  • Thorough ablations
    • Combinations of coarse (misclassified vs. confidently correct) and fine (simple vs. hard vs. random) strategies are validated;
    • Different difficulty scores are compared; the Forgetting score works best;
    • The number of curriculum stages is traded off against performance and cost; 3 stages is a good compromise.

Main Code

python
import os
import datetime
import time
import warnings
import numpy as np
import random
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn
import torchvision.transforms as transforms
from imagenet_ipc import ImageFolderIPC
import torch.nn.functional as F
from tqdm import tqdm
import json
warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")



def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="CCFS on CIFAR-100", add_help=add_help)

    parser.add_argument("--data-path", default=None, type=str, help="path to CIFAR-100 data folder")
    parser.add_argument("--filter-model", default="resnet18", type=str, help="filter model name")
    parser.add_argument("--teacher-model", default="resnet18", type=str, help="teacher model name")
    parser.add_argument("--teacher-path", default=None, type=str, help="path to teacher model")
    parser.add_argument("--eval-model", default="resnet18", type=str, help="model for final evaluation")
    
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("-b", "--batch-size", default=64, type=int, help="Batch size")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="# training epochs for both the filter and the evaluation model")
    parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, metavar="W", help="weight decay (default: 1e-4)", dest="weight_decay")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("-T", "--temperature", default=20, type=float, help="temperature for distillation loss")
    parser.add_argument("--print-freq", default=1000, type=int, help="print frequency")
    # --- CCFS parameters ---
    # target number of images per class (IPC) in the final synthetic set
    parser.add_argument("--image-per-class", default=50, type=int, help="number of synthetic images per class")
    parser.add_argument("--distill-data-path", default=None, type=str, help="path to already distilled data")

    # distillation portion: ratio of synthetic (distilled) images to selected real images;
    # cpc = IPC * alpha is the number of condensed (distilled) images per class
    # spc = IPC - cpc is the number of real images to select per class
    parser.add_argument('--alpha', type=float, default=0.2, help='Distillation portion')

    # number of curriculum stages for the coarse-to-fine selection (e.g., 3)
    parser.add_argument('--curriculum-num', type=int, default=None, help='Number of curricula')
    # coarse stage: select samples the filter misclassifies (True) or classifies correctly (False)
    parser.add_argument('--select-misclassified', action='store_true', help='Selection strategy in coarse stage')
    # fine stage strategy: simple / hard / random (the paper's "simplest yet unlearned" / "hardest" / "random")
    parser.add_argument('--select-method', type=str, default='simple', choices=['random', 'hard', 'simple'], help='Selection strategy in fine stage')

    # whether to select a class-balanced number of samples per class
    parser.add_argument('--balance', action='store_true', help='Whether to balance the amount of the synthetic data between classes')

    # which difficulty score to use in the fine stage
    parser.add_argument('--score', type=str, default='forgetting', choices=['logits', 'forgetting', 'cscore'], help='Difficulty score used in fine stage')

    # if not using logits, load a pre-computed difficulty score (e.g., forgetting score) from this path
    parser.add_argument('--score-path', type=str, default=None, help='Path to the difficulty score')
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--num-eval", default=1, type=int, help="number of evaluations")

    return parser
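
# Example invocation (hypothetical: the script name and all paths below are placeholders;
# the flags themselves are the ones defined above):
#   python ccfs_cifar100.py --data-path ./data --distill-data-path ./cda_cifar100_ipc50 \
#       --teacher-path ./teacher_resnet18.pth --image-per-class 50 --alpha 0.2 \
#       --curriculum-num 3 --select-misclassified --select-method simple \
#       --score forgetting --score-path ./forgetting_scores.npy --balance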

def load_data(args):
    '''
    Dataset loading.
    Returns:
        dataset: the distilled dataset
        images_og, labels_og: all original training samples (used for selection)
        dataset_test: the validation (test) set
        plus the corresponding samplers
    '''
    # Data loading code
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])
    print("Loading distilled data")
    train_transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    # ImageFolderIPC is a custom dataset that picks (or randomly samples) cpc images per class
    # from a larger distilled dataset; cpc = IPC * alpha is the number of distilled images per class
    dataset = ImageFolderIPC(root=args.distill_data_path, ipc=args.cpc, transform=train_transform)
        
    print("Loading validation data")
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    # load the validation (test) set
    dataset_test = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=val_transform)

    print("Loading original training data")
    # load the original training set (used for coarse selection / teacher-correctness filtering, etc.)
    dataset_og = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=val_transform)

    # Build the original training data: unpack all CIFAR-100 training images into one large tensor
    # images_og with labels labels_og. This is fine while memory allows; for larger datasets it
    # could be replaced by batched or lazy access.
    images_og = [torch.unsqueeze(dataset_og[i][0], dim=0) for i in range(len(dataset_og))]
    labels_og = [dataset_og[i][1] for i in range(len(dataset_og))]
    images_og = torch.cat(images_og, dim=0)
    labels_og = torch.tensor(labels_og, dtype=torch.long)

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, images_og, labels_og, dataset_test, train_sampler, test_sampler


def create_model(model_name, device, num_classes, path=None):
    # build the backbone by name (TODO: no pretrained weights loaded by default)
    model = torchvision.models.get_model(model_name, weights=None, num_classes=num_classes)
    # replace the first conv and the max-pooling layer with CIFAR-style counterparts
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)
    model.maxpool = nn.Identity()

    # load checkpoint weights if a path is given (TODO: whether to load pretrained weights)
    if path is not None:
        checkpoint = torch.load(path, map_location="cpu")
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        elif "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        if "module." in list(checkpoint.keys())[0]:
            checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
        model.load_state_dict(checkpoint)
    model.to(device)
    return model

def curriculum_arrangement(spc, curriculum_num):
    '''
    Curriculum allocation.
    Evenly split the total number of real images to select per class (spc) across curriculum_num stages:
    e.g., spc=7, curriculum_num=3 gives [3, 2, 2] (the remainder goes to the earlier stages).
    '''
    remainder = spc % curriculum_num
    arrangement = [spc // curriculum_num] * curriculum_num
    for i in range(remainder):
        arrangement[i] += 1

    return arrangement
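
# Quick check of the allocation above (illustrative only):
#   curriculum_arrangement(7, 3)  -> [3, 2, 2]
#   curriculum_arrangement(40, 3) -> [14, 13, 13]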

def train_one_epoch(model, teacher_model, criterion, optimizer, data_loader, device, epoch, args):
    """
    在一个 epoch(遍历一遍 data_loader)里,用 KL 散度蒸馏(distillation)student(model)去学习 teacher_model 的"软标签"。
    具体做法是:把 teacher 和 student 的 logits 都除以温度 T 后做 log_softmax,然后用 KLDivLoss;
    最后乘上 T^2 做梯度缩放,确保温度对 loss 的影响保持一致。
    """
    # 切换student到train模式
    model.train()
    # 切换teacher到eval模式,只做前向不更新
    teacher_model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        # 1) teacher forward pass
        teacher_output = teacher_model(image)
        # 2) student forward pass
        output = model(image)

        # divide the logits by the temperature, then take log_softmax
        teacher_output_log_softmax = F.log_softmax(teacher_output/args.temperature, dim=1)
        output_log_softmax = F.log_softmax(output/args.temperature, dim=1)

        # KL-divergence loss, multiplied by T^2 to compensate for the gradient scaling introduced by the temperature
        loss = criterion(output_log_softmax, teacher_output_log_softmax) * (args.temperature ** 2)

        # standard backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # top-1/top-5 accuracy of the student on the original hard labels (target)
        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]

        # update the metrics tracked by metric_logger
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

def evaluate(model, criterion, data_loader, device, log_suffix=""):
    """
    在测试/验证集上跑一个完整的 forward,计算交叉熵 loss 和 top1/top5 准确率。
    不做梯度更新,只做推理。
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0  # running count of processed samples

    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # forward pass
            output = model(image)
            # cross-entropy against the hard labels
            loss = criterion(output, target)

            # compute top-1/top-5 accuracy
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size

    # in the distributed setting, sum the processed-sample counts across processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    return metric_logger.acc1.global_avg

def curriculum_train(current_curriculum, dst_train, test_loader, model, teacher_model, args):
    """
    对当前的"合成数据+已选真实数据" dst_train进行一次完整的filter模型训练(蒸馏学习):
    - 根据数据规模动态调整batch_size
    - 构造 DataLoader / Criterion / Optimizer / LR Scheduler(含 warmup)
    - 训练 args.epochs 轮,后 20% 轮做验证并记录最佳 acc1
    返回:训练好的 model 和最佳 top-1 准确率 best_acc1
    """

    best_acc1 = 0

    # 1. roughly choose the batch size from the size of dst_train (synthetic + real)
    if len(dst_train) < 50 * args.num_classes:
        args.batch_size = 32
    elif 50 * args.num_classes <= len(dst_train) < 100 * args.num_classes:
        args.batch_size = 64
    else:
        args.batch_size = 128
    
    # 2. wrap the training set with a random sampler so every epoch is reshuffled
    train_sampler = torch.utils.data.RandomSampler(dst_train)
    train_loader = torch.utils.data.DataLoader(
        dst_train,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    # 3. loss functions: CrossEntropy for hard labels, KLDiv for the distillation soft labels
    criterion = nn.CrossEntropyLoss()
    criterion_kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    parameters = utils.set_weight_decay(model, args.weight_decay)
    
    # 4. build the optimizer
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    # 5. build the main LR scheduler: StepLR / CosineAnnealingLR / ExponentialLR
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=0.0
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    # 6. if warmup is configured, chain the warmup scheduler with the main scheduler
    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        # milestones: switch to main_lr_scheduler after args.lr_warmup_epochs epochs
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler
        
    # 7. start training
    print("Start training on synthetic dataset...")
    start_time = time.time()
    pbar = tqdm(range(args.epochs), ncols=100)
    for epoch in pbar:
        # each epoch calls train_one_epoch defined above (KL distillation)
        train_one_epoch(model, teacher_model, criterion_kl, optimizer, train_loader, args.device, epoch, args)
        # step the LR scheduler after each epoch
        lr_scheduler.step()
        # only evaluate during the last 20% of epochs to save time
        if epoch > args.epochs * 0.8:
            acc1 = evaluate(model, criterion, test_loader, device=args.device)  # evaluate uses hard-label loss & accuracy

            # update best_acc1
            if acc1 > best_acc1:
                best_acc1 = acc1
            # show the current/best accuracy on the progress bar
            pbar.set_description(f"Epoch[{epoch}] Test Acc: {acc1:.2f}% Best Acc: {best_acc1:.2f}%")
    print(f"Best Accuracy {best_acc1:.2f}%")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    
    return model, best_acc1

def coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True):
    """
    对全量原始训练集 images_all 用 filter 模型做一次完整的推理:
    - 如果 get_correct=True,返回"预测正确"的样本索引列表;
      否则返回"预测错误"的样本索引列表。
    - 同时返回所有样本的 raw logits(未 softmax)。
    """
    true_labels = labels_all.cpu()
    filter.eval()   # forward only, no parameter updates
    logits = None

    # run inference in batches to avoid OOM
    for select_times in range((len(images_all)+batch_size-1)//batch_size):
        # slice out the current batch of images;
        # detach to avoid tracking gradients, then move it to the device
        current_data_batch = images_all[batch_size*select_times : batch_size*(select_times+1)].detach().to(args.device)

        # forward pass
        batch_logits = filter(current_data_batch)

        # concatenate the batch logits
        if logits is None:
            logits = batch_logits.detach()
        else:
            logits = torch.cat((logits, batch_logits.detach()),0)
    # the argmax of each row gives the predicted label
    predicted_labels = torch.argmax(logits, dim=1).cpu()

    # pick correct or incorrect indices according to get_correct
    target_indices = torch.where(true_labels == predicted_labels)[0] if get_correct else torch.where(true_labels != predicted_labels)[0]
    target_indices = target_indices.tolist()
    print('Acc on training set: {:.2f}%'.format(100*len(target_indices)/len(images_all) if get_correct else 100*(1-len(target_indices)/len(images_all))))
    return target_indices, logits
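
# Shape note for the CIFAR-100 setting loaded above: images_all is a [50000, 3, 32, 32] tensor and
# labels_all a [50000] tensor; with get_correct=False the returned indices form the misclassified
# candidate set D_mis used by the coarse stage.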

def selection_logits(selected_idx, teacher_correct_idx, images_all, labels_all, filter, args):
    """
    用 filter 模型的 logits 做 fine 阶段的选样:
    - teacher_correct_idx: teacher 在原始训练集上预测正确的样本索引
    - selected_idx: 已经在前几轮中选过的样本索引,避免重复
    返回当前轮要新增的选样索引列表
    """
    batch_size = 512
    true_labels = labels_all.cpu()
    filter.eval()
    print('Coarse Filtering...')

    # --- Coarse stage: decide which samples enter the fine stage
    # if select_misclassified=True, keep the samples the filter misclassifies
    if args.select_misclassified:
        target_indices, logits = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=False)
    else:
        target_indices, logits = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True)

    # --- Cross filtering: keep only samples the teacher also predicts correctly, and drop already-selected ones
    # teacher_correct_idx holds the indices the teacher classifies correctly on the original training set
    # (in the paper, only samples the teacher also gets right are considered).
    if teacher_correct_idx is not None:
        # intersect teacher_correct_idx with target_indices, then subtract selected_idx
        target_indices = list(set(teacher_correct_idx) & set(target_indices) - set(selected_idx))
    else:
        target_indices = list(set(target_indices) - set(selected_idx))
    print('Fine Selection...')
    selection = []
    if args.balance:
        # if class-balanced, select args.curpc samples per class
        target_idx_per_class = [[] for c in range(args.num_classes)]
        for idx in target_indices:
            target_idx_per_class[true_labels[idx]].append(idx)
        for c in range(args.num_classes):
            if args.select_method == 'random':
                # random sampling
                selection += random.sample(target_idx_per_class[c], args.curpc)
            elif args.select_method == 'hard':
                # sort by logits[c] ascending; a lower logit means less confidence ⇒ "harder"
                selection += sorted(target_idx_per_class[c], key=lambda i: logits[i][c], reverse=False)[:args.curpc]
            elif args.select_method == 'simple':
                # sort by logits[c] descending; a higher logit means more confidence ⇒ "simpler"
                selection += sorted(target_idx_per_class[c], key=lambda i: logits[i][c], reverse=True)[:args.curpc]
    else:
        # without class balancing, select a total of curpc * num_classes from all target_indices
        if args.select_method == 'random':
            selection = random.sample(target_indices, args.curpc*args.num_classes)
        elif args.select_method == 'hard':
            selection = sorted(target_indices, key=lambda i: logits[i][true_labels[i]], reverse=False)[:args.curpc*args.num_classes]
        elif args.select_method == 'simple':
            selection = sorted(target_indices, key=lambda i: logits[i][true_labels[i]], reverse=True)[:args.curpc*args.num_classes]

    return selection

def selection_score(selected_idx, teacher_correct_idx, images_all, labels_all, filter, score, reverse, args):
    """
    用预先计算好的difficult score 做fine阶段的选样:
    - score: numpy array, score[i]表示样本i的难度分数
    - reverse: bool
    其余流程同 selection_logits,只是排序依据改为 score
    """

    batch_size = 512
    true_labels = labels_all.cpu()
    filter.eval()
    
    print('Coarse Filtering...')

    # Coarse stage, same as above
    if args.select_misclassified:
        target_indices, _ = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=False)
    else:
        target_indices, _ = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True)
    
    # cross-filter with teacher_correct_idx and remove already-selected indices
    if teacher_correct_idx is not None:
        target_indices = list(set(teacher_correct_idx) & set(target_indices) - set(selected_idx))
    else:
        target_indices = list(set(target_indices) - set(selected_idx))
    print('Fine Selection...')
    selection = []
    if args.balance:
        target_idx_per_class = [[] for c in range(args.num_classes)]
        for idx in target_indices:
            target_idx_per_class[true_labels[idx]].append(idx)
        for c in range(args.num_classes):
            if args.select_method == 'random':
                selection += random.sample(target_idx_per_class[c], min(args.curpc, len(target_idx_per_class[c])))
            elif args.select_method == 'hard':
                selection += sorted(target_idx_per_class[c], key=lambda i: score[i], reverse=reverse)[:args.curpc]
            elif args.select_method == 'simple':
                # use the external pre-computed difficulty score
                selection += sorted(target_idx_per_class[c], key=lambda i: score[i], reverse=not reverse)[:args.curpc]
    else:
        if args.select_method == 'random':
            selection = random.sample(target_indices, min(args.curpc*args.num_classes, len(target_indices)))
        elif args.select_method == 'hard':
            selection = sorted(target_indices, key=lambda i: score[i], reverse=reverse)[:args.curpc*args.num_classes]
        elif args.select_method == 'simple':
            selection = sorted(target_indices, key=lambda i: score[i], reverse=not reverse)[:args.curpc*args.num_classes]
            
    return selection

def main(args):
    '''Preparation'''
    print('=> args.output_dir', args.output_dir)
    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dir = os.path.join(args.output_dir, 'CIFAR-100', start_time)
    os.makedirs(log_dir, exist_ok=True)
    
    device = torch.device(args.device)
    if device.type == 'cuda':
        print('Using GPU')
        torch.backends.cudnn.benchmark = True

    # compute cpc (condensed images per class) and spc (real images to select per class)
    args.cpc = int(args.image_per_class * args.alpha)   # condensed images per class
    args.spc = args.image_per_class - args.cpc          # selected real images per class
    args.num_classes = 100
    print('Target IPC: {}, num_classes: {}, distillation portion: {}, distilled images per class: {}, real images to be selected per class: {}'
        .format(args.image_per_class, args.num_classes, args.alpha, args.cpc, args.spc))

    # load data
    dataset_dis, images_og, labels_og, dataset_test, train_sampler, test_sampler = load_data(args)

    # load the difficulty score
    if args.score == 'forgetting':
        score = np.load(args.score_path)
        reverse = True
    elif args.score == 'cscore':
        score = np.load(args.score_path)
        reverse = False
    curriculum_num = args.curriculum_num

    # curriculum arrangement: how many real samples to select per class in each curriculum
    arrangement = curriculum_arrangement(args.spc, curriculum_num)

    # build the test loader
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=512, sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    # load the teacher model
    teacher_model = create_model(args.teacher_model, device, args.num_classes, args.teacher_path)

    # freeze the teacher's parameters
    for p in teacher_model.parameters():
        p.requires_grad = False
    teacher_model.eval()

    # run the teacher once over the original training set as a pre-filter: only samples the teacher predicts correctly can ever be selected
    teacher_correct_idx, _ = coarse_filtering(images_og, labels_og, teacher_model, 512, args, get_correct=True)
    print('teacher acc@1 on original training data: {:.2f}%'.format(100*len(teacher_correct_idx)/len(images_og)))

    '''Curriculum selection'''
    idx_selected = []
    dataset_sel = None
    dst_sel_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
    ])
    print('Selected images per class arrangement in each curriculum: ', arrangement)

    # start the curriculum loop
    for i in range(curriculum_num):
        print('----Curriculum [{}/{}]----'.format(i+1, curriculum_num))
        args.curpc = arrangement[i]
        # the first curriculum starts from the distilled synthetic set
        if i == 0:
            print('Begin with distilled dataset')
            syn_dataset = dataset_dis
            dataset_sel = []
        
        print('Synthetic dataset size:', len(syn_dataset), "distilled data:", len(dataset_dis), "selected data:", len(dataset_sel))
        # train a new filter (from scratch in every curriculum)
        filter = create_model(args.filter_model, device, args.num_classes)
        # TODO: curriculum training; the teacher model provides the soft labels
        filter, best_acc1 = curriculum_train(i, syn_dataset, test_loader, filter, teacher_model, args)

        print('Selecting real data...')
        if args.score == 'logits':
            selection = selection_logits(idx_selected, teacher_correct_idx, images_og, labels_og, filter, args)
        else:
            selection = selection_score(idx_selected, teacher_correct_idx, images_og, labels_og, filter, score, reverse, args)
        idx_selected += selection
        print('Selected {} in this curriculum'.format(len(selection)))
        imgs_select = images_og[idx_selected]
        labs_select = labels_og[idx_selected]
        dataset_sel = utils.TensorDataset(imgs_select, labs_select, dst_sel_transform)
        syn_dataset = torch.utils.data.ConcatDataset([dataset_dis, dataset_sel])
    print('----All curricula finished----')
    print('Final synthetic dataset size:', len(syn_dataset), "distilled data:", len(dataset_dis), "selected data:", len(dataset_sel))   

    print('Saving selected indices...')
    idx_file = os.path.join(log_dir, f'selected_indices.json')
    with open(idx_file, 'w') as f:
        json.dump({'ipc': args.image_per_class,
                   'alpha': args.alpha, 
                   'idx_selected': idx_selected}, f)

    '''Final evaluation'''
    num_eval = args.num_eval
    accs = []
    for i in range(num_eval):
        print(f'Evaluation {i+1}/{num_eval}')
        eval_model = create_model(args.eval_model, device, args.num_classes)
        _, best_acc1 = curriculum_train(0, syn_dataset, test_loader, eval_model, teacher_model, args)
        accs.append(best_acc1)
    acc_mean = np.mean(accs)
    acc_std = np.std(accs)
    print('----Evaluation Results----')
    print(f'Acc@1(mean): {acc_mean:.2f}%, std: {acc_std:.2f}')

    print('Saving results...')
    log_file = os.path.join(log_dir, f'exp_log.txt')
    with open(log_file, 'w') as f:
        f.write('EXP Settings: \n')
        f.write(f'IPC: {args.image_per_class},\tdistillation portion: {args.alpha},\tcurriculum_num: {args.curriculum_num}\n')
        f.write(f'filter model: {args.filter_model},\tteacher model: {args.teacher_model},\tbatch_size: {args.batch_size},\tepochs: {args.epochs}\n')
        f.write(f"coarse stage strategy: {'select misclassified' if args.select_misclassified else 'select correctly classified'}\n")
        f.write(f'fine stage strategy: {args.select_method},\tdifficulty score: {args.score},\tbalance: {args.balance}\n')
        f.write(f'eval model: {args.eval_model},\tAcc@1: {acc_mean:.2f}%,\tstd: {acc_std:.2f}\n')

if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)

Overall Algorithm Logic

Algorithm inputs

  • Distilled synthetic set $D_{\mathrm{distill}}$: a small "synthetic" dataset already produced by some distillation algorithm, containing $cpc = IPC \times \alpha$ images per class.
  • Original training set (images_og, labels_og): the full CIFAR-100 training samples, used for selection.
  • Teacher model $\phi_{teacher}$: a fixed model that performs well on the original training set and provides the "correct" soft-label reference.
  • Hyperparameters: $IPC$ (final images per class), $\alpha$ (distillation portion), number of curricula $J$, coarse strategy, fine strategy, type of difficulty score, etc.

Overall flow

  1. Initialization

    • Parse the command-line arguments and compute $cpc = \lfloor IPC \times \alpha \rfloor$ and $spc = IPC - cpc$.
    • Load $D_{\mathrm{distill}}$, the original training set, and the validation set.
    • Load and freeze the teacher model $\phi_{teacher}$, run it once over the original training set, and record the index set of its correct predictions, $I_{teacher}$.
  2. Curriculum arrangement

    Evenly split the $spc$ real images to select per class across $J$ curricula: $[k_1, k_2, \ldots, k_J]$ with $\sum_j k_j = spc$.

  3. Coarse-to-fine selection loop

    Initialize the current set $S_0 = D_{\mathrm{distill}}$ and the selected index set $I_{sel} = \emptyset$.

    For each curriculum stage $j = 1 \ldots J$:

    i. Distill-train the filter:

    • On $S_{j-1}$, train a new filter model $\phi_j$ by distillation from the teacher's soft labels.

    ii. Coarse filtering

    • Run $\phi_j$ over the entire original training set to obtain logits and predicted labels for all samples.
    • Depending on select_misclassified, keep the indices of the misclassified (or of the correctly classified) samples as the candidate set $C$.
    • Cross filtering: keep only indices that are in both $C$ and $I_{teacher}$ and not in $I_{sel}$.

    iii. Fine selection

    • Rank the candidate indices by the filter's logits or by an external pre-computed difficulty score, using the simple, hard, or random rule.
    • With --balance, take $k_j$ samples per class; otherwise take $k_j \times num\_classes$ samples overall.
    • Add the newly selected indices to $I_{sel}$.

    iv. Update the synthetic set
    $S_j \leftarrow D_{\mathrm{distill}} \cup \{\text{real sample } i : i \in I_{sel}\}$.

  4. Saving & final evaluation

    • Save $I_{sel}$ to JSON for later reproduction.
    • Train a new evaluation model on the final mixed set $S_J$, measure top-1 accuracy over several runs, and report the mean and standard deviation.

This procedure ensures that each curriculum systematically adds exactly what the model has "not yet truly learned", so the final synthetic set covers both common knowledge and the key hard cases.
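
As a concrete instantiation of steps 1–2 with the script's defaults for CIFAR-100 (IPC = 50, α = 0.2; J = 3 is an assumed choice, since --curriculum-num has no default), the sketch below reproduces the cpc/spc split and the per-curriculum quota:

python

# IPC and alpha are the script defaults above; J = 3 is an assumed setting.
IPC, alpha, J = 50, 0.2, 3

cpc = int(IPC * alpha)      # condensed (distilled) images per class -> 10
spc = IPC - cpc             # real images to select per class        -> 40

# same logic as curriculum_arrangement(spc, J)
quota = [spc // J] * J
for i in range(spc % J):
    quota[i] += 1

print(cpc, spc, quota)      # 10 40 [14, 13, 13]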
