From 60b59b11074e453276b89eb4ccbdbec0165bfa20 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 21 Apr 2022 10:53:36 +0100
Subject: [PATCH] Fix specs for BinarizeReward and CatFrames transforms (#86)

Failing tests for SSL reason
---
 test/test_transforms.py               |  4 +++
 torchrl/envs/transforms/transforms.py | 36 +++++++++++----------------
 2 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index c9d7b56f82e..1c8819897eb 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -663,6 +663,10 @@ def test_noop_reset_env(self, random):
     def test_binerized_reward(self, device):
         pass
 
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_reward_scaling(self, device):
+        pass
+
     @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda device found")
     @pytest.mark.parametrize("device", get_available_devices())
     def test_pin_mem(self, device):
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index bc2dc69419f..6ae56a90410 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -27,6 +27,7 @@
     NdUnboundedContinuousTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
+    BinaryDiscreteTensorSpec,
 )
 from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
 from torchrl.envs.common import _EnvClass, make_tensordict
@@ -50,7 +51,7 @@
     "DoubleToFloat",
     "CatTensors",
     "NoopResetEnv",
-    "BinerizeReward",
+    "BinarizeReward",
     "PinMemoryTransform",
     "VecNorm",
     "gSDENoise",
@@ -576,7 +577,7 @@ def __repr__(self) -> str:
         )
 
 
-class BinerizeReward(Transform):
+class BinarizeReward(Transform):
     """
     Maps the reward to a binary value (0 or 1) if the reward is null or
     non-null, respectively.
@@ -591,19 +592,14 @@ def __init__(self, keys: Optional[Sequence[str]] = None):
         super().__init__(keys=keys)
 
     def _apply(self, reward: torch.Tensor) -> torch.Tensor:
-        return (reward != 0.0).to(reward.dtype)
+        if not reward.shape or reward.shape[-1] != 1:
+            raise RuntimeError(
+                f"Reward shape last dimension must be singleton, got reward of shape {reward.shape}"
+            )
+        return (reward > 0.0).to(torch.long)
 
     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
-        if isinstance(reward_spec, UnboundedContinuousTensorSpec):
-            return BoundedTensorSpec(
-                0.0, 1.0, device=reward_spec.device, dtype=reward_spec.dtype
-            )
-        else:
-            raise NotImplementedError(
-                f"{self.__class__.__name__}.transform_reward_spec not "
-                f"implemented for tensor spec of type "
-                f"{type(reward_spec).__name__}"
-            )
+        return BinaryDiscreteTensorSpec(n=1, device=reward_spec.device)
 
 
 class Resize(ObservationTransform):
@@ -860,14 +856,12 @@ def reset(self, tensordict: _TensorDict) -> _TensorDict:
 
     def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
         if isinstance(observation_spec, CompositeSpec):
-            return CompositeSpec(
-                **{
-                    key: self.transform_observation_spec(_obs_spec)
-                    if key in self.keys
-                    else _obs_spec
-                    for key, _obs_spec in observation_spec._specs.items()
-                }
-            )
+            keys = [key for key in observation_spec.keys() if key in self.keys]
+            for key in keys:
+                observation_spec[key] = self.transform_observation_spec(
+                    observation_spec[key]
+                )
+            return observation_spec
         else:
             _observation_spec = observation_spec
             space = _observation_spec.space
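
For context, a minimal usage sketch (not part of the patch) of the behavior the patched BinarizeReward implements. The import paths are assumptions based on the __all__ entry and import block above; only BinarizeReward's __init__, _apply, and transform_reward_spec signatures are taken from the diff itself:

# Illustrative sketch only; import paths are assumptions, not from the patch.
import torch
from torchrl.envs.transforms import BinarizeReward
from torchrl.data import UnboundedContinuousTensorSpec

t = BinarizeReward(keys=["reward"])

# _apply now maps strictly positive rewards to 1 and everything else
# (zero or negative) to 0, returned as torch.long values.
reward = torch.tensor([[-1.0], [0.0], [2.5]])  # last dim must be singleton
print(t._apply(reward))  # tensor([[0], [0], [1]])

# A reward whose last dimension is not singleton now raises RuntimeError.
try:
    t._apply(torch.ones(3, 2))
except RuntimeError as err:
    print(err)

# The reward spec is now always a BinaryDiscreteTensorSpec(n=1) on the
# incoming spec's device, regardless of the input spec type.
spec = t.transform_reward_spec(UnboundedContinuousTensorSpec())
print(type(spec).__name__)  # BinaryDiscreteTensorSpec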