diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py index ac4806b..a8ee0a6 100644 --- a/RLA/easy_log/exp_loader.py +++ b/RLA/easy_log/exp_loader.py @@ -32,8 +32,7 @@ class ExperimentLoader(object): def __init__(self): self.task_name = exp_manager.hyper_param.get('loaded_task_name', None) self.load_date = exp_manager.hyper_param.get('loaded_date', None) - self.root = getattr(exp_manager, 'root', None) - self.data_root = None + self.data_root = getattr(exp_manager, 'root', None) pass def config(self, task_name, record_date, root): @@ -49,12 +48,12 @@ def is_valid_config(self): logger.warn("meet invalid loader config when use it") logger.warn("load_date", self.load_date) logger.warn("task_name", self.task_name) - logger.warn("root", self.root) + logger.warn("root", self.data_root) return False def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None): if self.is_valid_config: - load_tester = Tester.load_tester(self.load_date, self.task_name, self.root) + load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) target_hp = copy.deepcopy(exp_manager.hyper_param) target_hp.update(load_tester.hyper_param) if hp_to_overwrite is not None: @@ -75,7 +74,7 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: :return: """ if self.is_valid_config: - loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root) + loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) # load checkpoint load_res = {} if var_prefix is not None: @@ -100,7 +99,7 @@ def fork_log_files(self): if self.is_valid_config: global exp_manager assert isinstance(exp_manager, Tester) - loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root) + loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) # copy log file exp_manager.log_file_copy(loaded_tester) # copy attribute @@ -109,4 +108,4 @@ def fork_log_files(self): exp_manager.private_config = loaded_tester.private_config -exp_loader = experimental_loader = ExperimentLoader() \ No newline at end of file +exp_loader = experimental_loader = ExperimentLoader()