Skip to content

Commit

Permalink
Dev (#10)
Browse files Browse the repository at this point in the history
* Update README.md

* Dev (#20)

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility
  • Loading branch information
xionghuichen authored Jul 14, 2022
1 parent 560d15a commit 8d4c905
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
7 changes: 4 additions & 3 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_t
else:
return argparse.Namespace(**exp_manager.hyper_param)

def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None):
def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None, verbose=True):
"""
:param var_prefix: the prefix of namescope (for tf) to load. Set to '' to load all of the parameters.
Expand All @@ -81,8 +81,9 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
if verbose:
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
# load checkpoint
load_res = {}
if var_prefix is not None:
Expand Down
10 changes: 6 additions & 4 deletions RLA/easy_log/log_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ def __init__(self, proj_root, task_table_name, regex, filter, *args, **kwargs):
self.small_timestep_regs = []
super(DeleteLogTool, self).__init__(*args, **kwargs)

def _delete_related_log(self, regex, show=False):
def _delete_related_log(self, regex, show=False, delete_log_types=None):
log_found = 0
for log_type in self.log_types:
print(f"--- search {log_type} ---")
if delete_log_types is not None and log_type not in delete_log_types:
continue
root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, regex)
empty = True
for root_dir in glob.glob(root_dir_regex):
Expand Down Expand Up @@ -144,15 +146,15 @@ def _delete_related_log(self, regex, show=False):
if empty: print("empty regex {}".format(root_dir_regex))
return log_found

def delete_related_log(self, skip_ask=False):
self._delete_related_log(show=True, regex=self.regex)
def delete_related_log(self, skip_ask=False, delete_log_types=None):
self._delete_related_log(show=True, regex=self.regex, delete_log_types=delete_log_types)
if skip_ask:
s = 'y'
else:
s = input("delete these files? (y/n)")
if s == 'y':
print("do delete ...")
return self._delete_related_log(show=False, regex=self.regex)
return self._delete_related_log(show=False, regex=self.regex, delete_log_types=delete_log_types)
else:
return 0

Expand Down
16 changes: 11 additions & 5 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def update_log_files_location(self, root:str):
task_table_name = getattr(self, 'task_name', None)
print("[WARN] you are using an old-version RLA. "
"Some attributes' name have been changed (task_name->task_table_name).")
else:
raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
if task_table_name is None:
raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, task_table_name), '', is_file=False)
log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, task_table_name), '', is_file=False)
self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, task_table_name), '.pkl')
Expand Down Expand Up @@ -430,9 +430,15 @@ def log_file_finder(cls, record_date, task_table_name='train', file_root='../che
raise NotImplementedError
for search_item in search_list:
if search_item.startswith(str(record_date.strftime("%H-%M-%S-%f"))):
split_dir = search_item.split(' ')
# self.__ipaddr = split_dir[1]
info = " ".join(split_dir[2:])
try:
split_dir = search_item.split('_')
assert len(split_dir) >= 2
info = " ".join(split_dir[2:])
except AssertionError as e:
split_dir = search_item.split(' ')
# self.__ipaddr = split_dir[1]
info = "_".join(split_dir[2:])
print("[WARN] We find an old-version experiment data.")
logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info))
file_found = search_item
break
Expand Down

0 comments on commit 8d4c905

Please sign in to comment.