《轻量化 Transformers:开启计算机视觉新篇》

一、Compact Convolutional Transformers

ViT在小型数据集上性能不够好的问题,这个问题非常实际,现实情况下如果确实没有大量数据集,同时也没有合适的预训练模型需要从头训练的时候,ViT架构性能是不如CNN架构的。这篇文章实际上并没有引入大量的卷积操作,通过修改patch size,以及使用SeqPool的方法就可以取得不错的成绩。

CCT 的核心设计理念是将卷积神经网络强大的局部特征提取能力与 Transformer 卓越的全局建模能力深度融合,兼收并蓄两者之长。通过精心设计的卷积模块,CCT 能够高效地提取图像中的局部细节信息,精准捕捉图像中物体的边缘、纹理等特征;而 Transformer 模块则负责在更大的范围内建模,捕捉不同局部区域之间的长距离依赖关系,从而对图像的整体结构和语义有更全面、深入的理解。

核心贡献如下:

  • 通过引入ViT-Lite能够有效从头开始在小型数据集上实现更高精度,打破Transformer需要大量数据的神话。
  • 引入新型序列池化策略(sequence pooling)的CVT(Compact Vision Transformer),从而让Transformer无需class token
  • 引入CCT(Compact Convolutional Transformer)来提升模型性能,同时可以让图片输入尺寸更加灵活。

(1) ViT-Lite

(2) CVT

(3)CCT

与传统 ViT 中简单粗暴地将图像分割成均匀、非重叠的补丁不同,CCTTokenizer 引入了一个全卷积迷你网络.每一层卷积都像是在对图像进行一次深度的 "扫描",提取出不同层次的特征;零填充层巧妙地保持了图像的空间尺寸,确保信息的完整性;而最大池化层则在降低数据维度的同时,突出了图像中的关键特征。

CCT 模型构建与训练

(一)数据预处理与增强

在 CCT 模型的训练征程中,数据预处理与增强是至关重要的起始步骤,如同精心准备食材是烹饪出美味佳肴的基础。本次实验选用的 CIFAR - 10 数据集,宛如一座丰富的图像宝库,其中包含了 10 个不同类别的 60000 张彩色图像,每类图像各有 6000 张,这些图像如同繁星般璀璨,涵盖了飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车等众多物体,为模型的训练提供了丰富多样的样本。

(二)模型的搭建(keras)

python 复制代码
import tensorflow as tf
from keras import layers
import keras

import matplotlib.pyplot as plt
import numpy as np

'''导入需要的模块'''

'''超参数设置'''
positional_emb = True
conv_layers = 2
projection_dim = 128

num_heads = 2
transformer_units = [
    projection_dim,
    projection_dim,
]
transformer_layers = 2
stochastic_depth_rate = 0.1

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 30
image_size = 32

'''数据集下载---cifar10'''
num_classes = 10
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

'''CCT token'''


class CCTTokenizer(layers.Layer):
    def __init__(
            self,
            kernel_size=3,
            stride=1,
            padding=1,
            pooling_kernel_size=3,
            pooling_stride=2,
            num_conv_layers=conv_layers,
            num_output_channels=[64, 128],
            positional_emb=positional_emb,
            **kwargs,
    ):
        super().__init__(**kwargs)

        # This is our tokenizer.
        self.conv_model = keras.Sequential()
        for i in range(num_conv_layers):
            self.conv_model.add(
                layers.Conv2D(
                    num_output_channels[i],
                    kernel_size,
                    stride,
                    padding="valid",
                    use_bias=False,
                    activation="relu",
                    kernel_initializer="he_normal",
                )
            )
            self.conv_model.add(layers.ZeroPadding2D(padding))
            self.conv_model.add(
                layers.MaxPooling2D(pooling_kernel_size, pooling_stride, "same")
            )

        self.positional_emb = positional_emb

    def call(self, images):
        outputs = self.conv_model(images)
        # After passing the images through our mini-network the spatial dimensions
        # are flattened to form sequences.
        batch_size = tf.shape(outputs)[0]
        h = tf.shape(outputs)[1]
        w = tf.shape(outputs)[2]
        c = tf.shape(outputs)[3]

        reshaped = tf.reshape(
            outputs,
            (batch_size, h * w, c)
        )
        return reshaped


'''位置编码'''


class PositionEmbedding(keras.layers.Layer):
    def __init__(
            self,
            sequence_length,
            initializer="glorot_uniform",
            **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = tf.shape(inputs)
        feature_length = shape[-1]
        sequence_length = shape[-2]
        # trim to match the length of the input sequence, which might be less
        # than the sequence_length of the layer.
        position_embeddings = tf.convert_to_tensor(self.position_embeddings)
        position_embeddings = tf.slice(
            position_embeddings,
            [start_index, 0],
            [sequence_length, feature_length],
        )
        return tf.broadcast_to(position_embeddings, shape)

    def compute_output_shape(self, input_shape):
        return input_shape


class SequencePooling(layers.Layer):
    def __init__(self):
        super().__init__()
        self.attention = layers.Dense(1)

    def call(self, x):
        attention_weights = tf.nn.softmax(self.attention(x), axis=1)
        attention_weights = tf.transpose(attention_weights, perm=(0, 2, 1))
        weighted_representation = tf.matmul(attention_weights, x)
        return tf.squeeze(weighted_representation, -2)


class StochasticDepth(layers.Layer):
    def __init__(self, drop_prop, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prop

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x


def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


data_augmentation = keras.Sequential(
    [
        layers.Rescaling(scale=1.0 / 255),
        layers.RandomCrop(image_size, image_size),
        layers.RandomFlip("horizontal"),
    ],
    name="data_augmentation",
)


def create_cct_model(
        image_size=image_size,
        input_shape=input_shape,
        num_heads=num_heads,
        projection_dim=projection_dim,
        transformer_units=transformer_units,
):
    inputs = layers.Input(input_shape)

    # Augment data.
    augmented = data_augmentation(inputs)

    # Encode patches.
    cct_tokenizer = CCTTokenizer()
    encoded_patches = cct_tokenizer(augmented)

    # Apply positional embedding.
    if positional_emb:
        sequence_length = encoded_patches.shape[1]
        encoded_patches += PositionEmbedding(sequence_length=sequence_length)(
            encoded_patches
        )

    # Calculate Stochastic Depth probabilities.
    dpr = [x for x in np.linspace(0, stochastic_depth_rate, transformer_layers)]

    # Create multiple layers of the Transformer block.
    for i in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)

        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)

        # Skip connection 1.
        attention_output = StochasticDepth(dpr[i])(attention_output)
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-5)(x2)

        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)

        # Skip connection 2.
        x3 = StochasticDepth(dpr[i])(x3)
        encoded_patches = layers.Add()([x3, x2])

    # Apply sequence pooling.
    representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
    weighted_representation = SequencePooling()(representation)

    # Classify outputs.
    logits = layers.Dense(num_classes)(weighted_representation)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model


def run_experiment(model):
    optimizer = keras.optimizers.AdamW(learning_rate=0.001, weight_decay=0.0001)

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(
            from_logits=True, label_smoothing=0.1
        ),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


cct_model = create_cct_model()
history = run_experiment(cct_model)

模型性能

根据 中的实现和注释,CCT 模型在 CIFAR-10 数据集上展现出了高效的性能:

参数量:约 40 万参数,仅为标准 ViT 的 9%

准确率:在 30 个 epoch 训练后达到约 79% 的 top-1 准确率

训练速度:由于参数量少,训练速度快,适合在有限计算资源上运行

需要注意的是,CCT 模型的设计重点是参数效率,而不是绝对性能。在资源受限的环境中,CCT 提供了很好的性能和资源消耗平衡。

参考

https://arxiv.org/abs/2104.05704

https://github.com/SHI-Labs/Compact-Transformers

https://zhuanlan.zhihu.com/p/364589899


二 MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer

论文地址:https://arxiv.org/pdf/2110.02178

它可以有效地将局部和全局信息进行编码。与ViT及其变体不同,MobileViT从不同的角度学习全局表示。标准卷积涉及三个操作:展开(unfloading)局部处理(local processing)展开(folding) 。MobileViT块使用Transformer将卷积中的局部建模替换为全局建模。这使得MobileViT块具有CNN和ViT的性质,有助于它用更少的参数和简单的训练方式学习更好的表示。

相关推荐
搞科研的小刘选手4 小时前
【同济大学主办】第十一届能源资源与环境工程研究进展国际学术会议(ICAESEE 2025)
大数据·人工智能·能源·材质·材料工程·地理信息
MARS_AI_4 小时前
云蝠智能 VoiceAgent 2.0:全栈语音交互能力升级
人工智能·自然语言处理·交互·信息与通信·agi
top_designer4 小时前
Substance 3D Stager:电商“虚拟摄影”工作流
人工智能·3d·设计模式·prompt·技术美术·教育电商·游戏美术
雷神大青椒4 小时前
离别的十字路口: 是否还记得曾经追求的梦想
人工智能·程序人生·职场和发展·玩游戏
m0_650108245 小时前
多模态大模型 VS. 图像视频生成模型浅析
人工智能·技术边界与协同·mllm与生成模型·技术浅谈
ai_xiaogui5 小时前
Mac苹果版Krita AI一键安装教程:AIStarter+ComfyUI零基础部署全流程(X86/ARM双架构)
arm开发·人工智能·macos·comfyui·一键部署·ai绘画教程·kritaai
lapiii3586 小时前
[智能体设计模式] 第11章:目标设定与监控模式
人工智能·设计模式
这张生成的图像能检测吗6 小时前
(论文速读)WFF-Net:用于表面缺陷检测的可训练权重特征融合卷积神经网络
人工智能·深度学习·神经网络·缺陷检测·图像分割
shayudiandian6 小时前
RNN与LSTM详解:AI是如何“记住”信息的?
人工智能·rnn·lstm