Skip to content

Commit

Permalink
add custmize metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Oct 1, 2022
1 parent 407bdd4 commit ec6578a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 2 additions & 0 deletions step/step_runner/step_runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import torch

from basicts.runners import BaseTimeSeriesForecastingRunner
from basicts.metrics import masked_mae, masked_rmse, masked_mape


class STEPRunner(BaseTimeSeriesForecastingRunner):
def __init__(self, cfg: dict):
super().__init__(cfg)
self.metrics = cfg.get("METRICS", {"MAE": masked_mae, "RMSE": masked_rmse, "MAPE": masked_mape})
self.forward_features = cfg["MODEL"].get("FROWARD_FEATURES", None)
self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None)

Expand Down
1 change: 0 additions & 1 deletion step/step_runner/tsformer_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import time
import torch

from easytorch.utils.dist import master_only
Expand Down

0 comments on commit ec6578a

Please sign in to comment.