import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import layers
from keras import ops
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
num_classes = 10
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
learning_rate = 1e-3
weight_decay = 1e-4
batch_size = 64
num_epochs = 30
image_size = 72  # resize input images to this size
patch_size = 6  # size of each square patch
num_patches = (image_size // patch_size) ** 2  # number of patches per image
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 4,
    projection_dim,
]
transformer_layers = 8
mlp_head_units = [
2048,
1024,
]
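# Quick sanity check of the sizing arithmetic above (my own addition, purely
# illustrative): a 72x72 image cut into 6x6 patches gives a 12x12 grid of
# 144 patches, and each patch flattens to 6 * 6 * 3 = 108 values.
assert num_patches == (image_size // patch_size) ** 2 == 144
assert patch_size * patch_size * 3 == 108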
data_augmentation = keras.Sequential(
[
layers.Rescaling(scale=1.0 / 255),
layers.Resizing(image_size, image_size),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.02),
layers.RandomZoom(height_factor=0.1, width_factor=0.1),
],
name="data_augmentation",
)
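# Optional shape check for the augmentation pipeline (illustrative, not part
# of the original flow): a (2, 32, 32, 3) uint8 batch comes out rescaled to
# [0, 1] and resized to (2, 72, 72, 3). The random layers stay inactive
# outside of training.
_aug = data_augmentation(x_train[:2])
print(_aug.shape, float(ops.max(_aug)))  # (2, 72, 72, 3), max <= 1.0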
val_split = 0.1
val_indices = int(len(x_train) * val_split)
new_x_train, new_y_train = x_train[val_indices:], y_train[val_indices:]
x_val, y_val = x_train[:val_indices], y_train[:val_indices]
auto = tf.data.AUTOTUNE
def make_datasets(images, labels, is_train=False):
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
if is_train:
dataset = dataset.shuffle(batch_size * 20)
dataset = dataset.batch(batch_size)
return dataset.prefetch(auto)
train_dataset = make_datasets(new_x_train, new_y_train, is_train=True)
val_dataset = make_datasets(x_val, y_val)
test_dataset = make_datasets(x_test, y_test)
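# Optional peek at one training batch (illustrative): images stay raw uint8
# (batch_size, 32, 32, 3) here, since rescaling and resizing happen inside
# the model via data_augmentation.
for _imgs, _lbls in train_dataset.take(1):
    print("images:", _imgs.shape, _imgs.dtype, "labels:", _lbls.shape)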
# MLP block: a stack of Dense (linear) layers with GELU and dropout
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=keras.activations.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
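# Minimal usage sketch for mlp (my own check, not part of the original
# pipeline): on a (2, 144, 64) input with transformer_units = [256, 64],
# the block returns a tensor of the same (2, 144, 64) shape.
_x = ops.ones((2, num_patches, projection_dim))
print(mlp(_x, hidden_units=transformer_units, dropout_rate=0.1).shape)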
class Patches(layers.Layer):  # splits an image into patches
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size
    def call(self, images):
        input_shape = ops.shape(images)
        batch_size = input_shape[0]  # batch size
        height = input_shape[1]  # image height
        width = input_shape[2]  # image width
        channels = input_shape[-1]  # channels
        num_patches_h = height // self.patch_size  # 12 patches along the height
        num_patches_w = width // self.patch_size  # 12 patches along the width
        # extract the patches
        patches = keras.ops.image.extract_patches(images, size=self.patch_size)
        # (batch, 12, 12, 108): a 12x12 grid of patches, each flattened to 6*6*3 values
        print(patches.shape)
patches = ops.reshape(
patches,
(
batch_size,
num_patches_h * num_patches_w,
self.patch_size * self.patch_size * channels,
),
)
        # after reshaping: (batch, 144, 108) -> (samples, patches, pixels per patch)
        print(patches.shape)  # (batch, 144, 108)
return patches
    def get_config(self):  # layer configuration for serialization
config = super().get_config()
config.update({"patch_size": self.patch_size})
return config
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]  # pick a random training image
print(image.shape, image.max(), image.min())
plt.imshow(image.astype("uint8"))
plt.axis("off")
resized_image = ops.image.resize(
ops.convert_to_tensor([image.astype('float32')]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
n = int(np.sqrt(patches.shape[1]))  # n = 12 patches per side
print(n, patches[0].shape)
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):  # patches[0] has shape (144, 108)
    print(patch.shape)  # (108,)
    ax = plt.subplot(n, n, i + 1)
    patch_img = ops.reshape(patch, (patch_size, patch_size, 3))  # back to (6, 6, 3)
    print(patch_img.numpy().max(), type(patch_img))  # pixel ranges differ per patch
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")  # hide the axes
plt.show()
# Patch encoding layer: linear projection plus learned position embeddings
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches  # number of patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)  # linear projection
        self.position_embedding = layers.Embedding(  # position encoding
input_dim=num_patches, output_dim=projection_dim
)
    def call(self, patch):
        # add a batch axis: positions has shape (1, 144)
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)  # (None, 144, 108) -> (None, 144, 64)
        print(projected_patches.shape)
        # (None, 144, 64) + (1, 144, 64) -> (None, 144, 64), broadcast over the batch
encoded = projected_patches + self.position_embedding(positions)
print(encoded.shape)
return encoded
    def get_config(self):
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config
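# Minimal shape check for PatchEncoder (illustrative only): a dummy
# (2, 144, 108) patch tensor is projected to 64 dims and position-encoded.
_p = ops.ones((2, num_patches, patch_size * patch_size * 3))
print(PatchEncoder(num_patches, projection_dim)(_p).shape)  # (2, 144, 64)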
def create_vit_classifier():  # build the ViT model
    inputs = keras.Input(shape=input_shape)  # (32, 32, 3)
    # augment (rescale, resize, flip, rotate, zoom)
    augmented = data_augmentation(inputs)
    # split into patches: (None, 144, 108); the 32x32x3 input is resized to
    # 72x72, giving 144 patches of 6*6*3 = 108 values each
    patches = Patches(patch_size)(augmented)
    # encode the patches; encoded shape (None, 144, 64)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
print(f'{encoded_patches.shape=}')
    for _ in range(transformer_layers):  # stack 8 transformer encoder blocks
        # layer normalization 1
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # multi-head self-attention: x1's last dim is 64, split across 4 heads,
        # so key_dim = 64 // 4 = 16 per head; output is (None, 144, 64)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim // num_heads, dropout=0.1
        )(x1, x1)
        print(f'{attention_output.shape=}')
        # skip connection 1
        x2 = layers.Add()([attention_output, encoded_patches])
        # layer normalization 2
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # feed-forward MLP block
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # skip connection 2
        encoded_patches = layers.Add()([x3, x2])
        print(f'{encoded_patches.shape=}')  # (None, 144, 64)
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    # global average pooling over the patch axis keeps the head small: (None, 64)
    representation = layers.GlobalAveragePooling1D()(representation)
    representation = layers.Dropout(0.5)(representation)
    # MLP head
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.2)
    # 10-way classification logits, computed from the MLP head's output
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
vit_classifier = create_vit_classifier()
vit_classifier.summary()
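# Rough parameter accounting for the patch encoder alone (my own arithmetic,
# for orientation against the summary above): the projection Dense holds
# 108 * 64 + 64 = 6,976 weights and the position Embedding 144 * 64 = 9,216.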
# weight_decay: weight-decay coefficient for AdamW
optimizer = keras.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
)
vit_classifier.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="acc")
],
)
checkpoint_filepath = "./checkpoint/cifar10_transformer_best_1.keras"
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
        mode="min",
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=2, min_lr=8e-6
    ),
]
history = vit_classifier.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=callbacks,
)
vit_classifier.load_weights(checkpoint_filepath)
_, acc = vit_classifier.evaluate(test_dataset)
print("Test accuracy: %.2f %%" %(acc*100))
def plot_history(item):
plt.plot(history.history[item], label=item)
plt.plot(history.history["val_" + item], label="val_" + item)
plt.xlabel("Epochs")
plt.ylabel(item)
plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
plt.legend()
plt.grid()
plt.show()
plot_history("loss")