Time-Series-Library外围代码理解

github.com/thuml/Time-... 很不错的一个时序库,看起来包含比较新的,基于Transformer一系列的时间序列衍生模型。前一段时间对这块做了一些调研,这篇文章主要是关于model外项目结构上的一些理解,包括数据格式,代码结构等。

scripts目录包含了一系列脚本,首先按任务方向分为5个folder,分别是anomaly_detection(异常检测),classification,imputation,long term forecast和short term forecast,long term指的是预测series在96-720之间,short term指的是预测series在6-48之间,对于每个方向,里面脚本按<模型名 + 数据集名>组织,比如Informer_M4.sh,是在M4 dataset上应用Informer模型的脚本。

脚本格式:

css 复制代码
python -u run.py \
  --task_name long_term_forecast \
  --is_training 1 \
  --root_path ./dataset/ETT-small/ \
  --data_path ETTh1.csv \
  --model_id ETTh1_96_96 \
  --model $model_name \
  --data ETTh1 \
  --features M \
  --seq_len 96 \
  --label_len 48 \
  --pred_len 96 \
  --e_layers 2 \
  --d_layers 1 \
  --factor 3 \
  --enc_in 7 \
  --dec_in 7 \
  --c_out 7 \
  --des 'Exp' \
  --itr 1

所有任务都会走run.py(run.py在后面解释)。

task_name就是上面说的5种任务方向之一

root_path和data_path共同决定了要处理任务的数据集来源,印象里这个库是需要使用者自己手动下载的,我尝试了ETT这个数据集 github.com/zhouhaoyi/E... , 后面的解释很多也是以这个数据集为准,所以稍微解释下:

ETT里每个数据点均包含8维特征,包括数据点的记录日期、预测值"油温"以及6个不同类型的外部负载值。数据格式如下:

bash 复制代码
date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT
2016-07-01 00:00:00,5.827000141143799,2.009000062942505,1.5989999771118164,0.4620000123977661,4.203000068664552,1.3400000333786009,30.5310001373291

这里面最后一位是油温,属于target。

model_id 里面ETTh1_96_96,第一项是数据集名称(这里h1是因为按小时,还有按分钟的)96_96应该有一个是pred_len,一个是seq_len,不过具体得再看下代码。

features是一个比较重要的参数,它有三种取值:M/MS/S,对应的含义可以看code,M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate,代码里单纯是S的任务很少,大部分都是M,这点我理解M会把features和target看作一个向量,对应transformer里每个token的embedding,然后时间序列本身对应token间的关系,再加上正常batch的概念,时序和transformer也就对应起来了。

seq_len,label_len和pred_len,这些参数影响了数据如何组织,稍后再说。

layers后面这些应该都是模型本身的参数。

回来说说代码格式。如刚才所说,无论什么任务,run.py都是入口,它会根据任务方向的类型,调用对应的Exp_xx,比如long term forecast,会调用Exp_Long_Term_Forecast,所有Exp_xx都继承自Exp_Basic。

对于Exp_xx,最重要的是它的train/test函数,不过在看train/test之前,需要先看看项目对数据组织(data_provider)所做的事情。

data_provider和数据集有关,每个数据集有自己的Data类,比如Dataset_ETT_hour/Dataset_ETT_minute,基本会提供两个函数__read_data__和__getitem__。

__read_data__在Data类初始化时调用,会产生data_x,data_y和seq_mark。seq_mark来自数据集里的timestamp

less 复制代码
1. df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
2. df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
3. df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
4. df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)

这部分对应模型用到的seq_mark部分。

data_x和data_y在初始化时是一致的,在__getitem__中组织。

ini 复制代码
1. self.data_x = data[border1:border2]
2. self.data_y = data[border1:border2]

getitem,它有一个输入参数index,对应原始ETT数据csv里读取的每一行(不一定是每行,但iteration时候是这个意思)

ini 复制代码
1. s_begin = index                               == 初始位置,比如index=1
2. s_end = s_begin + self.seq_len                == x的结束为止,比如index=1,seq_len=100,那么s_end=101
3. r_begin = s_end - self.label_len              == y往前的位置,label_len
4. r_end = r_begin + self.label_len + self.pred_len
5. seq_x = self.data_x[s_begin:s_end]            == 这里组织成一个二维数组
6. seq_y = self.data_y[r_begin:r_end]
7. seq_mark_x和seq_mark_y                        == 来自对应行timestamp的拆解

之后把数据组织成batch,比如

css 复制代码
1. batch_x: torch.Size([32, 96, 7])
2. batch_y: torch.Size([32, 144, 7])
3. batch_x_mark: torch.Size([32, 96, 4])
4. batch_y_mark: torch.Size([32, 144, 4])

其中batch_x和batch_x_mark会是一样的维度,batch_y和batch_y_mark会是一样的维度。

data_loader = DataLoader 这是一个pytorch自身提供的class,根据batch size要求组织dataset,在train/test时使用。

现在回来看看train部分,偷懒了,直接在代码里加了一些注释:

scss 复制代码
1. def train 
    1. train_data, train_loader = self._get_data(flag='train') == 这里会调用data_set, data_loader = data_provider(self.args, flag) 直接读取组织好的数据
    2. for epoch in range(self.args.train_epochs): == 训练多轮
        1. for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): == 读取数据
            1. 组织数据
                1. dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 
                2. dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                3. == 这里先创建一个pred_len长度的全零向量,然后第二步在列维度和label_len长度的向量进行concat连接
                    1. batch_y,或者说seq_y本身的长度就是self.label_len + self.pred_len
            2. model预测,产生output
                1. outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) == 这里dec_inp相当于是把pred_len部分掩盖后的batch_y,预测这部分后,再和真实的batch_y比较error
            3. 把真值和预测值match上
                1. f_dim = -1 if self.args.features == 'MS' else 0 == 对于ms任务,只比较最后一项作为target,否则把每一项都作为target == 所以对ot问题,相当于不仅是比较OT项,也比较前面的feature项?
                2. loss = criterion(outputs, batch_y)
                3. loss.backward == 反向传播

再往下似乎应该说模型本身了,应该就不算外围代码了,所以这篇文章就到此为止了~

相关推荐
Vitalia2 分钟前
从零开始学 Rust:基本概念——变量、数据类型、函数、控制流
开发语言·后端·rust
猎人everest3 小时前
SpringBoot应用开发入门
java·spring boot·后端
孤雪心殇8 小时前
简单易懂,解析Go语言中的Map
开发语言·数据结构·后端·golang·go
小突突突10 小时前
模拟实现Java中的计时器
java·开发语言·后端·java-ee
web1376560764310 小时前
Scala的宝藏库:探索常用的第三方库及其应用
开发语言·后端·scala
闲猫10 小时前
go 反射 interface{} 判断类型 获取值 设置值 指针才可以设置值
开发语言·后端·golang·反射
LUCIAZZZ11 小时前
EasyExcel快速入门
java·数据库·后端·mysql·spring·spring cloud·easyexcel
Asthenia041211 小时前
依托IOC容器提供的Bean生命周期,我们能在Bean中做些什么?又能测些什么?
后端
Ase5gqe12 小时前
Spring中的IOC详解
java·后端·spring
小万编程12 小时前
基于SpringBoot+Vue奖学金评比系统(高质量源码,可定制,提供文档,免费部署到本地)
java·spring boot·后端·毕业设计·计算机毕业设计·项目源码