Compact Convolutional Transformers in Keras 3
Author: Sayak Paul
Date created: 2021/06/30
Last modified: 2023/08/07
This example uses Keras 3.
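As discussed in the Vision Transformers (ViT) paper, a Transformer-based architecture for vision typically requires a larger dataset than usual, as well as a longer pre-training schedule. For ViTs, ImageNet-1k (about a million images) is considered a medium-sized dataset. This is primarily because, unlike CNNs, ViTs (or a typical Transformer-based architecture) lack well-informed inductive biases, such as convolutions for processing images. This raises the question: can we combine the benefits of convolution and the benefits of Transformers in a single network architecture? These benefits include parameter efficiency and self-attention, which handles long-range and global dependencies (interactions between different regions of an image).
In Escaping the Big Data Paradigm with Compact Transformers, Hassani et al. present an approach for doing exactly this. They propose the Compact Convolutional Transformer (CCT) architecture. In this example, we implement CCT and see how well it performs on the CIFAR-10 dataset.
If you are unfamiliar with the concept of self-attention or Transformers, you can read this chapter from François Chollet's book Deep Learning with Python. This example uses code snippets from another example, Image classification with Vision Transformer.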
from keras import layers
import keras

import matplotlib.pyplot as plt
import numpy as np

positional_emb = True
conv_layers = 2
projection_dim = 128

num_heads = 2
# Hidden sizes of the encoder MLP; the last one must equal `projection_dim`
# so that the residual addition in each Transformer block lines up.
transformer_units = [
    projection_dim,
    projection_dim,
]
transformer_layers = 2
stochastic_depth_rate = 0.1

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 30
image_size = 32
Load the CIFAR-10 dataset
num_classes = 10
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# One-hot encode the integer class labels.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 10)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 10)
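The CCT tokenizer
The first recipe introduced by the CCT authors is the tokenizer for processing the images. In a standard ViT, images are organized into uniform, non-overlapping patches. This discards the boundary-level information that lies between neighboring patches, and that information is important for a neural network to exploit locality effectively. The figure below illustrates how images are organized into patches.
We already know that convolutions are quite good at exploiting locality information. Based on this, the authors introduce an all-convolutional mini-network to produce image patches.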
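class CCTTokenizer(layers.Layer):
    def __init__(
        self, kernel_size=3, stride=1, padding=1, pooling_kernel_size=3, pooling_stride=2,
        num_conv_layers=conv_layers, num_output_channels=[64, 128],
        positional_emb=positional_emb, **kwargs,
    ):
        super().__init__(**kwargs)
        # This is our tokenizer.
        self.conv_model = keras.Sequential()
        for i in range(num_conv_layers):
            self.conv_model.add(
                layers.Conv2D(
                    num_output_channels[i], kernel_size, stride, padding="valid",
                    use_bias=False, activation="relu", kernel_initializer="he_normal",
                )
            )
            self.conv_model.add(layers.ZeroPadding2D(padding))
            self.conv_model.add(
                layers.MaxPooling2D(pooling_kernel_size, pooling_stride, "same")
            )
        self.positional_emb = positional_emb

    def call(self, images):
        outputs = self.conv_model(images)
        # After passing the images through our mini-network the spatial dimensions
        # are flattened to form sequences.
        shape = keras.ops.shape(outputs)
        reshaped = keras.ops.reshape(outputs, (-1, shape[1] * shape[2], shape[-1]))
        return reshaped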
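To see how many tokens such a tokenizer produces, the spatial size can be traced by hand. The sketch below is not part of the original example. It assumes the settings of the upstream Keras example (a 3x3 convolution with stride 1 and no padding, one pixel of zero padding per side, then max pooling with stride 2 and "same" padding), and the helper names are hypothetical.

```python
import math


def conv_valid_out(size, kernel, stride):
    # Output size of a convolution with "valid" (no) padding.
    return (size - kernel) // stride + 1


def pool_same_out(size, stride):
    # Output size of a pooling layer with "same" padding.
    return math.ceil(size / stride)


def tokenizer_sequence_length(image_size=32, num_conv_layers=2):
    size = image_size
    for _ in range(num_conv_layers):
        size = conv_valid_out(size, kernel=3, stride=1)  # 32 -> 30, then 16 -> 14
        size = size + 2 * 1                              # zero padding, 1 pixel per side
        size = pool_same_out(size, stride=2)             # 32 -> 16, then 16 -> 8
    # The remaining H x W grid is flattened into a sequence of tokens.
    return size * size


print(tokenizer_sequence_length())  # 64
```

Under these assumptions a 32x32 CIFAR-10 image becomes a sequence of 64 tokens, each with as many features as the last convolution has output channels.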
Positional embeddings are optional in CCT. If we want to use them, we can use the Layer defined below.
class PositionEmbedding(keras.layers.Layer):
    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = keras.ops.shape(inputs)
        feature_length = shape[-1]
        sequence_length = shape[-2]
        # trim to match the length of the input sequence, which might be less
        # than the sequence_length of the layer.
        position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
        position_embeddings = keras.ops.slice(
            position_embeddings,
            (start_index, 0),
            (sequence_length, feature_length),
        )
        return keras.ops.broadcast_to(position_embeddings, shape)

    def compute_output_shape(self, input_shape):
        return input_shape
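Note that the layer returns only the position table, broadcast over the batch; the caller adds it to the tokens. A minimal NumPy sketch of that behavior follows. The shapes are illustrative, and a random table stands in for the learned weights.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, seq_len, dim = 4, 64, 128  # illustrative shapes
tokens = rng.normal(size=(batch, seq_len, dim))

# Stand-in for the learned (sequence_length, feature_size) weight table.
position_table = rng.normal(size=(seq_len, dim))

# Trim to the input length, then broadcast across the batch dimension.
start_index = 0
trimmed = position_table[start_index : start_index + seq_len, :]
position_embeddings = np.broadcast_to(trimmed, tokens.shape)

# Every sample in the batch receives the same per-position offsets.
encoded = tokens + position_embeddings
print(encoded.shape)  # (4, 64, 128)
```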
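Another recipe introduced in CCT is attention pooling, also called sequence pooling. In ViT, only the feature map corresponding to the class token is pooled and then used for the subsequent classification task (or any other downstream task).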
class SequencePooling(layers.Layer):
    def __init__(self):
        super().__init__()
        # Produces one importance score per token.
        self.attention = layers.Dense(1)

    def call(self, x):
        # Softmax over the sequence axis turns the scores into pooling weights.
        attention_weights = keras.ops.softmax(self.attention(x), axis=1)
        attention_weights = keras.ops.transpose(attention_weights, axes=(0, 2, 1))
        weighted_representation = keras.ops.matmul(attention_weights, x)
        return keras.ops.squeeze(weighted_representation, -2)
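The same computation can be followed step by step in plain NumPy. This is only an illustrative sketch: the random matrix `w` stands in for the learned `Dense(1)` kernel, and the shapes are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, dim = 2, 64, 128  # illustrative shapes
x = rng.normal(size=(batch, seq_len, dim))

# Stand-in for Dense(1): one score per token.
w = rng.normal(size=(dim, 1))
b = np.zeros((1,))
scores = x @ w + b  # (batch, seq_len, 1)

# Softmax over the sequence axis (shifted by the max for numerical stability).
scores = scores - scores.max(axis=1, keepdims=True)
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)

# (batch, 1, seq_len) @ (batch, seq_len, dim) -> (batch, 1, dim) -> (batch, dim)
pooled = np.matmul(weights.transpose(0, 2, 1), x).squeeze(-2)
print(pooled.shape)  # (2, 128)
```

Each output row is a weighted average of all tokens in the sequence, so every token can contribute to the final representation instead of a single class token.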
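Stochastic depth is a regularization technique that randomly drops a set of layers during training. During inference, the layers are kept as they are. It is very similar to Dropout, except that it operates on a block of layers rather than on individual nodes inside a layer. In CCT, stochastic depth is used just before the residual blocks of a Transformers encoder.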
# Referred from:
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_prob
            # One random draw per sample, broadcast over the remaining axes.
            shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, 0, 1, seed=self.seed_generator
            )
            random_tensor = keras.ops.floor(random_tensor)
            # Rescale the kept samples so the expected output matches the input.
            return (x / keep_prob) * random_tensor
        return x
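To make the masking and rescaling concrete, here is a NumPy sketch of the same logic. It is an illustration only; the function name and the test tensor are hypothetical.

```python
import numpy as np


def stochastic_depth(x, drop_prob, rng, training=True):
    # At inference time the block output passes through unchanged.
    if not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli(keep_prob) draw per sample, broadcast over the other axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = np.floor(keep_prob + rng.uniform(0.0, 1.0, size=shape))
    # Rescale kept samples so the expected output equals the input.
    return (x / keep_prob) * mask


rng = np.random.default_rng(1337)
x = np.ones((8, 4, 3))
out = stochastic_depth(x, drop_prob=0.5, rng=rng)
print(out[:, 0, 0])  # each entry is 0.0 (dropped) or 2.0 (kept and rescaled)
```

The mask has one entry per sample, so the block's output is either zeroed for the whole sample or kept in full, which is what distinguishes this from element-wise Dropout.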
MLP for the Transformers encoder
def mlp(x, hidden_units, dropout_rate):
    # A stack of Dense + Dropout layers with GELU activations.
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.ops.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
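For readers who have not met GELU before, the sketch below writes out its exact form and the common tanh-based approximation in plain Python. Which of the two a Keras GELU call computes depends on its `approximate` argument, so treat this as a reference for the math rather than a description of the library internals.

```python
import math


def gelu_exact(x):
    # x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Common tanh-based approximation of the same function.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


for value in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{value:+.1f}  exact={gelu_exact(value):+.4f}  tanh={gelu_tanh(value):+.4f}")
```

The two forms agree closely over this range, and both behave like the identity for large positive inputs and like zero for large negative ones.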
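In the original paper, the authors use AutoAugment to induce stronger regularization. For this example, we will use standard geometric augmentations such as random cropping and flipping.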
# Note the rescaling layer. These layers have pre-defined inference behavior.
data_augmentation = keras.Sequential(
    [
        layers.Rescaling(scale=1.0 / 255),
        layers.RandomCrop(image_size, image_size),
        layers.RandomFlip("horizontal"),
    ]
)
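The final CCT model
In CCT, the outputs from the Transformers encoder are weighted and then passed on to the final task-specific layer (in this example, we do classification).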
def create_cct_model(
    image_size=image_size,
    input_shape=input_shape,
    num_heads=num_heads,
    projection_dim=projection_dim,
    transformer_units=transformer_units,
):
    inputs = layers.Input(input_shape)

    # Augment data.
    augmented = data_augmentation(inputs)

    # Encode patches.
    cct_tokenizer = CCTTokenizer()
    encoded_patches = cct_tokenizer(augmented)

    # Apply positional embedding.
    if positional_emb:
        sequence_length = encoded_patches.shape[1]
        encoded_patches += PositionEmbedding(sequence_length=sequence_length)(
            encoded_patches
        )

    # Calculate Stochastic Depth probabilities.
    dpr = [x for x in np.linspace(0, stochastic_depth_rate, transformer_layers)]

    # Create multiple layers of the Transformer block.
    for i in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)

        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)

        # Skip connection 1.
        attention_output = StochasticDepth(dpr[i])(attention_output)
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-5)(x2)

        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)

        # Skip connection 2.
        x3 = StochasticDepth(dpr[i])(x3)
        encoded_patches = layers.Add()([x3, x2])

    # Apply sequence pooling.
    representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
    weighted_representation = SequencePooling()(representation)

    # Classify outputs.
    logits = layers.Dense(num_classes)(weighted_representation)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
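The model builder spreads the stochastic depth drop rates linearly over the Transformer blocks with `np.linspace`. The short sketch below shows the resulting schedule for two depths; the helper name `drop_rates` is hypothetical.

```python
import numpy as np


def drop_rates(max_rate, num_layers):
    # Linearly spaced drop probabilities, from 0 for the first block
    # up to `max_rate` for the last one.
    return [float(rate) for rate in np.linspace(0.0, max_rate, num_layers)]


for num_layers in (2, 4):
    print(num_layers, [round(rate, 4) for rate in drop_rates(0.1, num_layers)])
# 2 [0.0, 0.1]
# 4 [0.0, 0.0333, 0.0667, 0.1]
```

With the two-block setting used in this example, the first block is never dropped and the second is dropped with probability 0.1 during training.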
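def run_experiment(model):
    optimizer = keras.optimizers.AdamW(learning_rate=0.001, weight_decay=0.0001)
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )
    # Keep only the weights with the best validation accuracy.
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath, monitor="val_accuracy",
        save_best_only=True, save_weights_only=True,
    )
    history = model.fit(
        x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
        validation_split=0.1, callbacks=[checkpoint_callback],
    )
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
    return history

cct_model = create_cct_model()
history = run_experiment(cct_model)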
Epoch 1/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 90s 248ms/step - accuracy: 0.2578 - loss: 2.0882 - top-5-accuracy: 0.7553 - val_accuracy: 0.4438 - val_loss: 1.6872 - val_top-5-accuracy: 0.9046
Epoch 2/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 91s 258ms/step - accuracy: 0.4779 - loss: 1.6074 - top-5-accuracy: 0.9261 - val_accuracy: 0.5730 - val_loss: 1.4462 - val_top-5-accuracy: 0.9562
Epoch 3/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 260ms/step - accuracy: 0.5655 - loss: 1.4371 - top-5-accuracy: 0.9501 - val_accuracy: 0.6178 - val_loss: 1.3458 - val_top-5-accuracy: 0.9626
Epoch 4/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 261ms/step - accuracy: 0.6166 - loss: 1.3343 - top-5-accuracy: 0.9613 - val_accuracy: 0.6610 - val_loss: 1.2695 - val_top-5-accuracy: 0.9706
Epoch 5/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 261ms/step - accuracy: 0.6468 - loss: 1.2814 - top-5-accuracy: 0.9672 - val_accuracy: 0.6834 - val_loss: 1.2231 - val_top-5-accuracy: 0.9716
Epoch 6/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 261ms/step - accuracy: 0.6619 - loss: 1.2412 - top-5-accuracy: 0.9708 - val_accuracy: 0.6842 - val_loss: 1.2018 - val_top-5-accuracy: 0.9744
Epoch 7/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 263ms/step - accuracy: 0.6976 - loss: 1.1775 - top-5-accuracy: 0.9752 - val_accuracy: 0.6988 - val_loss: 1.1988 - val_top-5-accuracy: 0.9752
Epoch 8/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 263ms/step - accuracy: 0.7070 - loss: 1.1579 - top-5-accuracy: 0.9774 - val_accuracy: 0.7010 - val_loss: 1.1780 - val_top-5-accuracy: 0.9732
Epoch 9/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 95s 269ms/step - accuracy: 0.7219 - loss: 1.1255 - top-5-accuracy: 0.9795 - val_accuracy: 0.7166 - val_loss: 1.1375 - val_top-5-accuracy: 0.9784
Epoch 10/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 264ms/step - accuracy: 0.7273 - loss: 1.1087 - top-5-accuracy: 0.9801 - val_accuracy: 0.7258 - val_loss: 1.1286 - val_top-5-accuracy: 0.9814
Epoch 11/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 265ms/step - accuracy: 0.7361 - loss: 1.0863 - top-5-accuracy: 0.9828 - val_accuracy: 0.7222 - val_loss: 1.1412 - val_top-5-accuracy: 0.9766
Epoch 12/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 264ms/step - accuracy: 0.7504 - loss: 1.0644 - top-5-accuracy: 0.9834 - val_accuracy: 0.7418 - val_loss: 1.0943 - val_top-5-accuracy: 0.9812
Epoch 13/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 94s 266ms/step - accuracy: 0.7593 - loss: 1.0422 - top-5-accuracy: 0.9856 - val_accuracy: 0.7468 - val_loss: 1.0834 - val_top-5-accuracy: 0.9818
Epoch 14/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 265ms/step - accuracy: 0.7647 - loss: 1.0307 - top-5-accuracy: 0.9868 - val_accuracy: 0.7526 - val_loss: 1.0863 - val_top-5-accuracy: 0.9822
Epoch 15/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 263ms/step - accuracy: 0.7684 - loss: 1.0231 - top-5-accuracy: 0.9863 - val_accuracy: 0.7666 - val_loss: 1.0454 - val_top-5-accuracy: 0.9834
Epoch 16/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 94s 268ms/step - accuracy: 0.7809 - loss: 1.0007 - top-5-accuracy: 0.9859 - val_accuracy: 0.7670 - val_loss: 1.0469 - val_top-5-accuracy: 0.9838
Epoch 17/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 94s 268ms/step - accuracy: 0.7902 - loss: 0.9795 - top-5-accuracy: 0.9895 - val_accuracy: 0.7676 - val_loss: 1.0396 - val_top-5-accuracy: 0.9836
Epoch 18/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 106s 301ms/step - accuracy: 0.7920 - loss: 0.9693 - top-5-accuracy: 0.9889 - val_accuracy: 0.7616 - val_loss: 1.0791 - val_top-5-accuracy: 0.9828
Epoch 19/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 264ms/step - accuracy: 0.7965 - loss: 0.9631 - top-5-accuracy: 0.9893 - val_accuracy: 0.7850 - val_loss: 1.0149 - val_top-5-accuracy: 0.9842
Epoch 20/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 93s 265ms/step - accuracy: 0.8030 - loss: 0.9529 - top-5-accuracy: 0.9899 - val_accuracy: 0.7898 - val_loss: 1.0029 - val_top-5-accuracy: 0.9852
Epoch 21/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 261ms/step - accuracy: 0.8118 - loss: 0.9322 - top-5-accuracy: 0.9903 - val_accuracy: 0.7728 - val_loss: 1.0529 - val_top-5-accuracy: 0.9850
Epoch 22/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 91s 259ms/step - accuracy: 0.8104 - loss: 0.9308 - top-5-accuracy: 0.9906 - val_accuracy: 0.7874 - val_loss: 1.0090 - val_top-5-accuracy: 0.9876
Epoch 23/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 263ms/step - accuracy: 0.8164 - loss: 0.9193 - top-5-accuracy: 0.9911 - val_accuracy: 0.7800 - val_loss: 1.0091 - val_top-5-accuracy: 0.9844
Epoch 24/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 94s 268ms/step - accuracy: 0.8147 - loss: 0.9184 - top-5-accuracy: 0.9919 - val_accuracy: 0.7854 - val_loss: 1.0260 - val_top-5-accuracy: 0.9856
Epoch 25/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 262ms/step - accuracy: 0.8255 - loss: 0.9000 - top-5-accuracy: 0.9914 - val_accuracy: 0.7918 - val_loss: 1.0014 - val_top-5-accuracy: 0.9842
Epoch 26/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 90s 257ms/step - accuracy: 0.8297 - loss: 0.8865 - top-5-accuracy: 0.9933 - val_accuracy: 0.7924 - val_loss: 1.0065 - val_top-5-accuracy: 0.9834
Epoch 27/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 262ms/step - accuracy: 0.8339 - loss: 0.8837 - top-5-accuracy: 0.9931 - val_accuracy: 0.7906 - val_loss: 1.0035 - val_top-5-accuracy: 0.9870
Epoch 28/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 260ms/step - accuracy: 0.8362 - loss: 0.8781 - top-5-accuracy: 0.9934 - val_accuracy: 0.7878 - val_loss: 1.0041 - val_top-5-accuracy: 0.9850
Epoch 29/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 260ms/step - accuracy: 0.8398 - loss: 0.8707 - top-5-accuracy: 0.9942 - val_accuracy: 0.7854 - val_loss: 1.0186 - val_top-5-accuracy: 0.9858
Epoch 30/30
352/352 ━━━━━━━━━━━━━━━━━━━━ 92s 263ms/step - accuracy: 0.8438 - loss: 0.8614 - top-5-accuracy: 0.9933 - val_accuracy: 0.7892 - val_loss: 1.0123 - val_top-5-accuracy: 0.9846
313/313 ━━━━━━━━━━━━━━━━━━━━ 14s 44ms/step - accuracy: 0.7752 - loss: 1.0370 - top-5-accuracy: 0.9824
Test accuracy: 77.82%
Test top 5 accuracy: 98.42%
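The loss in `run_experiment` is configured with `label_smoothing=0.1`. As a rough illustration of what that does to the one-hot targets, the sketch below applies the usual smoothing formula in plain Python. The helper name is hypothetical, and this mirrors the standard formulation rather than quoting the Keras source.

```python
def smooth_labels(one_hot, smoothing):
    # Move `smoothing` of the probability mass to a uniform distribution.
    num_classes = len(one_hot)
    return [y * (1.0 - smoothing) + smoothing / num_classes for y in one_hot]


target = [0.0] * 10
target[3] = 1.0  # one-hot label for class 3
smoothed = smooth_labels(target, smoothing=0.1)
print([round(p, 2) for p in smoothed])
# [0.01, 0.01, 0.01, 0.91, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
```

Because the targets are never exactly 0 or 1, the cross-entropy loss cannot reach zero, which is one reason the training loss in the log above stays well above zero even as accuracy climbs.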
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.show()

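The CCT model we just trained has only 0.4 million parameters, and it reaches about 79% top-1 validation accuracy (77.82% on the test set) within 30 epochs. The plot above shows no signs of overfitting either. This means we can train this network for longer (perhaps with a bit more regularization) and may obtain even better performance. Performance can be improved further with additional recipes such as a cosine decay learning rate schedule and other data augmentation techniques like AutoAugment, MixUp, or Cutmix. With these modifications, the authors report 95.1% top-1 accuracy on the CIFAR-10 dataset. The authors also present a number of experiments that study how the number of convolution blocks, Transformers layers, etc. affect the final performance of CCTs.
For comparison, a ViT model takes about 4.7 million parameters and 100 epochs of training to reach a top-1 accuracy of 78.22% on the CIFAR-10 dataset. You can refer to this notebook for details of the experimental setup.
The authors also demonstrate the performance of Compact Convolutional Transformers on NLP tasks, where they report competitive results.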
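As a pointer for the cosine decay learning rate schedule mentioned above, here is a minimal plain-Python sketch of the curve. It is not the configuration used by the authors: the function name, the final learning rate of 0, and the choice of total steps are assumptions made for illustration.

```python
import math


def cosine_decay(step, total_steps, initial_lr=0.001, final_lr=0.0):
    # Clamp so that steps past the end stay at the final learning rate.
    progress = min(step, total_steps) / total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (initial_lr - final_lr) * cosine


total_steps = 30 * 352  # epochs * steps per epoch, matching the training log
for step in (0, total_steps // 2, total_steps):
    print(step, f"{cosine_decay(step, total_steps):.6f}")
```

Keras also provides `keras.optimizers.schedules.CosineDecay`, which can be passed to an optimizer as its `learning_rate` instead of a constant.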