Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 13, 2024
1 parent 8260544 commit 470b831
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def make_storage():
rb_trainer2.register(trainer2)
if re_init:
trainer2._process_batch_hook(td.to_tensordict().zero_())
trainer2.load_from_file(file)
trainer2.load_from_file(file, weights_only=False)
assert state_dict_has_been_called[0]
assert load_state_dict_has_been_called[0]
assert state_dict_has_been_called_td[0]
Expand Down
9 changes: 7 additions & 2 deletions torchrl/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,17 @@ def save_trainer(self, force_save: bool = False) -> None:
if _save and self.save_trainer_file:
self._save_trainer()

def load_from_file(self, file: Union[str, pathlib.Path]) -> Trainer:
def load_from_file(self, file: Union[str, pathlib.Path], **kwargs) -> Trainer:
"""Loads a file and its state-dict in the trainer.
Keyword arguments are passed to the :func:`~torch.load` function.
"""
if _CKPT_BACKEND == "torchsnapshot":
snapshot = Snapshot(path=file)
snapshot.restore(app_state=self.app_state)
elif _CKPT_BACKEND == "torch":
loaded_dict: OrderedDict = torch.load(file)
loaded_dict: OrderedDict = torch.load(file, **kwargs)
self.load_state_dict(loaded_dict)
return self

Expand Down

0 comments on commit 470b831

Please sign in to comment.