github.com/thuml/Time-... 很不错的一个时序库,看起来包含比较新的,基于Transformer一系列的时间序列衍生模型。前一段时间对这块做了一些调研,这篇文章主要是关于model外项目结构上的一些理解,包括数据格式,代码结构等。
scripts目录包含了一系列脚本,首先按任务方向分为5个folder,分别是anomaly_detection(异常检测),classification,imputation,long term forecast和short term forecast,long term指的是预测series在96-720之间,short term指的是预测series在6-48之间,对于每个方向,里面脚本按<模型名 + 数据集名>组织,比如Informer_M4.sh,是在M4 dataset上应用Informer模型的脚本。
脚本格式:
css
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_96 \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len 96 \
--label_len 48 \
--pred_len 96 \
--e_layers 2 \
--d_layers 1 \
--factor 3 \
--enc_in 7 \
--dec_in 7 \
--c_out 7 \
--des 'Exp' \
--itr 1
所有任务都会走run.py(run.py在后面解释)。
task_name就是上面说的5种任务方向之一
root_path和data_path共同决定了要处理任务的数据集来源,印象里这个库是需要使用者自己手动下载的,我尝试了ETT这个数据集 github.com/zhouhaoyi/E... , 后面的解释很多也是以这个数据集为准,所以稍微解释下:
ETT里每个数据点均包含8维特征,包括数据点的记录日期、预测值"油温"以及6个不同类型的外部负载值。数据格式如下:
bash
date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT
2016-07-01 00:00:00,5.827000141143799,2.009000062942505,1.5989999771118164,0.4620000123977661,4.203000068664552,1.3400000333786009,30.5310001373291
这里面最后一位是油温,属于target。
model_id 里面ETTh1_96_96,第一项是数据集名称(这里h1是因为按小时,还有按分钟的)96_96应该有一个是pred_len,一个是seq_len,不过具体得再看下代码。
features是一个比较重要的参数,它有三种取值:M/MS/S,对应的含义可以看code,M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate,代码里单纯是S的任务很少,大部分都是M,这点我理解M会把features和target看作一个向量,对应transformer里每个token的embedding,然后时间序列本身对应token间的关系,再加上正常batch的概念,时序和transformer也就对应起来了。
seq_len,label_len和pred_len,这些参数影响了数据如何组织,稍后再说。
layers后面这些应该都是模型本身的参数。
回来说说代码格式。如刚才所说,无论什么任务,run.py都是入口,它会根据任务方向的类型,调用对应的Exp_xx,比如long term forecast,会调用Exp_Long_Term_Forecast,所有Exp_xx都继承自Exp_Basic。
对于Exp_xx,最重要的是它的train/test函数,不过在看train/test之前,需要先看看项目对数据组织(data_provider)所做的事情。
data_provider和数据集有关,每个数据集有自己的Data类,比如Dataset_ETT_hour/Dataset_ETT_minute,基本会提供两个函数__read_data__和__getitem__。
__read_data__在Data类初始化时调用,会产生data_x,data_y和seq_mark。seq_mark来自数据集里的timestamp
less
1. df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
2. df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
3. df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
4. df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
这部分对应模型用到的seq_mark部分。
data_x和data_y在初始化时是一致的,在__getitem__中组织。
ini
1. self.data_x = data[border1:border2]
2. self.data_y = data[border1:border2]
getitem,它有一个输入参数index,对应原始ETT数据csv里读取的每一行(不一定是每行,但iteration时候是这个意思)
ini
1. s_begin = index == 初始位置,比如index=1
2. s_end = s_begin + self.seq_len == x的结束为止,比如index=1,seq_len=100,那么s_end=101
3. r_begin = s_end - self.label_len == y往前的位置,label_len
4. r_end = r_begin + self.label_len + self.pred_len
5. seq_x = self.data_x[s_begin:s_end] == 这里组织成一个二维数组
6. seq_y = self.data_y[r_begin:r_end]
7. seq_mark_x和seq_mark_y == 来自对应行timestamp的拆解
之后把数据组织成batch,比如
css
1. batch_x: torch.Size([32, 96, 7])
2. batch_y: torch.Size([32, 144, 7])
3. batch_x_mark: torch.Size([32, 96, 4])
4. batch_y_mark: torch.Size([32, 144, 4])
其中batch_x和batch_x_mark会是一样的维度,batch_y和batch_y_mark会是一样的维度。
data_loader = DataLoader 这是一个pytorch自身提供的class,根据batch size要求组织dataset,在train/test时使用。
现在回来看看train部分,偷懒了,直接在代码里加了一些注释:
scss
1. def train
1. train_data, train_loader = self._get_data(flag='train') == 这里会调用data_set, data_loader = data_provider(self.args, flag) 直接读取组织好的数据
2. for epoch in range(self.args.train_epochs): == 训练多轮
1. for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): == 读取数据
1. 组织数据
1. dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
2. dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
3. == 这里先创建一个pred_len长度的全零向量,然后第二步在列维度和label_len长度的向量进行concat连接
1. batch_y,或者说seq_y本身的长度就是self.label_len + self.pred_len
2. model预测,产生output
1. outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) == 这里dec_inp相当于是把pred_len部分掩盖后的batch_y,预测这部分后,再和真实的batch_y比较error
3. 把真值和预测值match上
1. f_dim = -1 if self.args.features == 'MS' else 0 == 对于ms任务,只比较最后一项作为target,否则把每一项都作为target == 所以对ot问题,相当于不仅是比较OT项,也比较前面的feature项?
2. loss = criterion(outputs, batch_y)
3. loss.backward == 反向传播
再往下似乎应该说模型本身了,应该就不算外围代码了,所以这篇文章就到此为止了~