From 6b87184c0dcfb42c6b5a898701c3469572da7b84 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Thu, 18 Apr 2024 17:04:24 +0200 Subject: [PATCH] [Feature] Extend TensorDictPrimer default_value options (#2071) Co-authored-by: Vincent Moens --- test/test_transforms.py | 89 +++++++++++++++++---- torchrl/envs/transforms/transforms.py | 104 +++++++++++++++++++------ torchrl/objectives/value/functional.py | 1 - 3 files changed, 154 insertions(+), 40 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 6e978d0ab5b..40c4f76e539 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6423,17 +6423,11 @@ def test_trans_parallel_env_check(self): finally: env.close() - def test_trans_serial_env_check(self): - with pytest.raises(RuntimeError, match="The leading shape of the primer specs"): - env = TransformedEnv( - SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])), - ) - _ = env.observation_spec - + @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) + def test_trans_serial_env_check(self, spec_shape): env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), + TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)), ) check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6533,6 +6527,72 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): r1 = env.rollout(100, break_when_any_done=break_when_any_done) tensordict.tensordict.assert_allclose_td(r0, r1) + def test_callable_default_value(self): + def create_tensor(): + return torch.ones(3) + + env = TransformedEnv( + ContinuousActionVecMockEnv(), + TensorDictPrimer( + mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor + ), + ) + check_env_specs(env) + assert "mykey" in env.reset().keys() + assert ("next", "mykey") in env.rollout(3).keys(True) + + def test_dict_default_value(self): + + # Test with a dict of float default values + key1_spec = UnboundedContinuousTensorSpec([3]) + key2_spec = UnboundedContinuousTensorSpec([3]) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + TensorDictPrimer( + mykey1=key1_spec, + mykey2=key2_spec, + default_value={ + "mykey1": 1.0, + "mykey2": 2.0, + }, + ), + ) + check_env_specs(env) + reset_td = env.reset() + assert "mykey1" in reset_td.keys() + assert "mykey2" in reset_td.keys() + rollout_td = env.rollout(3) + assert ("next", "mykey1") in rollout_td.keys(True) + assert ("next", "mykey2") in rollout_td.keys(True) + assert (rollout_td.get(("next", "mykey1")) == 1.0).all() + assert (rollout_td.get(("next", "mykey2")) == 2.0).all() + + # Test with a dict of callable default values + key1_spec = UnboundedContinuousTensorSpec([3]) + key2_spec = DiscreteTensorSpec(3, dtype=torch.int64) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + TensorDictPrimer( + mykey1=key1_spec, + mykey2=key2_spec, + default_value={ + "mykey1": lambda: torch.ones(3), + "mykey2": lambda: torch.tensor(1, dtype=torch.int64), + }, + ), + ) + check_env_specs(env) + reset_td = env.reset() + assert "mykey1" in reset_td.keys() + assert "mykey2" in reset_td.keys() + rollout_td = env.rollout(3) + assert ("next", "mykey1") in rollout_td.keys(True) + assert ("next", "mykey2") in rollout_td.keys(True) + assert (rollout_td.get(("next", "mykey1")) == torch.ones(3)).all + assert ( + rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64) + ).all + class TestTimeMaxPool(TransformBase): @pytest.mark.parametrize("T", [2, 4]) @@ -6813,18 +6873,13 @@ def make_env(): finally: env.close() - def test_trans_serial_env_check(self): + @pytest.mark.parametrize("shape", [(), (2,)]) + def test_trans_serial_env_check(self, shape): state_dim = 7 action_dim = 7 - with pytest.raises(RuntimeError, match="The leading shape of the primer"): - env = TransformedEnv( - SerialEnv(2, ContinuousActionVecMockEnv), - gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=()), - ) - check_env_specs(env) env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), + gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=shape), ) try: check_env_specs(env) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 128c0fe51f8..6aeea0529ce 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4435,8 +4435,12 @@ class TensorDictPrimer(Transform): random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. - default_value (float, optional): if non-random filling is chosen, this - value will be used to populate the tensors. Defaults to `0.0`. + default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random + filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float, + all elements of the tensors will be set to that value. If it is a callable, this callable is expected to + return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value` + is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will + be used to generate the corresponding tensors. Defaults to `0.0`. reset_key (NestedKey, optional): the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) @@ -4493,8 +4497,11 @@ class TensorDictPrimer(Transform): def __init__( self, primers: dict | CompositeSpec = None, - random: bool = False, - default_value: float = 0.0, + random: bool | None = None, + default_value: float + | Callable + | Dict[NestedKey, float] + | Dict[NestedKey, Callable] = None, reset_key: NestedKey | None = None, **kwargs, ): @@ -4509,8 +4516,31 @@ def __init__( if not isinstance(kwargs, CompositeSpec): kwargs = CompositeSpec(kwargs) self.primers = kwargs + if random and default_value: + raise ValueError( + "Setting random to True and providing a default_value are incompatible." + ) + default_value = ( + default_value or 0.0 + ) # if not random and no default value, use 0.0 self.random = random + if isinstance(default_value, dict): + default_value = TensorDict(default_value, []) + default_value_keys = default_value.keys( + True, + True, + is_leaf=lambda x: issubclass(x, (NonTensorData, torch.Tensor)), + ) + if set(default_value_keys) != set(self.primers.keys(True, True)): + raise ValueError( + "If a default_value dictionary is provided, it must match the primers keys." + ) + else: + default_value = { + key: default_value for key in self.primers.keys(True, True) + } self.default_value = default_value + self._validated = False self.reset_key = reset_key # sanity check @@ -4563,6 +4593,9 @@ def to(self, *args, **kwargs): self.primers = self.primers.to(device) return super().to(*args, **kwargs) + def _expand_shape(self, spec): + return spec.expand((*self.parent.batch_size, *spec.shape)) + def transform_observation_spec( self, observation_spec: CompositeSpec ) -> CompositeSpec: @@ -4572,15 +4605,13 @@ def transform_observation_spec( ) for key, spec in self.primers.items(): if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - raise RuntimeError( - f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " - f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." - ) + expanded_spec = self._expand_shape(spec) + spec = expanded_spec try: device = observation_spec.device except RuntimeError: device = self.device - observation_spec[key] = spec.to(device) + observation_spec[key] = self.primers[key] = spec.to(device) return observation_spec def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: @@ -4593,8 +4624,13 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: def _batch_size(self): return self.parent.batch_size + def _validate_value_tensor(self, value, spec): + if not spec.is_in(value): + raise RuntimeError(f"Value ({value}) is not in the spec domain ({spec}).") + return True + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - for key, spec in self.primers.items(): + for key, spec in self.primers.items(True, True): if spec.shape[: len(tensordict.shape)] != tensordict.shape: raise RuntimeError( "The leading shape of the spec must match the tensordict's, " @@ -4605,11 +4641,21 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.random: value = spec.rand() else: - value = torch.full_like( - spec.zero(), - self.default_value, - ) + value = self.default_value[key] + if callable(value): + value = value() + if not self._validated: + self._validate_value_tensor(value, spec) + else: + value = torch.full( + spec.shape, + value, + device=spec.device, + ) + tensordict.set(key, value) + if not self._validated: + self._validated = True return tensordict def _step( @@ -4638,22 +4684,36 @@ def _reset( ) _reset = _get_reset(self.reset_key, tensordict) if _reset.any(): - for key, spec in self.primers.items(): + for key, spec in self.primers.items(True, True): if self.random: value = spec.rand(shape) else: - value = torch.full_like( - spec.zero(shape), - self.default_value, - ) - prev_val = tensordict.get(key, 0.0) - value = torch.where(expand_as_right(_reset, value), value, prev_val) + value = self.default_value[key] + if callable(value): + value = value() + if not self._validated: + self._validate_value_tensor(value, spec) + else: + value = torch.full( + spec.shape, + value, + device=spec.device, + ) + prev_val = tensordict.get(key, 0.0) + value = torch.where( + expand_as_right(_reset, value), value, prev_val + ) tensordict_reset.set(key, value) + self._validated = True return tensordict_reset def __repr__(self) -> str: class_name = self.__class__.__name__ - return f"{class_name}(primers={self.primers}, default_value={self.default_value}, random={self.random})" + default_value = { + key: value if isinstance(value, float) else "Callable" + for key, value in self.default_value.items() + } + return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})" class PinMemoryTransform(Transform): diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index e93386b34ef..082c0ae9e9a 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import functools import math import warnings