
Fix specs for BinarizeReward and CatFrames transforms (#86)
Failing tests for SSL reason
vmoens committed Apr 21, 2022
1 parent 308ed49 commit 60b59b1
Showing 2 changed files with 19 additions and 21 deletions.
4 changes: 4 additions & 0 deletions test/test_transforms.py
@@ -663,6 +663,10 @@ def test_noop_reset_env(self, random):
     def test_binerized_reward(self, device):
         pass
 
+    @pytest.mark.parametrize("device", get_available_devices())
+    def test_reward_scaling(self, device):
+        pass
+
     @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda device found")
     @pytest.mark.parametrize("device", get_available_devices())
     def test_pin_mem(self, device):
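For reference, a minimal sketch of what the binarized-reward placeholder above could assert once the SSL-related failures are sorted out, based only on the _apply logic added in the transforms.py diff below. The standalone function form, the "cpu"-only parametrization, and the import path are assumptions; the real suite defines these tests as class methods parametrized over get_available_devices().

import pytest
import torch

from torchrl.envs.transforms import BinarizeReward  # import path assumed


@pytest.mark.parametrize("device", ["cpu"])  # stand-in for get_available_devices()
def test_binerized_reward_sketch(device):
    t = BinarizeReward()
    reward = torch.tensor([[-1.0], [0.0], [2.5]], device=device)
    # rewards strictly greater than zero map to 1, everything else to 0
    assert (t._apply(reward) == torch.tensor([[0], [0], [1]], device=device)).all()
    # the new shape check rejects rewards whose last dimension is not singleton
    with pytest.raises(RuntimeError):
        t._apply(torch.zeros(3, 2, device=device))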
36 changes: 15 additions & 21 deletions torchrl/envs/transforms/transforms.py
@@ -27,6 +27,7 @@
     NdUnboundedContinuousTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
+    BinaryDiscreteTensorSpec,
 )
 from torchrl.data.tensordict.tensordict import _TensorDict, TensorDict
 from torchrl.envs.common import _EnvClass, make_tensordict
@@ -50,7 +51,7 @@
     "DoubleToFloat",
     "CatTensors",
     "NoopResetEnv",
-    "BinerizeReward",
+    "BinarizeReward",
     "PinMemoryTransform",
     "VecNorm",
     "gSDENoise",
@@ -576,7 +577,7 @@ def __repr__(self) -> str:
         )
 
 
-class BinerizeReward(Transform):
+class BinarizeReward(Transform):
     """
     Maps the reward to a binary value (0 or 1) if the reward is null or
     non-null, respectively.
@@ -591,19 +592,14 @@ def __init__(self, keys: Optional[Sequence[str]] = None):
         super().__init__(keys=keys)
 
     def _apply(self, reward: torch.Tensor) -> torch.Tensor:
-        return (reward != 0.0).to(reward.dtype)
+        if not reward.shape or reward.shape[-1] != 1:
+            raise RuntimeError(
+                f"Reward shape last dimension must be singleton, got reward of shape {reward.shape}"
+            )
+        return (reward > 0.0).to(torch.long)
 
     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
-        if isinstance(reward_spec, UnboundedContinuousTensorSpec):
-            return BoundedTensorSpec(
-                0.0, 1.0, device=reward_spec.device, dtype=reward_spec.dtype
-            )
-        else:
-            raise NotImplementedError(
-                f"{self.__class__.__name__}.transform_reward_spec not "
-                f"implemented for tensor spec of type "
-                f"{type(reward_spec).__name__}"
-            )
+        return BinaryDiscreteTensorSpec(n=1, device=reward_spec.device)
 
 
 class Resize(ObservationTransform):
@@ -860,14 +856,12 @@ def reset(self, tensordict: _TensorDict) -> _TensorDict:
 
     def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
         if isinstance(observation_spec, CompositeSpec):
-            return CompositeSpec(
-                **{
-                    key: self.transform_observation_spec(_obs_spec)
-                    if key in self.keys
-                    else _obs_spec
-                    for key, _obs_spec in observation_spec._specs.items()
-                }
-            )
+            keys = [key for key in observation_spec.keys() if key in self.keys]
+            for key in keys:
+                observation_spec[key] = self.transform_observation_spec(
+                    observation_spec[key]
+                )
+            return observation_spec
         else:
             _observation_spec = observation_spec
             space = _observation_spec.space
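As a quick illustration of the spec change above: transform_reward_spec previously returned a BoundedTensorSpec(0.0, 1.0, ...) and raised NotImplementedError for anything other than an unbounded continuous reward spec; it now unconditionally returns a one-element binary discrete spec, consulting only the incoming spec's device. A minimal sketch of the spec downstream consumers now see (import path assumed; the constructor call mirrors the one added in the diff):

import torch

from torchrl.data import BinaryDiscreteTensorSpec  # import path assumed

# what BinarizeReward.transform_reward_spec now returns for a CPU reward spec
reward_spec = BinaryDiscreteTensorSpec(n=1, device=torch.device("cpu"))
print(reward_spec)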
