Skip to content

Commit

Permalink
[BugFix] Fix tests failing because of pytorch/pytorch#137602
Browse files Browse the repository at this point in the history
cc mikaylagawarecki albanD

ghstack-source-id: 6fc7434a259f92b0fca8875b20ac22624ecf1a03
Pull Request resolved: #2558
  • Loading branch information
vmoens committed Nov 13, 2024
1 parent 50a35f6 commit 165163a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def make_storage():
rb_trainer2.register(trainer2)
if re_init:
trainer2._process_batch_hook(td.to_tensordict().zero_())
trainer2.load_from_file(file)
trainer2.load_from_file(file, weights_only=False)
assert state_dict_has_been_called[0]
assert load_state_dict_has_been_called[0]
assert state_dict_has_been_called_td[0]
Expand Down
9 changes: 7 additions & 2 deletions torchrl/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,17 @@ def save_trainer(self, force_save: bool = False) -> None:
if _save and self.save_trainer_file:
self._save_trainer()

def load_from_file(self, file: Union[str, pathlib.Path]) -> Trainer:
def load_from_file(self, file: Union[str, pathlib.Path], **kwargs) -> Trainer:
"""Loads a file and its state-dict in the trainer.
Keyword arguments are passed to the :func:`~torch.load` function.
"""
if _CKPT_BACKEND == "torchsnapshot":
snapshot = Snapshot(path=file)
snapshot.restore(app_state=self.app_state)
elif _CKPT_BACKEND == "torch":
loaded_dict: OrderedDict = torch.load(file)
loaded_dict: OrderedDict = torch.load(file, **kwargs)
self.load_state_dict(loaded_dict)
return self

Expand Down

1 comment on commit 165163a

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 165163a Previous: 50a35f6 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 37.82997442983486 iter/sec (stddev: 0.15613410023354232) 225.78334844105714 iter/sec (stddev: 0.0009734842916720349) 5.97

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.