From 790bef602b6224bf1a729a325aee1dfee3c55c06 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 21 Jan 2025 09:43:36 +0000 Subject: [PATCH] [Feature] flexible return type when indexing prob sequences ghstack-source-id: 74d28ee84d965c11c527c60b20d9123ef30007f6 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1189 --- tensordict/_contextlib.py | 6 +++++- tensordict/nn/probabilistic.py | 21 +++++++++++++++++++++ test/test_nn.py | 19 +++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/tensordict/_contextlib.py b/tensordict/_contextlib.py index ac5f74b18..db31e907a 100644 --- a/tensordict/_contextlib.py +++ b/tensordict/_contextlib.py @@ -330,7 +330,11 @@ def _reverse_squeeze(self, args, kwargs, out): def _reverse_to_module(self, args, kwargs, out): try: - with out.unlock_() if not is_compiling() else contextlib.nullcontext(): + with ( + out.unlock_() + if not is_compiling() and out is not None + else contextlib.nullcontext() + ): return self.to_module(*args, **kwargs, swap_dest=out) except AttributeError: # This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 9e043554b..1633b01ad 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -990,6 +990,27 @@ def __init__( super().__init__(*modules, partial_tolerant=partial_tolerant) self.return_composite = return_composite + def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase: + if isinstance(index, (int, str)): + return self.module.__getitem__(index) + else: + mods = self.module.__getitem__(index) + if self.return_composite and any( + isinstance( + item, + (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential), + ) + for item in mods + ): + return type(self)(*mods, return_composite=self.return_composite) + elif isinstance( + mods[-1], + (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential), + ): + return type(self)(*mods) + else: + return TensorDictSequential(*mods) + _dist_sample = ProbabilisticTensorDictModule._dist_sample @property diff --git a/test/test_nn.py b/test/test_nn.py index 46a84001a..a2c3597b7 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2251,6 +2251,25 @@ def test_nested_keys_probabilistic_normal(self, log_prob_key): else: assert td_out[module.log_prob_key].shape == (3, 4, 1) + def test_index_prob_seq(self): + m0 = ProbabilisticTensorDictModule( + in_keys=["loc"], out_keys=["sample"], distribution_class=Normal + ) + m1 = TensorDictModule(lambda x: x, in_keys=["other"], out_keys=["something"]) + m2 = ProbabilisticTensorDictModule( + in_keys=["scale"], out_keys=["sample2"], distribution_class=Normal + ) + seq = ProbabilisticTensorDictSequential(m0, m1, m2) + assert isinstance(seq[0], ProbabilisticTensorDictModule) + assert isinstance(seq[:2], TensorDictSequential) + assert not isinstance(seq[:2], ProbabilisticTensorDictSequential) + assert isinstance(seq[-2:], ProbabilisticTensorDictSequential) + + seq = ProbabilisticTensorDictSequential(m0, m1, m2, return_composite=True) + assert isinstance(seq[0], ProbabilisticTensorDictModule) + assert isinstance(seq[:2], ProbabilisticTensorDictSequential) + assert isinstance(seq[-2:], ProbabilisticTensorDictSequential) + class TestEnsembleModule: def test_init(self):