Commit b7f2840

Merge pull request #13 from zezhishao/dev/refactor

zezhishao authored Sep 30, 2022
2 parents dc8ba47 + 14a12ab
Showing 21 changed files with 1,745 additions and 220 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -91,7 +91,7 @@ Configuration file `step/STEP_$DATASET.py` describes the forecasting configurati
Edit `BATCH_SIZE` and `GPU_NUM` in the configuration file and `--gpus` in the command line to run on your own hardware.
Note that a different number of GPUs leads to a different effective batch size, which affects the learning-rate setting and the forecasting accuracy.

- Our training logs are shown in `train_logs/Backend_metr.log`, `train_logs/Backend_pems04.log`, and `train_logs/Backend_pemsbay.log`.
+ Our training logs are shown in `train_logs/STEP_METR-LA.log`, `train_logs/STEP_PEMS04.log`, and `train_logs/STEP_PEMS-BAY.log`.

## ⚒ Train STEP from Scratch

@@ -110,6 +110,8 @@ Edit the `BATCH_SIZE` and `GPU_NUM` in the configuration file and `--gpu` in the
All training artifacts, including the config file, training log, and checkpoints, will be saved in `checkpoints/MODEL_EPOCH/MD5_of_config_file`.
For example, `checkpoints/TSFormer_100/5afe80b3e7a3dc055158bcfe99afbd7f`.

+ Our training logs are shown in `train_logs/TSFormer_METR-LA.log`, `train_logs/TSFormer_PEMS04.log`, and `train_logs/TSFormer_PEMS-BAY.log`, and the pre-trained TSFormers for each dataset are placed in the `tsformer_ckpt` folder.

### **Forecasting Stage**

After pre-training TSFormer, move your pre-trained best checkpoint to `tsformer_ckpt/`.
@@ -124,7 +126,7 @@ Replace `$DATASET_NAME` with one of `METR-LA`, `PEMS-BAY`, `PEMS04`.
Then train the downstream STGNN (Graph WaveNet) as in section 4.

## 📈 Performance and Visualization
<!-- <img src="figures/Table3.png" alt="Table3" style="zoom:60.22%;" /><img src="figures/Table4.png" alt="Table4" style="zoom:51%;" /> -->

<img src="figure/MainResults.png" alt="TheTable" style="zoom:49.4%;" />

<img src="figure/Inspecting.jpg" alt="Visualization" style="zoom:25%;" />
64 changes: 63 additions & 1 deletion basicts/data/transform.py
@@ -21,7 +21,7 @@ def standard_transform(data: np.array, output_dir: str, train_index: list, histo
np.array: normalized raw time series data.
"""

- # data: L, N, C
+ # data: L, N, C, C=1
data_train = data[:train_index[-1][1], ...]

mean, std = data_train[..., 0].mean(), data_train[..., 0].std()
@@ -57,3 +57,65 @@ def re_standard_transform(data: torch.Tensor, **kwargs) -> torch.Tensor:
data = data * std
data = data + mean
return data
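Aside: the transform functions in this file pickle a small scaler dict (a registered function name plus its args), and at evaluation time the runner resolves that name in `SCALER_REGISTRY` and applies the inverse transform. A minimal sketch of that lookup, assuming `SCALER_REGISTRY` is importable from `basicts.data.registry` and that a scaler pickle already exists at the hypothetical path below:

```python
import pickle

import torch
from basicts.data.registry import SCALER_REGISTRY  # assumed import path

# hypothetical scaler file written by one of the *_transform functions below
with open("datasets/METR-LA/scaler_in12_out12.pkl", "rb") as f:
    scaler = pickle.load(f)

prediction = torch.randn(32, 12, 207, 1)  # dummy model output: B, L, N, C
rescaled = SCALER_REGISTRY.get(scaler["func"])(prediction, **scaler["args"])
```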


@SCALER_REGISTRY.register()
def min_max_transform(data: np.array, output_dir: str, train_index: list, history_seq_len: int, future_seq_len: int) -> np.array:
"""Min-max normalization.
Args:
data (np.array): raw time series data.
output_dir (str): output dir path.
train_index (list): train index.
history_seq_len (int): historical sequence length.
future_seq_len (int): future sequence length.
Returns:
np.array: normalized raw time series data.
"""

# L, N, C, C=1
data_train = data[:train_index[-1][1], ...]

min_value = data_train.min(axis=(0, 1), keepdims=False)[0]
max_value = data_train.max(axis=(0, 1), keepdims=False)[0]

print("min: (training data)", min_value)
print("max: (training data)", max_value)
scaler = {}
scaler["func"] = re_min_max_transform.__name__
scaler["args"] = {"min_value": min_value, "max_value": max_value}
# label to identify the scaler for different settings.
# For fairness, only one transformation can be applied per dataset.
# TODO: therefore we do not (for now) distinguish between the data produced by different transformation methods.
with open(output_dir + "/scaler_in{0}_out{1}.pkl".format(history_seq_len, future_seq_len), "wb") as f:
pickle.dump(scaler, f)

def normalize(x):
# ref:
# https://github.com/guoshnBJTU/ASTGNN/blob/f0f8c2f42f76cc3a03ea26f233de5961c79c9037/lib/utils.py#L17
x = 1. * (x - min_value) / (max_value - min_value)
x = 2. * x - 1.
return x

data_norm = normalize(data)
return data_norm


@SCALER_REGISTRY.register()
def re_min_max_transform(data: torch.Tensor, **kwargs) -> torch.Tensor:
"""Standard re-min-max transform.
Args:
data (torch.Tensor): input data.
Returns:
torch.Tensor: re-scaled data.
"""

min_value, max_value = kwargs["min_value"], kwargs["max_value"]
# ref:
# https://github.com/guoshnBJTU/ASTGNN/blob/f0f8c2f42f76cc3a03ea26f233de5961c79c9037/lib/utils.py#L23
data = (data + 1.) / 2.
data = 1. * data * (max_value - min_value) + min_value
return data
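The pair above is invertible: `min_max_transform` maps raw values into [-1, 1] using the training-set extrema, and `re_min_max_transform` undoes it. A quick round-trip check of the arithmetic (hypothetical min/max values):

```python
import torch

min_value, max_value = 0.0, 100.0           # hypothetical training-data extrema
raw = torch.tensor([0.0, 25.0, 100.0])

# forward: same arithmetic as min_max_transform's normalize()
norm = 2.0 * (raw - min_value) / (max_value - min_value) - 1.0
print(norm)                                 # tensor([-1.0000, -0.5000,  1.0000])

# inverse: same arithmetic as re_min_max_transform
restored = (norm + 1.0) / 2.0 * (max_value - min_value) + min_value
print(torch.allclose(restored, raw))        # True
```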
9 changes: 5 additions & 4 deletions basicts/runners/base_tsf_runner.py
@@ -2,6 +2,7 @@
from typing import Tuple, Union, Optional

import torch
+ import numpy as np
from easytorch.utils.dist import master_only

from .base_runner import BaseRunner
@@ -26,7 +27,7 @@ def __init__(self, cfg: dict):
super().__init__(cfg)
self.dataset_name = cfg["DATASET_NAME"]
# different datasets have different null_values, e.g., 0.0 or np.nan.
- self.null_val = cfg["TRAIN"].get("NULL_VAL", 0)
+ self.null_val = cfg["TRAIN"].get("NULL_VAL", np.nan)  # consistent with the metric functions
self.dataset_type = cfg["DATASET_TYPE"]

# read scaler for re-normalization
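Why `np.nan` is the safer default: the masked metrics in BasicTS skip entries equal to `null_val`, and a NaN sentinel must be detected with `isnan` rather than an equality test. A sketch of the usual masked-MAE pattern (following the Graph WaveNet convention; the exact body of `basicts.losses.masked_mae` may differ):

```python
import numpy as np
import torch

def masked_mae_sketch(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
    # NaN sentinels need isnan(); numeric sentinels use an equality test
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = labels != null_val
    mask = mask.float()
    mask /= torch.mean(mask)        # re-weight so the mean runs over valid entries only
    mask = torch.nan_to_num(mask)   # guard against 0/0 when everything is masked
    loss = torch.abs(preds - labels) * mask
    return torch.mean(torch.nan_to_num(loss))
```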
@@ -99,7 +100,7 @@ def build_train_dataset(self, cfg: dict):
index_file_path = "{0}/index_in{1}_out{2}.pkl".format(cfg["TRAIN"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"])

# build dataset args
- dataset_args = cfg.get("DATASET_ARGS", dict())
+ dataset_args = cfg.get("DATASET_ARGS", {})
# three necessary arguments, data file path, corresponding index file path, and mode (train, valid, or test)
dataset_args["data_file_path"] = data_file_path
dataset_args["index_file_path"] = index_file_path
@@ -127,7 +128,7 @@ def build_val_dataset(cfg: dict):
index_file_path = "{0}/index_in{1}_out{2}.pkl".format(cfg["VAL"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"])

# build dataset args
- dataset_args = cfg.get("DATASET_ARGS", dict())
+ dataset_args = cfg.get("DATASET_ARGS", {})
# three necessary arguments, data file path, corresponding index file path, and mode (train, valid, or test)
dataset_args["data_file_path"] = data_file_path
dataset_args["index_file_path"] = index_file_path
@@ -153,7 +154,7 @@ def build_test_dataset(cfg: dict):
index_file_path = "{0}/index_in{1}_out{2}.pkl".format(cfg["TEST"]["DATA"]["DIR"], cfg["DATASET_INPUT_LEN"], cfg["DATASET_OUTPUT_LEN"])

# build dataset args
- dataset_args = cfg.get("DATASET_ARGS", dict())
+ dataset_args = cfg.get("DATASET_ARGS", {})
# three necessary arguments, data file path, corresponding index file path, and mode (train, valid, or test)
dataset_args["data_file_path"] = data_file_path
dataset_args["index_file_path"] = index_file_path
4 changes: 2 additions & 2 deletions scripts/data_preparation/PEMS04/generate_training_data.py
@@ -9,7 +9,7 @@
from generate_adj_mx import generate_adj_pems04
# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../../../.."))
- from basicts.data.transform import standard_transform
+ from basicts.data.transform import min_max_transform


def generate_data(args: argparse.Namespace):
@@ -65,7 +65,7 @@ def generate_data(args: argparse.Namespace):
test_index = index_list[train_num_short +
valid_num_short: train_num_short + valid_num_short + test_num_short]

- scaler = standard_transform
+ scaler = min_max_transform
data_norm = scaler(data, output_dir, train_index, history_seq_len, future_seq_len)

# add external feature
10 changes: 5 additions & 5 deletions step/STEP_METR-LA.py
@@ -57,12 +57,12 @@
"mode":"forecasting"
},
"backend_args": {
"num_nodes" : 207,
"num_nodes" : 207,
"supports" :[torch.tensor(i) for i in adj_mx], # the supports are not used
"dropout" : 0.3,
"gcn_bool" : True,
"addaptadj" : True,
"aptinit" : None,
"dropout" : 0.3,
"gcn_bool" : True,
"addaptadj" : True,
"aptinit" : None,
"in_dim" : 2,
"out_dim" : 12,
"residual_channels" : 32,
4 changes: 3 additions & 1 deletion step/run.py
@@ -9,8 +9,10 @@

def parse_args():
parser = ArgumentParser(description="Run time series forecasting model in BasicTS framework!")
parser.add_argument("-c", "--cfg", default="step/STEP_METR-LA.py", help="training config")
# parser.add_argument("-c", "--cfg", default="step/TSFormer_METR-LA.py", help="training config")
# parser.add_argument("-c", "--cfg", default="step/STEP_METR-LA.py", help="training config")

parser.add_argument("-c", "--cfg", default="step/TSFormer_PEMS04.py", help="training config")
parser.add_argument("--gpus", default="0", help="visible gpus")
return parser.parse_args()

2 changes: 1 addition & 1 deletion step/step_arch/discrete_graph_learning.py
@@ -50,7 +50,7 @@ class DiscreteGraphLearning(nn.Module):

def __init__(self, dataset_name, k, input_seq_len, output_seq_len):
super().__init__()

self.k = k # the "k" of knn graph
self.num_nodes = {"METR-LA": 207, "PEMS04": 307, "PEMS-BAY": 325}[dataset_name]
self.train_length = {"METR-LA": 23990, "PEMS04": 13599, "PEMS-BAY": 36482}[dataset_name]
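For context on the `k` noted above: a discrete kNN graph keeps, for each node, only its `k` most similar neighbors. A minimal sketch of that selection step (illustrative only; the module's actual graph learning is more involved):

```python
import torch

def knn_graph_sketch(similarity: torch.Tensor, k: int) -> torch.Tensor:
    """Binary adjacency keeping the top-k most similar neighbors per node.

    similarity: [N, N] matrix, e.g. from a cosine-similarity function
    such as the one in step/step_arch/similarity.py.
    """
    _, topk_idx = torch.topk(similarity, k=k, dim=-1)  # k largest entries per row
    adj = torch.zeros_like(similarity)
    adj.scatter_(-1, topk_idx, 1.0)                    # mark the selected neighbors
    return adj
```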
4 changes: 2 additions & 2 deletions step/step_arch/graphwavenet/model.py
@@ -56,7 +56,7 @@ class GraphWaveNet(nn.Module):

def __init__(self, num_nodes, supports, dropout=0.3, gcn_bool=True, addaptadj=True, aptinit=None, in_dim=2,out_dim=12,residual_channels=32,dilation_channels=32,skip_channels=256,end_channels=512,kernel_size=2,blocks=4,layers=2, **kwargs):
"""
- kindly note that although there is a 'supports' parameter, we will not use the prior graph if there is a learned dependency graph.
+ kindly note that although there is a 'supports' parameter, we will not use the prior graph if there is a learned dependency graph.
Details can be found in the feed forward function.
"""
super(GraphWaveNet, self).__init__()
@@ -228,7 +228,7 @@ def forward(self, input, hidden_states, sampled_adj):
x = F.relu(skip)
x = F.relu(self.end_conv_1(x))
x = self.end_conv_2(x)

# reshape output: [B, P, N, 1] -> [B, N, P]
x = x.squeeze(-1).transpose(1, 2)
return x
2 changes: 1 addition & 1 deletion step/step_arch/similarity.py
@@ -10,7 +10,7 @@ def batch_cosine_similarity(x, y):
l2_m = torch.matmul(l2_x.unsqueeze(dim=2), l2_y.unsqueeze(dim=2).transpose(1, 2))
# compute the numerator
l2_z = torch.matmul(x, y.transpose(1, 2))
- # cos similarity affinity matrix
+ # cos similarity affinity matrix
cos_affnity = l2_z / l2_m
adj = cos_affnity
return adj
3 changes: 0 additions & 3 deletions step/step_arch/step.py
@@ -14,9 +14,6 @@ def __init__(self, dataset_name, pre_trained_tsformer_path, tsformer_args, backe
self.dataset_name = dataset_name
self.pre_trained_tsformer_path = pre_trained_tsformer_path

- # tsformer and backend model args
- tsformer_args = tsformer_args
- backend_args = backend_args
# initialize the tsformer and backend models
self.tsformer = TSFormer(**tsformer_args)
self.backend = GraphWaveNet(**backend_args)
6 changes: 3 additions & 3 deletions step/step_arch/tsformer/patch.py
@@ -11,9 +11,9 @@ def __init__(self, patch_size, in_channel, embed_dim, norm_layer):
self.input_channel = in_channel
self.output_channel = embed_dim
self.input_embedding = nn.Conv2d(
- in_channel,
- embed_dim,
- kernel_size=(self.len_patch, 1),
+ in_channel,
+ embed_dim,
+ kernel_size=(self.len_patch, 1),
stride=(self.len_patch, 1))
self.norm_layer = norm_layer if norm_layer is not None else nn.Identity()

2 changes: 1 addition & 1 deletion step/step_arch/tsformer/tsformer.py
@@ -121,7 +121,7 @@ def decoding(self, hidden_states_unmasked, masked_token_index):

# add mask tokens
hidden_states_masked = self.positional_encoding(
- self.mask_token.expand(batch_size, num_nodes, len(masked_token_index), hidden_states_unmasked.shape[-1]),
+ self.mask_token.expand(batch_size, num_nodes, len(masked_token_index), hidden_states_unmasked.shape[-1]),
index=masked_token_index
)
hidden_states_full = torch.cat([hidden_states_unmasked, hidden_states_masked], dim=-2) # B, N, P, d
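For readers tracing shapes through this hunk, a small runnable sketch of the expand-and-concatenate step (all dimensions hypothetical):

```python
import torch

B, N, d = 2, 207, 96                       # batch, nodes, hidden dim (hypothetical)
num_unmasked, num_masked = 3, 9            # e.g. 12 patches with a 75% mask ratio
hidden_states_unmasked = torch.randn(B, N, num_unmasked, d)
mask_token = torch.zeros(1, 1, 1, d)       # a learnable parameter in the real model
hidden_states_masked = mask_token.expand(B, N, num_masked, d)  # broadcast view, no copy
hidden_states_full = torch.cat([hidden_states_unmasked, hidden_states_masked], dim=-2)
print(hidden_states_full.shape)            # torch.Size([2, 207, 12, 96])
```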
2 changes: 1 addition & 1 deletion step/step_data/forecasting_dataset.py
@@ -9,7 +9,7 @@ class ForecastingDataset(Dataset):
"""Time series forecasting dataset."""

def __init__(self, data_file_path: str, index_file_path: str, mode: str, seq_len:int) -> None:
"""Init the dataset in the forecasting stage.
"""Init the dataset in the forecasting stage.
Args:
data_file_path (str): data file path.
1 change: 0 additions & 1 deletion step/step_loss/step_loss.py
@@ -1,4 +1,3 @@
- import torch
import numpy as np
from torch import nn
from basicts.losses import masked_mae