Skip to content

Commit

Permalink
Merge pull request #10 from xionghuichen/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
xionghuichen authored May 29, 2022
2 parents 3c06803 + b3deb67 commit 1724634
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class ExperimentLoader(object):
def __init__(self):
self.task_name = exp_manager.hyper_param.get('loaded_task_name', None)
self.load_date = exp_manager.hyper_param.get('loaded_date', None)
self.root = getattr(exp_manager, 'root', None)
self.data_root = None
self.data_root = getattr(exp_manager, 'root', None)
pass

def config(self, task_name, record_date, root):
Expand All @@ -49,12 +48,12 @@ def is_valid_config(self):
logger.warn("meet invalid loader config when use it")
logger.warn("load_date", self.load_date)
logger.warn("task_name", self.task_name)
logger.warn("root", self.root)
logger.warn("root", self.data_root)
return False

def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None):
if self.is_valid_config:
load_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
target_hp = copy.deepcopy(exp_manager.hyper_param)
target_hp.update(load_tester.hyper_param)
if hp_to_overwrite is not None:
Expand All @@ -75,7 +74,7 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
:return:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
# load checkpoint
load_res = {}
if var_prefix is not None:
Expand All @@ -100,7 +99,7 @@ def fork_log_files(self):
if self.is_valid_config:
global exp_manager
assert isinstance(exp_manager, Tester)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
# copy log file
exp_manager.log_file_copy(loaded_tester)
# copy attribute
Expand All @@ -109,4 +108,4 @@ def fork_log_files(self):
exp_manager.private_config = loaded_tester.private_config


exp_loader = experimental_loader = ExperimentLoader()
exp_loader = experimental_loader = ExperimentLoader()

0 comments on commit 1724634

Please sign in to comment.