I've been digging into the Vision Transformer (ViT) recently, and the idea of applying the Transformer to image classification is genuinely clever. MindSpore happens to have a complete tutorial for it, so I followed it end to end and wrote up the whole process here.
If you're interested in MindSpore, check out the 昇思MindSpore community.


The figure below shows the complete ViT architecture: the input image is split into patches, the patches are processed by a Transformer encoder, and a classification head produces the final result. The pipeline is clean, and we will implement it step by step.

1 Environment Setup and Data Preparation
1.1 Environment Configuration
First, make sure Python and MindSpore are installed locally. The tutorial is best run on a GPU; on a CPU it is painfully slow.
The dataset is a subset of ImageNet and is downloaded automatically on first run:
python
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
After the download finishes, the directory structure looks like this:
./dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
├── infer/
└── val/
1.2 Data Preprocessing
The preprocessing is fairly standard: resize, random crop, normalization, and so on:
python
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
The mean and std here are the standard ImageNet values. They are multiplied by 255 because Normalize is applied to the decoded image while its pixel values are still in the [0, 255] range.
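To sanity-check the pipeline, you can pull a single batch and inspect it; this is a quick check of my own, not part of the original tutorial:
python
# Fetch one batch to verify the preprocessing output.
# Expected: (16, 3, 224, 224) after HWC2CHW and batching.
for batch in dataset_train.create_dict_iterator(output_numpy=True, num_epochs=1):
    print(batch["image"].shape, batch["image"].dtype)
    break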
2 How the ViT Model Works
2.1 The Core of the Transformer: Self-Attention
To understand ViT, you first need to understand the core mechanism of the Transformer: Self-Attention. In short, it lets the model learn relationships between different positions in the input sequence.
The Self-Attention computation proceeds as follows:
- The input vector passes through three different linear transformations to produce Q (Query), K (Key), and V (Value)
- The dot product of Q and K gives the attention weights
- These weights are used to compute a weighted sum over V
In equations:
$$\begin{cases} q_i = W_q \cdot x_i \\ k_i = W_k \cdot x_i \\ v_i = W_v \cdot x_i \end{cases}$$
Then the attention scores are computed:
$$a_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d}}$$
After Softmax normalization, the final output is:
$$output_i = \sum_j \text{softmax}(a_{i,j}) \cdot v_j$$

The figure above details the Self-Attention computation: the input sequence X is linearly transformed into the Q, K, and V matrices, attention scores are computed and passed through Softmax to obtain weights, and a weighted sum yields the output. This mechanism lets the model dynamically attend to different parts of the input sequence.
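To make the formulas concrete, here is a minimal single-head attention sketch (my own illustration, not from the tutorial); the shapes mirror the equations above:
python
import numpy as np
import mindspore as ms
from mindspore import ops

# Toy input: batch of 1, sequence of 4 tokens, dimension d = 8.
x = ms.Tensor(np.random.randn(1, 4, 8).astype(np.float32))
w_q = ms.Tensor(np.random.randn(8, 8).astype(np.float32))
w_k = ms.Tensor(np.random.randn(8, 8).astype(np.float32))
w_v = ms.Tensor(np.random.randn(8, 8).astype(np.float32))

q, k, v = x @ w_q, x @ w_k, x @ w_v                             # q_i = W_q·x_i, etc.
scores = ops.matmul(q, ops.transpose(k, (0, 2, 1))) / 8 ** 0.5  # a_ij = q_i·k_j / sqrt(d)
weights = ops.softmax(scores, axis=-1)                          # normalize over j
output = ops.matmul(weights, v)                                 # sum_j softmax(a_ij)·v_j
print(output.shape)                                             # (1, 4, 8)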
2.2 Multi-Head Attention Implementation
Multi-head attention splits the input into several "heads", each of which computes attention independently; the results are then concatenated. This lets the model look at the input from multiple perspectives:
python
from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0 - attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0 - keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
        return out
The key points in this code:
- qkv = self.qkv(x) produces the Q, K, and V matrices in a single pass
- the reshape and transpose operations reorganize the data into multi-head form
- at the end, the outputs of all the heads are concatenated back together
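A quick shape check (my own addition) confirms that the module preserves the (batch, tokens, dim) layout:
python
import numpy as np

attn = Attention(dim=768, num_heads=12)
dummy = ms.Tensor(np.ones((2, 16, 768)), ms.float32)
print(attn(dummy).shape)  # (2, 16, 768): same layout in and out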
2.3 Feed Forward and Residual Connections
Besides the attention mechanism, the Transformer also needs a Feed Forward network and residual connections:
python
from typing import Optional

class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0 - keep_prob)

    def construct(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x

class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        return self.cell(x) + x
The residual connection is simple: the input is added directly to the output. This keeps gradients flowing in deep networks and mitigates the vanishing-gradient problem.
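Since a FeedForward block preserves the token dimension, it can be wrapped in a ResidualCell directly; a minimal shape check of my own:
python
import numpy as np

ffn = ResidualCell(FeedForward(in_features=768, hidden_features=3072))
tokens = ms.Tensor(np.ones((2, 16, 768)), ms.float32)
print(ffn(tokens).shape)  # (2, 16, 768): the residual add requires matching shapes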
2.4 The Complete TransformerEncoder
Combining attention, Feed Forward, and residual connections gives the TransformerEncoder:
python
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []
        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)
            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)
One detail to note: ViT places LayerNorm before the attention and Feed Forward blocks (pre-norm), unlike the post-norm arrangement of the original Transformer; experiments show this works better.
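The difference is easiest to see side by side; schematically (my own illustration):
python
# Post-norm (original Transformer): normalize after the residual add
#     x = norm(x + sublayer(x))
# Pre-norm (ViT, as built above):   normalize before the sublayer
#     x = x + sublayer(norm(x))
# In the code above, pre-norm is exactly:
#     ResidualCell(nn.SequentialCell([normalization1, attention]))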
3 ViT's Key Innovation: Turning Images into Sequences

The figure above shows the complete flow of ViT on an image: the raw image is split into patches, converted by an embedding, augmented with position encodings and a CLS token, processed by the Transformer encoder, and finally the CLS token is extracted for the classification prediction.
3.1 Patch Embedding
ViT's cleverest trick is converting an image into a sequence. Concretely, the image is cut into small patches, and each patch is flattened into a one-dimensional vector:
python
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x
A convolution whose kernel size and stride both equal the patch size implements the split more efficiently than slicing by hand. For a 224×224 image with 16×16 patches, this yields 14×14 = 196 patches.
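A shape check of my own confirms the output layout:
python
import numpy as np

patch_embed = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768)
img = ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32)
print(patch_embed(img).shape)  # (1, 196, 768): 196 patches, each a 768-dim vector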
3.2 Position Encoding and the Classification Token
After the image is split into patches, position information and a classification token still need to be added:
python
# inside ViT's __init__
self.cls_token = init(init_type=Normal(sigma=1.0),
                      shape=(1, 1, embed_dim),
                      dtype=ms.float32,
                      name='cls',
                      requires_grad=True)
self.pos_embedding = init(init_type=Normal(sigma=1.0),
                          shape=(1, num_patches + 1, embed_dim),
                          dtype=ms.float32,
                          name='pos_embedding',
                          requires_grad=True)
The classification token borrows BERT's idea: a special token is prepended to the sequence, and its final output is used for classification. The position encoding tells the model where each patch sits in the image.
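Here is how the token count works out, in a runnable sketch of my own (the real cls_token is a learned parameter):
python
import numpy as np

b = 2
patches = ms.Tensor(np.ones((b, 196, 768)), ms.float32)   # output of PatchEmbedding
cls_token = ms.Tensor(np.zeros((1, 1, 768)), ms.float32)  # learnable in the real model
cls_tokens = ops.tile(cls_token, (b, 1, 1))               # one copy per sample
tokens = ops.concat((cls_tokens, patches), axis=1)        # prepend the cls token
print(tokens.shape)  # (2, 197, 768): pos_embedding is (1, 197, 768) and broadcasts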
3.3 The Complete ViT Model
Putting all the components together gives the full ViT model:
python
from mindspore.common.initializer import Normal, initializer
from mindspore import Parameter

def init(init_type, shape, dtype, name, requires_grad):
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:  # added: the classification head below needs it
        super(ViT, self).__init__()
        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)
        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0 - keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0 - keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding
        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]  # take the classification token's output
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)
        return x
The overall flow: image → patch embedding → prepend cls token and add position encoding → Transformer encoder → classification head.
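An end-to-end smoke test of my own, using a randomly initialized ViT-B/16:
python
import numpy as np

net = ViT(num_classes=1000)
imgs = ms.Tensor(np.ones((2, 3, 224, 224)), ms.float32)
print(net(imgs).shape)  # (2, 1000): one logit per ImageNet class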
4 Training and Validation in Practice
4.1 Training Configuration
Before training we set up the loss function, optimizer, and so on:
python
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# Hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# Build the model
network = ViT()

# Load pretrained weights
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# Learning rate schedule
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# Optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
Because we start from a pretrained model, the learning rate is set quite low. The cosine decay schedule keeps training stable.
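cosine_decay_lr returns a plain Python list with one learning rate per step, so it is easy to inspect (a quick check of my own):
python
print(len(lr))        # epoch_size * step_size entries, one per training step
print(lr[0], lr[-1])  # starts at max_lr and decays toward min_lr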
4.2 Loss Function
We use cross-entropy with label smoothing:
python
class CrossEntropySmooth(LossBase):
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
Label smoothing keeps the model from becoming overconfident on the training labels and improves generalization.
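With smooth_factor=0.1 and num_classes=1000, the smoothed target is easy to work out by hand (my own check):
python
smooth_factor, num_classes = 0.1, 1000
on_value = 1.0 - smooth_factor                   # 0.9 for the true class
off_value = smooth_factor / (num_classes - 1)    # ~0.0001 for every other class
print(on_value + (num_classes - 1) * off_value)  # 1.0: still a valid distribution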
4.3 Starting the Training
python
# Checkpoint settings
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# Initialize the model
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics={"acc"}, amp_level="O0")

# Start training
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)
During training you will see output like this:
epoch: 1 step: 125, loss is 1.903618335723877
Train epoch time: 99857.517 ms, per step time: 798.860 ms
epoch: 2 step: 125, loss is 1.448015570640564
Train epoch time: 95555.111 ms, per step time: 764.441 ms
The loss decreases steadily, which indicates training is proceeding normally.

The figure above shows the training process: the loss curve on the left, the accuracy curve on the right, and a table below summarizing the training configuration and final results. The model converges stably and reaches solid performance.
4.4 Model Validation
After training, let's verify the model:
python
# Validation data preprocessing
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# Evaluation metrics
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics=eval_metrics, amp_level="O0")

# Run the evaluation
result = model.eval(dataset_val)
print(result)
The output:
{'Top_1_Accuracy': 0.75, 'Top_5_Accuracy': 0.928}
Top-1 accuracy of 75% and Top-5 accuracy of 92.8%: a respectable result.
5 Inference Testing
5.1 Inference Data Preparation
python
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
5.2 Inference and Result Visualization
python
import cv2
import numpy as np
from PIL import Image
from scipy import io

def index2label():
    """Build the mapping from class index to ImageNet label."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
    nums_children = list(zip(*meta))[4]
    # Keep only leaf synsets (the 1000 actual classes).
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping

# Inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
The inference result:
{236: 'Doberman'}

The model correctly identified the Doberman, so inference is working as expected.
6 Summary and Reflections
6.1 Strengths of ViT
A few strengths of ViT stood out in this exercise:
- Clean architecture: compared with the intricate convolutional stages of a CNN, ViT's architecture is more uniform and concise
- Scalability: the Transformer's parallelism makes it easy to scale the model up
- Transferability: after pretraining on a large dataset, it transfers well to downstream tasks
6.2 Pitfalls in Practice
- High compute requirements: ViT is demanding on GPU memory, so the batch size cannot be set too large
- Data hungry: compared with CNNs, ViT relies more heavily on large-scale pretraining data
- Position encoding matters: removing it causes a clear drop in performance
6.3 Highlights of the Implementation
A few things the MindSpore implementation does well:
- Modular design: each component is cleanly encapsulated, making it easy to understand and modify
- Automatic mixed precision: the amp_level parameter makes mixed-precision training trivial to enable
- Flexible data handling: the preprocessing pipeline is designed to be very flexible
Running the whole tutorial went smoothly; the code quality is good and the comments are clear. For anyone who wants to understand ViT's principles and implementation, this tutorial is a solid starting point.
Of course, truly mastering ViT takes more paper reading and more experiments. This is just a start; next steps could be fine-tuning on your own dataset or implementing some ViT variants.