[tune] [rllib] Allow checkpointing to object store instead of local disk (#1212)

* wip

* use normal pickle

* fix checkpoint test

* comment

* Comment

* fix test

* fix lint

* fix py 3.5

* Update agent.py

* fix lint
ericl authored Nov 19, 2017
1 parent d986294 commit ae4e1dd
Showing 3 changed files with 100 additions and 12 deletions.
51 changes: 51 additions & 0 deletions python/ray/rllib/agent.py
@@ -6,8 +6,11 @@
 
 import logging
 import numpy as np
+import io
 import os
+import gzip
 import pickle
+import shutil
 import tempfile
 import time
 import uuid
@@ -147,6 +150,35 @@ def save(self):
             open(checkpoint_path + ".rllib_metadata", "wb"))
         return checkpoint_path

+    def save_to_object(self):
+        """Saves the current model state to a Python object. It also
+        saves to disk but does not return the checkpoint path.
+
+        Returns:
+            Object holding checkpoint data.
+        """
+
+        checkpoint_prefix = self.save()
+
+        data = {}
+        base_dir = os.path.dirname(checkpoint_prefix)
+        for path in os.listdir(base_dir):
+            path = os.path.join(base_dir, path)
+            if path.startswith(checkpoint_prefix):
+                data[os.path.basename(path)] = open(path, "rb").read()
+
+        out = io.BytesIO()
+        with gzip.GzipFile(fileobj=out, mode="wb") as f:
+            compressed = pickle.dumps({
+                "checkpoint_name": os.path.basename(checkpoint_prefix),
+                "data": data,
+            })
+            print("Saving checkpoint to object store, {} bytes".format(
+                len(compressed)))
+            f.write(compressed)
+
+        return out.getvalue()

     def restore(self, checkpoint_path):
         """Restores training state from a given model checkpoint.
@@ -160,6 +192,25 @@ def restore(self, checkpoint_path):
         self._timesteps_total = metadata[2]
         self._time_total = metadata[3]

+    def restore_from_object(self, obj):
+        """Restores training state from a checkpoint object.
+
+        These checkpoints are returned from calls to save_to_object().
+        """
+
+        out = io.BytesIO(obj)
+        info = pickle.loads(gzip.GzipFile(fileobj=out, mode="rb").read())
+        data = info["data"]
+        tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
+        checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
+
+        for file_name, file_contents in data.items():
+            with open(os.path.join(tmpdir, file_name), "wb") as f:
+                f.write(file_contents)
+
+        self.restore(checkpoint_path)
+        shutil.rmtree(tmpdir)

     def stop(self):
         """Releases all resources used by this agent."""

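Since save_to_object() returns nothing more than a gzip-compressed pickle of a small dict, the blob can be unpacked without going through an agent at all. A minimal sketch of the decode path (inspect_checkpoint_blob is a hypothetical helper, not part of this commit; the field names match the dict pickled in save_to_object() above):

import gzip
import io
import pickle


def inspect_checkpoint_blob(blob):
    # Decompress and unpickle the container written by save_to_object().
    with gzip.GzipFile(fileobj=io.BytesIO(blob), mode="rb") as f:
        info = pickle.loads(f.read())
    # "checkpoint_name" is the basename of the checkpoint prefix;
    # "data" maps each checkpoint file name to its raw bytes.
    print("checkpoint name:", info["checkpoint_name"])
    for file_name, file_contents in info["data"].items():
        print("  {}: {} bytes".format(file_name, len(file_contents)))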
21 changes: 17 additions & 4 deletions python/ray/rllib/test/test_checkpoint_restore.py
@@ -26,8 +26,9 @@ def get_mean_action(alg, obs):
     "A3C": {"use_lstm": False},
 }
 
-for name in ["ES", "DQN", "PPO", "A3C"]:
-    cls = get_agent_class(name)
+
+def test(use_object_store, alg_name):
+    cls = get_agent_class(alg_name)
     alg1 = cls("CartPole-v0", CONFIGS[name])
     alg2 = cls("CartPole-v0", CONFIGS[name])

@@ -36,11 +37,23 @@ def get_mean_action(alg, obs):
         print("current status: " + str(res))
 
     # Sync the models
-    alg2.restore(alg1.save())
+    if use_object_store:
+        alg2.restore_from_object(alg1.save_to_object())
+    else:
+        alg2.restore(alg1.save())
 
     for _ in range(10):
         obs = np.random.uniform(size=4)
         a1 = get_mean_action(alg1, obs)
         a2 = get_mean_action(alg2, obs)
         print("Checking computed actions", alg1, obs, a1, a2)
-        assert abs(a1-a2) < .1, (a1, a2)
+        assert abs(a1 - a2) < .1, (a1, a2)


+if __name__ == "__main__":
+    # https://github.com/ray-project/ray/issues/1062 for enabling ES test too
+    for use_object_store in [False, True]:
+        for name in ["ES", "DQN", "PPO", "A3C"]:
+            test(use_object_store, name)
+
+    print("All checkpoint restore tests passed!")
40 changes: 32 additions & 8 deletions python/ray/tune/trial.py
@@ -91,6 +91,7 @@ def __init__(
         # Local trial state that is updated during the run
         self.last_result = None
         self._checkpoint_path = restore_path
+        self._checkpoint_obj = None
         self.agent = None
         self.status = Trial.PENDING
         self.location = None
@@ -106,7 +107,9 @@ def start(self):
 
         self._setup_agent()
         if self._checkpoint_path:
-            self.restore_from_path(path=self._checkpoint_path)
+            self.restore_from_path(self._checkpoint_path)
+        elif self._checkpoint_obj:
+            self.restore_from_obj(self._checkpoint_obj)
 
     def stop(self, error=False, stop_logger=True):
         """Stops this trial.
@@ -152,7 +155,7 @@ def pause(self):
 
         assert self.status == Trial.RUNNING, self.status
         try:
-            self.checkpoint()
+            self.checkpoint(to_object_store=True)
             self.stop(stop_logger=False)
             self.status = Trial.PAUSED
         except Exception:
@@ -226,16 +229,25 @@ def location_string(hostname, pid):
 
         return ', '.join(pieces)
 
-    def checkpoint(self):
-        """Synchronously checkpoints the state of this trial.
+    def checkpoint(self, to_object_store=False):
+        """Checkpoints the state of this trial.
 
         TODO(ekl): we should support a PAUSED state based on checkpointing.
+
+        Args:
+            to_object_store (bool): Whether to save to the Ray object store
+                (async) vs a path on local disk (sync).
         """
 
-        path = ray.get(self.agent.save.remote())
+        obj = None
+        path = None
+        if to_object_store:
+            obj = self.agent.save_to_object.remote()
+        else:
+            path = ray.get(self.agent.save.remote())
         self._checkpoint_path = path
-        print("Saved checkpoint to:", path)
-        return path
+        self._checkpoint_obj = obj
+
+        print("Saved checkpoint to:", path or obj)
+        return path or obj
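A subtlety in the branch above: with to_object_store=True there is no ray.get(), so checkpoint() returns immediately with an object ID for the in-flight save, while the local-disk branch blocks until the save completes. restore_from_obj() (below) can hand that object ID straight back to restore_from_object.remote(), because Ray resolves object IDs passed as task arguments. A hedged sketch of a caller that wants the raw bytes instead (assumes trial is a running Trial; not part of this commit):

ref = trial.checkpoint(to_object_store=True)  # object ID, returned immediately
blob = ray.get(ref)  # blocks until the agent's save_to_object() finishes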

     def restore_from_path(self, path):
         """Restores agent state from specified path.
@@ -253,6 +265,18 @@ def restore_from_path(self, path):
             print("Error restoring agent:", traceback.format_exc())
             self.status = Trial.ERROR

+    def restore_from_obj(self, obj):
+        """Restores agent state from the specified object."""
+
+        if self.agent is None:
+            print("Unable to restore - no agent")
+        else:
+            try:
+                ray.get(self.agent.restore_from_object.remote(obj))
+            except Exception:
+                print("Error restoring agent:", traceback.format_exc())
+                self.status = Trial.ERROR

     def _setup_agent(self):
         self.status = Trial.RUNNING
         agent_cls = get_agent_class(self.alg)
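Taken together with the agent-side methods, pausing a trial now round-trips its state through the object store: pause() saves via checkpoint(to_object_store=True), stop() tears the agent down, and a later start() finds _checkpoint_obj set and restores through restore_from_obj(). A rough sketch of that lifecycle (the real transitions are driven by the trial runner; trial is assumed):

trial.start()  # fresh agent
trial.pause()  # checkpoint(to_object_store=True) + stop(); status -> PAUSED
# ... later, after the runner decides to resume the trial ...
trial.start()  # _setup_agent(), then restore_from_obj(self._checkpoint_obj)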
