Huawei's Open-Source AI Framework MindSpore in Practice: Running Vision Transformer Image Classification

I've recently been studying the Vision Transformer (ViT), and the idea of applying a Transformer to image classification is genuinely elegant. MindSpore happens to ship a complete tutorial for it, so I followed it end to end and wrote up the whole process.

If you're interested in MindSpore, you can follow the MindSpore community.

The figure below shows the full ViT architecture: the input image is split into patches, processed by a Transformer encoder, and a classification head produces the final result. The pipeline is straightforward, and we will implement it step by step.

1 Environment Setup and Data Preparation

1.1 Environment Configuration

First, make sure Python and MindSpore are installed locally. This tutorial is best run on a GPU; on a CPU it will be painfully slow.

The dataset is a subset of ImageNet, downloaded automatically on first run:

python
from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)

After the download, the directory structure looks like this:

./dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

1.2 Data Preprocessing

The preprocessing is fairly standard: decode/resize, random cropping, normalization, and channel reordering:

python
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms

data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

The mean and std here are the standard ImageNet statistics, scaled by 255 because Normalize in this pipeline runs on raw pixel values in the [0, 255] range rather than on values rescaled to [0, 1].
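To see what those scaled statistics do, here is a minimal NumPy sketch (illustrative only, not MindSpore code) of the same normalization applied to a uint8 pixel array; the pixel values are chosen to sit near the ImageNet channel means:

```python
import numpy as np

# ImageNet channel statistics, scaled to the [0, 255] pixel range
mean = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255])
std = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255])

# a fake 2x2 HWC uint8 image whose pixels are close to the channel means
img = np.array([[[124, 116, 104]] * 2] * 2, dtype=np.uint8)

normalized = (img.astype(np.float32) - mean) / std
# pixels near the channel means normalize to values near zero
print(np.abs(normalized).max() < 0.05)  # True
```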

2 How the ViT Model Works

2.1 The Core of the Transformer: Self-Attention

To understand ViT, you first need to understand the Transformer's core mechanism: Self-Attention. In short, it lets the model learn which positions in the input sequence should attend to which others.

The Self-Attention computation works as follows:

  1. The input vectors pass through three different linear projections to obtain Q (Query), K (Key), and V (Value)
  2. Dot products of Q with K give the attention scores
  3. These scores, once normalized, weight a sum over V

The math looks like this:

$$
\begin{cases}
q_i = W_q \cdot x_i \\
k_i = W_k \cdot x_i \\
v_i = W_v \cdot x_i
\end{cases}
$$

The attention scores are then computed as

$$
a_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d}}
$$

and after Softmax normalization, the final output is

$$
output_i = \sum_j \text{softmax}(a_{i,j}) \cdot v_j
$$
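The three formulas above can be checked directly with a small NumPy sketch (a toy single-head example with made-up sizes, not the MindSpore implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 8                      # sequence length, model dimension

x = rng.standard_normal((n, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

q, k, v = x @ W_q, x @ W_k, x @ W_v          # q_i = W_q . x_i, etc.
scores = q @ k.T / np.sqrt(d)                # a_ij = q_i . k_j / sqrt(d)

weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)  # softmax over j
output = weights @ v                         # output_i = sum_j softmax(a_ij) v_j

print(output.shape)          # (4, 8): one d-dim vector per position
print(weights.sum(axis=-1))  # every row of weights sums to 1
```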

The figure above walks through the Self-Attention computation in detail: the input sequence X is linearly projected into the Q, K, and V matrices, attention scores are computed and passed through Softmax to obtain weights, and a weighted sum produces the output. This mechanism lets the model dynamically attend to different parts of the input sequence.

2.2 Implementing Multi-Head Attention

Multi-head attention splits the input into several "heads", lets each head compute attention independently, and concatenates the results. This lets the model look at the input from multiple perspectives:

python
from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out

The key points in this code:

  • qkv = self.qkv(x) produces the Q, K, and V matrices with a single projection
  • the reshape and transpose calls reorganize the data into per-head layout
  • at the end, the per-head outputs are merged back into one tensor
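The reshape/transpose bookkeeping is easier to follow on concrete shapes. A NumPy sketch of the same steps as in construct() (toy sizes, assumed for illustration):

```python
import numpy as np

b, n, c, num_heads = 2, 5, 12, 3     # batch, tokens, dim, heads
head_dim = c // num_heads            # 4

qkv = np.arange(b * n * 3 * c, dtype=np.float32).reshape(b, n, 3 * c)

# the same reorganisation performed in construct():
qkv = qkv.reshape(b, n, 3, num_heads, head_dim)
qkv = qkv.transpose(2, 0, 3, 1, 4)   # (3, b, heads, n, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]

print(q.shape)    # (2, 3, 5, 4): per-head sequences of head_dim vectors

# after attention, the heads are merged back:
out = q.transpose(0, 2, 1, 3).reshape(b, n, c)
print(out.shape)  # (2, 5, 12)
```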

2.3 Feed Forward and Residual Connections

Besides the attention mechanism, the Transformer also needs a feed-forward network and residual connections:

python
from typing import Optional

class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x

class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        return self.cell(x) + x

The residual connection is simple: add the input directly to the output. This mitigates vanishing gradients in deep networks.

2.4 The Complete TransformerEncoder

Combining attention, feed-forward, and residual connections gives the TransformerEncoder:

python
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)

One detail worth noting: ViT places LayerNorm before the attention and feed-forward sub-layers (pre-LN), unlike the original post-LN Transformer; experiments show this ordering works better.
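The pre-LN ordering built by the ResidualCell/SequentialCell stacking above is, per layer: x = x + attn(norm1(x)); x = x + ffn(norm2(x)). A minimal NumPy sketch with stub sub-layers (the 0.1-scaling stub is purely an assumption for illustration):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize over the last axis, as nn.LayerNorm((dim,)) does
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def pre_ln_block(x, sublayer):
    # pre-LN: normalize first, apply the sublayer, then add the residual
    return x + sublayer(layer_norm(x))

x = np.random.default_rng(1).standard_normal((4, 8))
stub = lambda h: 0.1 * h          # stand-in for attention / feed-forward

y = pre_ln_block(pre_ln_block(x, stub), stub)
print(y.shape)   # (4, 8): the residual path preserves the shape
```

With a zero sublayer the block reduces to the identity, which is exactly why deep stacks of such blocks stay trainable.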

3 ViT's Key Innovation: Turning an Image into a Sequence

The figure above shows ViT's full image-processing pipeline: the raw image is split into patches, embedded, combined with position encodings and a CLS token, passed through the Transformer encoder, and finally the CLS token's output is used for the classification prediction.

3.1 Patch Embedding

ViT's cleverest trick is converting an image into a sequence. Concretely, the image is cut into small blocks (patches), and each patch is flattened into a one-dimensional vector:

python
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, 
                             kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x

Here a convolution with kernel size and stride equal to the patch size implements the patch split, which is more efficient than manual slicing. For a 224×224 image with 16×16 patches, this yields 14×14 = 196 patches.
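The equivalence between the strided convolution and hand-slicing patches can be checked with NumPy (toy embedding dimension; the per-patch linear projection is a plain weight matrix here, assumed for illustration):

```python
import numpy as np

image_size, patch_size, embed_dim, c = 224, 16, 8, 3
g = image_size // patch_size            # 14 patches per side
num_patches = g * g
print(num_patches)                      # 196

# hand-slicing: cut the image into patches and flatten each one
img = np.random.default_rng(2).standard_normal((c, image_size, image_size))
p = patch_size
patches = img.reshape(c, g, p, g, p).transpose(1, 3, 0, 2, 4).reshape(num_patches, c * p * p)

# a conv with kernel_size = stride = patch_size applies one and the same
# linear map to every patch; here that map is just a weight matrix
W = np.random.default_rng(3).standard_normal((c * p * p, embed_dim))
tokens = patches @ W
print(tokens.shape)                     # (196, 8): one embedding per patch
```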

3.2 Position Encoding and the Classification Token

After the image is cut into patches, position information and a classification token still need to be added:

python
# inside ViT.__init__
self.cls_token = init(init_type=Normal(sigma=1.0),
                      shape=(1, 1, embed_dim),
                      dtype=ms.float32,
                      name='cls',
                      requires_grad=True)

self.pos_embedding = init(init_type=Normal(sigma=1.0),
                          shape=(1, num_patches + 1, embed_dim),
                          dtype=ms.float32,
                          name='pos_embedding',
                          requires_grad=True)

The classification token borrows BERT's idea: a special token is prepended to the sequence, and its final output is used for classification. The position encoding tells the model where each patch sits in the image.
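The shape bookkeeping for these two additions is worth seeing concretely. A NumPy sketch (toy batch size and embedding dimension, assumed for illustration) of what ViT's construct() does with them:

```python
import numpy as np

b, num_patches, embed_dim = 2, 196, 8
patch_tokens = np.zeros((b, num_patches, embed_dim))

cls_token = np.random.default_rng(4).standard_normal((1, 1, embed_dim))
pos_embedding = np.random.default_rng(5).standard_normal((1, num_patches + 1, embed_dim))

# tile the cls token across the batch and prepend it to the sequence
cls = np.tile(cls_token, (b, 1, 1))
x = np.concatenate([cls, patch_tokens], axis=1)
print(x.shape)            # (2, 197, 8): 196 patches + 1 cls token

# the position embedding broadcasts over the batch dimension
x = x + pos_embedding
print(x[:, 0].shape)      # (2, 8): the cls slot later feeds the classifier
```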

3.3 The Complete ViT Model

Putting all the components together gives the full ViT model:

python
from mindspore.common.initializer import Normal, initializer
from mindspore import Parameter

def init(init_type, shape, dtype, name, requires_grad):
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]  # take the classification token's output
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

The full flow is: image → patch embedding → prepend cls token and add position encoding → Transformer encoder → classification head.

4 Training and Validation in Practice

4.1 Training Configuration

Before training, set up the loss function, optimizer, and related pieces:

python
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# build the model
network = ViT()

# load pretrained weights
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# learning-rate schedule
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)

Since we start from a pretrained model, the learning rate is set quite small. Cosine-annealing decay keeps training stable.
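The cosine schedule itself is simple. A NumPy sketch of the decay shape (my own formulation of half-cosine decay, not the exact MindSpore cosine_decay_lr API):

```python
import numpy as np

def cosine_decay(step, max_lr=5e-5, min_lr=0.0, total_steps=1000):
    # lr falls from max_lr to min_lr along half a cosine period
    t = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(np.pi * t))

print(cosine_decay(0))      # 5e-05 at the start
print(cosine_decay(500))    # roughly half-way: ~2.5e-05
print(cosine_decay(1000))   # ~0 at the end
```

The slow start and gentle tail are what make this schedule forgiving when fine-tuning pretrained weights.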

4.2 Loss Function

We use cross-entropy with label smoothing:

python
class CrossEntropySmooth(LossBase):
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

Label smoothing discourages over-confident predictions, which helps prevent overfitting and improves generalization.
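What the on_value/off_value pair in CrossEntropySmooth actually produces can be verified in NumPy (toy class count, assumed for illustration):

```python
import numpy as np

num_classes, smooth_factor = 5, 0.1
on_value = 1.0 - smooth_factor
off_value = smooth_factor / (num_classes - 1)

label = 2
smoothed = np.full(num_classes, off_value)
smoothed[label] = on_value
print(smoothed)        # [0.025 0.025 0.9   0.025 0.025]
print(smoothed.sum())  # each row still sums to 1
```

Instead of a hard one-hot target, the true class gets 0.9 and the remaining 0.1 of probability mass is spread evenly over the other classes.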

4.3 Start Training

python
# checkpoint settings
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize the model
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O0")

# start training
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)

During training you will see output like this:

epoch: 1 step: 125, loss is 1.903618335723877
Train epoch time: 99857.517 ms, per step time: 798.860 ms
epoch: 2 step: 125, loss is 1.448015570640564
Train epoch time: 95555.111 ms, per step time: 764.441 ms

The loss is steadily decreasing, which indicates training is proceeding normally.

The figure above illustrates the training run: the loss curve on the left trends downward, the accuracy curve on the right rises, and the table below summarizes the training configuration and final results. The model converges stably and reaches good performance.

4.4 Model Validation

After training, validate the model:

python
# validation data preprocessing
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# evaluation metrics
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O0")

# run evaluation
result = model.eval(dataset_val)
print(result)

The result:

{'Top_1_Accuracy': 0.75, 'Top_5_Accuracy': 0.928}

Top-1 accuracy of 75% and Top-5 accuracy of 92.8% is a solid result.
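Top-1/Top-5 accuracy are easy to compute by hand if you want to sanity-check the metrics; a NumPy sketch with made-up logits:

```python
import numpy as np

def topk_accuracy(logits, labels, k):
    # a prediction counts if the true label is among the k largest logits
    topk = np.argsort(logits, axis=1)[:, -k:]
    return np.mean([labels[i] in topk[i] for i in range(len(labels))])

logits = np.array([[0.1, 0.5, 0.2, 0.9],   # top-1 = class 3
                   [0.8, 0.1, 0.6, 0.3],   # top-1 = class 0
                   [0.2, 0.3, 0.9, 0.1]])  # top-1 = class 2
labels = np.array([3, 2, 2])

print(topk_accuracy(logits, labels, 1))  # 2/3: sample 1 misses at top-1
print(topk_accuracy(logits, labels, 2))  # 1.0: class 2 is sample 1's 2nd best
```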

5 Inference Testing

5.1 Preparing Inference Data

python
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

5.2 Inference and Result Visualization

python
import numpy as np
from scipy import io

def index2label():
    """获取ImageNet类别标签"""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
    
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
    
    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
    
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping

# run inference; build the index-to-label mapping once, outside the loop
mapping = index2label()
for image in dataset_infer.create_dict_iterator(output_numpy=True):
    image = ms.Tensor(image["image"])
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    output = {int(label): mapping[int(label)]}
    print(output)

Inference output:

{236: 'Doberman'}

The model correctly identifies the Doberman, so inference works as expected.

6 Summary and Reflections

6.1 Strengths of ViT

A few strengths of ViT stood out in this exercise:

  1. Clean architecture: compared with CNNs' intricate convolutional designs, ViT's architecture is more uniform and simple
  2. Strong scalability: the Transformer's parallel computation makes it easy to scale the model up
  3. Good transferability: after pretraining on a large dataset, it transfers well to downstream tasks

6.2 Pitfalls in Practice

  1. High compute requirements: ViT needs a lot of GPU memory, so the batch size cannot be set too large
  2. Data hungry: compared with CNNs, ViT relies more heavily on large-scale pretraining data
  3. Position encoding matters: removing it causes a clear drop in performance

6.3 Highlights of the Implementation

The MindSpore implementation does a few things well:

  1. Modular design: each component is cleanly encapsulated, making it easy to understand and modify
  2. Automatic mixed precision: the amp_level argument enables mixed-precision training with one switch
  3. Flexible data handling: the preprocessing pipeline is designed to be very flexible

The whole run went fairly smoothly; the code quality is good and the comments are clear. For anyone wanting to understand ViT's principles and implementation, this tutorial is a solid starting point.

Of course, truly mastering ViT takes more paper-reading and more experiments. This is just a start; next steps could include fine-tuning on your own dataset or implementing some ViT variants.
