论文阅读笔记:《Dataset Condensation with Distribution Matching》

论文阅读笔记:《Dataset Condensation with Distribution Matching》

WACV23 github

核心思想一句话总结:

用少量可学习的合成图像,通过多组随机网络上的分布匹配(MMD),高效地"蒸馏"出与原始大数据集等价的训练集。


1.解决了什么问题?(Motivation)

训练大型数据集耗时且昂贵,现有"核心集"只能数据、"蒸馏"常需双层优化都各有局限。本工作旨在:

  • 用少量合成图像(每类几十到几百)
  • 保持模型在测试集上的性能
  • 且避免繁重的bi-level优化

2.关键方法与创新点(Key Method & Innovation)

  • 分布匹配视角 :首次用最大均值差异(MMD)在特征空间对齐合成与真实数据分布,而非仅作子集选择或梯度匹配。
  • 随机网络嵌入 :不用预训练模型,随机初始化多个同构网络 ψ θ ψ_θ ψθ作为多种"看法",增强合成数据集的泛化。
  • 单层优化 :只对合成图像本身求梯度、SGD更新网络权重固定,省去双层优化开销。
  • 可微分西雅姆增强 (DSA):对真实和合成样本做相同随机变换,提升分布估计稳定性。

3.实验结果与贡献 (Experiments & Contributions)

  • 在 CIFAR-10/100、TinyImageNet、ImageNet-1K 上:
    • 每类仅 10--50 张合成图即可训练出接近原始数据的模型精度(如 CIFAR-10 10 张时 ≈70%+)。
    • 合成速度比 Gradient Matching 提升 ∼45×。
  • 下游任务验证:
    • 持续学习:更小的记忆库即可保持准确率。
    • 神经架构搜索:用代理合成集显著加速搜索且不损失性能。
  • 开源代码与可视化结果:每隔若干迭代保存合成图像演化,便于直观对比。

4.个人思考与启发

  • 高效vs. 代表性 :只匹配特征均值简单有效,但或许忽略高阶统计和类内多样性
  • 生成质量 vs. 训练效果:无需最求"图像好看",只要"训练有用";但在某些任务中是否要兼顾真实的视觉特征。

主体代码

python 复制代码
''' initialize the synthetic data '''
image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, deviceargs=.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]

if args.init == 'real':
    print('initialize synthetic data from random real images')
    for c in range(num_classes):
        image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')


            ''' training '''
            # 只更新image_syn
            optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
            optimizer_img.zero_grad()
            print('%s training begins'%get_time())

            for it in range(args.Iteration+1):

                ''' Evaluate synthetic data '''
                if it in eval_it_pool:
                    for model_eval in model_eval_pool:
                        print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))

                        print('DSA augmentation strategy: \n', args.dsa_strategy)
                        print('DSA augmentation parameters: \n', args.dsa_param.__dict__)

                        accs = []
                        for it_eval in range(args.num_eval):
                            # 每次用新的随机初始化网络来评估合成集
                            net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                            # 深拷贝合成数据集
                            image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                            # 测试与评估
                            _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                            accs.append(acc_test)
                            print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                            if it == args.Iteration: # record the final results
                                accs_all_exps[model_eval] += accs

                                ''' visualize and save '''
                                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                                image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                                for ch in range(channel):
                                    image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                                    image_syn_vis[image_syn_vis<0] = 0.0
                                    image_syn_vis[image_syn_vis>1] = 1.0
                                    save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.



                                    ''' Train synthetic data '''
                                    # --- 用当前合成数据计算损失并更新(核心:分布匹配) ---

                                    # 新的随机网络(视角embedding)
                                    net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
                                    net.train()
                                    # 合成数据训练时冻结网络参数(只优化合成图像)
                                    for param in list(net.parameters()):
                                        param.requires_grad = False

                                        # 多GPU支持,如果使用了DataParallel,embed在module下面
                                        embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel

                                        loss_avg = 0  # 记录各类 loss 平均(后面除法)


                                        ''' update synthetic data '''
                                        # --- 计算合成图像和真实图像在embedding space 上的均值差(即MMD的简化版本)---
                                        if 'BN' not in args.model: # for ConvNet 没有batch norm的网络
                                            loss = torch.tensor(0.0).to(args.device)
                                            for c in range(num_classes):
                                                # 每类分别取真实图和合成图
                                                img_real = get_images(c, args.batch_real)
                                                img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                                                # 可微分增强(DSA):对real/syn做同样的随机变换以稳定分布估计
                                                if args.dsa:
                                                    seed = int(time.time() * 1000) % 100000  # 保证 real 和 syn 用同样的 seed
                                                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                                                    img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                                                    # 投影到embedding空间
                                                    output_real = embed(img_real).detach() # 真实特征不方向传播
                                                    output_syn = embed(img_syn)  # 合成特征是要参与梯度的

                                                    # 均值匹配(特征均值之差平方和)
                                                    loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)

                                                    else: # for ConvNetBN   BatchNorm 的 ConvNe
                                                        images_real_all = []
                                                        images_syn_all = []
                                                        loss = torch.tensor(0.0).to(args.device)
                                                        for c in range(num_classes):
                                                            img_real = get_images(c, args.batch_real)
                                                            img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                                                            if args.dsa:
                                                                seed = int(time.time() * 1000) % 100000
                                                                img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                                                                img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                                                                images_real_all.append(img_real)
                                                                images_syn_all.append(img_syn)

                                                                # 把每类真实/合成拼成一个大 batch,送进 embedding 一次得到所有类的特征
                                                                images_real_all = torch.cat(images_real_all, dim=0)
                                                                images_syn_all = torch.cat(images_syn_all, dim=0)

                                                                output_real = embed(images_real_all).detach()
                                                                output_syn = embed(images_syn_all)

                                                                # reshape 以便按类计算均值,再做平方差累加
                                                                loss += torch.sum((torch.mean(output_real.reshape(num_classes, args.batch_real, -1), dim=1) - torch.mean(output_syn.reshape(num_classes, args.ipc, -1), dim=1))**2)


                                                                # 梯度累积与更新 synthetic images
                                                                optimizer_img.zero_grad()
                                                                loss.backward()
                                                                optimizer_img.step()
                                                                loss_avg += loss.item()


                                                                loss_avg /= (num_classes)        # 梯度累积与更新 synthetic images


                                                                if it%10 == 0:
                                                                    print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))

                                                                    if it == args.Iteration: # only record the final results
                                                                        data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                                                                        torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))

算法逻辑总结

  1. 准备:假设要蒸馏一个3类数据集,每类只想保留5张合成图。
  2. 多次"看法" :每次随机初始化一个小网络,把真实图和合成图都送进去提取特征。
  3. 测差异:对每个类别,计算真实图和和冲突在该网络特征空间的平均差距。
  4. 更新合成图 :把所有类别的平均差距累加成一个损失,反向梯度作用到图像像素上,轻微调整它们,让下次"看"更像真实图。
  5. 重复:多次切换网络、多次迭代,合成图不断逼近真实数据的"分布"。
  6. 评估:在最终合成图训练几个随机网络,验证它们在测试集上的准确率,确认蒸馏效果
相关推荐
Godspeed Zhao10 分钟前
自动驾驶中的传感器技术7——概述(7)-IMU
人工智能·机器学习·自动驾驶·传感器·imu·惯性导航
数据智研11 分钟前
【数据分享】各省粮食外贸依存度、粮食波动率等粮食相关数据合集(2011-2022)(获取方式看文末)
大数据·人工智能
W.KN20 分钟前
Spring 学习笔记
笔记·学习·spring
Blossom.11840 分钟前
基于深度学习的医学图像分析:使用PixelRNN实现医学图像超分辨率
c语言·人工智能·python·深度学习·yolo·目标检测·机器学习
小小洋洋43 分钟前
笔记:C语言中指向指针的指针作用
c语言·开发语言·笔记
摘星编程1 小时前
MCP革命:Anthropic如何重新定义AI与外部世界的连接标准
人工智能·ai·anthropic·mcp·ai连接标准
陈敬雷-充电了么-CEO兼CTO1 小时前
从游戏NPC到手术助手:Agent AI重构多模态交互,具身智能打开AGI新大门
人工智能·深度学习·算法·chatgpt·重构·transformer·agi
测试者家园1 小时前
Browser-Use在UI自动化测试中的应用
自动化测试·软件测试·人工智能·llm·ui自动化测试
deephub1 小时前
NSA稀疏注意力深度解析:DeepSeek如何将Transformer复杂度从O(N²)降至线性,实现9倍训练加速
人工智能·深度学习·transformer·deepseek·稀疏注意力
Virgil1391 小时前
【DL学习笔记】计算图与自动求导
笔记·学习