diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index 66f7807..ac36bc4 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -113,7 +113,7 @@ def __init__(self): @deprecated_alias(task_name='task_table_name', private_config_path='rla_config', log_root='data_root') def configure(self, task_table_name: str, rla_config: Union[str, dict], data_root: str, - ignore_file_path: Optional[str] = None, run_file: Optional[str] = None, + ignore_file_path: Optional[str] = None, run_file: Union[str, List[str]] = None, is_master_node: bool = False, code_root: Optional[str] = None): """ The function to configure your exp_manager, which should be run before your experiments. @@ -131,7 +131,7 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo :type ignore_file_path: str :param run_file: If you have extra files out of your codebase (e.g., some scripts to run the code), you can pass it to the run_file. Then we will backup the run_file too. - :type run_file: str + :type run_file: str or list :param is_master_node: In "distributed training & centralized logs" mode (By set SEND_LOG_FILE in rla_config.yaml to True), you should mark the master node (is_master_node=True) to collect logs of the slave nodes (is_master_node=False). :type is_master_node: bool @@ -502,18 +502,25 @@ def get_ignore_files(self, src, names): def __copy_source_code(self, run_file, code_dir): import shutil + def _copy_run_file(run_file, code_dir): + if type(run_file) == list: + for file_name in run_file: + shutil.copy(file_name, code_dir) + else: + shutil.copy(run_file, code_dir) if self.private_config["PROJECT_TYPE"]["backup_code_by"] == 'lib': assert os.listdir(code_dir) == [] os.removedirs(code_dir) shutil.copytree(osp.join(self.project_root, self.private_config["BACKUP_CONFIG"]["lib_dir"]), code_dir) assert run_file is not None, "you should define the run_file in lib backup mode." - shutil.copy(run_file, code_dir) + _copy_run_file(run_file, code_dir) elif self.private_config["PROJECT_TYPE"]["backup_code_by"] == 'source': - for dir_name in self.private_config["BACKUP_CONFIG"]["backup_code_dir"]: - shutil.copytree(osp.join(self.project_root, dir_name), osp.join(code_dir, dir_name), - ignore=self.get_ignore_files) + if self.private_config["BACKUP_CONFIG"].get("backup_code_dir"): + for dir_name in self.private_config["BACKUP_CONFIG"]["backup_code_dir"]: + shutil.copytree(osp.join(self.project_root, dir_name), osp.join(code_dir, dir_name), + ignore=self.get_ignore_files) if run_file is not None: - shutil.copy(run_file, code_dir) + _copy_run_file(run_file, code_dir) else: raise NotImplementedError