基于Transformer的图像分类网络Vit

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

一、前言

长久以来,CNN一直是CV领域最受欢迎的网络,在NLP也有CNN的一席之地。限于CNN上下文的能力,RNN系列网络在长文本任务中要比CNN更受欢迎,但是RNN系列网络也一直存在性能问题。2017年的一篇论文《Attention is all you need》提出了Transformer架构,Transformer的出现打破了RNN的绝对优势,Transformer在NLP领域取得了不菲的成绩。

Transformer出来不久,就有许多关于Transformer与CNN相关的讨论。本文我们要讨论的就是Transformer在CV领域的应用,我们要实现论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的Vit(Vision Transformer)的网络。

二、Transformer

在开始介绍Vit之前,我们来简单看看Transformer网络。其结构如下图:

在Vit中我们只需要用到Encoder部分,此时和之前情感分析的例子,详见:基于Transformer的文本情感分析。此时Transformer可以分为三个部分:Input、Attention、分类网络,我们一一介绍。

2.1 Input

在Input的部分,原始的Transformer需要输入词的id,然后经过Embedding层,并加上Positional Encoding。此过程可以描述成下面的伪代码:

python 复制代码
embedded = embedding(idxes)
embedded = embedded + positioal_encoding

而图像的处理则和句子不一样,在Vit中,图像会被分为多个patch,每个patch会被看作是原本的一个token,如下图所示:

而Positional Encoding也被替换成了Positional Embedding。原本的Positional Encoding是由三角函数生成的固定值,而Positional Embedding则是和普通Embedding类似的一种可供学习的位置嵌入,只不过在Positional Embedding中输入的id变为的固定值(第一个patch的位置id为0,第二个为1,...)。

另外,在原Transformer中是Embedding结果与Positional Encoding相加,而在Vit中是patch经过线性层后与Positional Embedding相加。

为了方便后续使用,我们可以封装一个PatchEncoder,用于处理输入(分patch并加上位置嵌入)。代码如下:

python 复制代码
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

在后续构建网络时需要使用到PatchEncoder。

2.2 Attention

Vit中的self-attention与Transformer中的self-attention是一样的,这里不详细赘述,其代码实现如下:

python 复制代码
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

# 输入图片信息
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# 计算self-attention
attention_output = layers.MultiHeadAttention(
    num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# 残差连接
x2 = layers.Add()([attention_output, encoded_patches])
# layernorm
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# 线性层
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
# 残差连接
encoded_patches = layers.Add()([x3, x2])

在实际训练时,这部分会重复多次。在self-attention中,会提取图片的关键信息,将注意力集中在对分类起决定作用的图像区域。

2.3 分类网络

分类网络用来做最后的分类工作,由全连接和残差连接构成,其结构与attention部分非常类似,具体代码如下:

python 复制代码
# 展开
representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)
# 全连接
features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
# 分类
logits = layers.Dense(num_classes)(features)

logits的结果就是类别分数,后续会与真实标签计算crossentropy。

本文假定读者已经对Transformer有些许了解,因此省去部分细节。下面就使用Vit网络来完成一个实际的任务。

三、使用Vit网络进行图像分类

本文使用cifar100数据集训练Vit网络,首先导入需要用的模块:

python 复制代码
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

3.1 准备数据

这里使用keras直接加载cifar100的数据:

python 复制代码
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

输出结果如下:

python 复制代码
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1) 
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)

共计60000张图片。

3.2 配置超参数

下面配置训练需要用到的一些参数:

ini 复制代码
learning_rate = 0.001
# 权重衰减系数
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72 
# 单个patch的尺寸
patch_size = 6
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
# attention的神经元数量
transformer_units = [
    projection_dim * 2,
    projection_dim,
]
transformer_layers = 8
# 分类网络的神经元数量
mlp_head_units = [2048, 1024]

3.3 数据增强

为了提高泛化能力,可以添加数据增强的操作,代码如下:

python 复制代码
# 数据增强层
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# 对训练数据进行数据正确
data_augmentation.layers[0].adapt(x_train)

3.4 构建Vit模型

下面我们使用前面三部分的代码创建Vit模型,代码如下:

python 复制代码
def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    # 重复多次attention
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])
    # 全连接
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    logits = layers.Dense(num_classes)(features)
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

3.5 训练Vit

训练的代码非常简单,代码如下:

python 复制代码
vit = create_vit_classifier()
vit.compile(
    'adam', 
    # 因为模型输出的结果没有经过softmax,因此需要设置参数from_logits=True
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['acc']
)
vit.fit(
    x_train, y_train, 
    batch_size=batch_size,
    epochs=num_epochs,
    validation_data=[x_test, y_test]
)

训练完成后准确率在70%左右,这比CNN表现稍差。

四、总结

在使用Vit实验后,发现结果并没有比CNN好。那么是不是Transformer就不适合应用在CV领域呢?答案是否定的。相比传统的CNN,vit的参数量更大,训练时间也更长。在数据量比较小时,Transformer会欠拟合,此时CNN依旧是最佳选择。而数据量较大时,CNN将到达性能瓶颈,此时可以考虑使用Vit网络,或许可以得到更好的结果。

今天的内容就分享到这里,感兴趣的读者可以继续深入研究Vit模型。本文参考自:keras.io/examples/vi...

相关推荐
尼尔森系4 小时前
排序与算法:希尔排序
c语言·算法·排序算法
AC使者4 小时前
A. C05.L08.贪心算法入门
算法·贪心算法
冠位观测者4 小时前
【Leetcode 每日一题】624. 数组列表中的最大距离
数据结构·算法·leetcode
yadanuof5 小时前
leetcode hot100 滑动窗口&子串
算法·leetcode
可爱de艺艺5 小时前
Go入门之函数
算法
武乐乐~5 小时前
欢乐力扣:旋转图像
算法·leetcode·职场和发展
a_j586 小时前
算法与数据结构(子集)
数据结构·算法·leetcode
清水加冰6 小时前
【算法精练】背包问题(01背包问题)
c++·算法
慢一点会很快8 小时前
FRRouting配置与OSPF介绍,配置,命令,bfd算法:
算法·智能路由器·php·ospf
88号技师9 小时前
2024年中科院一区SCI-雪雁优化算法Snow Geese Algorithm-附Matlab免费代码
开发语言·人工智能·算法·matlab·优化算法