Skip to content

Commit

Permalink
Merge branch 'dev' of github.com:xionghuichen/RLAssistant into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
DrZero0 committed Jun 12, 2022
2 parents b557e9c + 9fdb615 commit d6a785f
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 30 deletions.
24 changes: 14 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,18 @@ We build an example project for integrating RLA, which can be seen in ./example/
We record scalars by `RLA.easy_log.logger`:
```python
from RLA.easy_log import logger

# scalar variable
value = 1
logger.record_tabular("k", value)

# tensorflow summary
import tensorflow as tf
summary = tf.Summary()
logger.log_from_tf_summary(summary)
from RLA.easy_log.time_step import time_step_holder

for i in range(1000):
# time-steps (iterations)
time_step_holder.set_time(i)
# scalar variable
value = 1
logger.record_tabular("k", value)
# tensorflow summary
summary = tf.Summary()
logger.log_from_tf_summary(summary)

```
**Record checkpoints**
Expand Down Expand Up @@ -242,8 +245,9 @@ PS:
2. An alternative way is to build your own NFS for your physical machines and point data_root to the NFS.

# TODO
- [ ] video visualization.
- [ ] support sftp-based sync.
- [ ] support custom data structure saving and loading.
- [ ] support video visualization.
- [ ] add comments and documents to the functions.
- [ ] add an auto integration script.
- [ ] download / upload experiment logs through timestamp.

13 changes: 6 additions & 7 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class ExperimentLoader(object):
def __init__(self):
self.task_name = exp_manager.hyper_param.get('loaded_task_name', None)
self.load_date = exp_manager.hyper_param.get('loaded_date', None)
self.root = getattr(exp_manager, 'root', None)
self.data_root = None
self.data_root = getattr(exp_manager, 'root', None)
pass

def config(self, task_name, record_date, root):
Expand All @@ -49,12 +48,12 @@ def is_valid_config(self):
logger.warn("meet invalid loader config when use it")
logger.warn("load_date", self.load_date)
logger.warn("task_name", self.task_name)
logger.warn("root", self.root)
logger.warn("root", self.data_root)
return False

def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None):
if self.is_valid_config:
load_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
target_hp = copy.deepcopy(exp_manager.hyper_param)
target_hp.update(load_tester.hyper_param)
if hp_to_overwrite is not None:
Expand All @@ -75,7 +74,7 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
:return:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
# load checkpoint
load_res = {}
if var_prefix is not None:
Expand All @@ -100,7 +99,7 @@ def fork_log_files(self):
if self.is_valid_config:
global exp_manager
assert isinstance(exp_manager, Tester)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
# copy log file
exp_manager.log_file_copy(loaded_tester)
# copy attribute
Expand All @@ -109,4 +108,4 @@ def fork_log_files(self):
exp_manager.private_config = loaded_tester.private_config


exp_loader = experimental_loader = ExperimentLoader()
exp_loader = experimental_loader = ExperimentLoader()
2 changes: 1 addition & 1 deletion RLA/easy_log/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def configure(dir=None, format_strs=None, comm=None, framework='tensorflow'):
if format_strs is None:
format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
format_strs = filter(None, format_strs)
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
output_formats = [make_output_format(f, dir, log_suffix, framework) for f in format_strs]
warn_output_formats = make_output_format('warn', dir, log_suffix, framework)
backup_output_formats = make_output_format('backup', dir, log_suffix, framework)

Expand Down
16 changes: 8 additions & 8 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import shutil
import argparse
from typing import Optional, Union, Dict, Any
from RLA.const import DEFAULT_X_NAME
from RLA.const import DEFAULT_X_NAME, FRAMEWORK
import pathspec

def import_hyper_parameters(task_name, record_date):
Expand Down Expand Up @@ -187,7 +187,7 @@ def _init_logger(self):
self.writer = None
# logger configure
logger.info("store file %s" % self.pkl_file)
logger.configure(self.log_dir, self.private_config["LOG_USED"])
logger.configure(self.log_dir, self.private_config["LOG_USED"], framework=self.private_config["DL_FRAMEWORK"])
for fmt in logger.Logger.CURRENT.output_formats:
if isinstance(fmt, logger.TensorBoardOutputFormat):
self.writer = fmt.writer
Expand Down Expand Up @@ -481,7 +481,7 @@ def new_saver(self, max_to_keep, var_prefix=None):
:param max_to_keep:
:return:
"""
if self.dl_framework == 'tensorflow':
if self.dl_framework == FRAMEWORK.tensorflow:
import tensorflow as tf
if var_prefix is None:
var_prefix = ''
Expand All @@ -490,20 +490,20 @@ def new_saver(self, max_to_keep, var_prefix=None):
for v in var_list:
logger.info(v)
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True)
elif self.dl_framework == 'pytorch':
elif self.dl_framework == FRAMEWORK.torch:
self.max_to_keep = max_to_keep
self.checkpoint_keep_list = []
else:
raise NotImplementedError

def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Optional[dict]=None):
if self.dl_framework == 'tensorflow':
if self.dl_framework == FRAMEWORK.tensorflow:
import tensorflow as tf
iter = self.time_step_holder.get_time()
cpt_name = osp.join(self.checkpoint_dir, 'checkpoint')
logger.info("save checkpoint to ", cpt_name, iter)
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
elif self.dl_framework == 'pytorch':
elif self.dl_framework == FRAMEWORK.torch:
import torch
iter = self.time_step_holder.get_time()
torch.save(model_dict, f=tester.checkpoint_dir + "checkpoint-{}.pt".format(iter))
Expand All @@ -522,7 +522,7 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace')

def load_checkpoint(self):
if self.dl_framework == 'tensorflow':
if self.dl_framework == FRAMEWORK.tensorflow:
# TODO: load with variable scope.
import tensorflow as tf
cpt_name = osp.join(self.checkpoint_dir)
Expand All @@ -532,7 +532,7 @@ def load_checkpoint(self):
max_iter = ckpt_path.split('-')[-1]
self.time_step_holder.set_time(max_iter)
return int(max_iter), None
elif self.dl_framework == 'pytorch':
elif self.dl_framework == FRAMEWORK.torch:
import torch
return self.checkpoint_keep_list[-1], torch.load(tester.checkpoint_dir + "checkpoint-{}.pt".format(self.checkpoint_keep_list[-1]))

Expand Down
4 changes: 2 additions & 2 deletions example/sb3_ppo_example/rla_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ LOG_USED:
- 'tensorboard'
- 'csv'

# select a DL framework: tensorflow or pytorch.
DL_FRAMEWORK: 'pytorch'
# select a DL framework: tensorflow or torch.
DL_FRAMEWORK: 'torch'

SEND_LOG_FILE: False
REMOTE_SETTING:
Expand Down
2 changes: 1 addition & 1 deletion example/sb_ppo_example/rla_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ LOG_USED:
- 'tensorboard'
- 'csv'

# select a DL framework: tensorflow or pytorch.
# select a DL framework: tensorflow or torch.
DL_FRAMEWORK: 'tensorflow'

SEND_LOG_FILE: False
Expand Down
2 changes: 1 addition & 1 deletion example/simplest_code/rla_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ LOG_USED:
- 'tensorboard'
- 'csv'

# select a DL framework: tensorflow or pytorch.
# select a DL framework: "tensorflow" or "torch".
DL_FRAMEWORK: 'tensorflow'

SEND_LOG_FILE: False
Expand Down

0 comments on commit d6a785f

Please sign in to comment.