Time-Series-Library外围代码理解

github.com/thuml/Time-... 很不错的一个时序库,看起来包含比较新的,基于Transformer一系列的时间序列衍生模型。前一段时间对这块做了一些调研,这篇文章主要是关于model外项目结构上的一些理解,包括数据格式,代码结构等。

scripts目录包含了一系列脚本,首先按任务方向分为5个folder,分别是anomaly_detection(异常检测),classification,imputation,long term forecast和short term forecast,long term指的是预测series在96-720之间,short term指的是预测series在6-48之间,对于每个方向,里面脚本按<模型名 + 数据集名>组织,比如Informer_M4.sh,是在M4 dataset上应用Informer模型的脚本。

脚本格式:

css 复制代码
python -u run.py \
  --task_name long_term_forecast \
  --is_training 1 \
  --root_path ./dataset/ETT-small/ \
  --data_path ETTh1.csv \
  --model_id ETTh1_96_96 \
  --model $model_name \
  --data ETTh1 \
  --features M \
  --seq_len 96 \
  --label_len 48 \
  --pred_len 96 \
  --e_layers 2 \
  --d_layers 1 \
  --factor 3 \
  --enc_in 7 \
  --dec_in 7 \
  --c_out 7 \
  --des 'Exp' \
  --itr 1

所有任务都会走run.py(run.py在后面解释)。

task_name就是上面说的5种任务方向之一

root_path和data_path共同决定了要处理任务的数据集来源,印象里这个库是需要使用者自己手动下载的,我尝试了ETT这个数据集 github.com/zhouhaoyi/E... , 后面的解释很多也是以这个数据集为准,所以稍微解释下:

ETT里每个数据点均包含8维特征,包括数据点的记录日期、预测值"油温"以及6个不同类型的外部负载值。数据格式如下:

bash 复制代码
date,HUFL,HULL,MUFL,MULL,LUFL,LULL,OT
2016-07-01 00:00:00,5.827000141143799,2.009000062942505,1.5989999771118164,0.4620000123977661,4.203000068664552,1.3400000333786009,30.5310001373291

这里面最后一位是油温,属于target。

model_id 里面ETTh1_96_96,第一项是数据集名称(这里h1是因为按小时,还有按分钟的)96_96应该有一个是pred_len,一个是seq_len,不过具体得再看下代码。

features是一个比较重要的参数,它有三种取值:M/MS/S,对应的含义可以看code,M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate,代码里单纯是S的任务很少,大部分都是M,这点我理解M会把features和target看作一个向量,对应transformer里每个token的embedding,然后时间序列本身对应token间的关系,再加上正常batch的概念,时序和transformer也就对应起来了。

seq_len,label_len和pred_len,这些参数影响了数据如何组织,稍后再说。

layers后面这些应该都是模型本身的参数。

回来说说代码格式。如刚才所说,无论什么任务,run.py都是入口,它会根据任务方向的类型,调用对应的Exp_xx,比如long term forecast,会调用Exp_Long_Term_Forecast,所有Exp_xx都继承自Exp_Basic。

对于Exp_xx,最重要的是它的train/test函数,不过在看train/test之前,需要先看看项目对数据组织(data_provider)所做的事情。

data_provider和数据集有关,每个数据集有自己的Data类,比如Dataset_ETT_hour/Dataset_ETT_minute,基本会提供两个函数__read_data__和__getitem__。

__read_data__在Data类初始化时调用,会产生data_x,data_y和seq_mark。seq_mark来自数据集里的timestamp

less 复制代码
1. df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
2. df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
3. df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
4. df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)

这部分对应模型用到的seq_mark部分。

data_x和data_y在初始化时是一致的,在__getitem__中组织。

ini 复制代码
1. self.data_x = data[border1:border2]
2. self.data_y = data[border1:border2]

getitem,它有一个输入参数index,对应原始ETT数据csv里读取的每一行(不一定是每行,但iteration时候是这个意思)

ini 复制代码
1. s_begin = index                               == 初始位置,比如index=1
2. s_end = s_begin + self.seq_len                == x的结束为止,比如index=1,seq_len=100,那么s_end=101
3. r_begin = s_end - self.label_len              == y往前的位置,label_len
4. r_end = r_begin + self.label_len + self.pred_len
5. seq_x = self.data_x[s_begin:s_end]            == 这里组织成一个二维数组
6. seq_y = self.data_y[r_begin:r_end]
7. seq_mark_x和seq_mark_y                        == 来自对应行timestamp的拆解

之后把数据组织成batch,比如

css 复制代码
1. batch_x: torch.Size([32, 96, 7])
2. batch_y: torch.Size([32, 144, 7])
3. batch_x_mark: torch.Size([32, 96, 4])
4. batch_y_mark: torch.Size([32, 144, 4])

其中batch_x和batch_x_mark会是一样的维度,batch_y和batch_y_mark会是一样的维度。

data_loader = DataLoader 这是一个pytorch自身提供的class,根据batch size要求组织dataset,在train/test时使用。

现在回来看看train部分,偷懒了,直接在代码里加了一些注释:

scss 复制代码
1. def train 
    1. train_data, train_loader = self._get_data(flag='train') == 这里会调用data_set, data_loader = data_provider(self.args, flag) 直接读取组织好的数据
    2. for epoch in range(self.args.train_epochs): == 训练多轮
        1. for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): == 读取数据
            1. 组织数据
                1. dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 
                2. dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                3. == 这里先创建一个pred_len长度的全零向量,然后第二步在列维度和label_len长度的向量进行concat连接
                    1. batch_y,或者说seq_y本身的长度就是self.label_len + self.pred_len
            2. model预测,产生output
                1. outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) == 这里dec_inp相当于是把pred_len部分掩盖后的batch_y,预测这部分后,再和真实的batch_y比较error
            3. 把真值和预测值match上
                1. f_dim = -1 if self.args.features == 'MS' else 0 == 对于ms任务,只比较最后一项作为target,否则把每一项都作为target == 所以对ot问题,相当于不仅是比较OT项,也比较前面的feature项?
                2. loss = criterion(outputs, batch_y)
                3. loss.backward == 反向传播

再往下似乎应该说模型本身了,应该就不算外围代码了,所以这篇文章就到此为止了~

相关推荐
潘多编程1 小时前
Spring Boot微服务架构设计与实战
spring boot·后端·微服务
2402_857589361 小时前
新闻推荐系统:Spring Boot框架详解
java·spring boot·后端
2401_857622662 小时前
新闻推荐系统:Spring Boot的可扩展性
java·spring boot·后端
Amagi.3 小时前
Spring中Bean的作用域
java·后端·spring
侠客行03173 小时前
xxl-job调度平台之任务触发
java·后端·源码
2402_857589363 小时前
Spring Boot新闻推荐系统设计与实现
java·spring boot·后端
J老熊3 小时前
Spring Cloud Netflix Eureka 注册中心讲解和案例示范
java·后端·spring·spring cloud·面试·eureka·系统架构
Benaso4 小时前
Rust 快速入门(一)
开发语言·后端·rust
sco52824 小时前
SpringBoot 集成 Ehcache 实现本地缓存
java·spring boot·后端
原机小子4 小时前
在线教育的未来:SpringBoot技术实现
java·spring boot·后端