Skip to content

Commit 77ef3e0

Browse files
cjcchen authored and ryanjulian committed
Add TensorBoard Support
Adds TensorBoard support for basic key-value pairs. Anything logged via `logger.record_tabular()` is also available via TensorBoard.
1 parent b3a2899 commit 77ef3e0

File tree

3 files changed

+155
-42
lines changed

3 files changed

+155
-42
lines changed

environment.yml

+1
Original file line number | Diff line number | Diff line change
@@ -62,3 +62,4 @@ dependencies:
6262
- pylru==1.0.9
6363
- hyperopt
6464
- polling
65+
- tensorboard

rllab/misc/logger.py

+62-7
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import pickle
1717
import base64
18+
import tensorflow as tf
1819

1920
_prefixes = []
2021
_prefix_str = ''
@@ -31,13 +32,17 @@
3132
_tabular_fds = {}
3233
_tabular_header_written = set()
3334

35+
_tensorboard_writer = None
3436
_snapshot_dir = None
3537
_snapshot_mode = 'all'
3638
_snapshot_gap = 1
3739

3840
_log_tabular_only = False
3941
_header_printed = False
4042

43+
_tensorboard_default_step = 0
44+
_tensorboard_step_key = None
45+
4146

4247
def _add_output(file_name, arr, fds, mode='a'):
4348
if file_name not in arr:
@@ -77,6 +82,20 @@ def remove_tabular_output(file_name):
7782
_remove_output(file_name, _tabular_outputs, _tabular_fds)
7883

7984

85+
def set_tensorboard_dir(dir_name):
    """Point the logger's TensorBoard output at ``dir_name``.

    Any writer that is currently open is closed first. A falsy ``dir_name``
    then leaves TensorBoard logging disabled; otherwise a new
    ``tf.summary.FileWriter`` is opened on ``dir_name`` and the automatic
    step counter is reset to 0.

    Args:
        dir_name: Directory for the TensorBoard event files, or a falsy
            value (``None`` / ``""``) to turn TensorBoard logging off.
    """
    # Both names must be declared global: without the declaration the
    # counter reset below would only bind a function-local variable and the
    # module-level step counter would never be reset.
    global _tensorboard_writer, _tensorboard_default_step

    # Close the previous writer in every case so that re-pointing the logger
    # at a new directory does not leave the old writer open.
    if _tensorboard_writer:
        _tensorboard_writer.close()
        _tensorboard_writer = None

    if not dir_name:
        return

    # Only the parent directory is created here.
    # NOTE(review): presumably FileWriter creates ``dir_name`` itself - confirm.
    mkdir_p(os.path.dirname(dir_name))
    _tensorboard_writer = tf.summary.FileWriter(dir_name)
    _tensorboard_default_step = 0
    print("tensorboard data will be logged into:", dir_name)
97+
98+
8099
def set_snapshot_dir(dir_name):
81100
global _snapshot_dir
82101
_snapshot_dir = dir_name
@@ -94,18 +113,26 @@ def set_snapshot_mode(mode):
94113
global _snapshot_mode
95114
_snapshot_mode = mode
96115

116+
97117
def get_snapshot_gap():
    """Return the module-level snapshot gap setting."""
    return _snapshot_gap
99119

120+
100121
def set_snapshot_gap(gap):
    """Store ``gap`` as the module-level snapshot gap setting."""
    global _snapshot_gap
    _snapshot_gap = gap
103124

125+
104126
def set_log_tabular_only(log_tabular_only):
    """Store ``log_tabular_only`` as the module-level tabular-only flag."""
    global _log_tabular_only
    _log_tabular_only = log_tabular_only
107129

108130

131+
def set_tensorboard_step_key(key):
    """Store ``key`` as the tabular key whose value is used as the step.

    ``dump_tensorboard`` reads this setting; a falsy value makes it fall
    back to its internal step counter.
    """
    global _tensorboard_step_key
    _tensorboard_step_key = key
134+
135+
109136
def get_log_tabular_only():
    """Return the module-level tabular-only flag."""
    return _log_tabular_only
111138

@@ -186,6 +213,23 @@ def refresh(self):
186213
table_printer = TerminalTablePrinter()
187214

188215

216+
def _tensorboard_scalar(value):
    """Return ``value`` converted to float, or None when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def dump_tensorboard(*args, **kwargs):
    """Write the pending tabular key/value pairs to TensorBoard.

    Does nothing when ``_tabular`` is empty or no TensorBoard writer is
    configured. ``*args`` and ``**kwargs`` are accepted and ignored.

    The step is taken from the entry stored under ``_tensorboard_step_key``
    when that key is set, present and numeric; otherwise the internal
    counter ``_tensorboard_default_step`` is used and then incremented.
    Entries whose value cannot be converted to float are skipped.
    """
    global _tensorboard_default_step

    if not _tabular or not _tensorboard_writer:
        return

    tabular_dict = dict(_tabular)

    step = None
    if _tensorboard_step_key and _tensorboard_step_key in tabular_dict:
        raw_step = tabular_dict[_tensorboard_step_key]
        if isinstance(raw_step, int):
            step = raw_step
        else:
            # Going through float() accepts values such as "3.0", which a
            # bare int() rejects. NOTE(review): presumably tabular values
            # can be stored as strings - confirm against record_tabular.
            try:
                step = int(float(raw_step))
            except (TypeError, ValueError, OverflowError):
                step = None
    if step is None:
        step = _tensorboard_default_step
        _tensorboard_default_step += 1

    summary = tf.Summary()
    for tag, raw_value in tabular_dict.items():
        scalar = _tensorboard_scalar(raw_value)
        if scalar is None:
            # A single non-numeric entry must not abort the whole dump.
            continue
        summary.value.add(tag=tag, simple_value=scalar)
    _tensorboard_writer.add_summary(summary, step)
    _tensorboard_writer.flush()
231+
232+
189233
def dump_tabular(*args, **kwargs):
190234
wh = kwargs.pop("write_header", None)
191235
if len(_tabular) > 0:
@@ -195,11 +239,18 @@ def dump_tabular(*args, **kwargs):
195239
for line in tabulate(_tabular).split('\n'):
196240
log(line, *args, **kwargs)
197241
tabular_dict = dict(_tabular)
242+
243+
# write to the tensorboard folder
244+
# This assumes that the keys in each iteration won't change!
245+
dump_tensorboard(args, kwargs)
246+
198247
# Also write to the csv files
199248
# This assumes that the keys in each iteration won't change!
200249
for tabular_fd in list(_tabular_fds.values()):
201-
writer = csv.DictWriter(tabular_fd, fieldnames=list(tabular_dict.keys()))
202-
if wh or (wh is None and tabular_fd not in _tabular_header_written):
250+
writer = csv.DictWriter(
251+
tabular_fd, fieldnames=list(tabular_dict.keys()))
252+
if wh or (wh is None
253+
and tabular_fd not in _tabular_header_written):
203254
writer.writeheader()
204255
_tabular_header_written.add(tabular_fd)
205256
writer.writerow(tabular_dict)
@@ -245,7 +296,8 @@ def log_parameters(log_file, args, classes):
245296
log_params[name] = params
246297
else:
247298
log_params[name] = getattr(cls, "__kwargs", dict())
248-
log_params[name]["_name"] = cls.__module__ + "." + cls.__class__.__name__
299+
log_params[name][
300+
"_name"] = cls.__module__ + "." + cls.__class__.__name__
249301
mkdir_p(os.path.dirname(log_file))
250302
with open(log_file, "w") as f:
251303
json.dump(log_params, f, indent=2, sort_keys=True)
@@ -258,13 +310,13 @@ def stub_to_json(stub_sth):
258310
data = dict()
259311
for k, v in stub_sth.kwargs.items():
260312
data[k] = stub_to_json(v)
261-
data["_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__
313+
data[
314+
"_name"] = stub_sth.proxy_class.__module__ + "." + stub_sth.proxy_class.__name__
262315
return data
263316
elif isinstance(stub_sth, instrument.StubAttr):
264317
return dict(
265318
obj=stub_to_json(stub_sth.obj),
266-
attr=stub_to_json(stub_sth.attr_name)
267-
)
319+
attr=stub_to_json(stub_sth.attr_name))
268320
elif isinstance(stub_sth, instrument.StubMethodCall):
269321
return dict(
270322
obj=stub_to_json(stub_sth.obj),
@@ -294,7 +346,10 @@ def default(self, o):
294346
if isinstance(o, type):
295347
return {'$class': o.__module__ + "." + o.__name__}
296348
elif isinstance(o, Enum):
297-
return {'$enum': o.__module__ + "." + o.__class__.__name__ + '.' + o.name}
349+
return {
350+
'$enum':
351+
o.__module__ + "." + o.__class__.__name__ + '.' + o.name
352+
}
298353
return json.JSONEncoder.default(self, o)
299354

300355

scripts/run_experiment_lite.py

+92-35
Original file line number | Diff line number | Diff line change
@@ -29,41 +29,95 @@ def run_experiment(argv):
2929

3030
default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
3131
parser = argparse.ArgumentParser()
32-
parser.add_argument('--n_parallel', type=int, default=1,
33-
help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
34-
parser.add_argument(
35-
'--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
36-
parser.add_argument('--log_dir', type=str, default=None,
37-
help='Path to save the log and iteration snapshot.')
38-
parser.add_argument('--snapshot_mode', type=str, default='all',
39-
help='Mode to save the snapshot. Can be either "all" '
40-
'(all iterations will be saved), "last" (only '
41-
'the last iteration will be saved), "gap" (every'
42-
'`snapshot_gap` iterations are saved), or "none" '
43-
'(do not save snapshots)')
44-
parser.add_argument('--snapshot_gap', type=int, default=1,
45-
help='Gap between snapshot iterations.')
46-
parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
47-
help='Name of the tabular log file (in csv).')
48-
parser.add_argument('--text_log_file', type=str, default='debug.log',
49-
help='Name of the text log file (in pure text).')
50-
parser.add_argument('--params_log_file', type=str, default='params.json',
51-
help='Name of the parameter log file (in json).')
52-
parser.add_argument('--variant_log_file', type=str, default='variant.json',
53-
help='Name of the variant log file (in json).')
54-
parser.add_argument('--resume_from', type=str, default=None,
55-
help='Name of the pickle file to resume experiment from.')
56-
parser.add_argument('--plot', type=ast.literal_eval, default=False,
57-
help='Whether to plot the iteration results')
58-
parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
59-
help='Whether to only print the tabular log information (in a horizontal format)')
60-
parser.add_argument('--seed', type=int,
61-
help='Random seed for numpy')
62-
parser.add_argument('--args_data', type=str,
63-
help='Pickled data for stub objects')
64-
parser.add_argument('--variant_data', type=str,
65-
help='Pickled data for variant configuration')
66-
parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)
32+
parser.add_argument(
33+
'--n_parallel',
34+
type=int,
35+
default=1,
36+
help=
37+
'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
38+
)
39+
parser.add_argument(
40+
'--exp_name',
41+
type=str,
42+
default=default_exp_name,
43+
help='Name of the experiment.')
44+
parser.add_argument(
45+
'--log_dir',
46+
type=str,
47+
default=None,
48+
help='Path to save the log and iteration snapshot.')
49+
parser.add_argument(
50+
'--snapshot_mode',
51+
type=str,
52+
default='all',
53+
help='Mode to save the snapshot. Can be either "all" '
54+
'(all iterations will be saved), "last" (only '
55+
'the last iteration will be saved), "gap" (every'
56+
'`snapshot_gap` iterations are saved), or "none" '
57+
'(do not save snapshots)')
58+
parser.add_argument(
59+
'--snapshot_gap',
60+
type=int,
61+
default=1,
62+
help='Gap between snapshot iterations.')
63+
parser.add_argument(
64+
'--tabular_log_file',
65+
type=str,
66+
default='progress.csv',
67+
help='Name of the tabular log file (in csv).')
68+
parser.add_argument(
69+
'--text_log_file',
70+
type=str,
71+
default='debug.log',
72+
help='Name of the text log file (in pure text).')
73+
parser.add_argument(
74+
'--tensorboard_log_dir',
75+
type=str,
76+
default='progress',
77+
help='Name of the folder for tensorboard_summary.')
78+
parser.add_argument(
79+
'--tensorboard_step_key',
80+
type=str,
81+
default=None,
82+
help=
83+
'Name of the step key in log data which shows the step in tensorboard_summary.'
84+
)
85+
parser.add_argument(
86+
'--params_log_file',
87+
type=str,
88+
default='params.json',
89+
help='Name of the parameter log file (in json).')
90+
parser.add_argument(
91+
'--variant_log_file',
92+
type=str,
93+
default='variant.json',
94+
help='Name of the variant log file (in json).')
95+
parser.add_argument(
96+
'--resume_from',
97+
type=str,
98+
default=None,
99+
help='Name of the pickle file to resume experiment from.')
100+
parser.add_argument(
101+
'--plot',
102+
type=ast.literal_eval,
103+
default=False,
104+
help='Whether to plot the iteration results')
105+
parser.add_argument(
106+
'--log_tabular_only',
107+
type=ast.literal_eval,
108+
default=False,
109+
help=
110+
'Whether to only print the tabular log information (in a horizontal format)'
111+
)
112+
parser.add_argument('--seed', type=int, help='Random seed for numpy')
113+
parser.add_argument(
114+
'--args_data', type=str, help='Pickled data for stub objects')
115+
parser.add_argument(
116+
'--variant_data',
117+
type=str,
118+
help='Pickled data for variant configuration')
119+
parser.add_argument(
120+
'--use_cloudpickle', type=ast.literal_eval, default=False)
67121

68122
args = parser.parse_args(argv[1:])
69123

@@ -87,6 +141,7 @@ def run_experiment(argv):
87141
tabular_log_file = osp.join(log_dir, args.tabular_log_file)
88142
text_log_file = osp.join(log_dir, args.text_log_file)
89143
params_log_file = osp.join(log_dir, args.params_log_file)
144+
tensorboard_log_dir = osp.join(log_dir, args.tensorboard_log_dir)
90145

91146
if args.variant_data is not None:
92147
variant_data = pickle.loads(base64.b64decode(args.variant_data))
@@ -100,12 +155,14 @@ def run_experiment(argv):
100155

101156
logger.add_text_output(text_log_file)
102157
logger.add_tabular_output(tabular_log_file)
158+
logger.set_tensorboard_dir(tensorboard_log_dir)
103159
prev_snapshot_dir = logger.get_snapshot_dir()
104160
prev_mode = logger.get_snapshot_mode()
105161
logger.set_snapshot_dir(log_dir)
106162
logger.set_snapshot_mode(args.snapshot_mode)
107163
logger.set_snapshot_gap(args.snapshot_gap)
108164
logger.set_log_tabular_only(args.log_tabular_only)
165+
logger.set_tensorboard_step_key(args.tensorboard_step_key)
109166
logger.push_prefix("[%s] " % args.exp_name)
110167

111168
if args.resume_from is not None:

0 commit comments

Comments
 (0)