In long-tailed image classification, the core idea of structured knowledge distillation is that the student model should imitate not only the teacher's final outputs (labels) but, more importantly, the structured knowledge the teacher has learned internally, such as geometric relations in the feature space, inter-class similarities, and the shape of the data distribution. This matters greatly for long-tailed problems: tail classes have too few samples to learn robust features from directly, whereas a well-pretrained teacher (even one trained on balanced data) already encodes rich, transferable inter-class structure in its feature space.
1. Core principle: learning "relations" beyond labels

Unlike standard knowledge distillation (KD), which only matches the teacher's and student's output probability distributions, structured knowledge distillation transfers deeper knowledge representations. Its advantages are especially pronounced in long-tailed settings:

- Mitigating under-trained tail representations: tail classes have few samples, so their statistics (e.g., mean and variance) are poorly estimated. Structured knowledge lets the student learn a more robust, more discriminative feature-space structure from the teacher instead of merely memorizing the few tail samples.
- Transferring inter-class semantic relations: the teacher captures semantic relations such as "a cat is more similar to a tiger than to a car." Structured distillation lets the student inherit this relational prior, so even with very few "cheetah" samples the student can still classify them correctly via their similarity to "cat." A minimal numeric illustration follows this list.
- Strengthening feature discriminability: by matching the shape of the feature distribution (e.g., intra-class compactness and inter-class separation), the student learns more discriminative features and is less likely to misclassify tail classes as visually similar head classes.
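To make "relations beyond labels" concrete, here is a minimal sketch (all numbers are made up purely for illustration) of how a teacher's softened outputs encode inter-class similarity that hard labels discard:

import torch
import torch.nn.functional as F

# Hypothetical teacher logits for one "cat" image over [cat, tiger, car].
teacher_logits = torch.tensor([[6.0, 3.5, -2.0]])

hard_label = teacher_logits.argmax(dim=1)            # tensor([0]): "cat" only
soft_label = F.softmax(teacher_logits / 3.0, dim=1)  # softened with temperature T=3

print(soft_label)  # roughly [[0.66, 0.29, 0.05]]: "tiger" is close to "cat", "car" is not
# A student trained on the hard label alone never sees the cat-tiger affinity;
# distilling the softened distribution (and feature-space relations) transfers it.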
2. Code example: structured knowledge distillation (feature distribution + relational knowledge)

The following PyTorch implementation combines feature-distribution alignment with inter-class relation alignment. It uses a strategy that is effective on long-tailed data: class-balanced sampling to build batches, combined with a feature-distribution loss and a structured relation loss (a triplet-style term).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
class StructuredKnowledgeDistillationLoss(nn.Module):
    """
    Structured knowledge distillation loss.
    Combines: 1. feature-distribution alignment (L2/MMD)
              2. inter-class relation alignment (a triplet-style relation term)
    """
    def __init__(self, feat_dim, num_classes, temperature=3.0, alpha=1.0, beta=1.0, margin=1.0):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha    # weight of the vanilla KD loss
        self.beta = beta      # weight of the feature-distribution loss
        self.margin = margin  # reserved for margin-based variants (unused below)
        self.num_classes = num_classes
        # Learnable projection head mapping student features into a space
        # comparable with the teacher's features
        self.projection = nn.Sequential(
            nn.Linear(feat_dim, 512),
            nn.ReLU(),
            nn.Linear(512, feat_dim)
        )
        # Class prototypes (per-class feature means), maintained by momentum updates
        self.register_buffer('teacher_prototypes', torch.zeros(num_classes, feat_dim))
        self.register_buffer('student_prototypes', torch.zeros(num_classes, feat_dim))
    def forward(self, student_logits, teacher_logits, student_feat, teacher_feat, labels):
        """
        Compute the total structured distillation loss.
        Args:
            student_logits: student classifier output [B, C]
            teacher_logits: teacher classifier output [B, C]
            student_feat: student backbone features [B, D]
            teacher_feat: teacher backbone features [B, D]
            labels: ground-truth labels [B]
        Returns:
            total_loss: combined loss
            loss_dict: dictionary of individual loss terms
        """
        loss_dict = {}

        # 1. Vanilla knowledge distillation loss (softened labels)
        soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_prob = F.log_softmax(student_logits / self.temperature, dim=1)
        kd_loss = F.kl_div(soft_prob, soft_targets.detach(), reduction='batchmean') * (self.temperature ** 2)
        loss_dict['kd_loss'] = kd_loss

        # 2. Structured feature-distribution loss (on projected student features)
        projected_student_feat = self.projection(student_feat)

        # 2.1 L2 feature-matching loss
        feat_l2_loss = F.mse_loss(F.normalize(projected_student_feat, dim=1),
                                  F.normalize(teacher_feat, dim=1))
        loss_dict['feat_l2_loss'] = feat_l2_loss

        # 2.2 Optional: MMD loss matching the overall feature distributions
        #     (more robust on long-tailed data)
        mmd_loss = self.compute_mmd_loss(projected_student_feat, teacher_feat)
        loss_dict['feat_mmd_loss'] = mmd_loss

        feat_dist_loss = feat_l2_loss + 0.1 * mmd_loss
        # To drop the MMD term, use instead: feat_dist_loss = feat_l2_loss
        loss_dict['feat_dist_loss'] = feat_dist_loss

        # 3. Inter-class relational loss: in the student's feature space, the
        #    relations between a sample and same/different-class samples should
        #    match the corresponding relations in the teacher's space
        triplet_loss = self.structured_triplet_loss(projected_student_feat, teacher_feat, labels)
        loss_dict['triplet_loss'] = triplet_loss

        # 4. Momentum update of the class prototypes
        self.update_prototypes(teacher_feat, student_feat, labels)

        # Combine all terms
        total_loss = self.alpha * kd_loss + self.beta * feat_dist_loss + triplet_loss
        loss_dict['total_loss'] = total_loss
        return total_loss, loss_dict
    def structured_triplet_loss(self, student_feat, teacher_feat, labels):
        """
        Structured triplet loss: the relative distances between student features
        should mimic the relative distances between the teacher's features.
        Simplified version: for each sample, find the hardest positive/negative
        pair in the teacher's space, then make the student reproduce that relation.
        """
        batch_size = student_feat.shape[0]
        if batch_size < 2:
            return torch.tensor(0.0, device=student_feat.device)

        # Normalize features
        teacher_feat_norm = F.normalize(teacher_feat, dim=1)
        student_feat_norm = F.normalize(student_feat, dim=1)

        # Pairwise distance matrices for teacher and student
        t_dist = torch.cdist(teacher_feat_norm, teacher_feat_norm, p=2)  # [B, B]
        s_dist = torch.cdist(student_feat_norm, student_feat_norm, p=2)  # [B, B]

        # Label masks
        labels_expanded = labels.unsqueeze(1)
        same_class_mask = (labels_expanded == labels_expanded.T).float()
        diff_class_mask = 1 - same_class_mask
        # Exclude each sample itself
        same_class_mask.fill_diagonal_(0)

        loss = 0.0
        valid_count = 0
        for i in range(batch_size):
            # Skip samples with no positive or no negative in this batch
            if same_class_mask[i].sum() == 0 or diff_class_mask[i].sum() == 0:
                continue
            # Hardest positive in the teacher's space: same class, largest distance
            pos_distances = t_dist[i] * same_class_mask[i]
            hardest_pos_idx = torch.argmax(pos_distances)
            hardest_pos_dist_t = t_dist[i, hardest_pos_idx]
            # Hardest negative in the teacher's space: different class, smallest distance
            neg_distances = t_dist[i] * diff_class_mask[i]
            neg_distances[diff_class_mask[i] == 0] = float('inf')  # mask out non-negatives
            hardest_neg_idx = torch.argmin(neg_distances)
            hardest_neg_dist_t = t_dist[i, hardest_neg_idx]
            # Student's distances on the same pairs
            hardest_pos_dist_s = s_dist[i, hardest_pos_idx]
            hardest_neg_dist_s = s_dist[i, hardest_neg_idx]
            # Encourage (d_pos_s / d_neg_s) ≈ (d_pos_t / d_neg_t), i.e. drive
            # d_pos_s * d_neg_t - d_neg_s * d_pos_t towards zero
            relation_loss = torch.abs(hardest_pos_dist_s * hardest_neg_dist_t -
                                      hardest_neg_dist_s * hardest_pos_dist_t)
            loss += relation_loss
            valid_count += 1

        if valid_count == 0:
            return torch.tensor(0.0, device=student_feat.device)
        return loss / valid_count
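    def pairwise_relation_loss(self, student_feat, teacher_feat):
        """
        Alternative dense relation loss (an RKD-style sketch; not called by
        forward() above): instead of only the hardest pairs, match the full
        pairwise distance matrices of student and teacher features.
        """
        s = torch.cdist(F.normalize(student_feat, dim=1),
                        F.normalize(student_feat, dim=1), p=2)
        t = torch.cdist(F.normalize(teacher_feat, dim=1),
                        F.normalize(teacher_feat, dim=1), p=2)
        return F.smooth_l1_loss(s, t)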
    @torch.no_grad()
    def update_prototypes(self, teacher_feat, student_feat, labels, momentum=0.99):
        """
        Maintain teacher/student class prototypes (class centroids) via momentum
        updates. The prototypes are not used by the loss terms above; they are
        kept for analysis or prototype-based extensions.
        """
        for idx in range(self.num_classes):
            mask = (labels == idx)
            if mask.sum() > 0:
                # Mean feature of this class within the current batch
                teacher_mean = teacher_feat[mask].mean(dim=0)
                student_mean = student_feat[mask].mean(dim=0)
                # Momentum update
                self.teacher_prototypes[idx] = momentum * self.teacher_prototypes[idx] + (1 - momentum) * teacher_mean
                self.student_prototypes[idx] = momentum * self.student_prototypes[idx] + (1 - momentum) * student_mean
    def compute_mmd_loss(self, x, y, kernel_mul=2.0, kernel_num=5):
        """
        Maximum Mean Discrepancy (MMD) loss with a multi-scale RBF kernel, used
        to match the overall shape of two feature distributions. This is
        particularly useful for aligning head- and tail-class distributions in
        long-tailed data.
        """
        batch_size = min(x.size(0), y.size(0))
        x, y = x[:batch_size], y[:batch_size]
        xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
        rx = xx.diag().unsqueeze(0).expand_as(xx)
        ry = yy.diag().unsqueeze(0).expand_as(yy)
        # Squared pairwise distances within x, within y, and across x/y
        dxx = rx.t() + rx - 2 * xx
        dyy = ry.t() + ry - 2 * yy
        dxy = rx.t() + ry - 2 * zz
        # Accumulate RBF kernels over several bandwidths
        XX, YY, XY = (torch.zeros_like(xx) for _ in range(3))
        for sigma in [kernel_mul ** i for i in range(kernel_num)]:
            gamma = 1.0 / (2 * sigma)
            XX += torch.exp(-gamma * dxx)
            YY += torch.exp(-gamma * dyy)
            XY += torch.exp(-gamma * dxy)
        mmd = XX.mean() + YY.mean() - 2 * XY.mean()
        return torch.clamp(mmd, min=0.0)
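# A quick sanity check of the loss module with random tensors (shapes and
# values here are illustrative only, not part of the training recipe):
#   skd = StructuredKnowledgeDistillationLoss(feat_dim=64, num_classes=10)
#   s_logits, t_logits = torch.randn(8, 10), torch.randn(8, 10)
#   s_feat, t_feat = torch.randn(8, 64), torch.randn(8, 64)
#   labels = torch.randint(0, 10, (8,))
#   loss, parts = skd(s_logits, t_logits, s_feat, t_feat, labels)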
# ==================== Training-loop integration example ====================
def train_one_epoch(student, teacher, train_loader, criterion, skd_criterion, optimizer, device):
    """
    One full training epoch.
    """
    student.train()
    teacher.eval()  # the teacher stays frozen
    total_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # 1. Forward pass; the models are assumed to return (logits, features)
        student_logits, student_feat = student(images)
        with torch.no_grad():
            teacher_logits, teacher_feat = teacher(images)

        # 2. Losses
        # 2.1 Standard cross-entropy loss
        ce_loss = criterion(student_logits, labels)
        # 2.2 Structured knowledge distillation loss
        skd_loss, loss_dict = skd_criterion(student_logits, teacher_logits,
                                            student_feat, teacher_feat, labels)
        # 2.3 Combined loss
        lambda_skd = 0.7  # weight of the structured distillation loss
        total_batch_loss = ce_loss + lambda_skd * skd_loss

        # 3. Backward pass
        optimizer.zero_grad()
        total_batch_loss.backward()
        optimizer.step()
        total_loss += total_batch_loss.item()

        # Periodic logging
        if batch_idx % 50 == 0:
            print(f'Batch {batch_idx}: CE={ce_loss.item():.4f}, '
                  f'SKD={skd_loss.item():.4f}, '
                  f'Feat={loss_dict["feat_dist_loss"].item():.4f}, '
                  f'Triplet={loss_dict["triplet_loss"].item():.4f}')
    return total_loss / len(train_loader)
# ==================== Balanced sampler for long-tailed data ====================
def create_balanced_sampler(dataset):
    """
    Build a class-balanced sampler for a long-tailed dataset.
    Under structured distillation, balanced sampling puts more tail-class
    samples into each batch, which makes aligning the structure of tail
    features more effective.
    """
    # Assume dataset.targets holds the label of every sample
    if hasattr(dataset, 'targets'):
        labels = dataset.targets
    else:
        # Otherwise collect the labels by iterating over the dataset
        labels = [label for _, label in dataset]
    class_counts = np.bincount(labels)
    num_samples = len(labels)
    # Weight each sample by the inverse frequency of its class
    class_weights = 1.0 / class_counts
    sample_weights = [class_weights[label] for label in labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=num_samples, replacement=True)
    return sampler
# ==================== Example model definition ====================
class SimpleCNN(nn.Module):
    """A simple CNN that returns both logits and features."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(64, num_classes)
        self.feat_dim = 64

    def forward(self, x):
        feat = self.features(x).view(x.size(0), -1)
        logits = self.fc(feat)
        return logits, feat
if __name__ == '__main__':
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Initialize the models
    num_classes = 100  # e.g. 100 classes with a long-tailed distribution
    student_model = SimpleCNN(num_classes=num_classes).to(device)
    teacher_model = SimpleCNN(num_classes=num_classes).to(device)
    # Load pretrained teacher weights (assumed to exist)
    teacher_model.load_state_dict(torch.load('teacher.pth'))

    # Initialize the structured knowledge distillation loss
    skd_loss = StructuredKnowledgeDistillationLoss(
        feat_dim=student_model.feat_dim,
        num_classes=num_classes,
        temperature=3.0,
        alpha=1.0,  # vanilla KD weight
        beta=0.5,   # feature-distribution loss weight
        margin=1.0
    ).to(device)

    # Optimizer: include the projection head inside skd_loss so it gets trained too
    optimizer = torch.optim.Adam(
        list(student_model.parameters()) + list(skd_loss.parameters()), lr=1e-3)

    # Assume an existing long-tailed dataset `long_tail_dataset`
    train_dataset = YourLongTailDataset(...)  # placeholder: supply your own dataset class
    balanced_sampler = create_balanced_sampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=64, sampler=balanced_sampler)

    print("Model and loss function initialized.")
    print("Note: prepare the long-tailed dataset and data loader before actual training.")