From 1773116c49ab53143e98cf306ba78f22c163f036 Mon Sep 17 00:00:00 2001
From: Chenghe Wang <33591044+DrZero0@users.noreply.github.com>
Date: Thu, 12 May 2022 11:41:44 +0800
Subject: [PATCH 1/5] Update exp_loader.py

change self.root to self.data_root
---
 RLA/easy_log/exp_loader.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py
index ac4806b..a8ee0a6 100644
--- a/RLA/easy_log/exp_loader.py
+++ b/RLA/easy_log/exp_loader.py
@@ -32,8 +32,7 @@ class ExperimentLoader(object):
     def __init__(self):
         self.task_name = exp_manager.hyper_param.get('loaded_task_name', None)
         self.load_date = exp_manager.hyper_param.get('loaded_date', None)
-        self.root = getattr(exp_manager, 'root', None)
-        self.data_root = None
+        self.data_root = getattr(exp_manager, 'root', None)
         pass
 
     def config(self, task_name, record_date, root):
@@ -49,12 +48,12 @@ def is_valid_config(self):
             logger.warn("meet invalid loader config when use it")
             logger.warn("load_date", self.load_date)
             logger.warn("task_name", self.task_name)
-            logger.warn("root", self.root)
+            logger.warn("root", self.data_root)
             return False
 
     def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None):
         if self.is_valid_config:
-            load_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
+            load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
             target_hp = copy.deepcopy(exp_manager.hyper_param)
             target_hp.update(load_tester.hyper_param)
             if hp_to_overwrite is not None:
@@ -75,7 +74,7 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
         :return:
         """
         if self.is_valid_config:
-            loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
+            loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
             # load checkpoint
             load_res = {}
             if var_prefix is not None:
@@ -100,7 +99,7 @@ def fork_log_files(self):
         if self.is_valid_config:
             global exp_manager
             assert isinstance(exp_manager, Tester)
-            loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
+            loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
             # copy log file
             exp_manager.log_file_copy(loaded_tester)
             # copy attribute
@@ -109,4 +108,4 @@ def fork_log_files(self):
             exp_manager.private_config = loaded_tester.private_config
 
 
-exp_loader = experimental_loader = ExperimentLoader()
\ No newline at end of file
+exp_loader = experimental_loader = ExperimentLoader()

From f8aad62d3fc12ad1f205cd9c53ab56fd734c697f Mon Sep 17 00:00:00 2001
From: Xiong-Hui Chen
Date: Mon, 30 May 2022 07:42:04 +0800
Subject: [PATCH 2/5] update req

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 4c73f48..ee394c6 100644
--- a/README.md
+++ b/README.md
@@ -242,7 +242,9 @@ PS:
 2. An alternative way is building your own NFS for your physical machines and locate data_root to the NFS.
 
 # TODO
-- [ ] video visualization.
+- [ ] support sftp-based sync.
+- [ ] support custom data structure saving and loading.
+- [ ] support video visualization.
 - [ ] add comments and documents to the functions.
 - [ ] add an auto integration script.
 - [ ] download / upload experiment logs through timestamp.
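Note on PATCH 1: the change above only renames the attribute that holds the data root (`self.root` -> `self.data_root`); the loading workflow itself is unchanged. A minimal, hypothetical usage sketch of the `ExperimentLoader` methods touched by the patch — the task name, record date, and root path are placeholders, and the current experiment is assumed to be configured already:

```python
from RLA.easy_log.exp_loader import exp_loader

# Point the loader at a previously recorded experiment.
# 'demo_task', the record date, and the project root are made-up
# placeholder values, not values taken from the patches.
exp_loader.config(task_name='demo_task',
                  record_date='2022/05/12/11-41-44-666666',
                  root='/path/to/your/project/')

# Merge the recorded run's hyper-parameters into the current run,
# restore its checkpoint, and fork its log files.
hyper_param = exp_loader.import_hyper_parameters()
exp_loader.load_from_record_date()
exp_loader.fork_log_files()
```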
From a9ffd2d7de66bd48796cf752e3c7062c3b246dab Mon Sep 17 00:00:00 2001
From: Yihao-Sun <1778826780@qq.com>
Date: Fri, 3 Jun 2022 19:19:14 +0800
Subject: [PATCH 3/5] fix the tensorboard bug

---
 RLA/easy_log/logger.py | 2 +-
 RLA/easy_log/tester.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/RLA/easy_log/logger.py b/RLA/easy_log/logger.py
index 6ac17c2..7f6ed4d 100644
--- a/RLA/easy_log/logger.py
+++ b/RLA/easy_log/logger.py
@@ -688,7 +688,7 @@ def configure(dir=None, format_strs=None, comm=None, framework='tensorflow'):
     if format_strs is None:
         format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
     format_strs = filter(None, format_strs)
-    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
+    output_formats = [make_output_format(f, dir, log_suffix, framework) for f in format_strs]
     warn_output_formats = make_output_format('warn', dir, log_suffix, framework)
     backup_output_formats = make_output_format('backup', dir, log_suffix, framework)
 
diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py
index 9fce214..e7ba846 100644
--- a/RLA/easy_log/tester.py
+++ b/RLA/easy_log/tester.py
@@ -187,7 +187,7 @@ def _init_logger(self):
         self.writer = None
         # logger configure
         logger.info("store file %s" % self.pkl_file)
-        logger.configure(self.log_dir, self.private_config["LOG_USED"])
+        logger.configure(self.log_dir, self.private_config["LOG_USED"], framework=self.private_config["DL_FRAMEWORK"])
         for fmt in logger.Logger.CURRENT.output_formats:
             if isinstance(fmt, logger.TensorBoardOutputFormat):
                 self.writer = fmt.writer

From 530a1eb5ca245da152579e95528d217a5a6bd97d Mon Sep 17 00:00:00 2001
From: Xiong-Hui Chen
Date: Sat, 4 Jun 2022 23:50:32 +0800
Subject: [PATCH 4/5] update readme. fix typos in yaml

---
 README.md                               | 20 +++++++++++---------
 example/sb3_ppo_example/rla_config.yaml |  4 ++--
 example/sb_ppo_example/rla_config.yaml  |  2 +-
 example/simplest_code/rla_config.yaml   |  2 +-
 4 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index ee394c6..6b715e0 100644
--- a/README.md
+++ b/README.md
@@ -125,15 +125,18 @@ We build an example project for integrating RLA, which can be seen in ./example/
 We record scalars by `RLA.easy_log.logger`:
 ```python
 from RLA.easy_log import logger
-
-# scalar variable
-value = 1
-logger.record_tabular("k", value)
-
-# tensorflow summary
 import tensorflow as tf
-summary = tf.Summary()
-logger.log_from_tf_summary(summary)
+from RLA.easy_log.time_step import time_step_holder
+
+for i in range(1000):
+    # time-steps (iterations)
+    time_step_holder.set_time(i)
+    # scalar variable
+    value = 1
+    logger.record_tabular("k", value)
+    # tensorflow summary
+    summary = tf.Summary()
+    logger.log_from_tf_summary(summary)
 ```
 
 **Record checkpoints**
diff --git a/example/sb3_ppo_example/rla_config.yaml b/example/sb3_ppo_example/rla_config.yaml
index 51573cf..2a8e9fc 100644
--- a/example/sb3_ppo_example/rla_config.yaml
+++ b/example/sb3_ppo_example/rla_config.yaml
@@ -18,8 +18,8 @@ LOG_USED:
 - 'tensorboard'
 - 'csv'
 
-# select a DL framework: tensorflow or pytorch.
-DL_FRAMEWORK: 'pytorch'
+# select a DL framework: tensorflow or torch.
+DL_FRAMEWORK: 'torch'
 SEND_LOG_FILE: False
 
 REMOTE_SETTING:
diff --git a/example/sb_ppo_example/rla_config.yaml b/example/sb_ppo_example/rla_config.yaml
index e9472e6..a8224ed 100644
--- a/example/sb_ppo_example/rla_config.yaml
+++ b/example/sb_ppo_example/rla_config.yaml
@@ -18,7 +18,7 @@ LOG_USED:
 - 'tensorboard'
 - 'csv'
 
-# select a DL framework: tensorflow or pytorch.
+# select a DL framework: tensorflow or torch.
 DL_FRAMEWORK: 'tensorflow'
 SEND_LOG_FILE: False
 
diff --git a/example/simplest_code/rla_config.yaml b/example/simplest_code/rla_config.yaml
index dc464f3..65714b3 100644
--- a/example/simplest_code/rla_config.yaml
+++ b/example/simplest_code/rla_config.yaml
@@ -18,7 +18,7 @@ LOG_USED:
 - 'tensorboard'
 - 'csv'
 
-# select a DL framework: tensorflow or pytorch.
+# select a DL framework: "tensorflow" or "torch".
 DL_FRAMEWORK: 'tensorflow'
 SEND_LOG_FILE: False
 

From 9fdb615d34afec045bbb48fcd64ec2f6ae2a10bc Mon Sep 17 00:00:00 2001
From: Xiong-Hui Chen
Date: Tue, 7 Jun 2022 11:03:28 +0800
Subject: [PATCH 5/5] fix: fix a bug of dl_framework name typos

---
 RLA/easy_log/tester.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py
index e7ba846..38cfc10 100644
--- a/RLA/easy_log/tester.py
+++ b/RLA/easy_log/tester.py
@@ -25,7 +25,7 @@
 import shutil
 import argparse
 from typing import Optional, Union, Dict, Any
-from RLA.const import DEFAULT_X_NAME
+from RLA.const import DEFAULT_X_NAME, FRAMEWORK
 import pathspec
 
 def import_hyper_parameters(task_name, record_date):
@@ -473,7 +473,7 @@ def new_saver(self, max_to_keep, var_prefix=None):
         :param max_to_keep:
         :return:
         """
-        if self.dl_framework == 'tensorflow':
+        if self.dl_framework == FRAMEWORK.tensorflow:
             import tensorflow as tf
             if var_prefix is None:
                 var_prefix = ''
@@ -482,20 +482,20 @@
             for v in var_list:
                 logger.info(v)
             self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True)
-        elif self.dl_framework == 'pytorch':
+        elif self.dl_framework == FRAMEWORK.torch:
             self.max_to_keep = max_to_keep
             self.checkpoint_keep_list = []
         else:
             raise NotImplementedError
 
     def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Optional[dict]=None):
-        if self.dl_framework == 'tensorflow':
+        if self.dl_framework == FRAMEWORK.tensorflow:
             import tensorflow as tf
             iter = self.time_step_holder.get_time()
             cpt_name = osp.join(self.checkpoint_dir, 'checkpoint')
             logger.info("save checkpoint to ", cpt_name, iter)
             self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
-        elif self.dl_framework == 'pytorch':
+        elif self.dl_framework == FRAMEWORK.torch:
             import torch
             iter = self.time_step_holder.get_time()
             torch.save(model_dict, f=tester.checkpoint_dir + "checkpoint-{}.pt".format(iter))
@@ -514,7 +514,7 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
         self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace')
 
     def load_checkpoint(self):
-        if self.dl_framework == 'tensorflow':
+        if self.dl_framework == FRAMEWORK.tensorflow:
             # TODO: load with variable scope.
             import tensorflow as tf
             cpt_name = osp.join(self.checkpoint_dir)
             ckpt_path = tf.train.latest_checkpoint(cpt_name)
             max_iter = ckpt_path.split('-')[-1]
             self.time_step_holder.set_time(max_iter)
             return int(max_iter), None
-        elif self.dl_framework == 'pytorch':
+        elif self.dl_framework == FRAMEWORK.torch:
             import torch
             return self.checkpoint_keep_list[-1], torch.load(tester.checkpoint_dir + "checkpoint-{}.pt".format(self.checkpoint_keep_list[-1]))
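Note on PATCH 4/5: the yaml value is renamed from 'pytorch' to 'torch' because PATCH 5 compares `self.dl_framework` (presumably read from the `DL_FRAMEWORK` entry of `rla_config.yaml`) against members of the `FRAMEWORK` constant imported from `RLA.const`. The constant itself is not shown in this series; a minimal definition consistent with the comparisons above would be:

```python
# RLA/const.py -- sketch only. The two members are inferred from the
# comparisons in PATCH 5 (FRAMEWORK.tensorflow / FRAMEWORK.torch) and the
# yaml values in PATCH 4; the real module may define them differently.
class FRAMEWORK:
    tensorflow = 'tensorflow'
    torch = 'torch'
```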
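Note on the torch checkpoint path: in the torch branches above, `new_saver` only records the keep settings, `save_checkpoint` persists whatever dict the caller passes, and `load_checkpoint` returns the last kept iteration together with the dict loaded by `torch.load`. A rough usage sketch, assuming the module-level `tester` instance referenced inside the diff and an experiment that has already been configured with `DL_FRAMEWORK: 'torch'` — the model and dict key are placeholders:

```python
import torch
import torch.nn as nn

from RLA.easy_log.tester import tester  # assumed import of the module-level Tester instance

model = nn.Linear(4, 2)  # placeholder model, not part of the patches

tester.new_saver(max_to_keep=5)  # torch branch: initializes the checkpoint keep list
tester.save_checkpoint(model_dict={'policy': model.state_dict()})

# torch branch: returns (last kept iteration, dict loaded via torch.load).
last_iter, state = tester.load_checkpoint()
model.load_state_dict(state['policy'])
```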