From c9f39809fae7d56a1efc01b9652ff1318620ab23 Mon Sep 17 00:00:00 2001 From: Janice Lan Date: Mon, 17 Jul 2023 17:14:54 -0700 Subject: [PATCH] default for get task metrics --- ocpmodels/trainers/base_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index 3b7929341d..69dc16f2d0 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -179,7 +179,7 @@ def __init__( self.evaluator = Evaluator( task=name, eval_metrics=self.config["task"].get( - "evaluation_metrics", Evaluator.task_metrics[name] + "evaluation_metrics", Evaluator.task_metrics.get(name, {}) ), ) @@ -952,7 +952,7 @@ def validate(self, split="val", disable_tqdm=False): evaluator = Evaluator( task=self.name, eval_metrics=self.config["task"].get( - "evaluation_metrics", Evaluator.task_metrics[self.name] + "evaluation_metrics", Evaluator.task_metrics.get(self.name, {}) ), )