代码复现(三):TGRS-T4Fire

文章目录


analysis/i4_variation.py

处理两份CSV文件,绘制非火灾和火灾状态下波段 B a n d 4 Band4 Band4的变化图像。

data_processor/TokenizeProcessor.py

实现类TokenizeProcessor,实现将多光谱图像分块、标记化(Tokenize,本研究中指加上时间token)扁平化(flatten)。类中包含def tokenizing(self, data_path, window_size)def flatten_window(self, array, window_size)两个方法。

  • def tokenizing(self, data_path, window_size):加载.npy数据文件,对其维度进行转置,以便在后续的数据处理流程中可以更方便地访问和使用数据。
    • data_path:存放数据的.npy文件。
    • window_size:窗口参数 w w w(代码中未被使用)。
py 复制代码
def tokenizing(self, data_path, window_size):
	#将原本的维度(0,1,2,3,4)转为(0,3,4,1,2),即变为(batch_size,channels,time,height,width)格式
	array = np.load(data_path).transpose((0, 3, 4, 1, 2))
	return array
  • def flatten_window(self, array, window_size):将数组按窗口大小展平
    • array:四维numpy数组(时间序列数据)。
    • window_size:窗口参数 w w w。
py 复制代码
def flatten_window(self, array, window_size):
	#创建全零矩阵,形状为(array.shape[2],(array.shape[3]) * pow(window_size, 2))))
	##array.shape[3]) * pow(window_size, 2)即为嵌入(行)向量的维度w^2c
	##array.shape[2]即为嵌入向量的数量,对应时间标记
	output_array = np.zeros((array.shape[2], (array.shape[3])*pow(window_size,2))))
	#遍历每个时间步,将输入的四维数组array展平为大小为w^2c的行向量并存放到数组,作为模型的输入
	for time in range(array.shape[2]):
		#用1,2,3,...,作为时间token
		output_array[time, :] = array[:, :, time, :].flatten('F')
	#返回向量序列,形状为(时间,展平后的数据)
	return output_array

在论文中提到过,以每一个像素为中心提取大小为 w × w w×w w×w大小的图像片段,对其进行标记化(tokenizing()实现)、扁平化(flatten_window()实现)得到大小为 w 2 C × 1 w^2C×1 w2C×1的向量组成的序列作为Transformer模型的输入。在脚本下给出了使用案例:

py 复制代码
if __name__ == '__main__':
    if platform.system() == 'Darwin':
        root_path = '/Users/zhaoyu/PycharmProjects/T4Fire/data'
    else:
        root_path = '/geoinfo_vol1/zhao2/proj5_dataset'
    #定义窗口大小
    window_size = 1
    #定义时间序列长度
    ts_length = 10
    tokenize_processor = TokenizeProcessor()
    
    #将数据格式变为(batch_size,channels,time,height,width)
    tokenized_array = tokenize_processor.tokenizing(os.path.join(root_path, f'proj5_train_img_seqtoseq_l{ts_length}.npy'), window_size)
    tokenized_array = np.nan_to_num(tokenized_array)
    
    # 保存处理后的数据
    np.save(os.path.join(root_path, f'proj5_train_img_seqtoone_l{ts_length}_w{window_size}.npy'), tokenized_array.reshape(-1, tokenized_array.shape[-2], tokenized_array.shape[-1]))

data_processor/PreprocessingService.py

PreprocessingService中实现了多个用于处理遥感影像数据的功能。

  • def padding(self, coarse_arr, array_to_be_downsampled)
  • def down_sampling(self, input_arr):返回数组均值。
    • input_arr:输入数组。
  • def standardization(self, array):对每个通道的数组进行标准化处理。
    • array:输入数组,形状为 ( n c h a n n e l s , h e i g h t , w i d t h ) (n_channels, height, width) (nchannels,height,width)。
  • def normalization(self, array, channel_first=True):对数组进行归一化处理,将每个通道的值缩放到 [0, 1] 范围内。
    • array:待归一化的数组。
    • channel_first=True:布尔值,指示数组的通道维度在前还是在后。
  • def read_tiff(self, file_path):读取影像TIF文件,返回影像数据和元数据(profile)。
    • file_path:TIFF 文件的路径。
  • def write_tiff(self, file_path, arr, profile):将处理后的数组写入TIF文件。
    • file_path:保存 TIFF 文件的路径。
    • arr:要写入的数组,需转为浮点类型。
    • profile:TIFF 文件的元数据。

data_processor/DatesetProcessor.py

实现了类DatasetProcessor用于处理遥感影像数据,处理对象为TIF格式,最终生成用于模型训练的数据集。需先加载yml配置文件,该文件存放了不同地点的开始日期等信息,可用于时间序列数据的处理。

  • def reconstruct_tif_proj2(self, location, satellite='VIIRS_Day', image_size=(224, 224)) :将预处理过的npy格式的影像文件重新投影并保存为tif格式的文件。

    • location:影像数据存储的文件夹名称。
    • satellite :卫星名称,默认为'VIIRS_Day',可构建影像路径data/location/ satellite/
    • image_size:图像尺寸。
  • def dataset_generator_proj2_image(self, locations, file_name, image_size=(224, 224)) :处理多个位置的卫星数据(由locations指定,实际是按位置将数据划分为多个文件夹,locations列表中的位置信息可参与构建目录路径),将图像裁剪为224x224大小,并将生成的多维数组保存为.npy文件。其中,多维数组维度为 ( 10 , 7 , 224 , 224 ) (10,7,224,224) (10,7,224,224), 10 10 10表示时间序列的长度, 7 = 5 + 2 7=5+2 7=5+2表示图像的五个波段及掩码信息。

    • locations:包含位置信息的字符串列表。
    • file_name :生成的.npy文件的文件名。
    • image_size:裁剪生成的图像patch尺寸。

在函数内部使用save_path定义了处理后数据的路径,n_channels表示输出图像通道为5(包含了波段 B a n d 1 − B a n d 5 Band1-Band5 Band1−Band5)

  • def dataset_generator_proj2_image_test(self, location, file_name, image_size=(224, 224)) :生成一个基于给定位置和时间序列的卫星图像数据集,并将结果保存为.npy格式。

temporal_models/vit_keras/layers.py

实现ViT的Encoder模块。

  • class ClassToken(tf.keras.layers.Layer):向输入序列添加一个可训练的Class token,增强了模型对输入数据的表示能力,适合于分类任务中使用。
    • def build(self, input_shape):初始化class token,形状为 ( 1 , 1 , h i d d e n _ s i z e ) (1,1,hidden\_size) (1,1,hidden_size),其中 h i d d e n _ s i z e hidden\_size hidden_size是输入向量序列 i n p u t _ s h a p e input\_shape input_shape最后一个维度的大小(特征的维度)。class token初始化为零向量, f l o a t 32 float32 float32类型,可训练。
    • def call(self, inputs):前向传播,将Class token广播到与输入数据相同的批次大小,形状变为 ( b a t c h _ s i z e , 1 , h i d d e n _ s i z e ) (batch\_size,1,hidden\_size) (batch_size,1,hidden_size),并将其拼接到输入数据的最前面,拼接操作沿着第二个维度(序列维度)进行,得到张量形状为 ( b a t c h _ s i z e , s e q u e n c e _ l e n g t h + 1 , h i d d e n _ s i z e ) (batch\_size,sequence\_length+1,hidden\_size) (batch_size,sequence_length+1,hidden_size)(可理解为在行向量序列中增加了一个行向量)。
    • def get_config(self):保存层的配置,保证在模型保存与加载时可以恢复层的参数和行为。
    • def from_config(cls, config):从配置字典恢复该层实例,帮助实现模型的反序列化。
  • class AddPositionEmbs(tf.keras.layers.Layer):为输入序列添加可学习的位置嵌入(Positional Embeddings),将该嵌入与输入序列逐元素相加,帮助Transformer学习输入数据的位置信息。
    • def build(self, input_shape):初始化位置嵌入变量(可训练),形状为 ( 1 , i n p u t _ s h a p e [ 1 ] , i n p u t _ s h a p e [ 2 ] ) = ( 1 , s e q u e n c e _ l e n g t h , h i d d e n _ s i z e ) (1,input\_shape[1], input\_shape[2])=(1,sequence\_length,hidden\_size) (1,input_shape[1],input_shape[2])=(1,sequence_length,hidden_size),需对序列每一个向量进行位置嵌入。
    • def call(self, inputs):前向传播,通过相加操作,将位置编码嵌入到每一个向量上。
    • def get_config(self):返回层的配置,以便在保存和加载模型时保留该层的结构和参数。
    • def from_config(cls, config):从配置字典中恢复层的实例,帮助实现模型的反序列化。
  • class MultiHeadSelfAttention(tf.keras.layers.Layer):实现多头注意力机制,每个注意力头独立计算注意力并学习不同的特征表示,然后将多个头的结果拼接在一起。这个机制有助于模型捕捉输入序列中不同位置之间的依赖关系。
    • __init__(self, *args, is_masked, num_heads, **kwargs):通过is_mask确定是否使用掩膜,通过num_head确定多头注意力机制的头数,每个头独立计算注意力,并捕获不同的模式或关系。
    • def build(self, input_shape):根据 i n p u t _ s h a p e input\_shape input_shape的形状初始化参数。包括根据 h i d d e n _ s i z e hidden\_size hidden_size初始化 Q 、 V 、 K Q、V、K Q、V、K矩阵,判断 h i d d e n _ s i z e hidden\_size hidden_size是否可被 n u m _ h e a d s num\_heads num_heads整除等(多头自注意力机制需要将嵌入向量均分到每个头)。
    • def attention(self, query, key, value, attention_mask=None):定义注意力机制的计算方式,即计算 A t t e n t i o n ( Q , K , V ) Attention(Q,K,V) Attention(Q,K,V)。
    • def separate_heads(self, x, batch_size):将输入的最后一个维度 ( h i d d e n _ s i z e ) (hidden\_size) (hidden_size)分成多个头的维度,并改变张量的顺序。维度从 ( b a t c h _ s i z e , s e q u e n c e _ l e n g t h , h i d d e n _ s i z e ) (batch\_size, sequence\_length, hidden\_size) (batch_size,sequence_length,hidden_size)变为 ( b a t c h _ s i z e , n u m _ h e a d s , s e q u e n c e _ l e n g t h , p r o j e c t i o n _ d i m ) (batch\_size, num\_heads, sequence\_length, projection\_dim) (batch_size,num_heads,sequence_length,projection_dim)。
    • def call(self, inputs):前向传播,返回两个结果。
      • output:经过多头自注意力机制处理后的输出张量。形状为 ( b a t c h _ s i z e , s e q u e n c e _ l e n g t h , h i d d e n _ s i z e ) (batch\_size, sequence\_length, hidden\_size) (batch_size,sequence_length,hidden_size),这是通过将每个头的输出重新组合并线性变换得到的。
      • weights:注意力权重矩阵,形状为 ( b a t c h _ s i z e , n u m _ h e a d s , s e q u e n c e _ l e n g t h , s e q u e n c e _ l e n g t h ) (batch\_size,num\_heads, sequence\_length, sequence\_length) (batch_size,num_heads,sequence_length,sequence_length)。
    • def get_config(self):保存层的配置,保证在模型保存与加载时可以恢复层的参数和行为。
    • def from_config(cls, config):从配置字典恢复该层实例,帮助实现模型的反序列化。
  • class TransformerBlock(tf.keras.layers.Layer):实现了Transformer Encoder模块。
    • def __init__(self, *args, num_heads, mlp_dim, dropout, is_masked, **kwargs):初始化Transformer模块的参数,包括 n u m _ h e a d s num\_heads num_heads(多头注意力机制参数)、 m l p _ d i m mlp\_dim mlp_dim(MLP隐藏层维度)、 d r o p o u t dropout dropout(正则化参数)、 i s _ m a s k e d is\_masked is_masked(是否使用掩码,应用在解码器当中)。
    • def build(self, input_shape):定义了Transformer模块的各个层,包括 s e l f . a t t self.att self.att(多头注意力机制)、 s e l f . m l p b l o c k self.mlpblock self.mlpblock(MLP模块)、 s e l f . l a y e r n o r m 1 self.layernorm1 self.layernorm1(LayerNormalization层)、 s e l f . l a y e r n o r m 2 self.layernorm2 self.layernorm2(LayerNormalization层)、 s e l f . d r o p o u t _ l a y e r self.dropout\_layer self.dropout_layer(dropout层)。 s e l f . m l p b l o c k self.mlpblock self.mlpblock(MLP模块)内部结构:
    • def call(self, inputs, training):完成 T r a n s f o r m e r B l o c k Transformer Block TransformerBlock的前向传播过程。
    • def get_config(self):返回层的配置,以便在保存和加载模型时保留该层的结构和参数。
    • def from_config(cls, config):从配置字典中恢复层的实例,帮助实现模型的反序列化。

def call(self, inputs, training)代码:

py 复制代码
    def call(self, inputs, training):
    	#LayerNormalization归一化
        x = self.layernorm1(inputs)
        #计算注意力权重与输出
        x, weights = self.att(x)
        #dropout
        x = self.dropout_layer(x, training=training)
        #残差连接
        x = x + inputs
        #LayerNormalization归一化
        y = self.layernorm2(x)
        #MLP模块
        y = self.mlpblock(y)
        #残差连接,返回结果与注意力权重
        return x + y, weights

temporal_models/vit_keras/patch_encoder.py

实现类class PathEncoder(layers.Layer),为嵌入向量添加位置编码。

py 复制代码
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        #num_patches:输入的序列中,图像patch的数目(即为嵌入向量的数目)
        self.num_patches = num_patches
        #定义一个嵌入层,将每个位置编码为一个大小为projection_dim的(行)向量
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def get_config(self):
        config = super().get_config()

        config.update({
            "num_patches":self.num_patches,
            "position_embedding": self.position_embedding,
        })

        return config

    def call(self, patch):
    	#生成[0,...,num_patches-1]的整数序列作为位置编码
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        #将整数序列输入位置嵌入层,生成位置嵌入向量
        return self.position_embedding(positions)

temporal_models/vit_keras/utils.py

定义读取和预处理图像数据的方法,且能将来自Flax训练的预训练模型权重加载到Keras模型中。

  • def get_imagenet_classes():在vit_keras包中找到imagenet2012.txt文件(保存类别信息的文本文件),返回包含ImageNet 2012数据集所有类别的字符串列表。
  • def read(filepath_or_buffer: ImageInputType, size, timeout=None):给出要读取的图像源( f i l e p a t h _ o r _ b u f f e r filepath\_or\_buffer filepath_or_buffer),将图像转换为指定大小并转换为RGB格式并返回。
    • filepath_or_buffer:指定要读取的图像源,可以是PIL对象、io.BytesIO、client.HTTPResponse、URL、文件路径。
    • size:指定返回图像的尺寸。
  • def apply_embedding_weights(target_layer, source_weights, num_x_patches, num_y_patches):将嵌入层的权重(source_weights)应用到目标层,并在需要时对权重进行调整。
  • def load_weights_numpy(model, params_path):从numpy文件中加载模型参数,并将其应用到指定的keras模型中。
    • model:要加载权重的 Keras 模型。
    • params_path:记录模型参数的numpy文件。

temporal_models/vit_keras/vit.py

构建ViT模型框架,包括不同版本的预定义配置以及允许自定义的功能。模型的构建主要通过堆叠 Transformer 块和其他 Keras 层来实现,适用于多种输入形状和任务设置。

  • ConfigDict = tx.TypedDict({...}):定义Vit模型配置参数的类型字典。
  • CONFIG_Ti: ConfigDict = {...}CONFIG_S: ConfigDict = {...}CONFIG_B: ConfigDict = {...}:定义ViT的Tiny、Small、Base、Large的配置参数,包括 d r o p o u t 、 m l p _ d i m 、 n u m _ h e a d s 、 n u m _ l a y e r s 、 h i d d e n _ s i z e dropout、mlp\_dim、num\_heads、num\_layers、hidden\_size dropout、mlp_dim、num_heads、num_layers、hidden_size。
  • def build_model(input_shape: tuple,num_layers: int,hidden_size: int,num_heads: int,name: str,mlp_dim: int,classes: int,dropout=0.1,activation="linear",include_top=True,representation_size=None,return_sequence=True,is_masked=True):用于构建模型的函数。
参数名 功能
input_size 输入图像的形状。
num_layers Transformer模型中Encoder Block的数量。
hidden_size 隐藏层的大小。
num_heads 多头注意力机制参数。
name 模型名称。
mlp_dim MLP层的输出维度。
classes 分类任务的类别数量。
dropout Dropout 比率。
activation 最终层的激活函数。
include_top 是否包含分类层。
representation_size 分类层之前的表示层大小。
return_sequence 是否返回序列。
is_masked 是否使用掩码。
py 复制代码
def build_model(
    input_shape: tuple,
    num_layers: int,
    hidden_size: int,
    num_heads: int,
    name: str,
    mlp_dim: int,
    classes: int,
    dropout=0.1,
    activation="linear",
    include_top=True,
    representation_size=None,
    return_sequence=True,
    is_masked=True
):
	#创建输入层,形状为 input_shape
    x = tf.keras.layers.Input(shape=input_shape)
    #Linear projection of Spectral values of pixel at each Timestamp
    proj = tf.keras.layers.Dense(units=hidden_size)(x)
    #使用 PatchEncoder 对输入图像进行编码,生成补丁。
    y = PatchEncoder(input_shape[0], hidden_size)(x)
    #残差连接
    y = y + proj
    #按循环构建指定数量的Encoder Block模块
    for n in range(num_layers):
        y, _ = layers.TransformerBlock(
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout=dropout,
            is_masked=is_masked,
            name=f"Transformer/encoderblock_{n}",
        )(y)
    #LayerNormalization
    y = tf.keras.layers.LayerNormalization(
        epsilon=1e-6, name="Transformer/encoder_norm"
    )(y)
    #MLP-Head模块:Pre-Logits
    if representation_size is not None:
        y = tf.keras.layers.Dense(
            representation_size, name="pre_logits", activation="tanh"
        )(y)
    #展平
    if not return_sequence:
        y = tf.keras.layers.Flatten()(y)
    #MLP-Head模块:Linear层
    if include_top:
        y = tf.keras.layers.Dense(classes, name="head", activation=activation)(y)
    return tf.keras.models.Model(inputs=x, outputs=y, name=name)
  • def vit_base(input_shape=(10, 45),classes=2,activation="linear",include_top=True,weights="imagenet21k+imagenet2012"):创建一个Base版本的ViT模型并返回。

run_seq_model.py

时间序列模型(ViT、GRU、LSTM、TCN)的训练。

  • def get_dateset(window_size, batch_size):加载训练和测试数据集。
    • window_size:设置 w w w的值。
    • batch_size:批次大小。
py 复制代码
def get_dateset(window_size, batch_size):
	#加载训练数据集(已经过w处理),文件格式如'.npy',如'proj3_train_v2_w1.npy'
    x_dataset = np.load(os.path.join(root_path, 'proj3_train_v2_w' + str(window_size) + '.npy'))
    #初始化标签y_dataset,x_dataset.shape[0]表样本数,x_dataset.shape[1]表时间步数,第三个维度大小为2,表二分类任务
    y_dataset = np.zeros((x_dataset.shape[0],x_dataset.shape[1],2))

    #
    y_dataset[: ,:, 0] = x_dataset[:, :, pow(window_size,2)*5] == 0
    y_dataset[:, :, 1] = x_dataset[:, :, pow(window_size,2)*5] > 0
    
    x_dataset_val1 = np.load(os.path.join(root_path, 'proj3_walker_fire_w'+str(window_size)+'.npy'))
    x_dataset_val2 = np.load(os.path.join(root_path, 'proj3_hanceville_fire_w'+str(window_size)+'.npy'))

    x_dataset_val = np.concatenate((x_dataset_val1, x_dataset_val2), axis=0)
    y_dataset_val = np.zeros((x_dataset_val.shape[0],x_dataset_val.shape[1],2))
    y_dataset_val[: ,:, 0] = x_dataset_val[:, :, pow(window_size,2)*5] == 0
    y_dataset_val[:, :, 1] = x_dataset_val[:, :, pow(window_size,2)*5] > 0

    x_train, x_val, y_train, y_val = x_dataset[:,:,:pow(window_size,2)*5], x_dataset_val[:,:,:pow(window_size,2)*5], y_dataset, y_dataset_val
    def make_generator(inputs, labels):
        def _generator():
            for input, label in zip(inputs, labels):
                yield input, label

        return _generator
    #定义Tensorflow训练、测试数据集加载器
    train_dataset = tf.data.Dataset.from_generator(make_generator(x_train, y_train),
                                                   (tf.float32, tf.int16))
    val_dataset = tf.data.Dataset.from_generator(make_generator(x_val, y_val),
                                                 (tf.float32, tf.int16))
    #将训练、测试数据洗牌(shuffle),并按batch_size分批处理.之后重复数据集以便用于多个训练周期(epochs)
    train_dataset = train_dataset.shuffle(batch_size).repeat(MAX_EPOCHS).batch(batch_size)
    val_dataset = val_dataset.repeat(MAX_EPOCHS).batch(224*224)
    #计算每个训练周期需要的步骤数 steps_per_epoch,和每次验证的步骤数 validation_steps。
    steps_per_epoch = x_train.shape[0]//batch_size
    validation_steps = x_val.shape[0]//(224*224)
    #返回训练集、验证集,以及每个周期的训练和验证步数。
    return train_dataset, val_dataset, steps_per_epoch, validation_steps
  • def wandb_config(window_size, model_name, run, num_heads, num_layers, mlp_dim, hidden_size):Wandb 是一个常用于深度学习训练监控的平台,能够记录训练过程中的各种参数和指标。wandb_config()可用于配置和初始化该平台。
    • window_size:窗口大小,用于数据处理。
    • model_name:模型名称,确定选择哪种模型进行训练。
    • run:当前运行标识符。
    • num_heads:多头注意力机制Heads数量。
    • num_layers:模型层数。
    • mlp_dim:多层感知机隐藏层维度大小。
    • hidden_size:嵌入向量维度大小。
py 复制代码
def wandb_config(window_size, model_name, run, num_heads, num_layers, mlp_dim, hidden_size):
    wandb.login()
    wandb.init(project="proj3_"+model_name+"_grid_search", entity="zhaoyutim")
    wandb.run.name = 'num_heads_' + str(num_heads) + 'num_layers_'+ str(num_layers)+ 'mlp_dim_'+str(mlp_dim)+'hidden_size_'+str(hidden_size)+'batchsize_'+str(batch_size)
    #配置wandb记录的超参数信息
    wandb.config = {
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "epochs": MAX_EPOCHS,
        "batch_size": batch_size,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "mlp_dim": mlp_dim,
        "embed_dim": hidden_size
    }
  • main:训练多种深度学习模型(ViT、GRU、LSTM等),并通过wandb记录训练过程。
py 复制代码
if __name__=='__main__':
	#配置命令行参数
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('-m', type=str, help='Model to be executed')
    parser.add_argument('-w', type=int, help='Window size')
    parser.add_argument('-p', type=str, help='Load trained weights')
    parser.add_argument('-b', type=int, help='batch size')
    parser.add_argument('-r', type=int, help='run')
    parser.add_argument('-lr', type=float, help='learning rate')

    parser.add_argument('-nh', type=int, help='number-of-head')
    parser.add_argument('-md', type=int, help='mlp-dimension')
    parser.add_argument('-ed', type=int, help='embedding dimension')
    parser.add_argument('-nl', type=int, help='num_layers')
    #解析配置的参数
    args = parser.parse_args()
    model_name = args.m
    load_weights = args.p
    window_size = args.w
    batch_size=args.b
    num_heads=args.nh
    mlp_dim=args.md
    num_layers=args.nl
    hidden_size=args.ed
    is_masked=False

    run = args.r
    lr = args.lr
    MAX_EPOCHS = 50
    learning_rate = lr
    weight_decay = lr / 10
    num_classes=2

    input_shape=(10,pow(window_size,2)*5)
	#初始化wandb平台连接
    wandb_config(window_size, model_name, run, num_heads, mlp_dim, num_layers, hidden_size)
	#分布式训练策略,使用多GPU平台进行训练
    strategy = tf.distribute.MirroredStrategy()
    #strategy.scope()表示以下代码采用分布式策略,可在多GPU平台训练模型
    with strategy.scope():
    	#根据命令行参数加载不同模型
        if model_name == 'vit_small':
            model = vit.vit_small(
                input_shape=input_shape,
                classes=num_classes,
                activation='sigmoid',
                include_top=True,
            )
        elif model_name=='vit_tiny':
            model = vit.vit_tiny(
                input_shape=input_shape,
                classes=num_classes,
                activation='sigmoid',
                include_top=True,
            )
        elif model_name=='vit_tiny_custom':
            model = vit.vit_tiny_custom(
                input_shape=input_shape,
                classes=num_classes,
                activation='sigmoid',
                include_top=True,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                num_layers=num_layers,
                hidden_size=hidden_size,
                is_masked=is_masked
            )
        elif model_name == 'gru_custom':
            gru = GRUModel(input_shape, num_classes)
            model = gru.get_model_custom(input_shape, num_classes, num_layers, hidden_size)
        elif model_name == 'lstm_custom':
            lstm = LSTMModel(input_shape, num_classes)
            model = lstm.get_model_custom(input_shape, num_classes, num_layers, hidden_size)
        elif model_name=='vit_base':
            model = vit.vit_base(
                input_shape=input_shape,
                classes=num_classes,
                activation='sigmoid',
                include_top=True,
            )
        elif model_name=='tcn':
            model = compiled_tcn(return_sequences=True,
                                 num_feat=input_shape[-1],
                                 num_classes=num_classes,
                                 nb_filters=mlp_dim,
                                 kernel_size=hidden_size,
                                 dilations=[2 ** i for i in range(9)],
                                 nb_stacks=num_layers,
                                 max_len=input_shape[0],
                                 use_weight_norm=True,
                                 use_skip_connections=True)
        else:
            raise('no suport model')
		#打印模型概述
        model.summary()
        #设置Adam优化器,并使用SigmoidFocalCrossEntropy损失函数、使用CategoricalAccuracy评估分类准确度
        optimizer = tfa.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=weight_decay
        )
        train_dataset, val_dataset, steps_per_epoch, validation_steps = get_dateset(window_size, batch_size)
		#不同模型生成不同检查点文件
        if model_name == 'vit_tiny_custom':
            checkpoint = ModelCheckpoint(os.path.join(root_path, 'proj3_'+model_name+'w' + str(window_size) + '_nopretrained'+'_run'+str(run)+'_'+str(num_heads)+'_'+str(mlp_dim)+'_'+str(hidden_size)+'_'+str(num_layers)+'_'+str(batch_size)), monitor="val_loss", mode="min", save_best_only=True, verbose=1)
        elif model_name == 'gru_custom' or model_name == 'lstm_custom':
            checkpoint = ModelCheckpoint(os.path.join(root_path, 'proj3_'+model_name+'w' + str(window_size) + '_nopretrained'+'_run'+str(run)+'_'+str(hidden_size)+'_'+str(num_layers)), monitor="val_loss", mode="min", save_best_only=True, verbose=1)
        else:
            checkpoint = ModelCheckpoint(os.path.join(root_path, 'proj3_'+model_name+'w' + str(window_size) + '_nopretrained'+'_run'+str(run)), monitor="val_loss", mode="min", save_best_only=True, verbose=1)
		#模型编译并设置优化器、损失函数、评价指标
        model.compile(
            optimizer=optimizer,
            loss=tfa.losses.SigmoidFocalCrossEntropy(from_logits=False),
            metrics=[
                tf.keras.metrics.CategoricalAccuracy(name="accuracy")
            ],
        )
	#创建tf.data.Options对象,Tensorflow中专门用于处理数据集,如并行处理、数据分片等
    options = tf.data.Options()
    #关闭数据集在分布式训练环境下的自动分片策略(Auto Sharding,指TensorFlow 在分布式训练中,将输入数据集自动分割为多份,以供不同的计算设备(如多GPU)并行处理)
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    #将options应用到训练集、测试集上
    train_dataset = train_dataset.with_options(options)
    val_dataset = val_dataset.with_options(options)
    #直接加载模型权重或训练模型,并使用wandb记录训练过程
    if load_weights== 'yes':
        model.load_weights(os.path.join(root_path, 'proj3_' + model_name + 'w' + str(window_size) + '_nopretrained'+'_run'+str(run)))
    else:
        print('training in progress')
        history = model.fit(
            x=train_dataset,
            steps_per_epoch=steps_per_epoch,
            validation_data=val_dataset,
            validation_steps=validation_steps,
            epochs=MAX_EPOCHS,
            callbacks=[WandbCallback()],
        )
        #保存训练的模型
        if model_name == 'vit_tiny_custom':
            model.save(os.path.join(root_path, 'proj3_'+model_name+'w' + str(window_size) + '_nopretrained'+'_run'+str(run)+'_'+str(num_heads)+'_'+str(mlp_dim)+'_'+str(hidden_size)+'_'+str(num_layers)+'_'+str(batch_size)+'_'+str(is_masked)))
        elif model_name == 'gru_custom' or model_name == 'lstm_custom':
            model.save(os.path.join(root_path, 'proj3_'+model_name+'w' + str(window_size) + '_nopretrained'+'_run'+str(run)+'_'+str(hidden_size)+'_'+str(num_layers)+'_'+str(is_masked)))
        else:
            model.save(os.path.join(root_path, 'proj3_'+model_name+'w' + str(window_size) + '_nopretrained'+'_run'+str(run)+'_'+str(is_masked)))
相关推荐
chengxuyuan666667 分钟前
python基础语句整理
java·windows·python
一只会飞的猪_9 分钟前
国密加密golang加密,java解密
java·开发语言·golang
四念处茫茫24 分钟前
【C语言系列】深入理解指针(2)
c语言·开发语言·visual studio
LucianaiB30 分钟前
C语言之图像文件的属性
c语言·开发语言·microsoft·c语言之图像文件的属性
向着开发进攻30 分钟前
深入理解 Java 并发编程中的锁机制
java·开发语言
清弦墨客32 分钟前
【蓝桥杯】43691.拉马车
python·蓝桥杯·程序算法
CURRY30_HJH35 分钟前
JAVA 使用反射比较对象属性的变化,记录修改日志。使用注解【策略模式】,来进行不同属性枚举值到中英文描述的切换,支持前端国际化。
java·开发语言
千千道38 分钟前
QT 中 UDP 的使用
开发语言·qt·udp
C++小厨神2 小时前
MATLAB语言的编程范式
开发语言·后端·golang