Dev (#29)
* Update README.md

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility

* Dev (#10)

* Update README.md

* Dev (#20)

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility

* fix: a bug of sorting in torch-version checkpoint loading

* Dev (#11)

* Update README.md

* Dev (#20)

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility

* fix: a bug of sorting in torch-version checkpoint loading

* refactor: robust multi-key plot implementation

* feat: support pretty plotter

* refactor(log plotter): record scores of the log plotter

* fix(exp_loader): add parameter ckp_index

* update readme

* rm unsolved merge

* feat: tf-v2 compatible

* refactor: add timestep recorder. refactor on exp_loader

* test: add test data

* feat(plot): track the hyper-parameter from the exp_manager instead of the experiment name. refactor the plot_func for better readability

* Dev (#12)

* Update README.md

* Dev (#20)

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility

* Dev (#21)

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility

* fix: a bug of sorting in torch-version checkpoint loading

* Dev (#22)

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility

* fix: a bug of sorting in torch-version checkpoint loading

* refactor: robust multi-key plot implementation

* feat: support pretty plotter

* refactor(log plotter): record scores of the log plotter

* fix(exp_loader): add parameter ckp_index

* update readme

* Dev (#23)

* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader

* fix: minor changes for version compatibility

* fix: a bug of sorting in torch-version checkpoint loading

* refactor: robust multi-key plot implementation

* feat: support pretty plotter

* refactor(log plotter): record scores of the log plotter

* fix(exp_loader): add parameter ckp_index

* update readme

* rm unsolved merge

* feat: tf-v2 compatible

* refactor: add timestep recorder. refactor on exp_loader

* test: add test data

* feat(plot): track the hyper-parameter from the exp_manager instead of the experiment name. refactor the plot_func for better readability

* test(plot): add user cases and documents

* test(plot): add user cases

* simplify codes

* refactor: more robust freq print implementation

* update readme
xionghuichen authored Oct 4, 2022
1 parent 4efee65 commit 9bd402e
Showing 10 changed files with 53 additions and 26 deletions.
36 changes: 26 additions & 10 deletions README.md
@@ -67,7 +67,17 @@ Modifying:
Querying:
1. tensorboard: the recorded variables are added to tensorboard events and can be loaded via standard tensorboard tools.
![img.png](resource/tb-img.png)
2. easy_plot: We give some APIs to load and visualize the data in CSV files. The results will be something like this:
```python
from RLA.easy_plot.plot_func_v2 import plot_func
data_root='your_project'
task = 'sac_test'
regs = [
'2022/03/01/21-[12]*'
]
_ = plot_func(data_root=data_root, task_table_name=task,
              regs=regs, split_keys=['info', 'van_sac', 'alpha_max'], metrics=['perf/rewards'])
```
![](resource/sample-plot.png)


@@ -86,7 +96,7 @@ We also list the RL research projects using RLA as follows:
```angular2html
git clone https://github.com/xionghuichen/RLAssistant.git
cd RLAssistant
python setup.py install
pip install -e .
```


@@ -99,13 +109,19 @@ We build an example project for integrating RLA, which can be seen in ./example/
1. We define the property of the database in `rla_config.yaml`. You can construct your YAML file based on the template in ./example/simplest_code/rla_config.yaml.
2. We define the property of the table in exp_manager.config. Before starting your experiment, you should configure the global object RLA.easy_log.tester.exp_manager like this.
```python
from RLA.easy_log.tester import exp_manager
from RLA import exp_manager
kwargs = {'env_id': 'Hopper-v2', 'lr': 1e-3}
exp_manager.set_hyper_param(**kwargs) # kwargs are the hyper-parameters for your experiment
exp_manager.add_record_param(["env_id"]) # add parts of hyper-parameters to name the index of data items for better readability.
task_name = 'demo_task' # define your task
rla_data_root = '../' # the place to store the data items.
exp_manager.configure(task_name, private_config_path='../../../rla_config.yaml', data_root=rla_data_root)

def get_package_path():
return os.path.dirname(os.path.abspath(__file__))

rla_data_root = get_package_path() # the place to store the data items.

rla_config = os.path.join(get_package_path(), 'rla_config.yaml')
exp_manager.configure(task_table_name=task_name, rla_config=rla_config, data_root=rla_data_root)
exp_manager.log_files_gen() # initialize the data items.
exp_manager.print_args()
```
@@ -124,9 +140,9 @@ We build an example project for integrating RLA, which can be seen in ./example/

We record scalars by `RLA.easy_log.logger`:
```python
from RLA.easy_log import logger
from RLA import logger
import tensorflow as tf
from RLA.easy_log.time_step import time_step_holder
from RLA import time_step_holder

for i in range(1000):
# time-steps (iterations)
@@ -143,7 +159,7 @@ for i in range(1000):

We save checkpoints of neural networks by `exp_manager.save_checkpoint`.
```python
from RLA.easy_log.tester import exp_manager
from RLA import exp_manager
exp_manager.new_saver()

for i in range(1000):
@@ -157,7 +173,7 @@ Currently we can record complex-structure data based on tensorboard:
```python
# from tensorflow summary
import tensorflow as tf
from RLA.easy_log import logger
from RLA import logger
summary = tf.Summary()
logger.log_from_tf_summary(summary)
# from tensorboardX writer
@@ -169,7 +185,7 @@ We will develop APIs to record common-used complex-structure data in RLA.easy_lo
Now we give a MatplotlibRecorder tool to manage your figures generated by matplotlib:

```python
from RLA.easy_log.complex_data_recorder import MatplotlibRecorder as mpr
from RLA import MatplotlibRecorder as mpr
def plot_func():
import matplotlib.pyplot as plt
plt.plot([1,1,1], [2,2,2])
4 changes: 3 additions & 1 deletion RLA/__init__.py
@@ -1,3 +1,5 @@
from RLA.easy_log.tester import exp_manager
from RLA.easy_log import logger
from RLA.easy_plot.plot_func_v2 import plot_func
from RLA.easy_log.time_step import time_step_holder
from RLA.easy_plot.plot_func_v2 import plot_func
from RLA.easy_log.complex_data_recorder import MatplotlibRecorder
2 changes: 1 addition & 1 deletion RLA/easy_log/exp_loader.py
@@ -2,7 +2,7 @@
from RLA.easy_log.tester import exp_manager, Tester
import copy
import argparse
from typing import Optional, OrderedDict, Union, Dict, Any
from typing import Optional
from RLA.const import DEFAULT_X_NAME
from pprint import pprint

16 changes: 11 additions & 5 deletions RLA/easy_log/logger.py
@@ -406,7 +406,7 @@ def timestep():
ma_dict = {}


def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None):
def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None):
if key not in ma_dict:
ma_dict[key] = deque(maxlen=record_len)
if ignore_nan:
@@ -415,7 +415,10 @@ def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[U
else:
ma_dict[key].append(val)
if len(ma_dict[key]) == record_len:
record_tabular(key, np.mean(ma_dict[key]), exclude)
record_tabular(key, np.mean(ma_dict[key]), exclude, freq)


lst_print_dict = {}

def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None):
"""
@@ -426,8 +429,11 @@ def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Opt
:param key: (Any) save to log this key
:param val: (Any) save to log this value
"""
if freq is None or timestep() % freq == 0:
if key not in lst_print_dict:
lst_print_dict[key] = -np.inf
if freq is None or timestep() - lst_print_dict[key] >= freq:
get_current().logkv(key, val, exclude)
lst_print_dict[key] = timestep()


def log_from_tf_summary(summary):
@@ -463,12 +469,12 @@ def logkv_mean(key, val):
"""
get_current().logkv_mean(key, val)

def logkvs(d, exclude:Optional[Union[str, Tuple[str, ...]]]=None):
def logkvs(d, prefix:Optional[str]='', exclude:Optional[Union[str, Tuple[str, ...]]]=None):
"""
Log a dictionary of key-value pairs
"""
for (k, v) in d.items():
logkv(k, v, exclude)
logkv(prefix+k, v, exclude)


def log_key_value(keys, values, prefix_name=''):
3 changes: 1 addition & 2 deletions RLA/easy_log/tester.py
@@ -15,7 +15,6 @@
import os.path as osp
import pprint

import tensorboardX

from RLA.easy_log.time_step import time_step_holder
from RLA.easy_log import logger
@@ -134,7 +133,7 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo
:param is_master_node: In "distributed training & centralized logs" mode (By set SEND_LOG_FILE in rla_config.yaml to True),
you should mark the master node (is_master_node=True) to collect logs of the slave nodes (is_master_node=False).
:type is_master_node: bool
: param code_root: Define the root of your codebase (for backup) explicitly. It will be in the same location as rla_config.yaml by default.
:param code_root: Define the root of your codebase (for backup) explicitly. It will be in the same location as rla_config.yaml by default.
"""
if isinstance(rla_config, str):
self.private_config = load_yaml(rla_config)
6 changes: 3 additions & 3 deletions RLA/easy_plot/plot_util.py
@@ -523,17 +523,17 @@ def allequal(qs):
if shaded_err:
res = g2lf[original_legend_keys[index] + '-se']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-err : ({:.2f} \pm {:.2f})".format(legend_keys[index], res[1][-1], res[2][-1]))
print("{}-err : ({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-err'] = [res[1][-1], res[2][-1]]
if shaded_std:
res = g2lf[original_legend_keys[index] + '-ss']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-std :({:.2f} \pm {:.2f})".format(legend_keys[index], res[1][-1], res[2][-1]))
print("{}-std :({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-std'] = [res[1][-1], res[2][-1]]
if shaded_range:
res = g2lf[original_legend_keys[index] + '-sr']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-range : ({:.2f}, {:.2f})".format(legend_keys[index], res[1][-1], res[2][-1]))
print("{}-range : ({:.3f}, {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-range'] = [res[1][-1], res[2][-1]]

if bound_line is not None:
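The plot_util tweak above widens the reported scores from two to three decimals and wraps the plus-minus in `$...$` so matplotlib's mathtext renders it as ±. A quick check of the new format string (the key name and values here are made up for illustration):

```python
# Hypothetical sample values; the real ones come from the loaded experiment logs.
mean, std = 1.23456, 0.04567
line = r"{}-std :({:.3f} $\pm$ {:.3f})".format('van_sac', mean, std)
print(line)  # → van_sac-std :(1.235 $\pm$ 0.046)
```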
2 changes: 1 addition & 1 deletion example/sb3_ppo_example/ppo/main.py
@@ -22,7 +22,7 @@ def mujoco_arg_parser():
task_name = 'demo_task'
exp_manager.set_hyper_param(**vars(args))
exp_manager.add_record_param(["info", "seed", 'env'])
exp_manager.configure(task_name, private_config_path='../rla_config.yaml', data_root='../')
exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root='../')
exp_manager.log_files_gen()
exp_manager.print_args()

2 changes: 1 addition & 1 deletion example/sb_ppo_example/ppo2/run_mujoco.py
@@ -64,7 +64,7 @@ def main():
task_name = 'demo_task'
exp_manager.set_hyper_param(**vars(args))
exp_manager.add_record_param(["info", "seed", 'env'])
exp_manager.configure(task_name, private_config_path='../rla_config.yaml', data_root='../')
exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root='../')
exp_manager.log_files_gen()
exp_manager.print_args()
# [RLA] optional:
2 changes: 1 addition & 1 deletion example/simplest_code/project/main.py
@@ -28,7 +28,7 @@ def get_param():

task_name = 'demo_task'
rla_data_root = '../'
exp_manager.configure(task_name, private_config_path='../rla_config.yaml', data_root=rla_data_root)
exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root)
exp_manager.log_files_gen()
exp_manager.print_args()

6 changes: 5 additions & 1 deletion test/test_proj/proj/test_manager.py
@@ -24,7 +24,7 @@ def _init_proj(self, config_yaml, **kwargs):
task_name = 'test_manger_demo_task'
rla_data_root = os.path.join(DATABASE_ROOT, 'test_data_root')
config_yaml['BACKUP_CONFIG']['backup_code_dir'] = ['proj']
exp_manager.configure(task_name, private_config_path=config_yaml, data_root=rla_data_root,
exp_manager.configure(task_name, rla_config=config_yaml, data_root=rla_data_root,
code_root=CODE_ROOT, **kwargs)
exp_manager.log_files_gen()
exp_manager.print_args()
@@ -66,6 +66,8 @@ def test_log_tf(self):
if i % 20 == 0:
exp_manager.save_checkpoint()
if i % 10 == 0:
logger.ma_record_tabular("perf/mse-long", np.mean(mse_loss.detach().cpu().numpy()), 10, freq=25)
logger.record_tabular("y_out-long", np.mean(y), freq=25)
def plot_func():
import matplotlib.pyplot as plt
testX = np.repeat(np.expand_dims(np.arange(-10, 10, 0.1), axis=-1), repeats=kwargs["input_size"], axis=-1)
@@ -109,6 +111,8 @@ def test_log_torch(self):
logger.ma_record_tabular("perf/mse", np.mean(mse_loss.detach().cpu().numpy()), 10)
logger.record_tabular("y_out", np.mean(y))
if i % 10 == 0:
logger.ma_record_tabular("perf/mse-long", np.mean(mse_loss.detach().cpu().numpy()), 10, freq=25)
logger.record_tabular("y_out-long", np.mean(y), freq=25)
def plot_func():
import matplotlib.pyplot as plt
testX = np.repeat(np.expand_dims(np.arange(-10, 10, 0.1), axis=-1), repeats=kwargs["input_size"], axis=-1)
