docs: ✏️ add training logs and ckpt for PEMS-BAY dataset
zezhishao committed Oct 10, 2022
1 parent b3b4caa commit 63b5478
Showing 7 changed files with 2,967 additions and 3 deletions.
157 changes: 157 additions & 0 deletions step/STEP_PEMS-BAY.py
@@ -0,0 +1,157 @@
import os
import sys


# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../../.."))
import torch
from easydict import EasyDict
from basicts.utils.serialization import load_adj

from .step_arch import STEP
from .step_runner import STEPRunner
from .step_loss import step_loss
from .step_data import ForecastingDataset


CFG = EasyDict()

# ================= general ================= #
CFG.DESCRIPTION = "STEP(PEMS-BAY) configuration"
CFG.RUNNER = STEPRunner
CFG.DATASET_CLS = ForecastingDataset
CFG.DATASET_NAME = "PEMS-BAY"
CFG.DATASET_TYPE = "Traffic speed"
CFG.DATASET_INPUT_LEN = 12
CFG.DATASET_OUTPUT_LEN = 12
CFG.DATASET_ARGS = {
"seq_len": 288 * 7
}
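# 288 * 7 = 2016 five-minute steps, i.e. one week of history per training sample;
# the pre-trained TSFormer presumably consumes this long sequence, while the
# 12-step windows above define the short-term input/output of the predictor.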
CFG.GPU_NUM = 2

# ================= environment ================= #
CFG.ENV = EasyDict()
CFG.ENV.SEED = 0
CFG.ENV.CUDNN = EasyDict()
CFG.ENV.CUDNN.ENABLED = True

# ================= model ================= #
CFG.MODEL = EasyDict()
CFG.MODEL.NAME = "STEP"
CFG.MODEL.ARCH = STEP
adj_mx, _ = load_adj("datasets/" + CFG.DATASET_NAME + "/adj_mx.pkl", "doubletransition")
CFG.MODEL.PARAM = {
"dataset_name": CFG.DATASET_NAME,
"pre_trained_tsformer_path": "tsformer_ckpt/TSFormer_PEMS-BAY.pt",
"tsformer_args": {
"patch_size":12,
"in_channel":1,
"embed_dim":96,
"num_heads":4,
"mlp_ratio":4,
"dropout":0.1,
"num_token":288 * 7 / 12,
"mask_ratio":0.75,
"encoder_depth":4,
"decoder_depth":1,
"mode":"forecasting"
},
"backend_args": {
"num_nodes" : 325,
"supports" :[torch.tensor(i) for i in adj_mx], # the supports are not used
"dropout" : 0.3,
"gcn_bool" : True,
"addaptadj" : True,
"aptinit" : None,
"in_dim" : 2,
"out_dim" : 12,
"residual_channels" : 32,
"dilation_channels" : 32,
"skip_channels" : 256,
"end_channels" : 512,
"kernel_size" : 2,
"blocks" : 4,
"layers" : 2
},
"dgl_args": {
"dataset_name": CFG.DATASET_NAME,
"k": 10,
"input_seq_len": CFG.DATASET_INPUT_LEN,
"output_seq_len": CFG.DATASET_OUTPUT_LEN
}
}
CFG.MODEL.FROWARD_FEATURES = [0, 1, 2]
CFG.MODEL.TARGET_FEATURES = [0]
CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = True
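# The runner presumably instantiates the network as CFG.MODEL.ARCH(**CFG.MODEL.PARAM),
# i.e. STEP(dataset_name=..., pre_trained_tsformer_path=..., tsformer_args=...,
# backend_args=..., dgl_args=...); this is an assumption about the BasicTS-style
# runner, not something shown in this diff.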

# ================= optim ================= #
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = step_loss
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
"lr":0.001,
"weight_decay":1.0e-5,
"eps":1.0e-8,
}
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
"milestones":[1, 18, 36, 54, 72],
"gamma":0.5
}

# ================= train ================= #
CFG.TRAIN.CLIP_GRAD_PARAM = {
"max_norm": 3.0
}
CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
"checkpoints",
"_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
# train data
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.NULL_VAL = 0.0
# read data
CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TRAIN.DATA.BATCH_SIZE = 32
CFG.TRAIN.DATA.PREFETCH = False
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.NUM_WORKERS = 2
CFG.TRAIN.DATA.PIN_MEMORY = True
# curriculum learning
CFG.TRAIN.CL = EasyDict()
CFG.TRAIN.CL.WARM_EPOCHS = 30
CFG.TRAIN.CL.CL_EPOCHS = 3
CFG.TRAIN.CL.PREDICTION_LENGTH = 12
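# Curriculum learning: after the 30-epoch warm-up, BasicTS-style runners typically
# enlarge the supervised horizon step by step every CL_EPOCHS epochs until the full
# 12-step (one-hour) horizon is reached; the exact ramp-up rule lives in the runner,
# not in this config.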

# ================= validate ================= #
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
# validating data
CFG.VAL.DATA = EasyDict()
# read data
CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.VAL.DATA.BATCH_SIZE = 32
CFG.VAL.DATA.PREFETCH = False
CFG.VAL.DATA.SHUFFLE = False
CFG.VAL.DATA.NUM_WORKERS = 2
CFG.VAL.DATA.PIN_MEMORY = True

# ================= test ================= #
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
# evaluation
# test data
CFG.TEST.DATA = EasyDict()
# read data
CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TEST.DATA.BATCH_SIZE = 32
CFG.TEST.DATA.PREFETCH = False
CFG.TEST.DATA.SHUFFLE = False
CFG.TEST.DATA.NUM_WORKERS = 2
CFG.TEST.DATA.PIN_MEMORY = True
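A minimal sketch of how this stage would be launched through the run.py entry point updated later in this commit, assuming the pre-trained encoder has already been exported to tsformer_ckpt/TSFormer_PEMS-BAY.pt (the path set in pre_trained_tsformer_path above); GPU_NUM = 2 implies two visible devices:

    python step/run.py --cfg step/STEP_PEMS-BAY.py --gpus "0,1"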
118 changes: 118 additions & 0 deletions step/TSFormer_PEMS-BAY.py
@@ -0,0 +1,118 @@
import os
import sys

# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../../.."))
from easydict import EasyDict
from basicts.losses import masked_mae

from .step_arch import TSFormer
from .step_runner import TSFormerRunner
from .step_data import PretrainingDataset


CFG = EasyDict()

# ================= general ================= #
CFG.DESCRIPTION = "TSFormer(PEMS-BAY) configuration"
CFG.RUNNER = TSFormerRunner
CFG.DATASET_CLS = PretrainingDataset
CFG.DATASET_NAME = "PEMS-BAY"
CFG.DATASET_TYPE = "Traffic speed"
CFG.DATASET_INPUT_LEN = 288 * 7
CFG.DATASET_OUTPUT_LEN = 12
CFG.GPU_NUM = 2

# ================= environment ================= #
CFG.ENV = EasyDict()
CFG.ENV.SEED = 0
CFG.ENV.CUDNN = EasyDict()
CFG.ENV.CUDNN.ENABLED = True

# ================= model ================= #
CFG.MODEL = EasyDict()
CFG.MODEL.NAME = "TSFormer"
CFG.MODEL.ARCH = TSFormer
CFG.MODEL.PARAM = {
"patch_size":12,
"in_channel":1,
"embed_dim":96,
"num_heads":4,
"mlp_ratio":4,
"dropout":0.1,
"num_token":288 * 7 / 12,
"mask_ratio":0.75,
"encoder_depth":4,
"decoder_depth":1,
"mode":"pre-train"
}
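# With patch_size = 12 and a 288 * 7 = 2016-step input, num_token works out to
# 168 patches; mask_ratio = 0.75 then hides roughly 126 of them and leaves 42
# visible to the encoder during pre-training.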
CFG.MODEL.FROWARD_FEATURES = [0]
CFG.MODEL.TARGET_FEATURES = [0]

# ================= optim ================= #
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = masked_mae
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
"lr":0.001,
"weight_decay":0,
"eps":1.0e-8,
"betas":(0.9, 0.95)
}
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
"milestones":[50],
"gamma":0.5
}

# ================= train ================= #
CFG.TRAIN.CLIP_GRAD_PARAM = {
"max_norm": 5.0
}
CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
"checkpoints",
"_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
# train data
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.NULL_VAL = 0.0
# read data
CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TRAIN.DATA.BATCH_SIZE = 16
CFG.TRAIN.DATA.PREFETCH = False
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.NUM_WORKERS = 2
CFG.TRAIN.DATA.PIN_MEMORY = True

# ================= validate ================= #
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
# validating data
CFG.VAL.DATA = EasyDict()
# read data
CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.VAL.DATA.BATCH_SIZE = 16
CFG.VAL.DATA.PREFETCH = False
CFG.VAL.DATA.SHUFFLE = False
CFG.VAL.DATA.NUM_WORKERS = 2
CFG.VAL.DATA.PIN_MEMORY = True

# ================= test ================= #
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
# evaluation
# test data
CFG.TEST.DATA = EasyDict()
# read data
CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TEST.DATA.BATCH_SIZE = 16
CFG.TEST.DATA.PREFETCH = False
CFG.TEST.DATA.SHUFFLE = False
CFG.TEST.DATA.NUM_WORKERS = 2
CFG.TEST.DATA.PIN_MEMORY = True
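This file saves its checkpoints under checkpoints/TSFormer_100/ (CKPT_SAVE_DIR above), whereas the STEP configuration in this commit loads the encoder from tsformer_ckpt/TSFormer_PEMS-BAY.pt, so the best pre-training checkpoint presumably has to be copied or renamed into that location before the forecasting stage starts. A sketch with an illustrative checkpoint filename (the actual name produced by the runner is not shown in this diff):

    python step/run.py --cfg step/TSFormer_PEMS-BAY.py --gpus "0,1"
    cp checkpoints/TSFormer_100/<run_dir>/TSFormer_best_val_MAE.pt tsformer_ckpt/TSFormer_PEMS-BAY.pt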
7 changes: 5 additions & 2 deletions step/run.py
@@ -13,8 +13,11 @@ def parse_args():
# parser.add_argument("-c", "--cfg", default="step/STEP_METR-LA.py", help="training config")

# parser.add_argument("-c", "--cfg", default="step/TSFormer_PEMS04.py", help="training config")
parser.add_argument("-c", "--cfg", default="step/STEP_PEMS04.py", help="training config")
parser.add_argument("--gpus", default="0, 1", help="visible gpus")
# parser.add_argument("-c", "--cfg", default="step/STEP_PEMS04.py", help="training config")

# parser.add_argument("-c", "--cfg", default="step/TSFormer_PEMS-BAY.py", help="training config")
parser.add_argument("-c", "--cfg", default="step/STEP_PEMS-BAY.py", help="training config")
parser.add_argument("--gpus", default="0", help="visible gpus")
return parser.parse_args()

if __name__ == "__main__":
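With this change, running python step/run.py with no arguments presumably trains STEP on PEMS-BAY on a single visible GPU; since both new PEMS-BAY configs set GPU_NUM = 2, passing --gpus "0,1" matches the device count they expect.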
2 changes: 1 addition & 1 deletion step/step_arch/discrete_graph_learning.py
@@ -58,7 +58,7 @@ def __init__(self, dataset_name, k, input_seq_len, output_seq_len):

# CNN for global feature extraction
## for the dimension, see https://github.com/zezhishao/STEP/issues/1#issuecomment-1191640023
-self.dim_fc = {"METR-LA": 383552, "PEMS04": 217296, "PEMS-BAY": 217296}[dataset_name]
+self.dim_fc = {"METR-LA": 383552, "PEMS04": 217296, "PEMS-BAY": 583424}[dataset_name]
self.embedding_dim = 100
## network structure
self.conv1 = torch.nn.Conv1d(1, 8, 10, stride=1) # .to(device)
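The hard-coded dim_fc is the flattened size of the per-node feature map after the convolutional extractor is applied to the full training series, which is why it differs per dataset (see the linked issue); the corrected PEMS-BAY value reflects its longer training series. A hedged sketch of how such a value could be derived programmatically instead of hard-coded, assuming a feature_extractor module wrapping the Conv1d stack and a train_series tensor of shape (1, 1, L) holding one node's training history (both names are illustrative, not part of this file):

    import torch

    @torch.no_grad()
    def infer_dim_fc(feature_extractor: torch.nn.Module, train_series: torch.Tensor) -> int:
        # One dummy forward pass through the conv stack; flatten to get the
        # per-node feature size that the following fully connected layer expects.
        out = feature_extractor(train_series)
        return int(out.reshape(out.shape[0], -1).shape[1])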
