[Question] Saving and loading a ReplayBuffer to disk #1588

Closed
daniloml opened this issue Oct 1, 2023 · 6 comments · Fixed by #1733

Comments

daniloml commented Oct 1, 2023

Hello

I'm using a TensorDictReplayBuffer with a LazyTensorStorage to train a model. After training, I need to save the replay buffer to disk for future use. I expected torch.save to be able to pickle it, but the object holds an RLock and cannot be pickled.

What would be the correct approach with torchrl?

I'm using torchrl==0.1.1
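
Roughly what I'm trying (a minimal sketch with placeholder data; the exact error wording is from memory):

import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage

rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000))
rb.extend(TensorDict({"obs": torch.randn(10, 4)}, [10]))

# Fails because the buffer holds a threading lock that pickle cannot handle,
# e.g. TypeError: cannot pickle '_thread.RLock' object
torch.save(rb, "replay_buffer.pt")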

vmoens commented Oct 4, 2023

Hello!
This is something we don't document well (sigh!), but it's much easier than that. The functionality is buried in the code; we should make it much more apparent and user-facing:

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from tensordict import TensorDict

# Build a memmap-backed storage that writes its buffers under ./_dump.
s = LazyMemmapStorage(100, scratch_dir="./_dump")
td = TensorDict({"a": 0}, [])
s._init(td)  # creates the memmap files on disk from the sample's structure
print("saved", s._storage)

# In another process / at a later time: point a fresh storage at the same files.
s2 = LazyMemmapStorage(100, scratch_dir="./_dump")
s2._storage = TensorDict.load_memmap("./_dump")
print("loaded", s2._storage)

which will print the same result.
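
If you then want to hand the reloaded storage back to a buffer, something along these lines should work (a rough sketch; note that _init and _storage are private, and the storage's length counter may need restoring by hand depending on the version):

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from tensordict import TensorDict

storage = LazyMemmapStorage(100, scratch_dir="./_dump")
storage._storage = TensorDict.load_memmap("./_dump")  # reuse the memmap files on disk
rb = TensorDictReplayBuffer(storage=storage)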

I will make a PR to have this functionality available at a high level!
Stay tuned!

bmanczak commented Dec 4, 2023

Hey, I am also interested in this. @vmoens, in your example you do not use TensorDictReplayBuffer.
How would you go about saving the buffer if it holds many GBs of TensorDicts (i.e., they will not fit into memory at once)?

vmoens commented Dec 4, 2023

It's part of the roadmap for the next release!

Here's what I'm envisioning:

  • the concept of state_dict is a bit blurry with tensordict: we don't know whether we're talking about a regular dict or a state-dict as in PyTorch, so I would avoid using that name
  • there are various scenarios we need to account for:
    1. Easiest: the replay buffer is on disk already and uses a memmap storage
    2. Middle: the replay buffer is not on disk but it's still something that inherits from TensorStorage
    3. Hard: the replay buffer storage is something else (eg, a list)

For 3. I think we should simply have a state_dict for the storage and call it a day. It'll be a regular storage.load_state_dict(storage.state_dict()) kind of thing.
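
A rough sketch of what that could look like for a list-based storage (assuming the storages expose state_dict/load_state_dict; the exact content of that dict is not settled):

import torch
from torchrl.data import ReplayBuffer, ListStorage

rb = ReplayBuffer(storage=ListStorage(100))
rb.extend([{"step": i} for i in range(10)])  # arbitrary Python objects

# Plain torch.save / torch.load (i.e. pickle) for the storage's state.
torch.save(rb._storage.state_dict(), "list_storage.pt")

rb2 = ReplayBuffer(storage=ListStorage(100))
rb2._storage.load_state_dict(torch.load("list_storage.pt"))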

For 2. we could do

rb = TensorDictReplayBuffer(storage=storage)
...
rb.dumps(path)  # saves the storage (and some other stuff, see below)
rb.loads(path) # loads the storage from the path

For 1., dumps would copy the memmap tensors somewhere else, but if no path is provided (or the path matches the rb path) nothing will be done. Then we can simply do rb.loads(path) and get the rb back.

In my current view, rb.dumps is mainly a wrapper around rb._storage.dumps, but in some cases we'll need to save extra stuff, so all modules will need a dumps or equivalent:

  • prioritized replay buffers will need a way to save the segment trees.
  • TensorDictMaxValueWriter will need to save the heap.
  • stateful transforms will need to save their state too.

I'm not 100% sure what shape dumps should take; ideally, we would like pickle not to be used at all (since we can avoid it for the storage). For the pure numerical data (TensorDictMaxValueWriter's heap and PRB's segment trees) an mmap file will do, but for stateful transforms this may not always be the case. On top of that, ensuring that all transforms have a proper way of being serialized that does not depend on state-dict can be challenging, so the first approach would be (sketched below):

  • storage, writer, sampler -> mmap
  • transforms -> state-dict + pickle (plain torch.save / torch.load) until we find a better solution
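
Put together, the user-facing API could look roughly like this (a sketch only; names and behaviour may change in the actual PR):

import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000))
rb.extend(TensorDict({"obs": torch.randn(10, 4)}, [10]))

# Storage, sampler and writer state get written under the given directory,
# memmap-backed wherever possible.
rb.dumps("./rb_checkpoint")

# Later / elsewhere: build a buffer with the same specs and reload it.
rb2 = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000))
rb2.loads("./rb_checkpoint")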

Thoughts?

vmoens linked a pull request on Dec 5, 2023 that will close this issue
vmoens commented Dec 5, 2023

Have a look at #1733 which implements this feature!

bmanczak commented Dec 5, 2023

Wow, thanks for the swift response! The proposed solution does exactly what I need: it will make restarting experiments a breeze and enable my team to use TorchRL buffers for our use case.

When do you expect to merge?

vmoens commented Dec 5, 2023

Tonight or tomorrow; there are some loose ends with the max-value writer in the tests.
