结构化知识蒸馏(特征分布+关系知识)

在长尾图像分类中,结构化知识蒸馏的核心思想是:它不仅让学生模型模仿教师模型的最终输出(标签),更重要的是模仿其内部学到的结构化知识,例如特征空间的几何关系、类间相似性、数据分布的形状等。这对于长尾问题至关重要,因为尾类样本稀少,难以从数据中直接学习鲁棒特征,但一个预训练好的教师模型(即使在均衡数据上训练)其特征空间已蕴含了丰富的、可迁移的类别间结构化关系。

一、核心原理:超越标签的"关系"学习 与标准知识蒸馏(KD)仅匹配教师和学生的输出概率分布不同,结构化知识蒸馏旨在迁移更深层的知识表示。在长尾场景下,其优势尤为突出: 缓解尾类表征学习不足:尾类样本少,其统计特征(如均值、方差)估计不准。结构化知识让学生模型从教师处学习一个更稳健、更具判别性的特征空间结构,而不仅仅记忆少数尾类样本。 迁移类别语义关系:教师模型能捕捉到"猫和老虎更相似,猫和汽车不相似"这样的语义关系。结构化蒸馏让学生继承这种关系先验,即使尾类"猎豹"样本很少,学生也能因其与"猫"的相似性而将其正确归类。 增强特征判别性:通过匹配特征分布的形状(如类内紧凑、类间分离),学生模型能学习到更具判别力的特征,减少将尾类误判为视觉相似的头部类别。

二、代码示例:结构化知识蒸馏(特征分布+关系知识) 以下是一个结合了特征分布对齐和类间关系对齐的PyTorch实现。我们使用一种在长尾数据上有效的策略:类平衡采样来构建批次,并结合特征分布损失与结构化关系损失(三元组损失)。

import torch

import torch.nn as nn

import torch.nn.functional as F

import torchvision.models as models

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import numpy as np

class StructuredKnowledgeDistillationLoss(nn.Module):

"""

结构化知识蒸馏损失模块。

结合:1. 特征分布对齐损失 (L2/MMD) 2. 类间关系对齐损失 (三元组损失)

"""

def init(self, feat_dim, num_classes, temperature=3.0, alpha=1.0, beta=1.0, margin=1.0):

super().init()

self.temperature = temperature

self.alpha = alpha # 传统KD损失权重

self.beta = beta # 特征分布损失权重

self.margin = margin # 三元组损失边际

self.num_classes = num_classes

可学习的投影头,将学生特征映射到与教师特征可比的空间

self.projection = nn.Sequential(

nn.Linear(feat_dim, 512),

nn.ReLU(),

nn.Linear(512, feat_dim)

)

用于存储类别原型(均值),动量更新

self.register_buffer('teacher_prototypes', torch.zeros(num_classes, feat_dim))

self.register_buffer('student_prototypes', torch.zeros(num_classes, feat_dim))

def forward(self, student_logits, teacher_logits, student_feat, teacher_feat, labels):

"""

计算总的结构化蒸馏损失。

Args:

student_logits: 学生模型分类头输出 [B, C]

teacher_logits: 教师模型分类头输出 [B, C]

student_feat: 学生模型backbone输出的特征 [B, D]

teacher_feat: 教师模型backbone输出的特征 [B, D]

labels: 真实标签 [B]

Returns:

total_loss: 总损失

loss_dict: 各损失分量字典

"""

loss_dict = {}

1. 传统知识蒸馏损失 (软化标签)

soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)

soft_prob = F.log_softmax(student_logits / self.temperature, dim=1)

kd_loss = F.kl_div(soft_prob, soft_targets.detach(), reduction='batchmean') * (self.temperature ** 2)

loss_dict['kd_loss'] = kd_loss

2. 特征分布结构化损失 (使用投影后的特征)

projected_student_feat = self.projection(student_feat)

2.1 基于L2的特征匹配损失

feat_l2_loss = F.mse_loss(F.normalize(projected_student_feat, dim=1),

F.normalize(teacher_feat, dim=1))

loss_dict['feat_l2_loss'] = feat_l2_loss

2.2 可选:基于MMD的特征分布匹配损失(对长尾更鲁棒)

mmd_loss = self.compute_mmd_loss(projected_student_feat, teacher_feat)

loss_dict['feat_mmd_loss'] = mmd_loss

feat_dist_loss = feat_l2_loss + 0.1 * mmd_loss

feat_dist_loss = feat_l2_loss

loss_dict['feat_dist_loss'] = feat_dist_loss

3. 类间关系结构化损失 (使用三元组损失对齐结构)

目标:让学生特征空间中,样本与同类原型、异类原型的关系,与教师特征空间一致

triplet_loss = self.structured_triplet_loss(projected_student_feat, teacher_feat, labels)

loss_dict['triplet_loss'] = triplet_loss

4. 更新类别原型 (动量更新)

self.update_prototypes(teacher_feat, student_feat, labels)

组合总损失

total_loss = self.alpha * kd_loss + self.beta * feat_dist_loss + triplet_loss

loss_dict['total_loss'] = total_loss

return total_loss, loss_dict

def structured_triplet_loss(self, student_feat, teacher_feat, labels):

"""

结构化三元组损失:让学生特征间的相对距离关系模仿教师特征间的距离关系。

这里实现一个简化版本:为每个样本寻找教师空间中的最难正/负样本对,然后让学生模仿这个关系。

"""

batch_size = student_feat.shape[0]

if batch_size < 2:

return torch.tensor(0.0, device=student_feat.device)

归一化特征

teacher_feat_norm = F.normalize(teacher_feat, dim=1)

student_feat_norm = F.normalize(student_feat, dim=1)

计算教师和学生的特征距离矩阵

t_dist = torch.cdist(teacher_feat_norm, teacher_feat_norm, p=2) # [B, B]

s_dist = torch.cdist(student_feat_norm, student_feat_norm, p=2) # [B, B]

创建标签掩码

labels_expanded = labels.unsqueeze(1)

same_class_mask = (labels_expanded == labels_expanded.T).float()

diff_class_mask = 1 - same_class_mask

忽略自身

same_class_mask.fill_diagonal_(0)

loss = 0.0

valid_count = 0

for i in range(batch_size):

在教师空间中找到最难正样本和最难负样本

最难正样本:同类中距离最大的

pos_distances = t_dist[i] * same_class_mask[i]

hardest_pos_idx = torch.argmax(pos_distances)

hardest_pos_dist_t = t_dist[i, hardest_pos_idx]

最难负样本:异类中距离最小的

neg_distances = t_dist[i] * diff_class_mask[i]

避免全零(如果没有负样本)

if neg_distances.sum() > 0:

neg_distances[neg_distances == 0] = float('inf')

hardest_neg_idx = torch.argmin(neg_distances)

hardest_neg_dist_t = t_dist[i, hardest_neg_idx]

获取学生在对应配对上的距离

hardest_pos_dist_s = s_dist[i, hardest_pos_idx]

hardest_neg_dist_s = s_dist[i, hardest_neg_idx]

计算三元组损失,目标是让学生距离比与教师一致

我们鼓励: (d_pos_s / d_neg_s) ≈ (d_pos_t / d_neg_t)

转换为鼓励: d_pos_s * d_neg_t - d_neg_s * d_pos_t 接近0

relation_loss = torch.abs(hardest_pos_dist_s * hardest_neg_dist_t -

hardest_neg_dist_s * hardest_pos_dist_t)

loss += relation_loss

valid_count += 1

if valid_count > 0:

loss = loss / valid_count

return loss

@torch.no_grad()

def update_prototypes(self, teacher_feat, student_feat, labels, momentum=0.99):

"""

使用动量更新方式,维护教师和学生的类别原型(类中心)。

"""

for idx in range(self.num_classes):

mask = (labels == idx)

if mask.sum() > 0:

计算当前批次中该类特征的均值

teacher_mean = teacher_feat[mask].mean(dim=0)

student_mean = student_feat[mask].mean(dim=0)

动量更新

self.teacher_prototypes[idx] = momentum * self.teacher_prototypes[idx] + (1 - momentum) * teacher_mean

self.student_prototypes[idx] = momentum * self.student_prototypes[idx] + (1 - momentum) * student_mean

def compute_mmd_loss(self, x, y, kernel_mul=2.0, kernel_num=5):

"""

计算最大均值差异(MMD)损失,用于匹配两个特征分布的整体形状。

这对于对齐长尾数据中头部和尾类的整体分布尤其有效。

"""

batch_size = min(x.size(0), y.size(0))

x, y = x[:batch_size], y[:batch_size]

xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

rx = (xx.diag().unsqueeze(0).expand_as(xx))

ry = (yy.diag().unsqueeze(0).expand_as(yy))

dxx = rx.t() + rx - 2 * xx

dyy = ry.t() + ry - 2 * yy

dxy = rx.t() + ry - 2 * zz

XX, YY, XY = (torch.zeros_like(xx) for _ in range(3))

for sigma in [kernel_mul ** i for i in range(kernel_num)]:

gamma = 1.0 / (2 * sigma)

XX += torch.exp(-gamma * dxx)

YY += torch.exp(-gamma * dyy)

XY += torch.exp(-gamma * dxy)

mmd = XX.mean() + YY.mean() - 2 * XY.mean()

return torch.clamp(mmd, min=0.0)

==================== 主训练流程集成示例 ====================

def train_one_epoch(student, teacher, train_loader, criterion, skd_criterion, optimizer, device):

"""

一个完整训练周期的示例。

"""

student.train()

teacher.eval() # 教师模型固定

total_loss = 0.0

for batch_idx, (images, labels) in enumerate(train_loader):

images, labels = images.to(device), labels.to(device)

1. 前向传播

假设我们的模型返回 (logits, features)

student_logits, student_feat = student(images)

with torch.no_grad():

teacher_logits, teacher_feat = teacher(images)

2. 计算损失

2.1 标准交叉熵损失

ce_loss = F.cross_entropy(student_logits, labels)

2.2 结构化知识蒸馏损失

skd_loss, loss_dict = skd_criterion(student_logits, teacher_logits,

student_feat, teacher_feat, labels)

2.3 组合损失

lambda_skd = 0.7 # 结构化蒸馏损失权重

total_batch_loss = ce_loss + lambda_skd * skd_loss

3. 反向传播

optimizer.zero_grad()

total_batch_loss.backward()

optimizer.step()

total_loss += total_batch_loss.item()

打印部分信息

if batch_idx % 50 == 0:

print(f'Batch {batch_idx}: CE={ce_loss.item():.4f}, '

f'SKD={skd_loss.item():.4f}, '

f'Feat={loss_dict["feat_dist_loss"].item():.4f}, '

f'Triplet={loss_dict.get("triplet_loss", 0).item():.4f}')

return total_loss / len(train_loader)

==================== 构建长尾数据采样器 ====================

def create_balanced_sampler(dataset):

"""

为长尾数据集创建一个平衡采样器。

在结构化蒸馏中,平衡采样有助于在批次中看到更多尾类,

从而更有效地对齐尾部特征的结构。

"""

假设dataset.targets包含所有样本的标签

if hasattr(dataset, 'targets'):

labels = dataset.targets

else:

否则需要遍历数据集收集标签

labels = [label for _, label in dataset]

class_counts = np.bincount(labels)

num_samples = len(labels)

为每个样本分配权重:该样本所属类别的倒数

class_weights = 1.0 / class_counts

sample_weights = [class_weights[label] for label in labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=num_samples, replacement=True)

return sampler

==================== 模型定义示例 ====================

class SimpleCNN(nn.Module):

"""一个简单的CNN示例,返回logits和特征"""

def init(self, num_classes=10):

super().init()

self.features = nn.Sequential(

nn.Conv2d(3, 32, kernel_size=3, padding=1),

nn.ReLU(),

nn.MaxPool2d(2),

nn.Conv2d(32, 64, kernel_size=3, padding=1),

nn.ReLU(),

nn.MaxPool2d(2),

nn.AdaptiveAvgPool2d((1, 1))

)

self.fc = nn.Linear(64, num_classes)

self.feat_dim = 64

def forward(self, x):

feat = self.features(x).view(x.size(0), -1)

logits = self.fc(feat)

return logits, feat

if name == 'main':

设备配置

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using device: {device}')

初始化模型

num_classes = 100 # 假设有100个类别,呈长尾分布

student_model = SimpleCNN(num_classes=num_classes).to(device)

teacher_model = SimpleCNN(num_classes=num_classes).to(device)

加载预训练的教师模型权重(此处假设已有)

teacher_model.load_state_dict(torch.load('teacher.pth'))

初始化结构化知识蒸馏损失

skd_loss = StructuredKnowledgeDistillationLoss(

feat_dim=student_model.feat_dim,

num_classes=num_classes,

temperature=3.0,

alpha=1.0, # 传统KD权重

beta=0.5, # 特征分布损失权重

margin=1.0

).to(device)

优化器

optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)

假设已有长尾数据集 `long_tail_dataset`

train_dataset = YourLongTailDataset(...)

balanced_sampler = create_balanced_sampler(train_dataset)

train_loader = DataLoader(train_dataset, batch_size=64, sampler=balanced_sampler)

print("模型与损失函数初始化完成。")

print("提示:在实际使用前,请准备好长尾数据集并配置好数据加载器。")

相关推荐
自然语1 小时前
数字生已经进化到一个分水岭面临选择?先实现“动态识别“还是先实现“特征信息归纳分类“,文中给出以给出答案,大家选哪个方向?
人工智能·分类·数据挖掘
围炉聊科技1 小时前
国内的大模型访问能访问墙外内容吗?
人工智能
fantasy_arch1 小时前
LSTM和DenseNet区别
人工智能·rnn·lstm
Deepoch1 小时前
发动机设计迎突破!Deepoc-M低幻觉模型重塑研发逻辑
大数据·人工智能·deepoc
Pocker_Spades_A1 小时前
智能时代的操作系统范式:openEuler的AI就绪度深度评估
人工智能
Sirius Wu1 小时前
智能体开发框架选型
人工智能·aigc
人工智能技术咨询.1 小时前
【无标题】卷积神经网络(CNN)详细介绍及其原理详解(2)
人工智能
sendnews1 小时前
红松亮相首届厦门银博会,以一站式社区平台展示退休生活新图景
大数据·人工智能
有Li1 小时前
一种交互式可解释人工智能方法,用于改进数字细胞病理学癌症亚型分类中的人机协作|文献速递-文献分享
大数据·论文阅读·人工智能·文献