引言
在计算机视觉领域,Vision Transformer(ViT)的出现标志着深度学习架构的重大转变。其中,BEiT(Bidirectional Encoder Representations from Image Transformers) 作为微软研究院提出的自监督视觉表示学习模型,通过借鉴NLP领域BERT的掩码语言建模(MLM)思想,创新性地引入了**掩码图像建模(Masked Image Modeling, MIM)**预训练策略,在ImageNet等基准数据集上取得了卓越性能。
本文将详细介绍如何利用BEiT-base模型,通过迁移学习在CIFAR-100数据集上实现高精度的图像分类。CIFAR-100包含100个细粒度类别(分为20个超类),每类600张32×32像素的彩色图像,是评估图像分类算法的重要基准。
一、BEiT模型架构与核心原理
1.1 模型架构特点
BEiT-base模型采用与ViT-base相似的架构配置:
-
12层Transformer编码器
-
768维隐藏层维度
-
12个注意力头
-
16×16像素的Patch分割
-
224×224像素的输入分辨率
与标准ViT的关键区别在于,BEiT使用相对位置编码 (类似T5模型)替代绝对位置编码,并通过对所有Patch的最终隐藏状态进行平均池化(mean-pooling)来进行图像分类,而非仅使用[CLS]标记。
1.2 掩码图像建模(MIM)预训练
BEiT的核心创新在于其预训练策略:
-
双视图输入:原始图像被分割为Patch序列,同时通过DALL-E的VQ-VAE编码器转换为离散视觉Token
-
随机掩码:随机遮蔽部分图像Patch(通常40%)
-
Token预测:模型学习预测被遮蔽区域的视觉Token,而非原始像素值
这种方法迫使模型学习图像的高级语义抽象,而非低级纹理细节,从而在下游任务中表现更佳。
二、项目架构与实现
2.1 环境准备与模型下载
首先通过ModelScope下载预训练的BEiT模型:
python
from modelscope import snapshot_download
import os
# 下载模型到本地
model_dir = snapshot_download(
'microsoft/beit-base-patch16-224',
cache_dir='./models',
revision='master'
)
模型选择说明 :我们选用microsoft/beit-base-patch16-224,该模型先在ImageNet-21k(1400万张图像,21,841个类别)上进行自监督预训练,再在ImageNet-1k(1000类)上进行微调。这种经过完整预训练流程的模型具有强大的视觉特征提取能力。
2.2 数据预处理与增强策略
由于BEiT期望224×224像素的输入,而CIFAR-100原始图像仅为32×32像素,我们需要精心设计数据预处理流程:
python
# 训练集数据增强
transform_train = transforms.Compose([
transforms.Resize((224, 224)), # 上采样至224×224
transforms.RandomCrop(224, padding=28), # 随机裁剪
transforms.RandomHorizontalFlip(p=0.5), # 水平翻转
transforms.RandomRotation(15), # 随机旋转
transforms.ColorJitter( # 颜色抖动
brightness=0.2,
contrast=0.2,
saturation=0.2
),
transforms.ToTensor(),
transforms.Normalize(
mean=feature_extractor.image_mean, # BEiT特定的归一化参数
std=feature_extractor.image_std
),
])
# 测试集预处理
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=feature_extractor.image_mean,
std=feature_extractor.image_std
),
])
关键设计决策:
-
Resize vs. Upsample:直接将32×32图像插值到224×224会损失细节,但这是适配预训练模型的必要妥协
-
强数据增强:针对小图像的上采样特性,采用 aggressive 的数据增强(大角度旋转、颜色抖动)防止过拟合
-
归一化对齐:使用BEiT特征提取器提供的均值和标准差,确保与预训练分布一致
2.3 模型适配与分类头修改
加载模型时需要修改分类头以适应100类输出:
python
def load_pretrained_model(num_classes=100, model_path="./models/microsoft/beit-base-patch16-224"):
# 加载特征提取器
feature_extractor = BeitFeatureExtractor.from_pretrained(model_path)
# 加载模型并修改分类头
model = BeitForImageClassification.from_pretrained(
model_path,
num_labels=num_classes,
ignore_mismatched_sizes=True # 关键参数:允许分类头尺寸不匹配
)
return model, feature_extractor
权重加载机制 :由于预训练模型输出1000类(ImageNet-1k),而我们只需要100类(CIFAR-100),ignore_mismatched_sizes=True参数会自动重新初始化分类层(classifier.weight和classifier.bias),同时保留Transformer编码器的预训练权重。
2.4 分层学习率优化策略
针对迁移学习场景,采用分层学习率(Layer-wise Learning Rate Decay):
python
optimizer = optim.AdamW([
{'params': model.beit.parameters(), 'lr': 5e-5}, # 预训练部分:小学习率
{'params': model.classifier.parameters(), 'lr': 1e-3} # 新分类头:大学习率
], weight_decay=0.05)
# 余弦退火学习率调度
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
策略原理:
-
骨干网络(5e-5):预训练的BEiT编码器已具备强大的特征提取能力,仅需微调
-
分类头(1e-3):随机初始化的分类层需要更激进的更新
-
AdamW优化器:结合权重衰减(weight decay=0.05)防止过拟合
-
余弦退火:平滑降低学习率,帮助收敛到更优局部最小值
2.5 损失函数与标签平滑
采用带标签平滑的交叉熵损失:
python
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
标签平滑(Label Smoothing)将硬标签(0/1)转换为软标签(如0.1/0.9),防止模型过度自信,提升泛化能力。
三、训练过程与结果分析
训练过程显卡显存占用情况👇:
bash
Sun Feb 15 10:03:02 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A10 Off | 00000000:00:07.0 Off | Off |
| 0% 59C P0 150W / 150W | 11323MiB / 24564MiB | 99% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
3.1 训练日志解读
🍎完整的训练日志如下:
bash
CUDA is available. GPU: NVIDIA A10
Memory allocated: 0.00 MB
Using device: cuda
Loading pre-trained BEiT model...
Some weights of BeitForImageClassification were not initialized from the model checkpoint at ./models/microsoft/beit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Preparing CIFAR-100 dataset...
Starting training for 10 epochs...
Epoch 1: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=1.305, Acc=83.99%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=89.33%]
Epoch 1/10: Train Loss: 1.3054, Train Acc: 83.99%, Test Acc: 89.33%, Best Acc: 89.33%, Time: 623.31s
Epoch 2: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=1.055, Acc=91.34%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=90.30%]
Epoch 2/10: Train Loss: 1.0548, Train Acc: 91.34%, Test Acc: 90.30%, Best Acc: 90.30%, Time: 624.62s
Epoch 3: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.984, Acc=93.78%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=90.95%]
Epoch 3/10: Train Loss: 0.9837, Train Acc: 93.78%, Test Acc: 90.95%, Best Acc: 90.95%, Time: 623.37s
Epoch 4: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.930, Acc=95.55%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=91.36%]
Epoch 4/10: Train Loss: 0.9297, Train Acc: 95.55%, Test Acc: 91.36%, Best Acc: 91.36%, Time: 623.35s
Epoch 5: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.884, Acc=97.06%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=92.11%]
Epoch 5/10: Train Loss: 0.8837, Train Acc: 97.06%, Test Acc: 92.11%, Best Acc: 92.11%, Time: 622.95s
Epoch 6: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.855, Acc=98.12%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.72it/s, Acc=92.55%]
Epoch 6/10: Train Loss: 0.8551, Train Acc: 98.12%, Test Acc: 92.55%, Best Acc: 92.55%, Time: 622.69s
Epoch 7: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.832, Acc=98.87%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=92.88%]
Epoch 7/10: Train Loss: 0.8319, Train Acc: 98.87%, Test Acc: 92.88%, Best Acc: 92.88%, Time: 622.77s
Epoch 8: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.817, Acc=99.34%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=93.30%]
Epoch 8/10: Train Loss: 0.8172, Train Acc: 99.34%, Test Acc: 93.30%, Best Acc: 93.30%, Time: 623.30s
Epoch 9: Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.811, Acc=99.51%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.71it/s, Acc=93.52%]
Epoch 9/10: Train Loss: 0.8106, Train Acc: 99.51%, Test Acc: 93.52%, Best Acc: 93.52%, Time: 623.31s
Epoch 10: Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [09:38<00:00, 1.35it/s, Loss=0.808, Acc=99.56%]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.70it/s, Acc=93.61%]
Epoch 10/10: Train Loss: 0.8078, Train Acc: 99.56%, Test Acc: 93.61%, Best Acc: 93.61%, Time: 623.23s
Training completed. Best accuracy: 93.61%
Loading best model for final evaluation...
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:42<00:00, 3.70it/s, Acc=93.61%]
Final test accuracy: 93.61%
Per-class accuracy:
apple : 99.00%
aquarium_fish : 97.00%
baby : 86.00%
bear : 97.00%
beaver : 91.00%
bed : 96.00%
bee : 100.00%
beetle : 94.00%
bicycle : 96.00%
bottle : 99.00%
bowl : 82.00%
boy : 85.00%
bridge : 95.00%
bus : 91.00%
butterfly : 99.00%
camel : 96.00%
can : 94.00%
castle : 94.00%
caterpillar : 92.00%
cattle : 97.00%
chair : 96.00%
chimpanzee : 98.00%
clock : 99.00%
cloud : 87.00%
cockroach : 97.00%
couch : 88.00%
crab : 92.00%
crocodile : 96.00%
cup : 90.00%
dinosaur : 95.00%
dolphin : 92.00%
elephant : 97.00%
flatfish : 95.00%
forest : 81.00%
fox : 96.00%
girl : 81.00%
hamster : 99.00%
house : 91.00%
kangaroo : 94.00%
keyboard : 98.00%
lamp : 96.00%
lawn_mower : 99.00%
leopard : 91.00%
lion : 98.00%
lizard : 98.00%
lobster : 90.00%
man : 90.00%
maple_tree : 76.00%
motorcycle : 98.00%
mountain : 98.00%
mouse : 90.00%
mushroom : 99.00%
oak_tree : 72.00%
orange : 99.00%
orchid : 97.00%
otter : 84.00%
palm_tree : 98.00%
pear : 97.00%
pickup_truck : 100.00%
pine_tree : 82.00%
plain : 89.00%
plate : 88.00%
poppy : 93.00%
porcupine : 88.00%
possum : 83.00%
rabbit : 93.00%
raccoon : 97.00%
ray : 96.00%
road : 98.00%
rocket : 99.00%
rose : 93.00%
sea : 93.00%
seal : 85.00%
shark : 93.00%
shrew : 86.00%
skunk : 100.00%
skyscraper : 97.00%
snail : 99.00%
snake : 95.00%
spider : 98.00%
squirrel : 94.00%
streetcar : 91.00%
sunflower : 99.00%
sweet_pepper : 94.00%
table : 93.00%
tank : 99.00%
telephone : 98.00%
television : 98.00%
tiger : 96.00%
tractor : 100.00%
train : 94.00%
trout : 97.00%
tulip : 90.00%
turtle : 97.00%
wardrobe : 99.00%
whale : 91.00%
willow_tree : 86.00%
wolf : 99.00%
woman : 94.00%
worm : 95.00%
Training History Summary:
Best training accuracy: 99.56%
Best test accuracy: 93.61%
Generating inference visualization...
可视化结果已保存为 beit_cifar100_inference_results.png
在NVIDIA A10 GPU上训练10个epoch的完整日志如下:
bash
Epoch 1/10: Train Loss: 1.3054, Train Acc: 83.99%, Test Acc: 89.33%
Epoch 2/10: Train Loss: 1.0548, Train Acc: 91.34%, Test Acc: 90.30%
Epoch 3/10: Train Loss: 0.9837, Train Acc: 93.78%, Test Acc: 90.95%
Epoch 4/10: Train Loss: 0.9297, Train Acc: 95.55%, Test Acc: 91.36%
Epoch 5/10: Train Loss: 0.8837, Train Acc: 97.06%, Test Acc: 92.11%
Epoch 6/10: Train Loss: 0.8551, Train Acc: 98.12%, Test Acc: 92.55%
Epoch 7/10: Train Loss: 0.8319, Train Acc: 98.87%, Test Acc: 92.88%
Epoch 8/10: Train Loss: 0.8172, Train Acc: 99.34%, Test Acc: 93.30%
Epoch 9/10: Train Loss: 0.8106, Train Acc: 99.51%, Test Acc: 93.52%
Epoch 10/10: Train Loss: 0.8078, Train Acc: 99.56%, Test Acc: 93.61%
关键观察:
-
快速收敛:第1个epoch即达到89.33%的测试准确率,体现迁移学习的优势
-
持续改进:从Epoch 1到Epoch 10,测试准确率稳步提升4.28%
-
过拟合控制:训练准确率(99.56%)与测试准确率(93.61%)存在差距,但在合理范围内
-
训练效率:每个epoch约623秒(10分23秒),单卡训练总时长约1小时45分钟
3.2 细粒度类别性能分析
模型在100个类别上的表现差异显著:
表现最佳类别(准确率≥99%):
- bee, butterfly, hamster, mushroom, orange, pickup_truck, skunk, snail, sunflower, tank, tractor, wardrobe, wolf
表现较弱类别(准确率≤82%):
- maple_tree (76.00%), oak_tree (72.00%), bowl (82.00%), forest (81.00%), girl (81.00%), possum (83.00%)
性能差异原因分析:
-
类别固有难度:树木类(maple_tree, oak_tree, pine_tree)在32×32低分辨率下难以区分细粒度特征
-
语义混淆:bowl与cup、plate等容器类容易混淆
-
数据分布:部分类别(如skunk, hamster)具有独特的视觉特征,易于识别
3.3 可视化推理结果
从可视化结果可以观察到:
-
高置信度正确预测:house (93.73%), butterfly (92.46%), keyboard (91.14%)
-
挑战性样本:couch (73.57%)等家具类在特定角度下识别困难
四、关键技术要点总结
4.1 迁移学习最佳实践
-
输入尺寸适配:CIFAR-100的32×32图像需上采样至224×224,虽然会引入插值伪影,但预训练模型的强大特征提取能力可以克服这一限制
-
数据增强强度:针对小图像上采样场景,采用更强的几何变换(RandomRotation(15°))和颜色抖动,有效扩充数据多样性
-
分层微调策略:预训练层使用较小学习率(5e-5),新分类层使用较大学习率(1e-3),平衡稳定性与收敛速度
-
标签平滑:0.1的标签平滑系数有效防止过拟合,特别是在细粒度分类任务中
4.2 性能优化技巧
-
混合精度训练 :虽然日志中未明确使用,但在A10 GPU上启用
torch.cuda.amp可进一步提升训练速度 -
梯度累积:若显存受限,可采用梯度累积模拟大批量训练
-
早停机制:当前实现保存最佳模型(基于验证准确率),防止过拟合
五、扩展与改进方向
-
更大模型尝试:可尝试BEiT-large(24层,1024维)进一步提升性能
-
自蒸馏策略:参考BEiT v2的矢量量化知识蒸馏,提升小模型性能
-
多尺度训练:结合CIFAR-100的原始分辨率与上采样分辨率进行多尺度训练
-
超类辅助学习:利用CIFAR-100的20个超类标签进行层次化分类
本文详细介绍了如何使用BEiT-base模型通过迁移学习在CIFAR-100数据集上实现93.61%的分类准确率。通过精心设计的数据预处理、分层学习率调度和标签平滑等技术,我们充分发挥了预训练视觉Transformer的强大能力。实验结果表明,即使面对32×32像素的低分辨率输入,经过ImageNet-21k大规模预训练的BEiT模型依然能够提取有效的语义特征,实现优异的细粒度分类性能。
完整的代码实现展示了从模型下载、数据预处理、训练到推理的完整流程,为基于Transformer的图像分类任务提供了可复用的技术方案。