Skip to content

Commit

Permalink
Dev (#20)
Browse files Browse the repository at this point in the history
* fix: fix bugs of torch-version ckp loader

* refactor: add sync_timestep for hp loader
  • Loading branch information
xionghuichen authored Jul 13, 2022
1 parent cf30b18 commit 33e6aee
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 24 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,11 @@ PS:
2. An alternative is to build your own NFS for your physical machines and point data_root to a location on that NFS.

# TODO
- [ ] support sftp-based sync.
- [ ] support custom data structure saving and loading.
- [ ] support video visualization.
- [ ] add comments and documentation to the functions.
- [ ] add an auto integration script.
- [ ] download / upload experiment logs by timestamp.
- [ ] add documentation for the plot function.
- [ ] allow syncing either LOG only or all log types.
- [ ] support aim and smarter logger.
3 changes: 3 additions & 0 deletions RLA/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ class FRAMEWORK:
class FTP_PROTOCOL_NAME:
FTP = 'ftp'
SFTP = 'sftp'

class LOG_NAME_FORMAT_VERSION:
V1 = 'v1'
21 changes: 15 additions & 6 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import argparse
from typing import Optional, OrderedDict, Union, Dict, Any
from RLA.const import DEFAULT_X_NAME
from pprint import pprint

class ExperimentLoader(object):
"""
Expand Down Expand Up @@ -32,7 +33,9 @@ class ExperimentLoader(object):
def __init__(self):
self.task_name = exp_manager.hyper_param.get('loaded_task_name', None)
self.load_date = exp_manager.hyper_param.get('loaded_date', None)
self.data_root = getattr(exp_manager, 'root', None)
self.data_root = getattr(exp_manager, 'data_root', None)
if self.data_root is None:
self.data_root = getattr(exp_manager, 'root', None)
pass

def config(self, task_name, record_date, root):
Expand All @@ -51,17 +54,20 @@ def is_valid_config(self):
logger.warn("root", self.data_root)
return False

def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None):
def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_timestep=False):
if self.is_valid_config:
load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
target_hp = copy.deepcopy(exp_manager.hyper_param)
target_hp.update(load_tester.hyper_param)
target_hp.update(loaded_tester.hyper_param)
if hp_to_overwrite is not None:
for v in hp_to_overwrite:
target_hp[v] = exp_manager.hyper_param[v]
args = argparse.Namespace(**target_hp)
args.load_date = self.load_date
args.load_task_name = self.task_name
if sync_timestep:
load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME)
exp_manager.time_step_holder.set_time(load_iter)
return args
else:
return argparse.Namespace(**exp_manager.hyper_param)
Expand All @@ -75,18 +81,21 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
# load checkpoint
load_res = {}
if var_prefix is not None:
loaded_tester.new_saver(var_prefix=var_prefix, max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
exp_manager.print_log_dir()
else:
loaded_tester.new_saver(max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
hist_variables = {}
if variable_list is not None:
for v in variable_list:
hist_variables[v] = loaded_tester.get_custom_data(v)
load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME)
exp_manager.time_step_holder.set_time(load_iter)
return load_iter, load_res, hist_variables
else:
return 0, {}, {}
Expand Down
4 changes: 4 additions & 0 deletions RLA/easy_log/log_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def delete_small_timestep_log(self, skip_ask=False):
for res in self.small_timestep_regs:
print("[delete small-timestep log] reg: ", res[1])
self._delete_related_log(show=True, regex=res[0] + '*')
print("summarize:")
for count, res in enumerate(self.small_timestep_regs):
print(f"[delete small-timestep log] {count} reg: {res[1]}")

if skip_ask:
s = 'y'
else:
Expand Down
65 changes: 49 additions & 16 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import tensorboardX

from RLA.easy_log.const import *
from RLA.easy_log.time_step import time_step_holder
from RLA.easy_log import logger
from RLA.easy_log.const import *
from RLA.const import *
import yaml
import shutil
import argparse
Expand Down Expand Up @@ -107,6 +107,8 @@ def __init__(self):
self.code_dir = None
self.saver = None
self.dl_framework = None
self.checkpoint_keep_list = None
self.log_name_format_version = LOG_NAME_FORMAT_VERSION.V1

@deprecated_alias(task_name='task_table_name', private_config_path='rla_config', log_root='data_root')
def configure(self, task_table_name: str, rla_config: Union[str, dict], data_root: str,
Expand Down Expand Up @@ -205,13 +207,28 @@ def log_files_gen(self):
self._feed_hyper_params_to_tb()
self.print_log_dir()

def update_log_files_location(self, root):
def update_log_files_location(self, root:str):
"""
This function is designed for the requirement of using copied/moved experiment logs to other databases for downstream task.
The location of the experiment logs might have changed compared with their original location.
The function automatically update the attributes related to the data_root to the current location.
:param root: current data_root
:type root: str
"""
self.data_root = root
code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, self.task_table_name), '', is_file=False)
log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, self.task_table_name), '', is_file=False)
self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, self.task_table_name), '.pkl')
self.checkpoint_dir, _ = self.__create_file_directory(osp.join(self.data_root, CHECKPOINT, self.task_table_name), is_file=False)
self.results_dir, _ = self.__create_file_directory(osp.join(self.data_root, OTHER_RESULTS, self.task_table_name), is_file=False)

task_table_name = getattr(self, 'task_table_name', None)
if task_table_name is None:
task_table_name = getattr(self, 'task_name', None)
print("[WARN] you are using an old-version RLA. "
"Some attributes' name have been changed (task_name->task_table_name).")
else:
raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, task_table_name), '', is_file=False)
log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, task_table_name), '', is_file=False)
self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, task_table_name), '.pkl')
self.checkpoint_dir, _ = self.__create_file_directory(osp.join(self.data_root, CHECKPOINT, task_table_name), is_file=False)
self.results_dir, _ = self.__create_file_directory(osp.join(self.data_root, OTHER_RESULTS, task_table_name), is_file=False)
self.log_dir = log_dir
self.code_dir = code_dir
self.print_log_dir()
Expand Down Expand Up @@ -487,15 +504,23 @@ def __create_file_directory(self, prefix, ext='', is_file=True, record_date=None
record_date = self.record_date
directory = str(record_date.strftime("%Y/%m/%d"))
directory = osp.join(prefix, directory)
version_num = getattr(self, 'log_name_format_version', None)

if version_num is None:
name_format = '{dir}/{timestep} {ip} {info}{ext}'
elif version_num == LOG_NAME_FORMAT_VERSION.V1:
name_format = '{dir}/{timestep}_{ip}_{info}{ext}'
else:
raise RuntimeError("unknown version name", version_num)

if is_file:
os.makedirs(directory, exist_ok=True)
file_name = '{dir}/{timestep}_{ip}_{info}{ext}'.format(dir=directory,
timestep=self.record_date_to_str(record_date),
file_name = name_format.format(dir=directory, timestep=self.record_date_to_str(record_date),
ip=str(self.ipaddr),
info=self.info,
ext=ext)
else:
directory = '{dir}/{timestep}_{ip}_{info}{ext}/'.format(dir=directory,
directory = (name_format + '/').format(dir=directory,
timestep=self.record_date_to_str(record_date),
ip=str(self.ipaddr),
info=self.info,
Expand Down Expand Up @@ -545,7 +570,6 @@ def new_saver(self, max_to_keep, var_prefix=None):
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True)
elif self.dl_framework == FRAMEWORK.torch:
self.max_to_keep = max_to_keep
self.checkpoint_keep_list = []
else:
raise NotImplementedError

Expand All @@ -558,6 +582,8 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
elif self.dl_framework == FRAMEWORK.torch:
import torch
if self.checkpoint_keep_list is None:
self.checkpoint_keep_list = []
iter = self.time_step_holder.get_time()
torch.save(model_dict, f=tester.checkpoint_dir + "checkpoint-{}.pt".format(iter))
self.checkpoint_keep_list.append(iter)
Expand All @@ -574,20 +600,27 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
self.add_custom_data(k, v, type(v), mode='replace')
self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace')

def load_checkpoint(self):
def load_checkpoint(self, ckp_index=None):
if self.dl_framework == FRAMEWORK.tensorflow:
# TODO: load with variable scope.
import tensorflow as tf
cpt_name = osp.join(self.checkpoint_dir)
logger.info("load checkpoint {}".format(cpt_name))
ckpt_path = tf.train.latest_checkpoint(cpt_name)
if ckp_index is None:
ckpt_path = tf.train.latest_checkpoint(cpt_name)
else:
ckpt_path = tf.train.latest_checkpoint(cpt_name, ckp_index)
self.saver.restore(tf.get_default_session(), ckpt_path)
max_iter = ckpt_path.split('-')[-1]
self.time_step_holder.set_time(max_iter)
return int(max_iter), None
elif self.dl_framework == FRAMEWORK.torch:
import torch
return self.checkpoint_keep_list[-1], torch.load(tester.checkpoint_dir + "checkpoint-{}.pt".format(self.checkpoint_keep_list[-1]))
all_ckps = sorted(os.listdir(self.checkpoint_dir))
print("all checkpoints:")
pprint.pprint(all_ckps)
if ckp_index is None:
ckp_index = all_ckps[-1].split('checkpoint-')[1].split('.pt')[0]
return ckp_index, torch.load(self.checkpoint_dir + "checkpoint-{}.pt".format(ckp_index))

def auto_parse_info(self):
return '&'.join(self.hyper_param_record)
Expand Down Expand Up @@ -648,7 +681,7 @@ def serialize_object_and_save(self):
saver = self.saver
self.saver = None
with open(self.pkl_file, 'wb') as f:
dill.dump(self, f)
dill.dump(self, f, recurse=True)
self.writer = writer
self.saver = saver

Expand Down
19 changes: 19 additions & 0 deletions RLA/rla_argparser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
import argparse


def boolean_flag(parser: argparse.ArgumentParser, name, default=False, help=None):
    """Add a paired on/off boolean flag to an argparse parser.

    Two options are registered that write to the same destination
    attribute (``name`` with every ``-`` replaced by ``_``):
    ``--<name>`` stores ``True`` and ``--no-<name>`` stores ``False``.

    Parameters
    ----------
    parser: argparse.ArgumentParser
        parser to add the flag to
    name: str
        flag name without leading dashes; --<name> will enable the flag,
        while --no-<name> will disable it
    default: bool or None
        value of the flag when neither option is given
    help: str
        help string for the ``--<name>`` option
    """
    dest = name.replace('-', '_')
    # Register the enabling option first: for a destination shared by several
    # actions, argparse takes the initial value from the first action added,
    # so `default` (and not store_false's implicit True) is what applies.
    parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help)
    # Give the negated option its own help text so it is not listed
    # undocumented in the generated --help output.
    parser.add_argument("--no-" + name, action="store_false", dest=dest,
                        help="disable --" + name)


def arg_parser_postprocess(parser: argparse.ArgumentParser):
parser.add_argument('--loaded_task_name', default='', type=str)
parser.add_argument('--info', default='default exp info', type=str)
Expand Down
1 change: 1 addition & 0 deletions example/simplest_code/project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def get_param():
parser.add_argument('--env_id', help='environment ID', default='Test-v1')
parser.add_argument('--learning_rate', help='a hyperparameter', default=1e-3, type=float)
parser.add_argument('--input_size', help='a hyperparameter', default=16, type=int)
# NOTE: add some recommended hyper-parameters for RLA.
parser = arg_parser_postprocess(parser)
args = parser.parse_args()
kwargs = vars(args)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='RLA',
version="0.5.3",
version="0.6.0-pre",
description=(
'RL assistant'
),
Expand Down

0 comments on commit 33e6aee

Please sign in to comment.