面对上千种类别 的图像分类任务(例如1000类、10000类甚至更多),情况会变得更加复杂。这通常被称为大规模图像分类,主要挑战包括:
- 计算负担:最后的全连接层参数量巨大(例如,2048维特征×10000类 = 2.05亿参数)。
- 数据稀疏性:可能出现长尾分布,部分类别样本极少。
- 类别间相似性:类别越多,相似类别对越多,区分难度指数级增加。
针对这种场景,需要对之前的方案进行根本性的调整。以下是升级后的模型设计方案:
1. 核心思路转变:从"直推式"到"检索式"
当类别数达到数千甚至数万时,将分类问题视为度量学习 或特征检索问题往往比直接分类更有效。
- 训练阶段 :学习一个能将图像映射到特征空间的嵌入模型,使得同类样本距离近,异类样本距离远。
- 推理阶段 :不再直接输出类别ID,而是提取图像特征,然后在预先构建的类别特征库中进行最近邻检索,找到最相似的类别。
2. 模型架构设计
2.1 Backbone选择
需要更强的骨干网络来提取更有区分性的特征:
- ConvNeXt Large/XLarge:现代CNN,效率与性能平衡好。
- Swin Transformer Base/Large:适合捕捉全局关系,对相似类别区分有优势。
- EfficientNetV2-M/L:计算效率极高,适合资源敏感场景。
- RegNetY:为大规模任务设计的网络结构。
2.2 分类头设计(关键改动)
方案A:分层分类(Hierarchical Classification)
如果这上千个类别存在天然层级关系(如动物界->哺乳类->犬科->狗->哈士奇),可以构建分类器树:
- 一级分类器:预测粗粒度类别(如犬科、猫科、鸟类)。
- 二级分类器:在粗粒度类别内部进行细粒度分类。
- 优点:每个分类器处理的类别数减少,难度降低;可以利用层级先验知识。
方案B:Partial-FC(部分全连接)
在大规模分类中,最后一层全连接层的GPU显存占用是主要瓶颈。可以采用Partial-FC策略(如Meta在训练千万级分类器时使用的方法):
- 原理:每次迭代,只采样一小部分类别(如全体类别的5%-10%)参与当前batch的梯度计算。
- 实现:将全量类别向量存储在统一的库中,每次前向传播时,动态构建一个小型的全连接层用于当前batch的计算。
- 优点:显存占用从O(类别数)降低到O(采样类别数),使得训练万级分类成为可能。
方案C:转为度量学习
彻底移除分类头,改为输出一个低维特征向量(如512维),然后用特殊损失函数训练。
python
# 度量学习模型示例
class MetricLearningModel(nn.Module):
def __init__(self, backbone, embed_dim=512):
super().__init__()
self.backbone = backbone
# 移除原始的分类头,添加一个嵌入头
self.embedding_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(backbone_feat_dim, embed_dim),
nn.BatchNorm1d(embed_dim), # 重要:对嵌入向量进行归一化
# L2 Norm 通常在损失函数中处理
)
def forward(self, x):
# 返回特征向量,而非类别logits
return self.embedding_head(self.backbone(x))
3. 损失函数选择
在千级甚至万级分类中,损失函数的选择直接影响模型收敛和特征质量。
3.1 如果保留分类头(Partial-FC)
- ArcFace / CosFace:原本用于人脸识别(百万级身份),对大规模分类非常有效。它在特征和权重之间引入角度间隔,能让特征在超球面上分布得更均匀。
- Softmax + 大间隔:增加Margin,提高对相似类别的区分能力。
3.2 如果采用度量学习
- Multi-Similarity Loss:考虑正负样本对的多种相似度,是目前综合效果较好的损失函数。
- Circle Loss:对每个相似度进行不同的加权,灵活性更高。
- Contrastive Loss / Triplet Loss:经典方案,但需要精心设计样本挖掘策略。
python
# 伪代码:使用ArcFace损失(配合Partial-FC)
# 特征: [batch, 512]
# 权重矩阵: [num_classes, 512] (存储在库中)
# 采样: 选出当前batch中出现的类别 + 随机采样负类别 (共 5000个类别,远小于100000)
sampled_class_ids = sample_classes(current_batch_labels, total_classes, num_sampled=5000)
sampled_weights = weight_bank[sampled_class_ids] # [5000, 512]
# 计算ArcFace Logits
logits = cosine_similarity(features, sampled_weights) # [batch, 5000]
loss = arcface_loss(logits, current_batch_labels, sampled_class_ids)
4. 训练策略优化
4.1 动态类别采样
- 难点:上千类别,如果使用完整Softmax,单卡显存可能不够。
- 方案:使用混合并行。模型并行(Model Parallel)将分类头权重切分到多卡,或者使用上述Partial-FC。
4.2 难例挖掘
- 必要性:类别越多,简单的正负样本对越容易分开,模型提升有限。
- 做法:在batch内部或在每个epoch结束后,挖掘难分的负样本(即与anchor特征很接近的不同类别),重新喂给模型训练。
4.3 两阶段训练
- 预热阶段:先用少量采样类别训练骨干网络,使其具备基本特征提取能力。
- 对齐阶段:逐步增加采样类别数量,使用ArcFace/Circle Loss微调,让特征空间适应大规模类别划分。
5. 推理阶段设计
当类别成千上万时,直接计算特征与所有类别中心(或每个类别的代表样本)的距离仍然可能很慢。
- 建立索引(必须) :不能使用暴力遍历。
- 将所有类别特征(可以是每类多个原型)存入向量数据库。
- 工具:Faiss(Facebook)、HNSW(Hierarchical Navigable Small World)、Milvus。
- 检索流程 :
- 提取查询图像特征。
- 使用Faiss等工具进行快速ANN(Approximate Nearest Neighbor,近似最近邻)检索,召回Top-K个最相似的类别原型。
- 对Top-K结果进行重排序(如果精度要求高,可以回退到精确计算)。
6. 长尾分布处理
上千类数据往往是不平衡的,头部类别样本多,尾部类别样本极少。
- 数据层面 :
- 类平衡采样:确保每个batch中,每个类别(特别是尾部类别)被抽到的概率均衡,而非按数据量比例。
- 数据增强:对尾部类别进行更强烈的增强,甚至使用图像生成技术(如Diffusion模型)生成额外的训练样本。
- 算法层面 :
- Logit Adjustment:在训练或推理时,根据先验类别频率调整预测logits。
- Balanced Softmax:修改Softmax函数,使其不受数据不平衡的影响。
7. 总结建议方案
假设你有一个包含5000类、长尾分布的数据集:
- 骨干网络:ConvNeXt Large 或 Swin-L (ImageNet-22K预训练)。
- 训练范式 :度量学习 + 分类头微调两阶段结合。
- 阶段一 :使用Multi-Similarity Loss + 动态采样,训练一个嵌入模型,输出512维特征。
- 阶段二 :固定Backbone,添加一个Partial-FC 头,用ArcFace损失微调5-10个epoch,让特征更有判别性。
- 索引构建 :用训练集(或每个类别的原型)通过模型提取特征,构建Faiss索引。
- 推理:查询图像 -> 提特征 -> Faiss检索 -> 返回Top-1/5类别。