DiVE长尾识别的虚拟实例蒸馏方法

一 原理解析

关于长尾识别中的 DiVE 方法,根据搜索结果,其核心思想是从知识蒸馏的视角出发,通过将教师模型的预测视为虚拟样本 ,并调整这些样本的分布来改善模型在尾类上的识别性能-1

为了让你能快速把握要点,我先用一个表格来梳理 DiVE 方法的核心思路与关键设计:

维度 核心思想 关键设计 解决的问题
基本思路 将教师模型对输入图像的预测视为"虚拟样本"-1 例如,一张狗的图片预测为(0.7狗, 0.3猫),则生成0.7个狗的虚拟样本和0.3个猫的虚拟样本 -1 尾部类别样本稀少,模型难以学习有效特征。
核心机制 从虚拟样本中蒸馏知识,这在一定约束下等效于标签分布学习 -1 通过知识蒸馏,将教师模型捕获的类别间关系迁移给学生模型-1 传统方法(如重采样、重加权)缺乏类别间交互。
分布调整 明确调整虚拟样本分布使其更平坦 ,降低头部类别权重,提升尾部类别影响-1 使虚拟样本分布比原始输入分布更平坦-1 长尾数据集中头部类别主导模型训练。

方法理解与启示

理解 DiVE 方法,可以注意以下几点:

  • 虚拟样本的本质 :它并非生成新的像素数据,而是利用教师模型预测的软标签 作为额外的监督信号。这些软标签包含了类别间的相关性(例如"狗"和"猫"在某些特征上的相似与差异),为学生模型提供了比原始独热编码更丰富的指导-1

  • 为何有效 :在长尾分布中,尾类样本少,模型从中学习到的特征往往不充分。DiVE 方法通过教师模型,让头部类别丰富的样本也能为识别尾类"贡献"一部分知识(如前文例子中狗的图片为识别猫贡献了0.3的虚拟样本)-1。通过平坦化的虚拟样本分布,间接提升了尾类在训练中的"话语权",从而缓解模型对头类的偏见。

  • 与相关方法对比 :一些研究也从不同角度改进蒸馏以应对长尾问题。例如,DeiT-LT 针对ViT模型,使用分布外图像 进行蒸馏,并让不同的标记分别专注于头类和尾类-5-7SSD 方法则引入了自监督学习 来辅助生成更好的蒸馏标签-6-9。DiVE 的核心区别在于其"虚拟样本"的构建与分布的显式平坦化调整。

二 代码实现

import torch

import torch.nn as nn

import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import numpy as np

class DiVEDistillation(nn.Module):

"""

DiVE (Distillation with Virtual Examples) 方法的概念实现

核心思想:通过调整教师模型预测分布来生成平衡的虚拟样本

"""

def init(self, teacher_model, student_model, num_classes, alpha=0.7, temperature=3.0):

super(DiVEDistillation, self).init()

self.teacher = teacher_model

self.student = student_model

self.num_classes = num_classes

self.alpha = alpha # 蒸馏损失权重

self.temperature = temperature # 温度参数

类别权重 - 用于分布平坦化

self.class_weights = None

def compute_class_weights(self, data_loader):

"""计算类别权重以实现分布平坦化"""

class_counts = torch.zeros(self.num_classes)

统计每个类别的样本数

for _, targets in data_loader:

for class_idx in range(self.num_classes):

class_counts[class_idx] += (targets == class_idx).sum()

计算权重:样本数越少,权重越高

weights = 1.0 / (class_counts + 1e-8)

weights = weights / weights.sum() * self.num_classes # 归一化

self.class_weights = weights

print(f"Class counts: {class_counts}")

print(f"Class weights: {self.class_weights}")

def flatten_distribution(self, teacher_logits, targets):

"""

核心方法:平坦化教师模型的预测分布

增加尾部类别的权重,减少头部类别的影响

"""

batch_size = teacher_logits.size(0)

if self.class_weights is None:

return teacher_logits

应用温度调节

teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)

应用类别权重进行平坦化

weight_matrix = self.class_weights.unsqueeze(0).expand(batch_size, -1)

flattened_probs = teacher_probs * weight_matrix

flattened_probs = flattened_probs / flattened_probs.sum(dim=1, keepdim=True)

return flattened_probs

def forward(self, images, targets):

"""

前向传播:结合真实标签和虚拟样本进行训练

"""

教师模型预测(不更新梯度)

with torch.no_grad():

teacher_logits = self.teacher(images)

virtual_probs = self.flatten_distribution(teacher_logits, targets)

学生模型预测

student_logits = self.student(images)

计算交叉熵损失(真实标签)

ce_loss = F.cross_entropy(student_logits, targets)

计算蒸馏损失(虚拟样本)

distill_loss = F.kl_div(

F.log_softmax(student_logits / self.temperature, dim=1),

virtual_probs,

reduction='batchmean'

) * (self.temperature ** 2)

组合损失

total_loss = (1 - self.alpha) * ce_loss + self.alpha * distill_loss

return {

'total_loss': total_loss,

'ce_loss': ce_loss,

'distill_loss': distill_loss,

'virtual_probs': virtual_probs.detach()

}

示例:简单的CNN模型

class SimpleCNN(nn.Module):

def init(self, num_classes=10):

super(SimpleCNN, self).init()

self.features = nn.Sequential(

nn.Conv2d(3, 32, 3, padding=1),

nn.ReLU(),

nn.MaxPool2d(2),

nn.Conv2d(32, 64, 3, padding=1),

nn.ReLU(),

nn.MaxPool2d(2),

)

self.classifier = nn.Linear(64 * 8 * 8, num_classes)

def forward(self, x):

x = self.features(x)

x = x.view(x.size(0), -1)

return self.classifier(x)

模拟长尾数据集

class LongTailDataset(Dataset):

def init(self, num_samples=1000, num_classes=10):

self.num_classes = num_classes

创建长尾分布:第一个类别样本最多,最后一个类别样本最少

samples_per_class = []

for i in range(num_classes):

samples = int(num_samples * (0.5 ** i))

samples_per_class.append(max(samples, 10)) # 每个类别至少10个样本

self.data = []

self.targets = []

for class_idx, num_samples in enumerate(samples_per_class):

for _ in range(num_samples):

模拟图像数据 (3, 32, 32)

img = torch.randn(3, 32, 32)

self.data.append(img)

self.targets.append(class_idx)

print(f"Created long-tail dataset: {samples_per_class}")

def len(self):

return len(self.data)

def getitem(self, idx):

return self.data[idx], self.targets[idx]

训练示例

def train_dive_example():

初始化模型和数据

num_classes = 5

teacher_model = SimpleCNN(num_classes)

student_model = SimpleCNN(num_classes)

dataset = LongTailDataset(num_samples=1000, num_classes=num_classes)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

初始化DiVE

dive = DiVEDistillation(teacher_model, student_model, num_classes)

dive.compute_class_weights(dataloader)

optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

训练循环

for epoch in range(3): # 示例中只训练3个epoch

for batch_idx, (images, targets) in enumerate(dataloader):

optimizer.zero_grad()

DiVE前向传播

losses = dive(images, targets)

反向传播

losses['total_loss'].backward()

optimizer.step()

if batch_idx % 10 == 0:

print(f'Epoch: {epoch}, Batch: {batch_idx}, '

f'Total Loss: {losses["total_loss"].item():.4f}, '

f'CE Loss: {losses["ce_loss"].item():.4f}, '

f'Distill Loss: {losses["distill_loss"].item():.4f}')

打印虚拟样本分布示例

if batch_idx == 0:

virtual_probs = losses['virtual_probs'][0]

print(f"Virtual probs example: {virtual_probs}")

if name == "main":

train_dive_example()

关键理解要点

  1. 虚拟样本本质:不是真实像素,而是软标签形式的概率分布

  2. 分布平坦化:通过类别权重调整,让尾部类别获得更多关注

  3. 知识蒸馏:教师模型的类别关系知识迁移到学生模型

  4. 损失组合:平衡真实标签监督和虚拟样本蒸馏

参考文献:

《Distilling Virtual Examples for Long-tailed Recognition》

相关推荐
是Dream呀几秒前
一个账号调用N个AI模型!从LLM到视频生成的丝滑解决方案
人工智能·大模型·aigc·音视频·deepseek
2301_797267344 分钟前
神经网络组植物分类学习规划与本周进展综述15
人工智能·神经网络·学习
xuehaikj5 分钟前
【实战案例】基于dino-4scale_r50_8xb2-36e_coco的棉田叶片病害识别与分类项目详解
人工智能·数据挖掘
算法与编程之美7 分钟前
探索不同的优化器、损失函数、batch_size对分类精度影响
人工智能·机器学习·计算机视觉·分类·batch
MicrosoftReactor15 分钟前
技术速递|GitHub Copilot 和 AI Agent 如何拯救传统系统
人工智能·github·copilot·agent
啊我不会诶16 分钟前
01BFS学习笔记
笔记·学习·算法
Ch_ty22 分钟前
leetcode解题思路分析(一百六十八)1452 - 1458 题
算法·leetcode·哈希算法
哼?~23 分钟前
算法学习--离散化
算法
only-code26 分钟前
SeqXGPT:Sentence-Level AI-Generated Text Detection —— 把大模型的“波形”变成测谎仪
人工智能·大语言模型·ai检测·文本检测