From 6427212f00dba6d8cae5f47e9c97489b946af3fc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 23 Apr 2024 21:22:04 +0100 Subject: [PATCH] amend --- sota-implementations/redq/redq.py | 18 +++++++++--------- torchrl/data/replay_buffers/replay_buffers.py | 5 ++++- torchrl/data/replay_buffers/storages.py | 8 +++++++- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index d6b1668aadf..c6b96db9292 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -76,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821 }, ) else: - logger = "" + logger = None key, init_env_steps, stats = None, None, None if not cfg.env.vecnorm and cfg.env.norm_stats: @@ -174,14 +174,14 @@ def main(cfg: "DictConfig"): # noqa: F821 t.loc.fill_(0.0) trainer = make_trainer( - collector, - loss_module, - recorder, - target_net_updater, - actor_model_explore, - replay_buffer, - logger, - cfg, + collector=collector, + loss_module=loss_module, + recorder=recorder, + target_net_updater=target_net_updater, + policy_exploration=actor_model_explore, + replay_buffer=replay_buffer, + logger=logger, + cfg=cfg, ) trainer.train() diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 68835e8dc3c..f75eaaeedcf 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -39,6 +39,7 @@ ) from torchrl.data.replay_buffers.storages import ( _get_default_collate, + _stack_anything, ListStorage, Storage, StorageEnsemble, @@ -1541,8 +1542,10 @@ def __init__( num_buffer_sampled: int | None = None, **kwargs, ): + if collate_fn is None: - collate_fn = LazyStackedTensorDict.maybe_dense_stack + collate_fn = _stack_anything + if rbs: if storages is not None or samplers is not None or writers is not None: raise RuntimeError diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 3cb809814b3..a1ada2eb72e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1323,10 +1323,16 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: def _collate_list_tensordict(x): - out = LazyStackedTensorDict.maybe_dense_stack(x, 0) + out = torch.stack(x, 0) return out +def _stack_anything(x): + if is_tensor_collection(x[0]): + return LazyStackedTensorDict.maybe_dense_stack(x) + return torch.stack(x) + + def _collate_id(x): return x