本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
一、前言
长久以来,CNN一直是CV领域最受欢迎的网络,在NLP也有CNN的一席之地。限于CNN上下文的能力,RNN系列网络在长文本任务中要比CNN更受欢迎,但是RNN系列网络也一直存在性能问题。2017年的一篇论文《Attention is all you need》提出了Transformer架构,Transformer的出现打破了RNN的绝对优势,Transformer在NLP领域取得了不菲的成绩。
Transformer出来不久,就有许多关于Transformer与CNN相关的讨论。本文我们要讨论的就是Transformer在CV领域的应用,我们要实现论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的Vit(Vision Transformer)的网络。
二、Transformer
在开始介绍Vit之前,我们来简单看看Transformer网络。其结构如下图:

在Vit中我们只需要用到Encoder部分,此时和之前情感分析的例子,详见:基于Transformer的文本情感分析。此时Transformer可以分为三个部分:Input、Attention、分类网络,我们一一介绍。
2.1 Input
在Input的部分,原始的Transformer需要输入词的id,然后经过Embedding层,并加上Positional Encoding。此过程可以描述成下面的伪代码:
python
embedded = embedding(idxes)
embedded = embedded + positioal_encoding
而图像的处理则和句子不一样,在Vit中,图像会被分为多个patch,每个patch会被看作是原本的一个token,如下图所示:

而Positional Encoding也被替换成了Positional Embedding。原本的Positional Encoding是由三角函数生成的固定值,而Positional Embedding则是和普通Embedding类似的一种可供学习的位置嵌入,只不过在Positional Embedding中输入的id变为的固定值(第一个patch的位置id为0,第二个为1,...)。
另外,在原Transformer中是Embedding结果与Positional Encoding相加,而在Vit中是patch经过线性层后与Positional Embedding相加。
为了方便后续使用,我们可以封装一个PatchEncoder,用于处理输入(分patch并加上位置嵌入)。代码如下:
python
class Patches(layers.Layer):
def __init__(self, patch_size):
super().__init__()
self.patch_size = patch_size
def call(self, images):
batch_size = tf.shape(images)[0]
patches = tf.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
padding="VALID",
)
patch_dims = patches.shape[-1]
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super().__init__()
self.num_patches = num_patches
self.projection = layers.Dense(units=projection_dim)
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = tf.range(start=0, limit=self.num_patches, delta=1)
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
在后续构建网络时需要使用到PatchEncoder。
2.2 Attention
Vit中的self-attention与Transformer中的self-attention是一样的,这里不详细赘述,其代码实现如下:
python
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
# 输入图片信息
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# 计算self-attention
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# 残差连接
x2 = layers.Add()([attention_output, encoded_patches])
# layernorm
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# 线性层
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
# 残差连接
encoded_patches = layers.Add()([x3, x2])
在实际训练时,这部分会重复多次。在self-attention中,会提取图片的关键信息,将注意力集中在对分类起决定作用的图像区域。
2.3 分类网络
分类网络用来做最后的分类工作,由全连接和残差连接构成,其结构与attention部分非常类似,具体代码如下:
python
# 展开
representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)
# 全连接
features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
# 分类
logits = layers.Dense(num_classes)(features)
logits的结果就是类别分数,后续会与真实标签计算crossentropy。
本文假定读者已经对Transformer有些许了解,因此省去部分细节。下面就使用Vit网络来完成一个实际的任务。
三、使用Vit网络进行图像分类
本文使用cifar100数据集训练Vit网络,首先导入需要用的模块:
python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
3.1 准备数据
这里使用keras直接加载cifar100的数据:
python
num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
输出结果如下:
python
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
共计60000张图片。
3.2 配置超参数
下面配置训练需要用到的一些参数:
ini
learning_rate = 0.001
# 权重衰减系数
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72
# 单个patch的尺寸
patch_size = 6
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
# attention的神经元数量
transformer_units = [
projection_dim * 2,
projection_dim,
]
transformer_layers = 8
# 分类网络的神经元数量
mlp_head_units = [2048, 1024]
3.3 数据增强
为了提高泛化能力,可以添加数据增强的操作,代码如下:
python
# 数据增强层
data_augmentation = keras.Sequential(
[
layers.Normalization(),
layers.Resizing(image_size, image_size),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.02),
layers.RandomZoom(
height_factor=0.2, width_factor=0.2
),
],
name="data_augmentation",
)
# 对训练数据进行数据正确
data_augmentation.layers[0].adapt(x_train)
3.4 构建Vit模型
下面我们使用前面三部分的代码创建Vit模型,代码如下:
python
def create_vit_classifier():
inputs = layers.Input(shape=input_shape)
augmented = data_augmentation(inputs)
patches = Patches(patch_size)(augmented)
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
# 重复多次attention
for _ in range(transformer_layers):
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
x2 = layers.Add()([attention_output, encoded_patches])
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
encoded_patches = layers.Add()([x3, x2])
# 全连接
representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)
features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
logits = layers.Dense(num_classes)(features)
model = keras.Model(inputs=inputs, outputs=logits)
return model
3.5 训练Vit
训练的代码非常简单,代码如下:
python
vit = create_vit_classifier()
vit.compile(
'adam',
# 因为模型输出的结果没有经过softmax,因此需要设置参数from_logits=True
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['acc']
)
vit.fit(
x_train, y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_data=[x_test, y_test]
)
训练完成后准确率在70%左右,这比CNN表现稍差。
四、总结
在使用Vit实验后,发现结果并没有比CNN好。那么是不是Transformer就不适合应用在CV领域呢?答案是否定的。相比传统的CNN,vit的参数量更大,训练时间也更长。在数据量比较小时,Transformer会欠拟合,此时CNN依旧是最佳选择。而数据量较大时,CNN将到达性能瓶颈,此时可以考虑使用Vit网络,或许可以得到更好的结果。
今天的内容就分享到这里,感兴趣的读者可以继续深入研究Vit模型。本文参考自:keras.io/examples/vi...