From 3667290e43c50fbc94232eff3964bba5c34f8c46 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 17:07:31 +0000 Subject: [PATCH 01/30] init --- test/mocking_classes.py | 6 +- test/test_cost.py | 15 +- test/{test_tensor_spec.py => test_specs.py} | 326 +++++++++++++++++++- torchrl/data/__init__.py | 2 +- torchrl/data/tensor_specs.py | 291 ++++++++++++++++- torchrl/envs/common.py | 20 +- torchrl/envs/libs/brax.py | 6 +- torchrl/envs/libs/gym.py | 4 +- torchrl/envs/model_based/common.py | 4 +- torchrl/envs/model_based/dreamer.py | 1 + torchrl/envs/vec_env.py | 20 +- 11 files changed, 644 insertions(+), 51 deletions(-) rename test/{test_tensor_spec.py => test_specs.py} (77%) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 39525fec36f..62a6a5dcaac 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -12,7 +12,7 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - MultOneHotDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, ) @@ -25,7 +25,7 @@ "categorical": DiscreteTensorSpec, "unbounded": UnboundedContinuousTensorSpec, "binary": BinaryDiscreteTensorSpec, - "mult_one_hot": MultOneHotDiscreteTensorSpec, + "mult_one_hot": MultiOneHotDiscreteTensorSpec, "composite": CompositeSpec, } @@ -39,7 +39,7 @@ ] }, BinaryDiscreteTensorSpec: {"n": 7}, - MultOneHotDiscreteTensorSpec: {"nvec": [7, 3, 5]}, + MultiOneHotDiscreteTensorSpec: {"nvec": [7, 3, 5]}, CompositeSpec: {}, } diff --git a/test/test_cost.py b/test/test_cost.py index b84c8ed1f95..d1d5ac2bf36 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -31,7 +31,7 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - MultOneHotDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, ) @@ -141,7 +141,10 @@ def _create_mock_actor( return module.to(device) actor = QValueActor( spec=CompositeSpec( - action=action_spec, action_value=None, chosen_action_value=None + action=action_spec, + action_value=None, + chosen_action_value=None, + shape=[], ), module=module, ).to(device) @@ -160,7 +163,7 @@ def _create_mock_distributional_actor( ): # Actor if action_spec_type == "mult_one_hot": - action_spec = MultOneHotDiscreteTensorSpec([atoms] * action_dim) + action_spec = MultiOneHotDiscreteTensorSpec([atoms] * action_dim) elif action_spec_type == "one_hot": action_spec = OneHotDiscreteTensorSpec(action_dim) elif action_spec_type == "categorical": @@ -175,7 +178,11 @@ def _create_mock_distributional_actor( # if is_nn_module: # return module actor = DistributionalQValueActor( - spec=CompositeSpec(action=action_spec, action_value=None), + spec=CompositeSpec( + action=action_spec, + action_value=None, + shape=[], + ), module=module, support=support, action_space="categorical" diff --git a/test/test_tensor_spec.py b/test/test_specs.py similarity index 77% rename from test/test_tensor_spec.py rename to test/test_specs.py index 8d0732cbeae..96e5494ecfa 100644 --- a/test/test_tensor_spec.py +++ b/test/test_specs.py @@ -17,9 +17,10 @@ CompositeSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, - MultOneHotDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, ) @@ -215,7 +216,7 @@ def test_binary(n, shape): def test_mult_onehot(shape, ns): torch.manual_seed(0) np.random.seed(0) - ts = MultOneHotDiscreteTensorSpec(nvec=ns) + ts = MultiOneHotDiscreteTensorSpec(nvec=ns) for _ in range(100): r = ts.rand(shape) assert 
r.shape == torch.Size( @@ -322,7 +323,7 @@ def test_discrete_conversion(n, device): @pytest.mark.parametrize("device", get_available_devices()) def test_multi_discrete_conversion(ns, device): categorical = MultiDiscreteTensorSpec(ns, device=device) - one_hot = MultOneHotDiscreteTensorSpec(ns, device=device) + one_hot = MultiOneHotDiscreteTensorSpec(ns, device=device) assert categorical != one_hot assert categorical.to_onehot() == one_hot @@ -818,33 +819,33 @@ def test_equality_multi_onehot(self, nvec): device = "cpu" dtype = torch.float16 - ts = MultOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) - ts_same = MultOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts_same = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) assert ts == ts_same other_nvec = np.array(nvec) + 3 - ts_other = MultOneHotDiscreteTensorSpec( + ts_other = MultiOneHotDiscreteTensorSpec( nvec=other_nvec, device=device, dtype=dtype ) assert ts != ts_other other_nvec = [12] - ts_other = MultOneHotDiscreteTensorSpec( + ts_other = MultiOneHotDiscreteTensorSpec( nvec=other_nvec, device=device, dtype=dtype ) assert ts != ts_other other_nvec = [12, 13] - ts_other = MultOneHotDiscreteTensorSpec( + ts_other = MultiOneHotDiscreteTensorSpec( nvec=other_nvec, device=device, dtype=dtype ) assert ts != ts_other - ts_other = MultOneHotDiscreteTensorSpec(nvec=nvec, device="cpu:0", dtype=dtype) + ts_other = MultiOneHotDiscreteTensorSpec(nvec=nvec, device="cpu:0", dtype=dtype) assert ts != ts_other - ts_other = MultOneHotDiscreteTensorSpec( + ts_other = MultiOneHotDiscreteTensorSpec( nvec=nvec, device=device, dtype=torch.float64 ) assert ts != ts_other @@ -983,7 +984,7 @@ def test_discrete_action_spec_reconstruct(self, action_spec_cls): def test_mult_discrete_action_spec_reconstruct(self): torch.manual_seed(0) - action_spec = MultOneHotDiscreteTensorSpec((10, 5)) + action_spec = MultiOneHotDiscreteTensorSpec((10, 5)) actions_tensors = [action_spec.rand() for _ in range(10)] actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors] @@ -1034,7 +1035,7 @@ def test_mult_discrete_action_spec_rand(self): torch.manual_seed(0) ns = (10, 5) N = 100000 - action_spec = MultOneHotDiscreteTensorSpec((10, 5)) + action_spec = MultiOneHotDiscreteTensorSpec((10, 5)) actions_tensors = [action_spec.rand() for _ in range(10)] actions_numpy = [action_spec.to_numpy(a) for a in actions_tensors] @@ -1085,3 +1086,304 @@ def test_ndbounded_shape(self): sample = torch.stack([spec.rand() for _ in range(100)], 0) assert (-3 <= sample).all() and (3 >= sample).all() assert sample.shape == torch.Size([100, 10, 5]) + + +class TestExpand: + @pytest.mark.parametrize( + "shape1", + [ + None, + (4,), + (5, 4), + ], + ) + @pytest.mark.parametrize("shape2", [(), (10,)]) + def test_binary(self, shape1, shape2): + spec = BinaryDiscreteTensorSpec( + n=4, shape=shape1, device="cpu", dtype=torch.bool + ) + if shape1 is not None: + shape2_real = (*shape2, *shape1) + else: + shape2_real = (*shape2, 4) + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + + @pytest.mark.parametrize("shape2", [(), (5,)]) + @pytest.mark.parametrize( + "shape1,mini,maxi", + [ + [(10,), -torch.ones([]), torch.ones([])], + [None, 
-torch.ones([10]), torch.ones([])], + [None, -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([10]), torch.ones([])], + [(10,), -torch.ones([10]), torch.ones([10])], + ], + ) + def test_bounded(self, shape1, shape2, mini, maxi): + spec = BoundedTensorSpec( + mini, maxi, shape=shape1, device="cpu", dtype=torch.bool + ) + shape1 = spec.shape + assert shape1 == torch.Size([10]) + shape2_real = (*shape2, *shape1) + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + + def test_composite(self): + batch_size = (5,) + spec1 = BoundedTensorSpec( + -torch.ones([*batch_size, 10]), + torch.ones([*batch_size, 10]), + shape=( + *batch_size, + 10, + ), + device="cpu", + dtype=torch.bool, + ) + spec2 = BinaryDiscreteTensorSpec( + n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool + ) + spec3 = DiscreteTensorSpec( + n=4, shape=batch_size, device="cpu", dtype=torch.long + ) + spec4 = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long + ) + spec5 = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec6 = OneHotDiscreteTensorSpec( + n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec7 = UnboundedContinuousTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.float64, + ) + spec8 = UnboundedDiscreteTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.long, + ) + spec = CompositeSpec( + spec1=spec1, + spec2=spec2, + spec3=spec3, + spec4=spec4, + spec5=spec5, + spec6=spec6, + spec7=spec7, + spec8=spec8, + shape=batch_size, + ) + for new_spec in (spec.expand((4, *batch_size)), spec.expand(4, *batch_size)): + assert new_spec is not spec + assert new_spec.shape == torch.Size([4, *batch_size]) + assert new_spec["spec1"].shape == torch.Size([4, *batch_size, 10]) + assert new_spec["spec2"].shape == torch.Size([4, *batch_size, 4]) + assert new_spec["spec3"].shape == torch.Size( + [ + 4, + *batch_size, + ] + ) + assert new_spec["spec4"].shape == torch.Size([4, *batch_size, 3]) + assert new_spec["spec5"].shape == torch.Size([4, *batch_size, 15]) + assert new_spec["spec6"].shape == torch.Size([4, *batch_size, 15]) + assert new_spec["spec7"].shape == torch.Size([4, *batch_size, 9]) + assert new_spec["spec8"].shape == torch.Size([4, *batch_size, 9]) + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + @pytest.mark.parametrize("shape2", [(), (10,)]) + def test_discrete(self, shape1, shape2): + spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + if shape1 is not None: + shape2_real = (*shape2, *shape1) + else: + shape2_real = shape2 + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + @pytest.mark.parametrize("shape2", [(), (10,)]) + def test_multidiscrete(self, shape1, shape2): + if shape1 is None: + shape1 = (3,) + else: + shape1 = (*shape1, 3) + spec = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, 
device="cpu", dtype=torch.long + ) + if shape1 is not None: + shape2_real = (*shape2, *shape1) + else: + shape2_real = shape2 + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + @pytest.mark.parametrize("shape2", [(), (10,)]) + def test_multionehot(self, shape1, shape2): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + if shape1 is not None: + shape2_real = (*shape2, *shape1) + else: + shape2_real = shape2 + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + @pytest.mark.parametrize("shape2", [(), (10,)]) + def test_onehot(self, shape1, shape2): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = OneHotDiscreteTensorSpec( + n=15, shape=shape1, device="cpu", dtype=torch.long + ) + if shape1 is not None: + shape2_real = (*shape2, *shape1) + else: + shape2_real = shape2 + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + @pytest.mark.parametrize("shape2", [(), (10,)]) + def test_unbounded(self, shape1, shape2): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedContinuousTensorSpec( + shape=shape1, device="cpu", dtype=torch.float64 + ) + if shape1 is not None: + shape2_real = (*shape2, *shape1) + else: + shape2_real = shape2 + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + @pytest.mark.parametrize("shape2", [(), (10,)]) + def test_unboundeddiscrete(self, shape1, shape2): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long) + if shape1 is not None: + shape2_real = (*shape2, *shape1) + else: + shape2_real = shape2 + + spec2 = spec.expand(shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert (spec2.zero() == spec.zero()).all() diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 9510cca0309..4a0ac554218 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -22,7 +22,7 @@ DEVICE_TYPING, DiscreteTensorSpec, MultiDiscreteTensorSpec, - MultOneHotDiscreteTensorSpec, + 
MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 8fc25cf7598..981464e2623 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -265,6 +265,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = True) -> np.ndarray: self.assert_is_in(val) return val.detach().cpu().numpy() + @property + def ndim(self): + return self.ndimension() + + def ndimension(self): + return len(self.shape) + @abc.abstractmethod def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: """Indexes the input tensor. @@ -279,6 +286,19 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten """ raise NotImplementedError + @abc.abstractmethod + def expand(self, *shape): + """Returns a new Spec with the extended shape. + + Args: + *shape (tuple or iterable of int): the new shape of the Spec. Must comply with the current shape: + its length must be at least as long as the current shape length, + and its last values must be complient too; ie they can only differ + from it if the current dimension is a singleton. + + """ + raise NotImplementedError + def _project(self, val: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -407,6 +427,8 @@ class OneHotDiscreteTensorSpec(TensorSpec): Args: n (int): number of possible outcomes. + shape (torch.Size, optional): total shape of the sampled tensors. + If provided, the last dimension must match n. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. user_register (bool): experimental feature. If True, every integer @@ -428,6 +450,7 @@ class OneHotDiscreteTensorSpec(TensorSpec): def __init__( self, n: int, + shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, use_register: bool = False, @@ -438,9 +461,33 @@ def __init__( space = DiscreteBox( n, ) - shape = torch.Size((space.n,)) + if shape is None: + shape = torch.Size((space.n,)) + else: + shape = torch.Size(shape) + if shape[-1] != space.n: + raise ValueError( + f"The last value of the shape must match n for transform of type {self.__class__}. " + f"Got n={space.n} and shape={shape}." + ) super().__init__(shape, space, device, dtype, "discrete") + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." + ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__( + n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) @@ -611,6 +658,26 @@ def __init__( shape, ContinuousBox(minimum, maximum), device, dtype, "continuous" ) + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." 
+ ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__( + minimum=self.space.minimum.expand(shape), + maximum=self.space.maximum.expand(shape), + shape=shape, + device=self.device, + dtype=self.dtype, + ) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) @@ -702,6 +769,20 @@ def rand(self, shape=None) -> torch.Tensor: def is_in(self, val: torch.Tensor) -> bool: return True + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." + ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + @dataclass(repr=False) class UnboundedDiscreteTensorSpec(TensorSpec): @@ -755,6 +836,20 @@ def rand(self, shape=None) -> torch.Tensor: def is_in(self, val: torch.Tensor) -> bool: return True + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." + ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__(shape=shape, device=self.device, dtype=self.dtype) + @dataclass(repr=False) class BinaryDiscreteTensorSpec(TensorSpec): @@ -762,9 +857,14 @@ class BinaryDiscreteTensorSpec(TensorSpec): Args: n (int): length of the binary vector. + shape (torch.Size, optional): total shape of the sampled tensors. + If provided, the last dimension must match n. device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors. + dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long. + Examples: + >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool) + >>> print(spec.zero()) """ shape: torch.Size @@ -776,12 +876,22 @@ class BinaryDiscreteTensorSpec(TensorSpec): def __init__( self, n: int, + shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Union[str, torch.dtype] = torch.long, ): dtype, device = _default_dtype_and_device(dtype, device) - shape = torch.Size((n,)) box = BinaryBox(n) + if shape is None: + shape = torch.Size((n,)) + else: + shape = torch.Size(shape) + if shape[-1] != box.n: + raise ValueError( + f"The last value of the shape must match n for transform of type {self.__class__}. " + f"Got n={box.n} and shape={shape}." 
+ ) + super().__init__(shape, box, device, dtype, domain="discrete") def rand(self, shape=None) -> torch.Tensor: @@ -803,20 +913,38 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten def is_in(self, val: torch.Tensor) -> bool: return ((val == 0) | (val == 1)).all() + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." + ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__( + n=shape[-1], shape=shape, device=self.device, dtype=self.dtype + ) + @dataclass(repr=False) -class MultOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): +class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): """A concatenation of one-hot discrete tensor spec. Args: nvec (iterable of integers): cardinality of each of the elements of the tensor. + shape (torch.Size, optional): total shape of the sampled tensors. + If provided, the last dimension must match sum(nvec). device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. Examples: - >>> ts = MultOneHotDiscreteTensorSpec((3,2,3)) + >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3)) >>> ts.is_in(torch.tensor([0,0,1, ... 0,1, ... 1,0,0])) @@ -831,12 +959,21 @@ class MultOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): def __init__( self, nvec: Sequence[int], + shape: Optional[torch.Size] = None, device=None, dtype=torch.long, use_register=False, ): dtype, device = _default_dtype_and_device(dtype, device) - shape = torch.Size((sum(nvec),)) + if shape is None: + shape = torch.Size((sum(nvec),)) + else: + shape = torch.Size(shape) + if shape[-1] != sum(nvec): + raise ValueError( + f"The last value of the shape must match sum(nvec) for transform of type {self.__class__}. " + f"Got sum(nvec)={sum(nvec)} and shape={shape}." + ) space = BoxList([DiscreteBox(n) for n in nvec]) self.use_register = use_register super(OneHotDiscreteTensorSpec, self).__init__( @@ -875,7 +1012,7 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: raise RuntimeError( f"value {v} is greater than the allowed max {space.n}" ) - x.append(super(MultOneHotDiscreteTensorSpec, self).encode(v, space)) + x.append(super(MultiOneHotDiscreteTensorSpec, self).encode(v, space)) return torch.cat(x, -1) def _split(self, val: torch.Tensor) -> Optional[torch.Tensor]: @@ -912,7 +1049,7 @@ def is_in(self, val: torch.Tensor) -> bool: if vals is None: return False return all( - super(MultOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals + super(MultiOneHotDiscreteTensorSpec, self).is_in(_val) for _val in vals ) def _project(self, val: torch.Tensor) -> torch.Tensor: @@ -924,6 +1061,23 @@ def to_categorical(self) -> MultiDiscreteTensorSpec: [_space.n for _space in self.space], self.device, self.dtype ) + def expand(self, *shape): + nvecs = [space.n for space in self.space] + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." 
+ ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__( + nvec=nvecs, shape=shape, device=self.device, dtype=self.dtype + ) + class DiscreteTensorSpec(TensorSpec): """A discrete tensor spec. @@ -943,7 +1097,7 @@ class DiscreteTensorSpec(TensorSpec): Args: n (int): number of possible outcomes. - shape: (torch.Size, optional): shape of the variable, default is "(1,)". + shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])". device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. @@ -1008,6 +1162,22 @@ def to_onehot(self) -> OneHotDiscreteTensorSpec: ) return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype) + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." + ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__( + n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + ) + @dataclass(repr=False) class MultiDiscreteTensorSpec(DiscreteTensorSpec): @@ -1016,6 +1186,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): Args: nvec (iterable of integers or torch.Tensor): cardinality of each of the elements of the tensor. Can have several axes. + shape (torch.Size, optional): total shape of the sampled tensors. + If provided, the last dimension must match nvec.shape[-1]. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. @@ -1031,6 +1203,7 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): def __init__( self, nvec: Union[Sequence[int], torch.Tensor, int], + shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.long, ): @@ -1040,7 +1213,16 @@ def __init__( nvec = nvec.unsqueeze(0) self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) - shape = nvec.shape + if shape is None: + shape = nvec.shape + else: + shape = torch.Size(shape) + if shape[-1] != nvec.shape[-1]: + raise ValueError( + f"The last value of the shape must match nvec.shape[-1] for transform of type {self.__class__}. " + f"Got nvec.shape[-1]={sum(nvec)} and shape={shape}." + ) + space = BoxList.from_nvec(nvec) super(DiscreteTensorSpec, self).__init__( shape, space, device, dtype, domain="discrete" @@ -1095,7 +1277,7 @@ def is_in(self, val: torch.Tensor) -> bool: return ((val >= torch.zeros(self.nvec.size())) & (val < self.nvec)).all().item() - def to_onehot(self) -> MultOneHotDiscreteTensorSpec: + def to_onehot(self) -> MultiOneHotDiscreteTensorSpec: if len(self.shape) > 1: raise RuntimeError( f"DiscreteTensorSpec with shape that has several dimensions can't be converted to" @@ -1103,10 +1285,26 @@ def to_onehot(self) -> MultOneHotDiscreteTensorSpec: f"nestedtensors but it is not implemented yet. If you would like to see that feature, please submit " f"an issue of torchrl's github repo. 
" ) - return MultOneHotDiscreteTensorSpec( + return MultiOneHotDiscreteTensorSpec( [_space.n for _space in self.space], self.device, self.dtype ) + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError( + f"{self.__class__.__name__}.extend does not support negative shapes." + ) + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." + ) + return self.__class__( + nvec=self.nvec, shape=shape, device=self.device, dtype=self.dtype + ) + class CompositeSpec(TensorSpec): """A composition of TensorSpecs. @@ -1172,6 +1370,7 @@ class CompositeSpec(TensorSpec): """ + shape: torch.Size domain: str = "composite" @classmethod @@ -1179,8 +1378,51 @@ def __new__(cls, *args, **kwargs): cls._device = torch.device("cpu") return super().__new__(cls) - def __init__(self, *args, **kwargs): - self._specs = kwargs + @property + def shape(self): + return self._shape + + @shape.setter + def shape(self, value: torch.Size): + for key, spec in self.items(): + if spec.shape[: self.ndim] != self.shape: + raise ValueError( + f"The shape of the spec and the CompositeSpec mismatch during shape resetting: the " + f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and " + f"CompositeSpec.shape={self.shape}." + ) + self._shape = torch.Size(value) + + @property + def ndim(self): + return self.ndimension() + + def ndimension(self): + return len(self.shape) + + def set(self, name, spec): + if spec is not None: + shape = spec.shape + if shape[: self.ndim] != self.shape: + raise ValueError( + "The shape of the spec and the CompositeSpec mismatch: the first " + f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " + f"CompositeSpec.shape={self.shape}." + ) + self._specs[name] = spec + + def __init__(self, *args, shape=None, **kwargs): + if shape is None: + # Should we do this? Other specs have a default empty shape, maybe it would make sense to keep it + # optional for composite (for clarity and easiness of use). + # warnings.warn("shape=None for CompositeSpec will soon be deprecated. Make sure you set the " + # "batch size of your CompositeSpec as you would do for a tensordict.") + shape = [] + self._shape = torch.Size(shape) + self._specs = {} + for key, value in kwargs.items(): + self.set(key, value) + if len(kwargs): _device = None for key, item in self.items(): @@ -1262,7 +1504,7 @@ def __setitem__(self, key, value): f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " f"All devices of CompositeSpec must match." ) - self._specs[key] = value + self.set(key, value) def __iter__(self): for k in self._specs: @@ -1408,6 +1650,25 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N self[key] = item return self + def expand(self, *shape): + if len(shape) == 1 and isinstance(shape[0], tuple): + shape = shape[0] + if any(val < 0 for val in shape): + raise ValueError("CompositeSpec.extend does not support negative shapes.") + if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): + raise ValueError( + f"The last {self.ndim} of the extended shape must match the" + f"shape of the CompositeSpec in CompositeSpec.extend." 
+ ) + out = CompositeSpec( + { + key: value.expand(*shape, *value.shape[self.ndim :]) + for key, value in tuple(self.items()) + }, + shape=shape, + ) + return out + def _keys_to_empty_composite_spec(keys): if not len(keys): diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 5ab927de561..82159db0e05 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -58,7 +58,7 @@ def __init__( def build_metadata_from_env(env) -> EnvMetaData: tensordict = env.fake_tensordict() specs = {key: getattr(env, key) for key in Specs._keys if key.endswith("_spec")} - specs = CompositeSpec(**specs) + specs = CompositeSpec(**specs, shape=env.batch_size) batch_size = env.batch_size env_str = str(env) device = env.device @@ -276,7 +276,7 @@ def action_spec(self) -> TensorSpec: @action_spec.setter def action_spec(self, value: TensorSpec) -> None: if self._input_spec is None: - self.input_spec = CompositeSpec(action=value) + self.input_spec = CompositeSpec(action=value, shape=self.batch_size) else: self.input_spec["action"] = value @@ -288,6 +288,8 @@ def input_spec(self) -> TensorSpec: def input_spec(self, value: TensorSpec) -> None: if not isinstance(value, CompositeSpec): raise TypeError("The type of an input_spec must be Composite.") + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError("The value of spec.shape must match the env batch size.") self.__dict__["_input_spec"] = value @property @@ -300,6 +302,8 @@ def reward_spec(self, value: TensorSpec) -> None: raise TypeError( f"reward_spec of type {type(value)} do not have a shape " f"attribute." ) + if value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError("The value of spec.shape must match the env batch size.") if len(value.shape) == 0: raise RuntimeError( "the reward_spec shape cannot be empty (this error" @@ -317,6 +321,8 @@ def observation_spec(self) -> TensorSpec: def observation_spec(self, value: TensorSpec) -> None: if not isinstance(value, CompositeSpec): raise TypeError("The type of an observation_spec must be Composite.") + elif value.shape[: len(self.batch_size)] != self.batch_size: + raise ValueError("The value of spec.shape must match the env batch size.") self.__dict__["_observation_spec"] = value def step(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -535,7 +541,7 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa tensordict = TensorDict( {}, device=self.device, batch_size=self.batch_size, _run_checks=False ) - action = self.action_spec.rand(self.batch_size) + action = self.action_spec.rand() tensordict.set("action", action) return self.step(tensordict) @@ -605,7 +611,7 @@ def rollout( if policy is None: def policy(td): - return td.set("action", self.action_spec.rand(self.batch_size)) + return td.set("action", self.action_spec.rand()) tensordicts = [] for i in range(max_steps): @@ -710,11 +716,11 @@ def to(self, device: DEVICE_TYPING) -> EnvBase: def fake_tensordict(self) -> TensorDictBase: """Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout.""" input_spec = self.input_spec - fake_input = input_spec.zero(self.batch_size) + fake_input = input_spec.zero() observation_spec = self.observation_spec - fake_obs = observation_spec.zero(self.batch_size) + fake_obs = observation_spec.zero() reward_spec = self.reward_spec - fake_reward = reward_spec.zero(self.batch_size) + fake_reward = reward_spec.zero() fake_td = TensorDict( { **fake_obs, diff --git 
a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 4d1852a704c..fa787855e0e 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -121,7 +121,8 @@ def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 self.input_spec = CompositeSpec( action=BoundedTensorSpec( minimum=-1, maximum=1, shape=(env.action_size,), device=self.device - ) + ), + shape=env.batch_size, ) self.reward_spec = UnboundedContinuousTensorSpec( shape=[ @@ -132,7 +133,8 @@ def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 self.observation_spec = CompositeSpec( observation=UnboundedContinuousTensorSpec( shape=(env.observation_size,), device=self.device - ) + ), + shape=env.batch_size, ) # extract state spec from instance self.state_spec = self._make_state_spec(env) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 8d8a9af1bd7..4bd8487a650 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -14,7 +14,7 @@ CompositeSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, - MultOneHotDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, @@ -84,7 +84,7 @@ def _gym_to_torchrl_spec_transform( return ( MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) if categorical_action_encoding - else MultOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) ) elif isinstance(spec, gym.spaces.Box): shape = spec.shape diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index d7885268164..ab0ddeaa944 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -46,8 +46,8 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): ... batch_size=self.batch_size, ... device=self.device, ... ) - ... tensordict = tensordict.update(self.input_spec.rand(self.batch_size)) - ... tensordict = tensordict.update(self.observation_spec.rand(self.batch_size)) + ... tensordict = tensordict.update(self.input_spec.rand()) + ... tensordict = tensordict.update(self.observation_spec.rand()) ... 
return tensordict >>> # This environment is used as follows: >>> import torch.nn as nn diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index f1606a3c332..3c81ff1be83 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -51,6 +51,7 @@ def set_specs_from_env(self, env: EnvBase): state=self.observation_spec["state"], belief=self.observation_spec["belief"], action=self.action_spec.to(self.device), + shape=env.batch_size, ) def _reset(self, tensordict=None, **kwargs) -> TensorDict: diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index c21733f484f..7ac9ec3dc66 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -294,15 +294,29 @@ def _set_properties(self): self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) self._device = meta_data[0].device # TODO: check that all action_spec and reward spec match (issue #351) - self.reward_spec = meta_data[0].specs["reward_spec"] + + reward_spec = meta_data[0].specs["reward_spec"] + reward_spec = reward_spec.expand(self.num_workers, *reward_spec.shape) + self.reward_spec = reward_spec + _observation_spec = {} for md in meta_data: _observation_spec.update(dict(**md.specs["observation_spec"])) - self.observation_spec = CompositeSpec(**_observation_spec) + observation_spec = CompositeSpec( + **_observation_spec, shape=meta_data[0].batch_size + ) + observation_spec = observation_spec.expand( + self.num_workers, *observation_spec.shape + ) + self.observation_spec = observation_spec + _input_spec = {} for md in meta_data: _input_spec.update(dict(**md.specs["input_spec"])) - self.input_spec = CompositeSpec(**_input_spec) + input_spec = CompositeSpec(**_input_spec, shape=meta_data[0].batch_size) + input_spec = input_spec.expand(self.num_workers, *input_spec.shape) + self.input_spec = input_spec + self._dummy_env_str = str(meta_data[0]) self._env_tensordict = torch.stack( [meta_data.tensordict for meta_data in meta_data], 0 From df28e60d69a9a8b472973d71ccade7e202840b98 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 22:44:38 +0000 Subject: [PATCH 02/30] amend --- test/test_specs.py | 35 +++++++++++++++++++++++++++++++++++ torchrl/data/tensor_specs.py | 27 ++++++++++++++++++++------- torchrl/envs/vec_env.py | 18 +++++++++++++++--- 3 files changed, 70 insertions(+), 10 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 96e5494ecfa..5bf504c13d2 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1111,10 +1111,14 @@ def test_binary(self, shape1, shape2): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape spec2 = spec.expand(*shape2_real) assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape @pytest.mark.parametrize("shape2", [(), (5,)]) @pytest.mark.parametrize( @@ -1140,10 +1144,14 @@ def test_bounded(self, shape1, shape2, mini, maxi): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape spec2 = spec.expand(*shape2_real) assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == 
spec2.shape def test_composite(self): batch_size = (5,) @@ -1230,10 +1238,14 @@ def test_discrete(self, shape1, shape2): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape spec2 = spec.expand(*shape2_real) assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape @pytest.mark.parametrize( "shape1", @@ -1261,10 +1273,14 @@ def test_multidiscrete(self, shape1, shape2): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape spec2 = spec.expand(*shape2_real) assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape @pytest.mark.parametrize( "shape1", @@ -1292,10 +1308,14 @@ def test_multionehot(self, shape1, shape2): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape spec2 = spec.expand(*shape2_real) assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape @pytest.mark.parametrize( "shape1", @@ -1323,10 +1343,14 @@ def test_onehot(self, shape1, shape2): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape spec2 = spec.expand(*shape2_real) assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape @pytest.mark.parametrize( "shape1", @@ -1354,10 +1378,14 @@ def test_unbounded(self, shape1, shape2): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape spec2 = spec.expand(*shape2_real) assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape @pytest.mark.parametrize( "shape1", @@ -1387,3 +1415,10 @@ def test_unboundeddiscrete(self, shape1, shape2): assert spec2 is not spec assert spec2.dtype == spec.dtype assert (spec2.zero() == spec.zero()).all() + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape + spec2 = spec.expand(*shape2_real) + assert spec2 is not spec + assert spec2.dtype == spec.dtype + assert spec2.rand().shape == spec2.shape + assert spec2.zero().shape == spec2.shape diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 981464e2623..b054f7080d4 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -490,7 +490,9 @@ def expand(self, *shape): def rand(self, shape=None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = self.shape[:-1] + else: + shape = torch.Size([*shape, *self.shape[:-1]]) return torch.nn.functional.gumbel_softmax( torch.rand(torch.Size([*shape, self.space.n]), 
device=self.device), hard=True, @@ -695,10 +697,16 @@ def rand(self, shape=None) -> torch.Tensor: out[out < a] = a.expand_as(out)[out < a] return out else: - interval = self.space.maximum - self.space.minimum - r = torch.rand( - torch.Size([*shape, *interval.shape]), device=interval.device - ) + if self.space.maximum.dtype == torch.bool: + maxi = self.space.maximum.int() + else: + maxi = self.space.maximum + if self.space.minimum.dtype == torch.bool: + mini = self.space.minimum.int() + else: + mini = self.space.minimum + interval = maxi - mini + r = torch.rand(torch.Size([*shape, *self.shape]), device=interval.device) r = interval * r r = self.space.minimum + r r = r.to(self.dtype).to(self.device) @@ -982,7 +990,10 @@ def __init__( def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = self.shape[:-1] + else: + shape = torch.Size([*shape, *self.shape[:-1]]) + x = torch.cat( [ torch.nn.functional.one_hot( @@ -1247,7 +1258,9 @@ def _rand(self, space: Box, shape: torch.Size, i: int): def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = self.shape[:-1] + else: + shape = torch.Size([*shape, *self.shape[:-1]]) x = self._rand(self.space, shape, self.nvec.ndim) if self.shape == torch.Size([1]): x = x.squeeze(-1) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 7ac9ec3dc66..c94ee872bdd 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -283,9 +283,21 @@ def _set_properties(self): meta_data = deepcopy(self.meta_data) if self._single_task: self._batch_size = meta_data.batch_size - self.observation_spec = meta_data.specs["observation_spec"] - self.reward_spec = meta_data.specs["reward_spec"] - self.input_spec = meta_data.specs["input_spec"] + observation_spec = meta_data.specs["observation_spec"] + + observation_spec = observation_spec.expand( + self.num_workers, *observation_spec.shape + ) + self.observation_spec = observation_spec + + reward_spec = meta_data.specs["reward_spec"] + reward_spec = reward_spec.expand(self.num_workers, *reward_spec.shape) + self.reward_spec = reward_spec + + input_spec = meta_data.specs["input_spec"] + input_spec = input_spec.expand(self.num_workers, *input_spec.shape) + self.input_spec = input_spec + self._dummy_env_str = meta_data.env_str self._device = meta_data.device self._env_tensordict = meta_data.tensordict From 48764960591aad84d135a300f95e4cebf8c0a497 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 22:46:14 +0000 Subject: [PATCH 03/30] amend --- torchrl/envs/common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 82159db0e05..f282691cf6d 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -353,9 +353,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = tensordict_out.get("reward") # unsqueeze rewards if needed - expected_reward_shape = torch.Size( - [*tensordict_out.batch_size, *self.reward_spec.shape] - ) + expected_reward_shape = self.reward_spec.shape n = len(expected_reward_shape) if len(reward.shape) >= n and reward.shape[-n:] != expected_reward_shape: reward = reward.view(*reward.shape[:n], *expected_reward_shape) From 60bccb37e5488037386c50fb625148323b9a64ae Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 23:02:55 +0000 Subject: [PATCH 04/30] amend --- test/mocking_classes.py | 85 ++++++++++++++++++++++++----------------- test/test_env.py | 2 +- 
2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 62a6a5dcaac..8617b815adb 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -168,19 +168,22 @@ def __new__( reward_spec=None, **kwargs, ): + batch_size = kwargs.pop("batch_size", None) + if batch_size is None: + batch_size = torch.Size([]) if action_spec is None: - action_spec = UnboundedContinuousTensorSpec((1,)) + action_spec = UnboundedContinuousTensorSpec((*batch_size, 1,)) if input_spec is None: input_spec = CompositeSpec( action=action_spec, - observation=UnboundedContinuousTensorSpec((1,)), + observation=UnboundedContinuousTensorSpec((*batch_size, 1,)),shape=batch_size, ) if observation_spec is None: observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((1,)) + observation=UnboundedContinuousTensorSpec((*batch_size, 1,)), shape=batch_size, ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec((1,)) + reward_spec = UnboundedContinuousTensorSpec((*batch_size, 1,)) cls._reward_spec = reward_spec cls._observation_spec = observation_spec cls._input_spec = input_spec @@ -278,14 +281,17 @@ def __new__( categorical_action_encoding=False, **kwargs, ): + batch_size = kwargs.pop("batch_size", None) + if batch_size is None: + batch_size = torch.Size([]) size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=torch.Size([size])), + observation=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, size])), observation_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([size]) - ), + shape=torch.Size([*batch_size, size]) + ),shape=batch_size, ) if action_spec is None: action_spec_cls = ( @@ -293,7 +299,7 @@ def __new__( if categorical_action_encoding else OneHotDiscreteTensorSpec ) - action_spec = action_spec_cls(7) + action_spec = action_spec_cls(*batch_size, 7) if reward_spec is None: reward_spec = UnboundedContinuousTensorSpec() @@ -365,19 +371,22 @@ def __new__( from_pixels=False, **kwargs, ): + batch_size = kwargs.pop("batch_size", None) + if batch_size is None: + batch_size = torch.Size([]) size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=torch.Size([size])), + observation=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, size])), observation_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([size]) - ), + shape=torch.Size([*batch_size, size]) + ), shape=batch_size, ) if action_spec is None: - action_spec = BoundedTensorSpec(-1, 1, (7,)) + action_spec = BoundedTensorSpec(-1, 1, (*batch_size, 7,)) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec() + reward_spec = UnboundedContinuousTensorSpec(shape=batch_size) if input_spec is None: cls._out_key = "observation_orig" @@ -385,7 +394,7 @@ def __new__( **{ cls._out_key: observation_spec["observation"], "action": action_spec, - } + }, shape=batch_size ) cls._reward_spec = reward_spec cls._observation_spec = observation_spec @@ -406,7 +415,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: if tensordict is None: tensordict = TensorDict({}, self.batch_size, device=self.device) tensordict = tensordict.select() - tensordict.update(self.observation_spec.rand(self.batch_size)) + tensordict.update(self.observation_spec.rand()) # tensordict.set("next_" + self.out_key, self._get_out_obs(state)) # 
tensordict.set("next_" + self._out_key, self._get_out_obs(state)) tensordict.set("done", torch.zeros(*tensordict.shape, 1, dtype=torch.bool)) @@ -467,16 +476,20 @@ def __new__( from_pixels=True, **kwargs, ): + batch_size = kwargs.pop("batch_size", None) + if batch_size is None: + batch_size = torch.Size([]) if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])), - pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])), + pixels=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 1, 7, 7])), + pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 1, 7, 7])), + shape=batch_size, ) if action_spec is None: - action_spec = OneHotDiscreteTensorSpec(7) + action_spec = OneHotDiscreteTensorSpec(7, shape=(*batch_size, 7)) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec() + reward_spec = UnboundedContinuousTensorSpec(shape=batch_size) if input_spec is None: cls._out_key = "pixels_orig" @@ -484,10 +497,11 @@ def __new__( **{ cls._out_key: observation_spec["pixels_orig"], "action": action_spec, - } + }, shape=batch_size ) return super().__new__( *args, + batch_size=batch_size, observation_spec=observation_spec, action_spec=action_spec, reward_spec=reward_spec, @@ -517,11 +531,15 @@ def __new__( categorical_action_encoding=False, **kwargs, ): + batch_size = kwargs.pop("batch_size", None) + if batch_size is None: + batch_size = torch.Size([]) if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), - pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), + pixels=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 7, 7, 3])), + pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 7, 7, 3])), + shape=batch_size, ) if action_spec is None: action_spec_cls = ( @@ -529,18 +547,19 @@ def __new__( if categorical_action_encoding else OneHotDiscreteTensorSpec ) - action_spec = action_spec_cls(7) + action_spec = action_spec_cls(7, shape=(*batch_size, 7)) if input_spec is None: cls._out_key = "pixels_orig" input_spec = CompositeSpec( **{ cls._out_key: observation_spec["pixels_orig"], "action": action_spec, - } + }, shape=batch_size, ) return super().__new__( *args, + batch_size=batch_size, observation_spec=observation_spec, action_spec=action_spec, reward_spec=reward_spec, @@ -689,20 +708,18 @@ def __init__( batch_size=batch_size, ) self.observation_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec((4,)) + hidden_observation=UnboundedContinuousTensorSpec((*self.batch_size, 4,)), shape=self.batch_size ) self.input_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec((4,)), - action=UnboundedContinuousTensorSpec((1,)), + hidden_observation=UnboundedContinuousTensorSpec((*self.batch_size, 4,)), + action=UnboundedContinuousTensorSpec((*self.batch_size, 1,)), shape=self.batch_size ) - self.reward_spec = UnboundedContinuousTensorSpec((1,)) + self.reward_spec = UnboundedContinuousTensorSpec((*self.batch_size, 1,)) def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: td = TensorDict( { - "hidden_observation": self.input_spec["hidden_observation"].rand( - self.batch_size - ), + "hidden_observation": self.input_spec["hidden_observation"].rand(), }, batch_size=self.batch_size, device=self.device, @@ -725,10 +742,10 @@ def __init__(self, max_steps: int = 5, **kwargs): 
self.max_steps = max_steps self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((1,)) + observation=UnboundedContinuousTensorSpec((*self.batch_size, 1,)), shape=self.batch_size ) - self.reward_spec = UnboundedContinuousTensorSpec((1,)) - self.input_spec = CompositeSpec(action=BinaryDiscreteTensorSpec(1)) + self.reward_spec = UnboundedContinuousTensorSpec((*self.batch_size, 1,)) + self.input_spec = CompositeSpec(action=BinaryDiscreteTensorSpec(*self.batch_size, 1)) self.count = torch.zeros( (*self.batch_size, 1), device=self.device, dtype=torch.int diff --git a/test/test_env.py b/test/test_env.py index ffc4cecff38..23bd16d10d6 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -345,7 +345,7 @@ def test_mb_env_batch_lock(self, device, seed=0): with pytest.raises(RuntimeError, match="batch_locked is a read-only property"): mb_env.batch_locked = False td = mb_env.reset() - td["action"] = mb_env.action_spec.rand(mb_env.batch_size) + td["action"] = mb_env.action_spec.rand() td_expanded = td.unsqueeze(-1).expand(10, 2).reshape(-1).to_tensordict() mb_env.step(td) From ffcd7c0d0d46510350901b72f99e23b87c257025 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 13 Jan 2023 23:09:46 +0000 Subject: [PATCH 05/30] amend --- test/mocking_classes.py | 193 ++++++++++++++++++++++++++++++---------- 1 file changed, 145 insertions(+), 48 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 8617b815adb..da0ad42ddd3 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -107,14 +107,30 @@ def __new__( reward_spec=None, **kwargs, ): + batch_size = kwargs.get("batch_size", torch.Size([])) if action_spec is None: - action_spec = UnboundedContinuousTensorSpec((1,)) + action_spec = UnboundedContinuousTensorSpec( + ( + *batch_size, + 1, + ) + ) if observation_spec is None: observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((1,)) + observation=UnboundedContinuousTensorSpec( + ( + *batch_size, + 1, + ) + ) ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec((1,)) + reward_spec = UnboundedContinuousTensorSpec( + ( + *batch_size, + 1, + ) + ) if input_spec is None: input_spec = CompositeSpec(action=action_spec) cls._reward_spec = reward_spec @@ -168,22 +184,42 @@ def __new__( reward_spec=None, **kwargs, ): - batch_size = kwargs.pop("batch_size", None) - if batch_size is None: - batch_size = torch.Size([]) + batch_size = kwargs.get("batch_size", torch.Size([])) if action_spec is None: - action_spec = UnboundedContinuousTensorSpec((*batch_size, 1,)) + action_spec = UnboundedContinuousTensorSpec( + ( + *batch_size, + 1, + ) + ) if input_spec is None: input_spec = CompositeSpec( action=action_spec, - observation=UnboundedContinuousTensorSpec((*batch_size, 1,)),shape=batch_size, + observation=UnboundedContinuousTensorSpec( + ( + *batch_size, + 1, + ) + ), + shape=batch_size, ) if observation_spec is None: observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((*batch_size, 1,)), shape=batch_size, + observation=UnboundedContinuousTensorSpec( + ( + *batch_size, + 1, + ) + ), + shape=batch_size, ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec((*batch_size, 1,)) + reward_spec = UnboundedContinuousTensorSpec( + ( + *batch_size, + 1, + ) + ) cls._reward_spec = reward_spec cls._observation_spec = observation_spec cls._input_spec = input_spec @@ -281,17 +317,18 @@ def __new__( categorical_action_encoding=False, **kwargs, ): - batch_size = kwargs.pop("batch_size", 
None) - if batch_size is None: - batch_size = torch.Size([]) + batch_size = kwargs.get("batch_size", torch.Size([])) size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, size])), + observation=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, size]) + ), observation_orig=UnboundedContinuousTensorSpec( shape=torch.Size([*batch_size, size]) - ),shape=batch_size, + ), + shape=batch_size, ) if action_spec is None: action_spec_cls = ( @@ -371,20 +408,28 @@ def __new__( from_pixels=False, **kwargs, ): - batch_size = kwargs.pop("batch_size", None) - if batch_size is None: - batch_size = torch.Size([]) + batch_size = kwargs.get("batch_size", torch.Size([])) size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, size])), + observation=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, size]) + ), observation_orig=UnboundedContinuousTensorSpec( shape=torch.Size([*batch_size, size]) - ), shape=batch_size, + ), + shape=batch_size, ) if action_spec is None: - action_spec = BoundedTensorSpec(-1, 1, (*batch_size, 7,)) + action_spec = BoundedTensorSpec( + -1, + 1, + ( + *batch_size, + 7, + ), + ) if reward_spec is None: reward_spec = UnboundedContinuousTensorSpec(shape=batch_size) @@ -394,7 +439,8 @@ def __new__( **{ cls._out_key: observation_spec["observation"], "action": action_spec, - }, shape=batch_size + }, + shape=batch_size, ) cls._reward_spec = reward_spec cls._observation_spec = observation_spec @@ -476,14 +522,16 @@ def __new__( from_pixels=True, **kwargs, ): - batch_size = kwargs.pop("batch_size", None) - if batch_size is None: - batch_size = torch.Size([]) + batch_size = kwargs.get("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 1, 7, 7])), - pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 1, 7, 7])), + pixels=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, 1, 7, 7]) + ), + pixels_orig=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, 1, 7, 7]) + ), shape=batch_size, ) if action_spec is None: @@ -497,7 +545,8 @@ def __new__( **{ cls._out_key: observation_spec["pixels_orig"], "action": action_spec, - }, shape=batch_size + }, + shape=batch_size, ) return super().__new__( *args, @@ -531,14 +580,16 @@ def __new__( categorical_action_encoding=False, **kwargs, ): - batch_size = kwargs.pop("batch_size", None) - if batch_size is None: - batch_size = torch.Size([]) + batch_size = kwargs.get("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 7, 7, 3])), - pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([*batch_size, 7, 7, 3])), + pixels=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, 7, 7, 3]) + ), + pixels_orig=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, 7, 7, 3]) + ), shape=batch_size, ) if action_spec is None: @@ -554,7 +605,8 @@ def __new__( **{ cls._out_key: observation_spec["pixels_orig"], "action": action_spec, - }, shape=batch_size, + }, + shape=batch_size, ) return super().__new__( @@ -594,26 +646,31 @@ def __new__( pixel_shape=None, 
**kwargs, ): + batch_size = kwargs.get("batch_size", torch.Size([])) if pixel_shape is None: pixel_shape = [1, 7, 7] if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)), + pixels=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, *pixel_shape]) + ), pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size(pixel_shape) + shape=torch.Size([*batch_size, *pixel_shape]) ), + shape=batch_size, ) if action_spec is None: - action_spec = BoundedTensorSpec(-1, 1, pixel_shape[-1]) + action_spec = BoundedTensorSpec(-1, 1, [*batch_size, pixel_shape[-1]]) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec() + reward_spec = UnboundedContinuousTensorSpec(shape=batch_size) if input_spec is None: cls._out_key = "pixels_orig" input_spec = CompositeSpec( - **{cls._out_key: observation_spec["pixels"], "action": action_spec} + **{cls._out_key: observation_spec["pixels"], "action": action_spec}, + shape=batch_size, ) return super().__new__( *args, @@ -645,11 +702,16 @@ def __new__( from_pixels=True, **kwargs, ): + batch_size = kwargs.get("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), - pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])), + pixels=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, 7, 7, 3]) + ), + pixels_orig=UnboundedContinuousTensorSpec( + shape=torch.Size([*batch_size, 7, 7, 3]) + ), ) return super().__new__( *args, @@ -708,13 +770,35 @@ def __init__( batch_size=batch_size, ) self.observation_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec((*self.batch_size, 4,)), shape=self.batch_size + hidden_observation=UnboundedContinuousTensorSpec( + ( + *self.batch_size, + 4, + ) + ), + shape=self.batch_size, ) self.input_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec((*self.batch_size, 4,)), - action=UnboundedContinuousTensorSpec((*self.batch_size, 1,)), shape=self.batch_size + hidden_observation=UnboundedContinuousTensorSpec( + ( + *self.batch_size, + 4, + ) + ), + action=UnboundedContinuousTensorSpec( + ( + *self.batch_size, + 1, + ) + ), + shape=self.batch_size, + ) + self.reward_spec = UnboundedContinuousTensorSpec( + ( + *self.batch_size, + 1, + ) ) - self.reward_spec = UnboundedContinuousTensorSpec((*self.batch_size, 1,)) def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: td = TensorDict( @@ -742,10 +826,23 @@ def __init__(self, max_steps: int = 5, **kwargs): self.max_steps = max_steps self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((*self.batch_size, 1,)), shape=self.batch_size + observation=UnboundedContinuousTensorSpec( + ( + *self.batch_size, + 1, + ) + ), + shape=self.batch_size, + ) + self.reward_spec = UnboundedContinuousTensorSpec( + ( + *self.batch_size, + 1, + ) + ) + self.input_spec = CompositeSpec( + action=BinaryDiscreteTensorSpec(*self.batch_size, 1) ) - self.reward_spec = UnboundedContinuousTensorSpec((*self.batch_size, 1,)) - self.input_spec = CompositeSpec(action=BinaryDiscreteTensorSpec(*self.batch_size, 1)) self.count = torch.zeros( (*self.batch_size, 1), device=self.device, dtype=torch.int From 134dddba4182d4961b67736304a0aec8d9717e2e Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 14 Jan 2023 17:30:08 +0000 Subject: [PATCH 06/30] amend --- test/mocking_classes.py | 5 
+++-- torchrl/collectors/collectors.py | 6 ++---- torchrl/data/tensor_specs.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index da0ad42ddd3..94b28f1a9c1 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -122,7 +122,8 @@ def __new__( *batch_size, 1, ) - ) + ), + shape=batch_size, ) if reward_spec is None: reward_spec = UnboundedContinuousTensorSpec( @@ -132,7 +133,7 @@ def __new__( ) ) if input_spec is None: - input_spec = CompositeSpec(action=action_spec) + input_spec = CompositeSpec(action=action_spec, shape=batch_size) cls._reward_spec = reward_spec cls._observation_spec = observation_spec cls._input_spec = input_spec diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 193d60dba76..3cabff1e65c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -439,10 +439,8 @@ def __init__( ): # if policy spec is non-empty, all the values are not None and the keys # match the out_keys we assume the user has given all relevant information - self._tensordict_out = ( - env.fake_tensordict().expand(env.batch_size).to_tensordict() - ) - self._tensordict_out.update(self.policy.spec.zero(env.batch_size)) + self._tensordict_out = env.fake_tensordict().to_tensordict() + self._tensordict_out.update(self.policy.spec.zero()) if env.device: self._tensordict_out = self._tensordict_out.to(env.device) self._tensordict_out = ( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index b054f7080d4..0f1886adb38 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1642,7 +1642,7 @@ def zero(self, shape=None) -> TensorDictBase: for key in self.keys(True) if isinstance(key, str) and self[key] is not None }, - shape, + torch.Size([*shape, *self.shape]), device=self.device, ) From da5b6a14f49a3aef74d9084f26e7bcffe1aa6414 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 15 Jan 2023 20:56:18 +0000 Subject: [PATCH 07/30] amend --- test/mocking_classes.py | 8 +++--- torchrl/envs/common.py | 54 ++++++++++++++++++++++------------------- torchrl/envs/vec_env.py | 2 ++ 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 94b28f1a9c1..1fe44c17135 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -339,7 +339,7 @@ def __new__( ) action_spec = action_spec_cls(*batch_size, 7) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec() + reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) if input_spec is None: cls._out_key = "observation_orig" @@ -432,7 +432,7 @@ def __new__( ), ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=batch_size) + reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) if input_spec is None: cls._out_key = "observation_orig" @@ -538,7 +538,7 @@ def __new__( if action_spec is None: action_spec = OneHotDiscreteTensorSpec(7, shape=(*batch_size, 7)) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=batch_size) + reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) if input_spec is None: cls._out_key = "pixels_orig" @@ -666,7 +666,7 @@ def __new__( action_spec = BoundedTensorSpec(-1, 1, [*batch_size, pixel_shape[-1]]) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=batch_size) + reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) if input_spec is None: cls._out_key = "pixels_orig" 
input_spec = CompositeSpec( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index f282691cf6d..fb47eb0efa7 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -353,24 +353,20 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = tensordict_out.get("reward") # unsqueeze rewards if needed - expected_reward_shape = self.reward_spec.shape - n = len(expected_reward_shape) - if len(reward.shape) >= n and reward.shape[-n:] != expected_reward_shape: - reward = reward.view(*reward.shape[:n], *expected_reward_shape) - tensordict_out.set("reward", reward) - elif len(reward.shape) < n: - reward = reward.view(expected_reward_shape) + batch_size = tensordict_out.shape + dims = len(batch_size) + expected_reward_shape = self.reward_spec.shape[dims:] + actual_reward_shape = reward.shape[dims:] + if actual_reward_shape != expected_reward_shape: + reward = reward.view([*batch_size, *expected_reward_shape]) tensordict_out.set("reward", reward) done = tensordict_out.get("done") # unsqueeze done if needed - expected_done_shape = torch.Size([*tensordict_out.batch_size, 1]) - n = len(expected_done_shape) - if len(done.shape) >= n and done.shape[-n:] != expected_done_shape: - done = done.view(*done.shape[:n], *expected_done_shape) - tensordict_out.set("done", done) - elif len(done.shape) < n: - done = done.view(expected_done_shape) + expected_done_shape = torch.Size([1]) + actual_done_shape = done.shape[dims:] + if actual_done_shape != expected_done_shape: + done = done.view([*batch_size, *expected_done_shape]) tensordict_out.set("done", done) if tensordict_out is tensordict: @@ -443,10 +439,23 @@ def reset( done = tensordict_reset.get("done", None) if done is not None: # unsqueeze done if needed - expected_done_shape = torch.Size([*tensordict_reset.batch_size, 1]) - if done.shape != expected_done_shape: - done = done.view(expected_done_shape) + batch_size = tensordict_reset.shape + dims = len(batch_size) + expected_done_shape = torch.Size([1]) + actual_done_shape = done.shape[dims:] + if actual_done_shape != expected_done_shape: + done = done.view([*batch_size, *expected_done_shape]) tensordict_reset.set("done", done) + else: + tensordict_reset.set( + "done", + torch.zeros( + *tensordict_reset.batch_size, + 1, + dtype=torch.bool, + device=self.device, + ), + ) if tensordict_reset.device != self.device: tensordict_reset = tensordict_reset.to(self.device) @@ -461,13 +470,6 @@ def reset( f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected." 
) - tensordict_reset.set_default( - "done", - torch.zeros( - *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device - ), - ) - if (_reset is None and tensordict_reset.get("done").any()) or ( _reset is not None and tensordict_reset.get("done")[_reset].any() ): @@ -725,7 +727,9 @@ def fake_tensordict(self) -> TensorDictBase: "next": fake_obs.clone(), **fake_input, "reward": fake_reward, - "done": fake_reward.to(torch.bool), + "done": torch.zeros( + (*self.batch_size, 1), dtype=torch.bool, device=self.device + ), }, batch_size=self.batch_size, device=self.device, diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index c94ee872bdd..a4903cba29a 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1045,6 +1045,8 @@ def _run_worker_pipe_shared_mem( _td["done"] = done = torch.zeros( *_td.batch_size, 1, dtype=torch.bool, device=env.device ) + elif done is not None and done.shape != torch.Size([*_td.batch_size, 1]): + _td.set("done", done.unsqueeze(-1)) if reset_keys is None: reset_keys = set(_td.keys()) if pin_memory: From 61f004edcd3381dd716fadf18028e968eb091a1e Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 15 Jan 2023 21:27:19 +0000 Subject: [PATCH 08/30] amend --- torchrl/collectors/collectors.py | 2 +- torchrl/envs/common.py | 2 +- torchrl/envs/vec_env.py | 5 ----- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 3cabff1e65c..7789b177301 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -63,7 +63,7 @@ def __init__(self, action_spec: TensorSpec): self.action_spec = action_spec def __call__(self, td: TensorDictBase) -> TensorDictBase: - return td.set("action", self.action_spec.rand(td.batch_size)) + return td.set("action", self.action_spec.rand()) def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index fb47eb0efa7..f72a0c34e2f 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -70,7 +70,7 @@ def expand(self, *size: int) -> EnvMetaData: batch_size = torch.Size([*size]) return EnvMetaData( tensordict, - self.specs, + self.specs.expand(size), batch_size, self.env_str, self.device, diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index a4903cba29a..5140900c9ab 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -285,17 +285,12 @@ def _set_properties(self): self._batch_size = meta_data.batch_size observation_spec = meta_data.specs["observation_spec"] - observation_spec = observation_spec.expand( - self.num_workers, *observation_spec.shape - ) self.observation_spec = observation_spec reward_spec = meta_data.specs["reward_spec"] - reward_spec = reward_spec.expand(self.num_workers, *reward_spec.shape) self.reward_spec = reward_spec input_spec = meta_data.specs["input_spec"] - input_spec = input_spec.expand(self.num_workers, *input_spec.shape) self.input_spec = input_spec self._dummy_env_str = meta_data.env_str From 83df95264a89a159e68d9f582c322c3f6a5d31be Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 16 Jan 2023 16:59:05 +0000 Subject: [PATCH 09/30] amend --- test/mocking_classes.py | 55 ++++++++++++++++-------- test/test_specs.py | 62 +++++++++++++++++++-------- test/test_transforms.py | 15 ++++--- torchrl/data/tensor_specs.py | 50 ++++++++++++++------- torchrl/envs/common.py | 34 ++++++++++----- torchrl/envs/transforms/transforms.py | 19 +++++--- 6 files changed, 162 insertions(+), 73 deletions(-) diff 
--git a/test/mocking_classes.py b/test/mocking_classes.py index 1fe44c17135..44961b82da8 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -107,7 +107,7 @@ def __new__( reward_spec=None, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) if action_spec is None: action_spec = UnboundedContinuousTensorSpec( ( @@ -185,7 +185,7 @@ def __new__( reward_spec=None, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) if action_spec is None: action_spec = UnboundedContinuousTensorSpec( ( @@ -243,11 +243,19 @@ def _set_seed(self, seed: Optional[int]): self.max_val = max(self.counter + 100, self.counter * 2) def _step(self, tensordict): + if len(self.batch_size): + leading_batch_size = ( + tensordict.shape[: -len(self.batch_size)] + if tensordict is not None + else [] + ) + else: + leading_batch_size = tensordict.shape if tensordict is not None else [] self.counter += 1 # We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv n = ( torch.full( - (*tensordict.batch_size, *self.observation_spec["observation"].shape), + [*leading_batch_size, *self.observation_spec["observation"].shape], self.counter, ) .to(self.device) @@ -255,9 +263,11 @@ def _step(self, tensordict): ) done = self.counter >= self.max_val done = torch.full( - (*tensordict.batch_size, 1), done, dtype=torch.bool, device=self.device + (*leading_batch_size, *self.batch_size, 1), + done, + dtype=torch.bool, + device=self.device, ) - return TensorDict( {"reward": n, "done": done, "observation": n}, tensordict.batch_size, @@ -266,20 +276,31 @@ def _step(self, tensordict): def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self.max_val = max(self.counter + 100, self.counter * 2) - if tensordict is None: - batch_size = self.batch_size + batch_size = self.batch_size + if len(batch_size): + leading_batch_size = ( + tensordict.shape[: -len(self.batch_size)] + if tensordict is not None + else [] + ) else: - batch_size = tensordict.batch_size + leading_batch_size = tensordict.shape if tensordict is not None else [] n = ( torch.full( - (*batch_size, *self.observation_spec["observation"].shape), self.counter + [*leading_batch_size, *self.observation_spec["observation"].shape], + self.counter, ) .to(self.device) .to(torch.get_default_dtype()) ) done = self.counter >= self.max_val - done = torch.full((*batch_size, 1), done, dtype=torch.bool, device=self.device) + done = torch.full( + (*leading_batch_size, *batch_size, 1), + done, + dtype=torch.bool, + device=self.device, + ) return TensorDict( {"reward": n, "done": done, "observation": n}, @@ -318,7 +339,7 @@ def __new__( categorical_action_encoding=False, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" @@ -409,7 +430,7 @@ def __new__( from_pixels=False, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" @@ -523,7 +544,7 @@ def __new__( from_pixels=True, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = 
"pixels" observation_spec = CompositeSpec( @@ -551,7 +572,6 @@ def __new__( ) return super().__new__( *args, - batch_size=batch_size, observation_spec=observation_spec, action_spec=action_spec, reward_spec=reward_spec, @@ -581,7 +601,7 @@ def __new__( categorical_action_encoding=False, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( @@ -612,7 +632,6 @@ def __new__( return super().__new__( *args, - batch_size=batch_size, observation_spec=observation_spec, action_spec=action_spec, reward_spec=reward_spec, @@ -647,7 +666,7 @@ def __new__( pixel_shape=None, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) if pixel_shape is None: pixel_shape = [1, 7, 7] if observation_spec is None: @@ -703,7 +722,7 @@ def __new__( from_pixels=True, **kwargs, ): - batch_size = kwargs.get("batch_size", torch.Size([])) + batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" observation_spec = CompositeSpec( diff --git a/test/test_specs.py b/test/test_specs.py index 5bf504c13d2..c8df9286db9 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -270,8 +270,8 @@ def test_multi_discrete(shape, ns, dtype): *_real_shape, *nvec_shape, ] - ) - assert ts.is_in(r) + ), (r.shape, ns, shape, _real_shape, nvec_shape) + assert ts.is_in(r), (r, r.shape, ns) rand = torch.rand( torch.Size( [ @@ -301,9 +301,21 @@ def test_multi_discrete(shape, ns, dtype): ], ) @pytest.mark.parametrize("device", get_available_devices()) -def test_discrete_conversion(n, device): - categorical = DiscreteTensorSpec(n, device=device) - one_hot = OneHotDiscreteTensorSpec(n, device=device) +@pytest.mark.parametrize( + "shape", + [ + None, + [], + [ + 1, + ], + [1, 2], + ], +) +def test_discrete_conversion(n, device, shape): + categorical = DiscreteTensorSpec(n, device=device, shape=shape) + shape_one_hot = [n] if not shape else [*shape, n] + one_hot = OneHotDiscreteTensorSpec(n, device=device, shape=shape_one_hot) assert categorical != one_hot assert categorical.to_onehot() == one_hot @@ -642,21 +654,33 @@ def test_equality_onehot(self): dtype = torch.float16 use_register = False - ts = OneHotDiscreteTensorSpec(n, device, dtype, use_register) + ts = OneHotDiscreteTensorSpec( + n=n, device=device, dtype=dtype, use_register=use_register + ) - ts_same = OneHotDiscreteTensorSpec(n, device, dtype, use_register) + ts_same = OneHotDiscreteTensorSpec( + n=n, device=device, dtype=dtype, use_register=use_register + ) assert ts == ts_same - ts_other = OneHotDiscreteTensorSpec(n + 1, device, dtype, use_register) + ts_other = OneHotDiscreteTensorSpec( + n=n + 1, device=device, dtype=dtype, use_register=use_register + ) assert ts != ts_other - ts_other = OneHotDiscreteTensorSpec(n, "cpu:0", dtype, use_register) + ts_other = OneHotDiscreteTensorSpec( + n=n, device="cpu:0", dtype=dtype, use_register=use_register + ) assert ts != ts_other - ts_other = OneHotDiscreteTensorSpec(n, device, torch.float64, use_register) + ts_other = OneHotDiscreteTensorSpec( + n=n, device=device, dtype=torch.float64, use_register=use_register + ) assert ts != ts_other - ts_other = OneHotDiscreteTensorSpec(n, device, dtype, not use_register) + ts_other = OneHotDiscreteTensorSpec( + n=n, device=device, dtype=dtype, use_register=not use_register + ) assert ts != ts_other ts_other = 
TestEquality._ts_make_all_fields_equal( @@ -730,21 +754,25 @@ def test_equality_discrete(self): device = "cpu" dtype = torch.float16 - ts = DiscreteTensorSpec(n, shape, device, dtype) + ts = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype) - ts_same = DiscreteTensorSpec(n, shape, device, dtype) + ts_same = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype) assert ts == ts_same - ts_other = DiscreteTensorSpec(n + 1, shape, device, dtype) + ts_other = DiscreteTensorSpec(n=n + 1, shape=shape, device=device, dtype=dtype) assert ts != ts_other - ts_other = DiscreteTensorSpec(n, shape, "cpu:0", dtype) + ts_other = DiscreteTensorSpec(n=n, shape=shape, device="cpu:0", dtype=dtype) assert ts != ts_other - ts_other = DiscreteTensorSpec(n, shape, device, torch.float64) + ts_other = DiscreteTensorSpec( + n=n, shape=shape, device=device, dtype=torch.float64 + ) assert ts != ts_other - ts_other = DiscreteTensorSpec(n, torch.Size([2]), device, torch.float64) + ts_other = DiscreteTensorSpec( + n=n, shape=torch.Size([2]), device=device, dtype=torch.float64 + ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( diff --git a/test/test_transforms.py b/test/test_transforms.py index 1f46ef7160e..c30cac1b84c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1102,10 +1102,14 @@ def make_env(): t_env.transform.init_stats( num_iter=11, reduce_dim=reduce_dim, cat_dim=cat_dim ) - - assert t_env.transform.loc.shape == t_env.observation_spec["observation"].shape + batch_dims = len(t_env.batch_size) + assert ( + t_env.transform.loc.shape + == t_env.observation_spec["observation"].shape[batch_dims:] + ) assert ( - t_env.transform.scale.shape == t_env.observation_spec["observation"].shape + t_env.transform.scale.shape + == t_env.observation_spec["observation"].shape[batch_dims:] ) assert t_env.transform.loc.dtype == t_env.observation_spec["observation"].dtype assert ( @@ -2256,11 +2260,10 @@ def test_batch_unlocked_with_batch_size_transformed(device): with pytest.raises(RuntimeError, match="batch_locked is a read-only property"): env.batch_locked = False - td = env.reset() - td["action"] = env.action_spec.rand(env.batch_size) - td_expanded = td.expand(2, 2).reshape(-1).to_tensordict() + td["action"] = env.action_spec.rand() env.step(td) + td_expanded = td.expand(2, 2).reshape(-1).to_tensordict() with pytest.raises( RuntimeError, match="Expected a tensordict with shape==env.shape, " diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0f1886adb38..1d2441dae2c 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -159,12 +159,15 @@ def __iter__(self): def __repr__(self): return f"{self.__class__.__name__}(boxes={self.boxes})" + def __len__(self): + return len(self.boxes) + @staticmethod def from_nvec(nvec: torch.Tensor): if nvec.ndim == 0: return DiscreteBox(nvec.item()) else: - return BoxList([BoxList.from_nvec(n) for n in nvec]) + return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)]) @dataclass(repr=False) @@ -465,7 +468,7 @@ def __init__( shape = torch.Size((space.n,)) else: shape = torch.Size(shape) - if shape[-1] != space.n: + if not len(shape) or shape[-1] != space.n: raise ValueError( f"The last value of the shape must match n for transform of type {self.__class__}. " f"Got n={space.n} and shape={shape}." 
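# Usage sketch for the shape handling above and in the hunks that follow
# (made-up shapes, illustrative only): the trailing dim of a one-hot spec's
# ``shape`` must equal ``n``, and the categorical/one-hot conversions keep
# the leading (batch) dims:
#
#     spec = OneHotDiscreteTensorSpec(7, shape=(4, 7))   # ok; shape=(4, 5) raises ValueError
#     categorical = spec.to_categorical()                # DiscreteTensorSpec, shape (4,)
#     one_hot_again = categorical.to_onehot()            # back to shape (4, 7)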
@@ -566,7 +569,9 @@ def __eq__(self, other): ) def to_categorical(self) -> DiscreteTensorSpec: - return DiscreteTensorSpec(self.space.n, device=self.device, dtype=self.dtype) + return DiscreteTensorSpec( + self.space.n, device=self.device, dtype=self.dtype, shape=self.shape[:-1] + ) @dataclass(repr=False) @@ -1068,8 +1073,12 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return torch.cat([super()._project(_val) for _val in vals], -1) def to_categorical(self) -> MultiDiscreteTensorSpec: + return MultiDiscreteTensorSpec( - [_space.n for _space in self.space], self.device, self.dtype + [_space.n for _space in self.space], + device=self.device, + dtype=self.dtype, + shape=[*self.shape[:-1], len(self.space)], ) def expand(self, *shape): @@ -1166,12 +1175,15 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: return super().to_numpy(val, safe) def to_onehot(self) -> OneHotDiscreteTensorSpec: - if len(self.shape) > 1: - raise RuntimeError( - f"DiscreteTensorSpec with shape that has several dimensions can't be converted to " - f"OneHotDiscreteTensorSpec. Got shape={self.shape}." - ) - return OneHotDiscreteTensorSpec(self.space.n, self.device, self.dtype) + # if len(self.shape) > 1: + # raise RuntimeError( + # f"DiscreteTensorSpec with shape that has several dimensions can't be converted to " + # f"OneHotDiscreteTensorSpec. Got shape={self.shape}." + # ) + shape = [*self.shape, self.space.n] + return OneHotDiscreteTensorSpec( + n=self.space.n, shape=shape, device=self.device, dtype=self.dtype + ) def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], tuple): @@ -1233,6 +1245,7 @@ def __init__( f"The last value of the shape must match nvec.shape[-1] for transform of type {self.__class__}. " f"Got nvec.shape[-1]={sum(nvec)} and shape={shape}." ) + self.nvec = self.nvec.expand(shape) space = BoxList.from_nvec(nvec) super(DiscreteTensorSpec, self).__init__( @@ -1243,7 +1256,7 @@ def _rand(self, space: Box, shape: torch.Size, i: int): x = [] for _s in space: if isinstance(_s, BoxList): - x.append(self._rand(_s, shape, i - 1)) + x.append(self._rand(_s, shape[:-1], i - 1)) else: x.append( torch.randint( @@ -1254,14 +1267,17 @@ def _rand(self, space: Box, shape: torch.Size, i: int): dtype=self.dtype, ) ) - return torch.stack(x, -i) + return torch.stack(x, -1) def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: - shape = torch.Size([*shape, *self.shape[:-1]]) - x = self._rand(self.space, shape, self.nvec.ndim) + shape = ( + *shape, + *self.shape[:-1], + ) + x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim) if self.shape == torch.Size([1]): x = x.squeeze(-1) return x @@ -1298,8 +1314,12 @@ def to_onehot(self) -> MultiOneHotDiscreteTensorSpec: f"nestedtensors but it is not implemented yet. If you would like to see that feature, please submit " f"an issue of torchrl's github repo. 
" ) + nvec = [_space.n for _space in self.space] return MultiOneHotDiscreteTensorSpec( - [_space.n for _space in self.space], self.device, self.dtype + nvec, + device=self.device, + dtype=self.dtype, + shape=[*self.shape[:-1], sum(nvec)], ) def expand(self, *shape): diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index f72a0c34e2f..96e21a9ce62 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -353,20 +353,27 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: reward = tensordict_out.get("reward") # unsqueeze rewards if needed - batch_size = tensordict_out.shape + # the input tensordict may have more leading dimensions than the batch_size + # e.g. in model-based contexts. + batch_size = self.batch_size dims = len(batch_size) - expected_reward_shape = self.reward_spec.shape[dims:] - actual_reward_shape = reward.shape[dims:] + leading_batch_size = ( + tensordict_out.batch_size[:-dims] if dims else tensordict_out.shape + ) + expected_reward_shape = torch.Size( + [*leading_batch_size, *self.reward_spec.shape] + ) + actual_reward_shape = reward.shape if actual_reward_shape != expected_reward_shape: - reward = reward.view([*batch_size, *expected_reward_shape]) + reward = reward.view(expected_reward_shape) tensordict_out.set("reward", reward) done = tensordict_out.get("done") # unsqueeze done if needed - expected_done_shape = torch.Size([1]) - actual_done_shape = done.shape[dims:] + expected_done_shape = torch.Size([*leading_batch_size, *batch_size, 1]) + actual_done_shape = done.shape if actual_done_shape != expected_done_shape: - done = done.view([*batch_size, *expected_done_shape]) + done = done.view(expected_done_shape) tensordict_out.set("done", done) if tensordict_out is tensordict: @@ -439,12 +446,17 @@ def reset( done = tensordict_reset.get("done", None) if done is not None: # unsqueeze done if needed - batch_size = tensordict_reset.shape + # the input tensordict may have more leading dimensions than the batch_size + # e.g. in model-based contexts. 
+ batch_size = self.batch_size dims = len(batch_size) - expected_done_shape = torch.Size([1]) - actual_done_shape = done.shape[dims:] + leading_batch_size = ( + tensordict_reset.batch_size[:-dims] if dims else tensordict_reset.shape + ) + expected_done_shape = torch.Size([*leading_batch_size, *batch_size, 1]) + actual_done_shape = done if actual_done_shape != expected_done_shape: - done = done.view([*batch_size, *expected_done_shape]) + done = done.view(expected_done_shape) tensordict_reset.set("done", done) else: tensordict_reset.set( diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index fb0cd2ccd7b..a2739cde005 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -50,7 +50,7 @@ def new_fun(self, observation_spec): for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key in observation_spec.keys(): d[out_key] = function(self, observation_spec[in_key]) - return CompositeSpec(d) + return CompositeSpec(d, shape=observation_spec.shape) else: return function(self, observation_spec) @@ -857,7 +857,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: return BoundedTensorSpec( self.clamp_min, self.clamp_max, - torch.Size((1,)), + shape=reward_spec.shape, device=reward_spec.device, dtype=reward_spec.dtype, ) @@ -898,7 +898,9 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: return (reward > 0.0).to(torch.long) def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: - return BinaryDiscreteTensorSpec(n=1, device=reward_spec.device) + return BinaryDiscreteTensorSpec( + n=1, device=reward_spec.device, shape=reward_spec.shape + ) class Resize(ObservationTransform): @@ -1005,7 +1007,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec if key in self.in_keys else _obs_spec for key, _obs_spec in observation_spec._specs.items() - } + }, + shape=observation_spec.shape, ) space = observation_spec.space @@ -2593,7 +2596,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec # Update observation_spec with episode_specs if not isinstance(observation_spec, CompositeSpec): - observation_spec = CompositeSpec(observation=observation_spec) + observation_spec = CompositeSpec( + observation=observation_spec, shape=self.parent.batch_size + ) observation_spec.update(episode_specs) return observation_spec @@ -2667,7 +2672,9 @@ def transform_observation_spec( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." 
) observation_spec["step_count"] = UnboundedDiscreteTensorSpec( - shape=torch.Size([]), dtype=torch.int64, device=observation_spec.device + shape=self.parent.batch_size, + dtype=torch.int64, + device=observation_spec.device, ) observation_spec["step_count"].space.minimum = 0 return observation_spec From 7da045ed7a5fc13f2218ecc8c897ebcb27054963 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 16 Jan 2023 21:02:17 +0000 Subject: [PATCH 10/30] amend --- test/mocking_classes.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 44961b82da8..d0ad1215b7b 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -301,10 +301,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: dtype=torch.bool, device=self.device, ) - return TensorDict( {"reward": n, "done": done, "observation": n}, - batch_size, + [ + *leading_batch_size, + *batch_size, + ], device=self.device, ) @@ -504,8 +506,9 @@ def _step( tensordict.set(self._out_key, self._get_out_obs(obs)) done = torch.isclose(obs, torch.ones_like(obs) * (self.counter + 1)) - reward = done.any(-1).unsqueeze(-1) - done = done.all(-1).unsqueeze(-1) + while done.shape != tensordict.shape: + done = done.any(-1) + done = reward = done.unsqueeze(-1) tensordict.set("reward", reward.to(torch.get_default_dtype())) tensordict.set("done", done) return tensordict From 225c3da24e066913e36df07fb9e5b973a3a554b5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 17 Jan 2023 08:10:26 +0000 Subject: [PATCH 11/30] final amend? --- test/mocking_classes.py | 3 ++- test/test_env.py | 12 ++++++------ test/test_transforms.py | 4 ++-- torchrl/modules/planners/cem.py | 4 ++-- torchrl/modules/planners/mppi.py | 4 ++-- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index d0ad1215b7b..ae0512205b4 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -864,7 +864,8 @@ def __init__(self, max_steps: int = 5, **kwargs): ) ) self.input_spec = CompositeSpec( - action=BinaryDiscreteTensorSpec(*self.batch_size, 1) + action=BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1]), + shape=self.batch_size, ) self.count = torch.zeros( diff --git a/test/test_env.py b/test/test_env.py index 23bd16d10d6..b2b28956e11 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -360,7 +360,7 @@ def test_mb_env_batch_lock(self, device, seed=0): with pytest.raises(RuntimeError, match="batch_locked is a read-only property"): mb_env.batch_locked = False td = mb_env.reset() - td["action"] = mb_env.action_spec.rand(mb_env.batch_size) + td["action"] = mb_env.action_spec.rand() td_expanded = td.expand(2) mb_env.step(td) # we should be able to do a step with a tensordict that has been expended @@ -909,7 +909,7 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size) ) env.set_seed(1) - action = env.action_spec.rand(env.batch_size) + action = env.action_spec.rand() action[:] = 1 for i in range(max_steps): @@ -947,7 +947,7 @@ def test_env_base_reset_flag(batch_size, max_steps=3): env = CountingEnv(max_steps=max_steps, batch_size=batch_size) env.set_seed(1) - action = env.action_spec.rand(env.batch_size) + action = env.action_spec.rand() action[:] = 1 for i in range(max_steps): @@ -1071,7 +1071,7 @@ def test_batch_locked(device): with pytest.raises(RuntimeError, match="batch_locked is a read-only property"): 
env.batch_locked = False td = env.reset() - td["action"] = env.action_spec.rand(env.batch_size) + td["action"] = env.action_spec.rand() td_expanded = td.expand(2).clone() td = env.step(td) @@ -1089,7 +1089,7 @@ def test_batch_unlocked(device): with pytest.raises(RuntimeError, match="batch_locked is a read-only property"): env.batch_locked = False td = env.reset() - td["action"] = env.action_spec.rand(env.batch_size) + td["action"] = env.action_spec.rand() td_expanded = td.expand(2).clone() td = env.step(td) @@ -1105,7 +1105,7 @@ def test_batch_unlocked_with_batch_size(device): env.batch_locked = False td = env.reset() - td["action"] = env.action_spec.rand(env.batch_size) + td["action"] = env.action_spec.rand() td_expanded = td.expand(2, 2).reshape(-1).to_tensordict() td = env.step(td) diff --git a/test/test_transforms.py b/test/test_transforms.py index c30cac1b84c..1221ca1c561 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2217,7 +2217,7 @@ def test_batch_locked_transformed(device): with pytest.raises(RuntimeError, match="batch_locked is a read-only property"): env.batch_locked = False td = env.reset() - td["action"] = env.action_spec.rand(env.batch_size) + td["action"] = env.action_spec.rand() td_expanded = td.expand(2).clone() env.step(td) @@ -2241,7 +2241,7 @@ def test_batch_unlocked_transformed(device): with pytest.raises(RuntimeError, match="batch_locked is a read-only property"): env.batch_locked = False td = env.reset() - td["action"] = env.action_spec.rand(env.batch_size) + td["action"] = env.action_spec.rand() td_expanded = td.expand(2).clone() env.step(td) env.step(td_expanded) diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 491f02f3391..f442c50529a 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -67,9 +67,9 @@ class CEMPlanner(MPCPlannerBase): ... device=self.device, ... ) ... tensordict = tensordict.update( - ... self.input_spec.rand(self.batch_size)) + ... self.input_spec.rand()) ... tensordict = tensordict.update( - ... self.observation_spec.rand(self.batch_size)) + ... self.observation_spec.rand()) ... return tensordict ... >>> from torchrl.modules import MLP, WorldModelWrapper diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index f1a5fe9b255..21fb53fae00 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -66,9 +66,9 @@ class MPPIPlanner(MPCPlannerBase): ... device=self.device, ... ) ... tensordict = tensordict.update( - ... self.input_spec.rand(self.batch_size)) + ... self.input_spec.rand()) ... tensordict = tensordict.update( - ... self.observation_spec.rand(self.batch_size)) + ... self.observation_spec.rand()) ... return tensordict >>> from torchrl.modules import MLP, WorldModelWrapper >>> import torch.nn as nn From ed49a64eb9a1ee18292daca18d4e9fcdcdd1c3d5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 17 Jan 2023 10:51:59 +0000 Subject: [PATCH 12/30] hopefully another almost final! 
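A minimal sketch of the batched-spec usage this series converges on (toy shapes;
assuming the specs are re-exported from torchrl.data as in the tests):

    import torch
    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    batch_size = torch.Size([4])
    # leaf specs carry the batch dims and the composite records them via ``shape``
    spec = CompositeSpec(
        observation=UnboundedContinuousTensorSpec(shape=(*batch_size, 3)),
        shape=batch_size,
    )
    sample = spec.rand()        # TensorDict whose batch size matches the spec's shape
    zeros = spec.zero()         # idem, zero-filled; no batch_size argument needed
    bigger = spec.expand(2, 4)  # expand to a larger batch shape, a la torch.Tensor.expand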
--- torchrl/data/tensor_specs.py | 45 +++++++++++++++++++++--------------- torchrl/envs/common.py | 8 +++++-- torchrl/envs/libs/jumanji.py | 14 ++++++----- torchrl/envs/libs/vmas.py | 11 ++++++--- torchrl/envs/utils.py | 2 +- 5 files changed, 50 insertions(+), 30 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1d2441dae2c..139d9a6212e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import warnings from copy import deepcopy from dataclasses import dataclass from textwrap import indent @@ -236,7 +237,10 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: val = torch.tensor(val, dtype=self.dtype, device=self.device) if val.shape[-len(self.shape) :] != self.shape: # option 1: add a singleton dim at the end - if self.shape == torch.Size([1]): + if ( + val.shape[-len(self.shape) :] == self.shape[:-1] + and self.shape[-1] == 1 + ): val = val.unsqueeze(-1) else: raise RuntimeError( @@ -476,7 +480,7 @@ def __init__( super().__init__(shape, space, device, dtype, "discrete") def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -666,7 +670,7 @@ def __init__( ) def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -735,9 +739,15 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: return val def is_in(self, val: torch.Tensor) -> bool: - return (val >= self.space.minimum.to(val.device)).all() and ( - val <= self.space.maximum.to(val.device) - ).all() + try: + return (val >= self.space.minimum.to(val.device)).all() and ( + val <= self.space.maximum.to(val.device) + ).all() + except RuntimeError as err: + if "The size of tensor a" in str(err): + warnings.warn(f"Got a shape mismatch: {str(err)}") + return False + raise err @dataclass(repr=False) @@ -783,7 +793,7 @@ def is_in(self, val: torch.Tensor) -> bool: return True def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -850,7 +860,7 @@ def is_in(self, val: torch.Tensor) -> bool: return True def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -927,7 +937,7 @@ def is_in(self, val: torch.Tensor) -> bool: return ((val == 0) | (val == 1)).all() def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -1083,7 +1093,7 @@ def to_categorical(self) -> MultiDiscreteTensorSpec: def expand(self, *shape): nvecs = [space.n for space in self.space] - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -1186,7 +1196,7 @@ def to_onehot(self) -> OneHotDiscreteTensorSpec: ) def expand(self, *shape): - if len(shape) == 1 and 
isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -1323,7 +1333,7 @@ def to_onehot(self) -> MultiOneHotDiscreteTensorSpec: ) def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError( @@ -1444,7 +1454,7 @@ def set(self, name, spec): ) self._specs[name] = spec - def __init__(self, *args, shape=None, **kwargs): + def __init__(self, *args, shape=None, device=None, **kwargs): if shape is None: # Should we do this? Other specs have a default empty shape, maybe it would make sense to keep it # optional for composite (for clarity and easiness of use). @@ -1456,8 +1466,8 @@ def __init__(self, *args, shape=None, **kwargs): for key, value in kwargs.items(): self.set(key, value) + _device = device if len(kwargs): - _device = None for key, item in self.items(): if item is None: continue @@ -1468,10 +1478,8 @@ def __init__(self, *args, shape=None, **kwargs): f"Setting a new attribute ({key}) on another device ({item.device} against {self.device}). " f"All devices of CompositeSpec must match." ) - self._device = _device + self._device = _device if len(args): - if not len(kwargs): - self._device = None if len(args) > 1: raise RuntimeError( "Got multiple arguments, when at most one is expected for CompositeSpec." @@ -1684,7 +1692,7 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N return self def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], tuple): + if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] if any(val < 0 for val in shape): raise ValueError("CompositeSpec.extend does not support negative shapes.") @@ -1699,6 +1707,7 @@ def expand(self, *shape): for key, value in tuple(self.items()) }, shape=shape, + device=self.device, ) return out diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 96e21a9ce62..f18d49ac774 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -289,7 +289,9 @@ def input_spec(self, value: TensorSpec) -> None: if not isinstance(value, CompositeSpec): raise TypeError("The type of an input_spec must be Composite.") if value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError("The value of spec.shape must match the env batch size.") + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." + ) self.__dict__["_input_spec"] = value @property @@ -322,7 +324,9 @@ def observation_spec(self, value: TensorSpec) -> None: if not isinstance(value, CompositeSpec): raise TypeError("The type of an observation_spec must be Composite.") elif value.shape[: len(self.batch_size)] != self.batch_size: - raise ValueError("The value of spec.shape must match the env batch size.") + raise ValueError( + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." 
+ ) self.__dict__["_observation_spec"] = value def step(self, tensordict: TensorDictBase) -> TensorDictBase: diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index a4ae6200e94..463ba27fc5a 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -184,15 +184,17 @@ def _make_input_spec(self, env) -> TensorSpec: action=_jumanji_to_torchrl_spec_transform( env.action_spec(), device=self.device ), - ) + ).expand(self.batch_size) def _make_observation_spec(self, env) -> TensorSpec: spec = env.observation_spec() new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device) if isinstance(spec, jumanji.specs.Array): - return CompositeSpec(observation=new_spec) + return CompositeSpec(observation=new_spec).expand(self.batch_size) elif isinstance(spec, jumanji.specs.Spec): - return CompositeSpec(**{k: v for k, v in new_spec.items()}) + return CompositeSpec(**{k: v for k, v in new_spec.items()}).expand( + self.batch_size + ) else: raise TypeError(f"Unsupported spec type {type(spec)}") @@ -202,7 +204,7 @@ def _make_reward_spec(self, env) -> TensorSpec: ) if not len(reward_spec.shape): reward_spec.shape = torch.Size([1]) - return reward_spec + return reward_spec.expand([*self.batch_size, *reward_spec.shape]) def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 @@ -212,7 +214,7 @@ def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 self.reward_spec = self._make_reward_spec(env) # extract state spec from instance - self.state_spec = self._make_state_spec(env) + self.state_spec = self._make_state_spec(env).expand(self.batch_size) self.input_spec["state"] = self.state_spec # build state example for data conversion @@ -249,7 +251,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # prepare inputs state = _tensordict_to_object(tensordict.get("state"), self._state_example) action = self.read_action(tensordict.get("action")) - reward = self.reward_spec.zero(self.batch_size) + reward = self.reward_spec.zero() # flatten batch size into vector state = _tree_flatten(state, self.batch_size) diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 0ec49794756..e012514870a 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -122,6 +122,11 @@ def __init__( "Batch size used in constructor is not compatible with vmas." 
) self.batch_size = torch.Size([self.n_agents, *self.batch_size]) + self.input_spec = self.input_spec.expand(self.batch_size) + self.observation_spec = self.observation_spec.expand(self.batch_size) + self.reward_spec = self.reward_spec.expand( + [*self.batch_size, *self.reward_spec.shape] + ) @property def lib(self): @@ -158,12 +163,12 @@ def _make_specs( device=self.device, ) ) - ) + ).expand(self.batch_size) self.reward_spec = UnboundedContinuousTensorSpec( shape=torch.Size((1,)), device=self.device, - ) + ).expand([*self.batch_size, 1]) self.observation_spec = CompositeSpec( observation=( @@ -184,7 +189,7 @@ def _make_specs( for key, value in self.scenario.info(agent0).items() }, ).to(self.device), - ) + ).expand(self.batch_size) def _check_kwargs(self, kwargs: Dict): if "env" not in kwargs: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 4e1a91bf187..2d4adeab4c4 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -209,7 +209,7 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True): # test dtypes real_tensordict = env.rollout(3) # keep empty structures, for example dict() - for key, value in real_tensordict.items(): + for key, value in real_tensordict[..., -1].items(): _check_isin(key, value, env.observation_spec, env.input_spec) From f5e09fea4352d7df34b3b86e8ffd2dd9861f208b Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 17 Jan 2023 15:57:41 +0000 Subject: [PATCH 13/30] amend --- torchrl/data/tensor_specs.py | 61 ++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 139d9a6212e..c643a3cdc5f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -39,6 +39,8 @@ _DEFAULT_SHAPE = torch.Size((1,)) +DEVICE_ERR_MSG = "device of empty CompositeSpec is not defined." + def _default_dtype_and_device( dtype: Union[None, torch.dtype], @@ -1475,7 +1477,7 @@ def __init__(self, *args, shape=None, device=None, **kwargs): _device = item.device elif item.device != _device: raise RuntimeError( - f"Setting a new attribute ({key}) on another device ({item.device} against {self.device}). " + f"Setting a new attribute ({key}) on another device ({item.device} against {_device}). " f"All devices of CompositeSpec must match." ) self._device = _device @@ -1507,15 +1509,16 @@ def device(self) -> DEVICE_TYPING: if _device is None: raise RuntimeError( "device of empty CompositeSpec is not defined. " - "You can set it directly by calling" + "You can set it directly by calling " "`spec.device = device`." ) self._device = _device return self._device @device.setter - def device(self, value: DEVICE_TYPING): - self._device = value + def device(self, device: DEVICE_TYPING): + device = torch.device(device) + self.to(device) def __getitem__(self, item): if isinstance(item, tuple) and len(item) > 1: @@ -1540,11 +1543,29 @@ def __setitem__(self, key, value): raise TypeError(f"Got key of type {type(key)} when a string was expected.") if key in {"shape", "device", "dtype", "space"}: raise AttributeError(f"CompositeSpec[{key}] cannot be set") - if value is not None and value.device != self.device: - raise RuntimeError( - f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " - f"All devices of CompositeSpec must match." - ) + try: + if value is not None and value.device != self.device: + raise RuntimeError( + f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). 
" + f"All devices of CompositeSpec must match." + ) + except RuntimeError as err: + cond1 = DEVICE_ERR_MSG in str(err) + cond2 = self._device is None + if cond1 and cond2: + try: + device_val = value.device + self.to(device_val) + except RuntimeError as suberr: + if DEVICE_ERR_MSG in str(suberr): + pass + else: + raise suberr + elif cond1: + pass + else: + raise err + self.set(key, value) def __iter__(self): @@ -1650,8 +1671,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: raise ValueError( "Only device casting is allowed with specs of type CompositeSpec." ) - - self.device = torch.device(dest) + if self._device and self._device == torch.device(dest): + return self + self._device = torch.device(dest) for key, value in list(self.items()): if value is None: continue @@ -1686,8 +1708,21 @@ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> N if key in self.keys(True) and isinstance(self[key], CompositeSpec): self[key].update(item) continue - if isinstance(item, TensorSpec) and item.device != self.device: - item = deepcopy(item).to(self.device) + try: + if isinstance(item, TensorSpec) and item.device != self.device: + item = deepcopy(item).to(self.device) + except RuntimeError as err: + if DEVICE_ERR_MSG in str(err): + try: + item_device = item.device + self.device = item_device + except RuntimeError as suberr: + if DEVICE_ERR_MSG in str(suberr): + pass + else: + raise suberr + else: + raise err self[key] = item return self From 66ef4e63637652d8c16265da0858965a8781bb62 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 19 Jan 2023 18:21:11 +0000 Subject: [PATCH 14/30] amend --- test/test_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 230b5e453a1..9d47deb3242 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1165,10 +1165,10 @@ def make_env(): ) assert t_env.transform.loc.shape == torch.Size( - [t_env.observation_spec["pixels"].shape[0], 1, 1] + [t_env.observation_spec["pixels"].shape[-3], 1, 1] ) assert t_env.transform.scale.shape == torch.Size( - [t_env.observation_spec["pixels"].shape[0], 1, 1] + [t_env.observation_spec["pixels"].shape[-3], 1, 1] ) def test_observationnorm_stats_already_initialized_error(self): From 4094d4e6e217195ce3a42d176ea72acb9a2bb37c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 12:57:17 +0000 Subject: [PATCH 15/30] amend --- knowledge_base/PRO-TIPS.md | 10 ++++++++++ test/test_shared.py | 1 - torchrl/data/tensor_specs.py | 16 ++++++++-------- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/knowledge_base/PRO-TIPS.md b/knowledge_base/PRO-TIPS.md index 4ed24925774..3749a11b38e 100644 --- a/knowledge_base/PRO-TIPS.md +++ b/knowledge_base/PRO-TIPS.md @@ -99,3 +99,13 @@ Errors to look for that may be related to this misconception are the following: are being used, using vectorized maps and functional programming (through functorch) instead of looping over the model configurations can provide a significant speedup. + +## Common bugs +- For bugs related to mujoco (incl. DeepMind Control suite and other libraries), + refer to the [MUJOCO_INSTALLATION](MUJOCO_INSTALLATION.md) file. +- `ValueError: bad value(s) in fds_to_keep`: this can have multiple reasons. One that is common in torchrl + is that you are trying to send a tensor across processes that is a view of another tensor. 
+ For instance, when sending the tensor `b = tensor.expand(new_shape)` across processes, the reference to the original + content will be lost (as the `expand` operation keeps the reference to the original tensor). + To debug this, look for such operations (`view`, `permute`, `expand`, etc.) and call `clone()` or `contiguous()` after + the call to the function. diff --git a/test/test_shared.py b/test/test_shared.py index 1329b0ef12d..c4790597359 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse -import sys import time import warnings diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 67f4f55580e..e9167f3de39 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -635,22 +635,22 @@ def __init__( if shape is not None and shape != maximum.shape: raise RuntimeError(err_msg) shape = maximum.shape - minimum = minimum.expand(*shape) + minimum = minimum.expand(*shape).contiguous() elif minimum.ndimension(): if shape is not None and shape != minimum.shape: raise RuntimeError(err_msg) shape = minimum.shape - maximum = maximum.expand(*shape) + maximum = maximum.expand(*shape).contiguous() elif shape is None: raise RuntimeError(err_msg) else: - minimum = minimum.expand(*shape) - maximum = maximum.expand(*shape) + minimum = minimum.expand(*shape).contiguous() + maximum = maximum.expand(*shape).contiguous() if minimum.numel() > maximum.numel(): - maximum = maximum.expand_as(minimum) + maximum = maximum.expand_as(minimum).contiguous() elif maximum.numel() > minimum.numel(): - minimum = minimum.expand_as(maximum) + minimum = minimum.expand_as(maximum).contiguous() if shape is None: shape = minimum.shape else: @@ -684,8 +684,8 @@ def expand(self, *shape): f"shape of the CompositeSpec in CompositeSpec.extend." 
) return self.__class__( - minimum=self.space.minimum.expand(shape), - maximum=self.space.maximum.expand(shape), + minimum=self.space.minimum.expand(shape).contiguous(), + maximum=self.space.maximum.expand(shape).contiguous(), shape=shape, device=self.device, dtype=self.dtype, From bdb28d605349eee558024471886d1a873868684c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 13:58:07 +0000 Subject: [PATCH 16/30] ammend --- test/test_libs.py | 10 ++++---- torchrl/data/tensor_specs.py | 18 +++++++++++--- torchrl/envs/common.py | 16 +++++++++--- torchrl/envs/libs/brax.py | 47 ++++++++++++++++++------------------ 4 files changed, 54 insertions(+), 37 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index e6891148c2b..fcebd4e4c28 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -342,7 +342,7 @@ def test_collector_run(self, env_lib, env_args, env_kwargs, device): class TestHabitat: def test_habitat(self, envname): env = HabitatEnv(envname) - rollout = env.rollout(3) + _ = env.rollout(3) check_env_specs(env) @pytest.mark.parametrize("from_pixels", [True, False]) @@ -518,7 +518,7 @@ def test_brax_grad(self, envname, batch_size): env = BraxEnv(envname, batch_size=batch_size, requires_grad=True) env.set_seed(0) td1 = env.reset() - action = torch.randn(batch_size + env.action_spec.shape) + action = torch.randn(env.action_spec.shape) action.requires_grad_(True) td1["action"] = action td2 = env.step(td1) @@ -572,7 +572,7 @@ def test_vmas_batch_size_error(self, scenario_name, batch_size): TypeError, match="Batch size used in constructor is not compatible with vmas.", ): - env = VmasEnv( + _ = VmasEnv( scenario_name=scenario_name, num_envs=num_envs, n_agents=n_agents, @@ -583,14 +583,14 @@ def test_vmas_batch_size_error(self, scenario_name, batch_size): TypeError, match="Batch size used in constructor does not match vmas batch size.", ): - env = VmasEnv( + _ = VmasEnv( scenario_name=scenario_name, num_envs=num_envs, n_agents=n_agents, batch_size=batch_size, ) else: - env = VmasEnv( + _ = VmasEnv( scenario_name=scenario_name, num_envs=num_envs, n_agents=n_agents, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index e9167f3de39..237413f89d3 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1471,11 +1471,20 @@ def __init__(self, *args, shape=None, device=None, **kwargs): _device = device if len(kwargs): for key, item in self.items(): + try: + item_device = item.device + except RuntimeError as err: + cond1 = DEVICE_ERR_MSG in str(err) + if cond1: + item_device = _device + else: + raise err + if item is None: continue if _device is None: - _device = item.device - elif item.device != _device: + _device = item_device + elif item_device != _device: raise RuntimeError( f"Setting a new attribute ({key}) on another device ({item.device} against {_device}). " f"All devices of CompositeSpec must match." 
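As context for the `fds_to_keep` note added to PRO-TIPS.md in PATCH 15 above (and for the `.contiguous()` calls that patch adds after `expand` in the bounded specs), a minimal sketch of the failure mode and of the suggested fix is given below. The `consumer`/`queue` names are illustrative only and are not part of the patch.

    import torch
    import torch.multiprocessing as mp

    def consumer(queue):
        print(queue.get().shape)

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        queue = ctx.Queue()
        proc = ctx.Process(target=consumer, args=(queue,))
        proc.start()
        base = torch.zeros(3)
        expanded = base.expand(4, 3)   # a view that shares storage with `base`
        # queue.put(expanded)          # may trigger "ValueError: bad value(s) in fds_to_keep"
        queue.put(expanded.clone())    # clone()/contiguous() gives the tensor its own storage
        proc.join()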
@@ -1644,6 +1653,7 @@ def rand(self, shape=None) -> TensorDictBase: return TensorDict( _dict, batch_size=shape, + device=self._device, ) def keys(self, yield_nesting_keys: bool = False) -> KeysView: @@ -1693,7 +1703,7 @@ def zero(self, shape=None) -> TensorDictBase: if isinstance(key, str) and self[key] is not None }, torch.Size([*shape, *self.shape]), - device=self.device, + device=self._device, ) def __eq__(self, other): @@ -1742,7 +1752,7 @@ def expand(self, *shape): for key, value in tuple(self.items()) }, shape=shape, - device=self.device, + device=self._device, ) return out diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 11d25537637..661618c5ef8 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -223,10 +223,6 @@ def __init__( # we want an error to be raised if we pass batch_size but # it's already been set self.batch_size = torch.Size(batch_size) - elif ("batch_size" not in self.__dir__()) and ( - "batch_size" not in self.__class__.__dict__ - ): - self.batch_size = torch.Size([]) self._run_type_checks = run_type_checks @classmethod @@ -269,6 +265,18 @@ def run_type_checks(self) -> bool: def run_type_checks(self, run_type_checks: bool) -> None: self._run_type_checks = run_type_checks + @property + def batch_size(self) -> TensorSpec: + if ("_batch_size" not in self.__dir__()) and ( + "_batch_size" not in self.__class__.__dict__ + ): + self._batch_size = torch.Size([]) + return self._batch_size + + @batch_size.setter + def batch_size(self, value: torch.Size) -> None: + self._batch_size = torch.Size(value) + @property def action_spec(self) -> TensorSpec: return self.input_spec["action"] diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index fa787855e0e..7fe03741fb8 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -114,27 +114,38 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): key = jax.random.PRNGKey(0) state = env.reset(key) state_dict = _object_to_tensordict(state, self.device, batch_size=()) - state_spec = _extract_spec(state_dict) + state_spec = _extract_spec(state_dict).expand(self.batch_size) return state_spec def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 self.input_spec = CompositeSpec( action=BoundedTensorSpec( - minimum=-1, maximum=1, shape=(env.action_size,), device=self.device + minimum=-1, + maximum=1, + shape=( + *self.batch_size, + env.action_size, + ), + device=self.device, ), - shape=env.batch_size, + shape=self.batch_size, ) self.reward_spec = UnboundedContinuousTensorSpec( shape=[ + *self.batch_size, 1, ], device=self.device, ) self.observation_spec = CompositeSpec( observation=UnboundedContinuousTensorSpec( - shape=(env.observation_size,), device=self.device + shape=( + *self.batch_size, + env.observation_size, + ), + device=self.device, ), - shape=env.batch_size, + shape=self.batch_size, ) # extract state spec from instance self.state_spec = self._make_state_spec(env) @@ -171,8 +182,8 @@ def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase: state = _object_to_tensordict(state, self.device, self.batch_size) # build result - reward = state.get("reward").view(*self.batch_size, *self.reward_spec.shape) - done = state.get("done").bool().view(*self.batch_size, *self.reward_spec.shape) + reward = state.get("reward").view(*self.reward_spec.shape) + done = state.get("done").bool().view(*self.reward_spec.shape) tensordict_out = TensorDict( source={ "observation": state.get("obs"), @@ -204,14 +215,8 @@ def _step_without_grad(self, 
tensordict: TensorDictBase): next_state = _object_to_tensordict(next_state, self.device, self.batch_size) # build result - reward = next_state.get("reward").view( - *self.batch_size, *self.reward_spec.shape - ) - done = ( - next_state.get("done") - .bool() - .view(*self.batch_size, *self.reward_spec.shape) - ) + reward = next_state.get("reward").view(self.reward_spec.shape) + done = next_state.get("done").bool().view(self.reward_spec.shape) tensordict_out = TensorDict( source={ "observation": next_state.get("obs"), @@ -238,19 +243,13 @@ def _step_with_grad(self, tensordict: TensorDictBase): self, state, action, *qp_values ) - # extract done values - next_done = ( - next_state_nograd.get("done") - .bool() - .view(*self.batch_size, *self.reward_spec.shape) - ) + # extract done values: we assume a shape identical to reward + next_done = next_state_nograd.get("done").bool().view(*self.reward_spec.shape) # merge with tensors with grad function next_state = next_state_nograd next_state["obs"] = next_obs - next_state["reward"] = next_reward.view( - *self.batch_size, *self.reward_spec.shape - ) + next_state["reward"] = next_reward.view(*self.reward_spec.shape) next_state["qp"].update(dict(zip(qp_keys, next_qp_values))) # build result From 65de17483c0bc92b302699481bb485ac15c2b3bf Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 16:13:11 +0000 Subject: [PATCH 17/30] ammend --- test/test_libs.py | 8 +- test/test_specs.py | 219 ++++++++++++++++++++++++++++++++- torchrl/data/tensor_specs.py | 230 ++++++++++++++++++++++++++++++----- torchrl/envs/common.py | 18 ++- 4 files changed, 432 insertions(+), 43 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index fcebd4e4c28..5f6180058c8 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -293,7 +293,6 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): @pytest.mark.skipif(IS_OSX, reason="rendering unstable on osx, skipping") -@pytest.mark.skipif(not (_has_dmc and _has_gym), reason="gym or dm_control not present") @pytest.mark.parametrize( "env_lib,env_args,env_kwargs", [ @@ -307,6 +306,11 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): @pytest.mark.parametrize("device", get_available_devices()) class TestCollectorLib: def test_collector_run(self, env_lib, env_args, env_kwargs, device): + if not _has_dmc and env_lib is DMControlEnv: + raise pytest.skip("no dmc") + if not _has_gym and env_lib is GymEnv: + raise pytest.skip("no gym") + from_pixels = env_kwargs.get("from_pixels", False) if from_pixels and (not torch.has_cuda or not torch.cuda.device_count()): raise pytest.skip("no cuda device") @@ -315,7 +319,7 @@ def test_collector_run(self, env_lib, env_args, env_kwargs, device): env = ParallelEnv(3, env_fn) collector = MultiaSyncDataCollector( create_env_fn=[env, env], - policy=RandomPolicy(env.action_spec), + policy=RandomPolicy(action_spec=env.action_spec), total_frames=-1, max_frames_per_traj=100, frames_per_batch=21, diff --git a/test/test_specs.py b/test/test_specs.py index c8df9286db9..ee2a8752549 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -469,10 +469,10 @@ def test_device_cast(self, is_complete, device, dtype, dest): # Note: trivial test in case there is only one device available. 
ts = self._composite_spec(is_complete, device, dtype) ts.rand() - ts.to(dest) - cast_r = ts.rand() + td_to = ts.to(dest) + cast_r = td_to.rand() - assert ts.device == dest + assert td_to.device == dest assert cast_r["obs"].device == dest if is_complete: assert cast_r["act"].device == dest @@ -1450,3 +1450,216 @@ def test_unboundeddiscrete(self, shape1, shape2): assert spec2.dtype == spec.dtype assert spec2.rand().shape == spec2.shape assert spec2.zero().shape == spec2.shape + + +class TestClone: + @pytest.mark.parametrize( + "shape1", + [ + None, + (4,), + (5, 4), + ], + ) + def test_binary(self, shape1): + spec = BinaryDiscreteTensorSpec( + n=4, shape=shape1, device="cpu", dtype=torch.bool + ) + assert spec == spec.clone() + assert spec is not spec.clone() + + @pytest.mark.parametrize( + "shape1,mini,maxi", + [ + [(10,), -torch.ones([]), torch.ones([])], + [None, -torch.ones([10]), torch.ones([])], + [None, -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([]), torch.ones([10])], + [(10,), -torch.ones([10]), torch.ones([])], + [(10,), -torch.ones([10]), torch.ones([10])], + ], + ) + def test_bounded(self, shape1, mini, maxi): + spec = BoundedTensorSpec( + mini, maxi, shape=shape1, device="cpu", dtype=torch.bool + ) + assert spec == spec.clone() + assert spec is not spec.clone() + + def test_composite(self): + batch_size = (5,) + spec1 = BoundedTensorSpec( + -torch.ones([*batch_size, 10]), + torch.ones([*batch_size, 10]), + shape=( + *batch_size, + 10, + ), + device="cpu", + dtype=torch.bool, + ) + spec2 = BinaryDiscreteTensorSpec( + n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool + ) + spec3 = DiscreteTensorSpec( + n=4, shape=batch_size, device="cpu", dtype=torch.long + ) + spec4 = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long + ) + spec5 = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec6 = OneHotDiscreteTensorSpec( + n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec7 = UnboundedContinuousTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.float64, + ) + spec8 = UnboundedDiscreteTensorSpec( + shape=(*batch_size, 9), + device="cpu", + dtype=torch.long, + ) + spec = CompositeSpec( + spec1=spec1, + spec2=spec2, + spec3=spec3, + spec4=spec4, + spec5=spec5, + spec6=spec6, + spec7=spec7, + spec8=spec8, + shape=batch_size, + ) + assert spec is not spec.clone() + spec_clone = spec.clone() + for key, item in spec.items(): + assert item == spec_clone[key], key + assert spec == spec.clone() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + def test_discrete( + self, + shape1, + ): + spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + assert spec == spec.clone() + assert spec is not spec.clone() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + def test_multidiscrete( + self, + shape1, + ): + if shape1 is None: + shape1 = (3,) + else: + shape1 = (*shape1, 3) + spec = MultiDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + assert spec == spec.clone() + assert spec is not spec.clone() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + def test_multionehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = MultiOneHotDiscreteTensorSpec( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + assert spec == 
spec.clone() + assert spec is not spec.clone() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + def test_onehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = OneHotDiscreteTensorSpec( + n=15, shape=shape1, device="cpu", dtype=torch.long + ) + assert spec == spec.clone() + assert spec is not spec.clone() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + def test_unbounded( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedContinuousTensorSpec( + shape=shape1, device="cpu", dtype=torch.float64 + ) + assert spec == spec.clone() + assert spec is not spec.clone() + + @pytest.mark.parametrize( + "shape1", + [ + None, + (), + (5,), + ], + ) + def test_unboundeddiscrete( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = UnboundedDiscreteTensorSpec(shape=shape1, device="cpu", dtype=torch.long) + assert spec == spec.clone() + assert spec is not spec.clone() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 237413f89d3..71c8ff7515d 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -100,6 +100,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: def __repr__(self): return f"{self.__class__.__name__}()" + def clone(self) -> DiscreteBox: + return deepcopy(self) + @dataclass(repr=False) class ContinuousBox(Box): @@ -108,18 +111,23 @@ class ContinuousBox(Box): minimum: torch.Tensor maximum: torch.Tensor + def __post_init__(self): + self.minimum = self.minimum.clone() + self.maximum = self.maximum.clone() + def __iter__(self): yield self.minimum yield self.maximum def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: - self.minimum = self.minimum.to(dest) - self.maximum = self.maximum.to(dest) - return self + return self.__class__(self.minimum.to(dest), self.maximum.to(dest)) + + def clone(self) -> ContinuousBox: + return self.__class__(self.minimum.clone(), self.maximum.clone()) def __repr__(self): - min_str = f"minimum={self.minimum}" - max_str = f"maximum={self.maximum}" + min_str = f"minimum=Tensor(shape={self.minimum.shape}, device={self.minimum.device}, dtype={self.minimum.dtype}, contiguous={self.maximum.is_contiguous()})" + max_str = f"maximum=Tensor(shape={self.maximum.shape}, device={self.maximum.device}, dtype={self.maximum.dtype}, contiguous={self.maximum.is_contiguous()})" return f"{self.__class__.__name__}({min_str}, {max_str})" def __eq__(self, other): @@ -140,7 +148,7 @@ class DiscreteBox(Box): register = invertible_dict() def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> DiscreteBox: - return self + return deepcopy(self) def __repr__(self): return f"{self.__class__.__name__}(n={self.n})" @@ -180,7 +188,7 @@ class BinaryBox(Box): n: int def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: - return self + return deepcopy(self) def __repr__(self): return f"{self.__class__.__name__}(n={self.n})" @@ -395,14 +403,13 @@ def zero(self, shape=None) -> torch.Tensor: shape = torch.Size([]) return torch.zeros((*shape, *self.shape), dtype=self.dtype, device=self.device) + @abc.abstractmethod def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": - if self.space is not None: - self.space.to(dest) - if isinstance(dest, (torch.device, str, int)): - self.device = torch.device(dest) - else: - self.dtype = dest - return self + raise 
NotImplementedError + + @abc.abstractmethod + def clone(self) -> "TensorSpec": + raise NotImplementedError def __repr__(self): shape_str = "shape=" + str(self.shape) @@ -481,6 +488,30 @@ def __init__( ) super().__init__(shape, space, device, dtype, "discrete") + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__( + n=self.space.n, + shape=self.shape, + device=dest_device, + dtype=dest_dtype, + use_register=self.use_register, + ) + + def clone(self) -> CompositeSpec: + return self.__class__( + n=self.space.n, + shape=self.shape, + device=self.device, + dtype=self.dtype, + use_register=self.use_register, + ) + def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] @@ -635,22 +666,22 @@ def __init__( if shape is not None and shape != maximum.shape: raise RuntimeError(err_msg) shape = maximum.shape - minimum = minimum.expand(*shape).contiguous() + minimum = minimum.expand(*shape).clone() elif minimum.ndimension(): if shape is not None and shape != minimum.shape: raise RuntimeError(err_msg) shape = minimum.shape - maximum = maximum.expand(*shape).contiguous() + maximum = maximum.expand(*shape).clone() elif shape is None: raise RuntimeError(err_msg) else: - minimum = minimum.expand(*shape).contiguous() - maximum = maximum.expand(*shape).contiguous() + minimum = minimum.expand(*shape).clone() + maximum = maximum.expand(*shape).clone() if minimum.numel() > maximum.numel(): - maximum = maximum.expand_as(minimum).contiguous() + maximum = maximum.expand_as(minimum).clone() elif maximum.numel() > minimum.numel(): - minimum = minimum.expand_as(maximum).contiguous() + minimum = minimum.expand_as(maximum).clone() if shape is None: shape = minimum.shape else: @@ -684,8 +715,8 @@ def expand(self, *shape): f"shape of the CompositeSpec in CompositeSpec.extend." 
) return self.__class__( - minimum=self.space.minimum.expand(shape).contiguous(), - maximum=self.space.maximum.expand(shape).contiguous(), + minimum=self.space.minimum.expand(shape).clone(), + maximum=self.space.maximum.expand(shape).clone(), shape=shape, device=self.device, dtype=self.dtype, @@ -751,6 +782,30 @@ def is_in(self, val: torch.Tensor) -> bool: return False raise err + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__( + minimum=self.space.minimum.to(dest), + maximum=self.space.maximum.to(dest), + shape=self.shape, + device=dest_device, + dtype=dest_dtype, + ) + + def clone(self) -> CompositeSpec: + return self.__class__( + minimum=self.space.minimum.clone(), + maximum=self.space.maximum.clone(), + shape=self.shape, + device=self.device, + dtype=self.dtype, + ) + @dataclass(repr=False) class UnboundedContinuousTensorSpec(TensorSpec): @@ -785,6 +840,18 @@ def __init__( domain="continuous", ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype) + + def clone(self) -> CompositeSpec: + return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) @@ -848,6 +915,18 @@ def __init__( domain="continuous", ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype) + + def clone(self) -> CompositeSpec: + return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) + def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) @@ -954,6 +1033,22 @@ def expand(self, *shape): n=shape[-1], shape=shape, device=self.device, dtype=self.dtype ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__( + n=self.space.n, shape=self.shape, device=dest_device, dtype=dest_dtype + ) + + def clone(self) -> CompositeSpec: + return self.__class__( + n=self.space.n, shape=self.shape, device=self.device, dtype=self.dtype + ) + @dataclass(repr=False) class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): @@ -989,6 +1084,7 @@ def __init__( dtype=torch.long, use_register=False, ): + self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) if shape is None: shape = torch.Size((sum(nvec),)) @@ -1005,6 +1101,28 @@ def __init__( shape, space, device, dtype, domain="discrete" ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__( + nvec=deepcopy(self.nvec), + shape=self.shape, + device=dest_device, + dtype=dest_dtype, + ) + + def clone(self) -> CompositeSpec: + return 
self.__class__( + nvec=deepcopy(self.nvec), + shape=self.shape, + device=self.device, + dtype=self.dtype, + ) + def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] @@ -1213,6 +1331,25 @@ def expand(self, *shape): n=self.space.n, shape=shape, device=self.device, dtype=self.dtype ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__( + n=self.space.n, shape=self.shape, device=dest_device, dtype=dest_dtype + ) + + def clone(self) -> CompositeSpec: + return self.__class__( + n=self.space.n, + shape=self.shape, + device=self.device, + dtype=self.dtype, + ) + @dataclass(repr=False) class MultiDiscreteTensorSpec(DiscreteTensorSpec): @@ -1259,11 +1396,30 @@ def __init__( ) self.nvec = self.nvec.expand(shape) - space = BoxList.from_nvec(nvec) + space = BoxList.from_nvec(self.nvec) super(DiscreteTensorSpec, self).__init__( shape, space, device, dtype, domain="discrete" ) + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + if isinstance(dest, torch.dtype): + dest_dtype = dest + dest_device = self.device + else: + dest_dtype = self.dtype + dest_device = torch.device(dest) + return self.__class__( + n=self.nvec.to(dest), shape=None, device=dest_device, dtype=dest_dtype + ) + + def clone(self) -> CompositeSpec: + return self.__class__( + nvec=self.nvec.clone(), + shape=None, + device=self.device, + dtype=self.dtype, + ) + def _rand(self, space: Box, shape: torch.Size, i: int): x = [] for _s in space: @@ -1471,6 +1627,9 @@ def __init__(self, *args, shape=None, device=None, **kwargs): _device = device if len(kwargs): for key, item in self.items(): + if item is None: + continue + try: item_device = item.device except RuntimeError as err: @@ -1480,8 +1639,6 @@ def __init__(self, *args, shape=None, device=None, **kwargs): else: raise err - if item is None: - continue if _device is None: _device = item_device elif item_device != _device: @@ -1607,7 +1764,7 @@ def __repr__(self) -> str: indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items() ] sub_str = ",\n".join(sub_str) - return f"CompositeSpec(\n{sub_str})" + return f"CompositeSpec(\n{sub_str}, device={self._device})" def type_check( self, @@ -1682,13 +1839,22 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: "Only device casting is allowed with specs of type CompositeSpec." 
) if self._device and self._device == torch.device(dest): - return self - self._device = torch.device(dest) - for key, value in list(self.items()): + return self.__class__(**self._specs, device=self._device) + + _device = torch.device(dest) + items = list(self.items()) + kwargs = {} + for key, value in items: if value is None: + kwargs[key] = value continue - self[key] = value.to(dest) - return self + kwargs[key] = value.to(dest) + return self.__class__(**kwargs, device=_device) + + def clone(self) -> CompositeSpec: + return self.__class__( + **{key: item.clone() for key, item in self.items()}, device=self._device + ) def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: return {key: self[key]._to_numpy(val) for key, val in val.items()} @@ -1708,7 +1874,7 @@ def zero(self, shape=None) -> TensorDictBase: def __eq__(self, other): return ( - type(self) == type(other) + type(self) is type(other) and self._device == other._device and self._specs == other._specs ) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 661618c5ef8..3f49828c6a5 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -56,12 +56,18 @@ def __init__( @staticmethod def build_metadata_from_env(env) -> EnvMetaData: - tensordict = env.fake_tensordict() - specs = {key: getattr(env, key) for key in Specs._keys if key.endswith("_spec")} + tensordict = env.fake_tensordict().clone() + specs = { + "input_spec": env.input_spec, + "observation_spec": env.observation_spec, + "reward_spec": env.reward_spec, + } specs = CompositeSpec(**specs, shape=env.batch_size) + batch_size = env.batch_size env_str = str(env) device = env.device + specs.to("cpu").clone().to(device).clone() batch_locked = env.batch_locked return EnvMetaData(tensordict, specs, batch_size, env_str, device, batch_locked) @@ -78,7 +84,7 @@ def expand(self, *size: int) -> EnvMetaData: ) def to(self, device: DEVICE_TYPING) -> EnvMetaData: - tensordict = self.tensordict.to(device) + tensordict = self.tensordict.contiguous().to(device) specs = self.specs.to(device) return EnvMetaData( tensordict, specs, self.batch_size, self.env_str, device, self.batch_locked @@ -86,13 +92,13 @@ def to(self, device: DEVICE_TYPING) -> EnvMetaData: def __setstate__(self, state): state["tensordict"] = state["tensordict"].to_tensordict().to(state["device"]) - state["specs"] = deepcopy(state["specs"]).to(state["device"]) + state["specs"] = state["specs"].clone().to(state["device"]) self.__dict__.update(state) def __getstate__(self): state = self.__dict__.copy() - state["tensordict"] = state["tensordict"].to("cpu") - state["specs"] = state["specs"].to("cpu") + state["tensordict"] = state["tensordict"].to_tensordict().to("cpu") + state["specs"] = state["specs"].clone().to("cpu") return state From 8b50297dcfebe22e15f655cf091c4f5946a65fa7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 18:20:11 +0000 Subject: [PATCH 18/30] ammend --- torchrl/envs/common.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 3f49828c6a5..ec9377022c4 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -38,6 +38,11 @@ def _tensor_to_np(t): class EnvMetaData: """A class for environment meta-data storage and passing in multiprocessed settings.""" + def __new__(cls, *args, **kwargs): + self._spec = None + self._tensordict = None + return cls + def __init__( self, tensordict: TensorDictBase, @@ -54,6 +59,22 @@ def __init__( self.device = device 
self.batch_locked = batch_locked + @property + def tensordict(self): + return self._tensordict.to(self.device) + + @property + def spec(self): + return self._specs.to(self.device) + + @tensordict.setter + def tensordict(self, value: TensorDictBase): + self._tensordict = value.to("cpu") + + @spec.setter + def spec(self, value: CompositeSpec): + self._spec = value.to("cpu") + @staticmethod def build_metadata_from_env(env) -> EnvMetaData: tensordict = env.fake_tensordict().clone() @@ -90,17 +111,6 @@ def to(self, device: DEVICE_TYPING) -> EnvMetaData: tensordict, specs, self.batch_size, self.env_str, device, self.batch_locked ) - def __setstate__(self, state): - state["tensordict"] = state["tensordict"].to_tensordict().to(state["device"]) - state["specs"] = state["specs"].clone().to(state["device"]) - self.__dict__.update(state) - - def __getstate__(self): - state = self.__dict__.copy() - state["tensordict"] = state["tensordict"].to_tensordict().to("cpu") - state["specs"] = state["specs"].clone().to("cpu") - return state - class Specs: """Container for action, observation and reward specs. From 3c2840aa629dcee678afe3a6c68f32dfc801cd66 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 18:21:00 +0000 Subject: [PATCH 19/30] ammend --- torchrl/envs/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index ec9377022c4..0aab545d145 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -39,8 +39,8 @@ class EnvMetaData: """A class for environment meta-data storage and passing in multiprocessed settings.""" def __new__(cls, *args, **kwargs): - self._spec = None - self._tensordict = None + cls._spec = None + cls._tensordict = None return cls def __init__( From d38805f09a87180e6c262ab1848dc0bd21c2bab7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 18:22:06 +0000 Subject: [PATCH 20/30] ammend --- torchrl/envs/common.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 0aab545d145..6e331ce8d7c 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -38,11 +38,6 @@ def _tensor_to_np(t): class EnvMetaData: """A class for environment meta-data storage and passing in multiprocessed settings.""" - def __new__(cls, *args, **kwargs): - cls._spec = None - cls._tensordict = None - return cls - def __init__( self, tensordict: TensorDictBase, From e9cd58b9c0de5d8de60f0881fd321c481d12df06 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 18:22:17 +0000 Subject: [PATCH 21/30] ammend --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 6e331ce8d7c..8286e9f3852 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -47,11 +47,11 @@ def __init__( device: torch.device, batch_locked: bool = True, ): + self.device = device self.tensordict = tensordict self.specs = specs self.batch_size = batch_size self.env_str = env_str - self.device = device self.batch_locked = batch_locked @property From 6b33cc2138d31e0822933e64b33d3220682d5a45 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 18:26:27 +0000 Subject: [PATCH 22/30] ammend --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 8286e9f3852..037617860f7 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -83,7 +83,7 @@ def build_metadata_from_env(env) -> 
EnvMetaData: batch_size = env.batch_size env_str = str(env) device = env.device - specs.to("cpu").clone().to(device).clone() + specs = specs.to("cpu").clone().to(device).clone() batch_locked = env.batch_locked return EnvMetaData(tensordict, specs, batch_size, env_str, device, batch_locked) From 8805a05adce9506b17e484bc67f30187434235cf Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 18:26:51 +0000 Subject: [PATCH 23/30] ammend --- torchrl/envs/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 037617860f7..7925f244c7b 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -78,12 +78,12 @@ def build_metadata_from_env(env) -> EnvMetaData: "observation_spec": env.observation_spec, "reward_spec": env.reward_spec, } - specs = CompositeSpec(**specs, shape=env.batch_size) + specs = CompositeSpec(**specs, shape=env.batch_size).to("cpu") batch_size = env.batch_size env_str = str(env) device = env.device - specs = specs.to("cpu").clone().to(device).clone() + specs = specs.to("cpu") batch_locked = env.batch_locked return EnvMetaData(tensordict, specs, batch_size, env_str, device, batch_locked) From c85fc43d713241296894a1ae67a4045635faaf3c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 20 Jan 2023 21:28:27 +0000 Subject: [PATCH 24/30] amend --- test/test_env.py | 1 - torchrl/data/tensor_specs.py | 10 ++++++---- torchrl/envs/common.py | 5 ++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index b2b28956e11..5d3d46f81da 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -911,7 +911,6 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): env.set_seed(1) action = env.action_spec.rand() action[:] = 1 - for i in range(max_steps): td = env.step( TensorDict( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 71c8ff7515d..069dfc5f3e8 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1764,7 +1764,7 @@ def __repr__(self) -> str: indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items() ] sub_str = ",\n".join(sub_str) - return f"CompositeSpec(\n{sub_str}, device={self._device})" + return f"CompositeSpec(\n{sub_str}, device={self._device}, shape={self.shape})" def type_check( self, @@ -1839,7 +1839,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: "Only device casting is allowed with specs of type CompositeSpec." 
) if self._device and self._device == torch.device(dest): - return self.__class__(**self._specs, device=self._device) + return self.__class__(**self._specs, device=self._device, shape=self.shape) _device = torch.device(dest) items = list(self.items()) @@ -1849,11 +1849,13 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: kwargs[key] = value continue kwargs[key] = value.to(dest) - return self.__class__(**kwargs, device=_device) + return self.__class__(**kwargs, device=_device, shape=self.shape) def clone(self) -> CompositeSpec: return self.__class__( - **{key: item.clone() for key, item in self.items()}, device=self._device + **{key: item.clone() for key, item in self.items()}, + device=self._device, + shape=self.shape, ) def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 7925f244c7b..3d49a644d63 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -6,7 +6,6 @@ from __future__ import annotations import abc -from copy import deepcopy from numbers import Number from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union @@ -89,10 +88,10 @@ def build_metadata_from_env(env) -> EnvMetaData: def expand(self, *size: int) -> EnvMetaData: tensordict = self.tensordict.expand(*size).to_tensordict() - batch_size = torch.Size([*size]) + batch_size = torch.Size(list(size)) return EnvMetaData( tensordict, - self.specs.expand(size), + self.specs.expand(*size), batch_size, self.env_str, self.device, From 0c8a01dba15a8b0689e608918086a580744ea3d9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 22 Jan 2023 09:56:49 +0000 Subject: [PATCH 25/30] amend --- torchrl/envs/common.py | 11 ++++++++ torchrl/envs/model_based/common.py | 7 ++--- torchrl/envs/transforms/transforms.py | 8 +++--- torchrl/envs/vec_env.py | 40 +-------------------------- 4 files changed, 19 insertions(+), 47 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 3d49a644d63..04e3a44142b 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +from copy import deepcopy from numbers import Number from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union @@ -98,6 +99,16 @@ def expand(self, *size: int) -> EnvMetaData: self.batch_locked, ) + def clone(self): + return EnvMetaData( + self.tensordict.clone(), + self.specs.clone(), + torch.Size([*self.batch_size]), + deepcopy(self.env_str), + self.device, + self.batch_locked, + ) + def to(self, device: DEVICE_TYPING) -> EnvMetaData: tensordict = self.tensordict.contiguous().to(device) specs = self.specs.to(device) diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index ab0ddeaa944..328569cc65b 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
import abc -from copy import deepcopy from typing import List, Optional, Union import numpy as np @@ -139,9 +138,9 @@ def __new__(cls, *args, **kwargs): def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" - self.observation_spec = deepcopy(env.observation_spec).to(self.device) - self.reward_spec = deepcopy(env.reward_spec).to(self.device) - self.input_spec = deepcopy(env.input_spec).to(self.device) + self.observation_spec = env.observation_spec.clone().to(self.device) + self.reward_spec = env.reward_spec.clone().to(self.device) + self.input_spec = env.input_spec.clone().to(self.device) def _step( self, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 646a8a11504..ec8f95187b2 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -7,7 +7,7 @@ import collections import multiprocessing as mp -from copy import copy, deepcopy +from copy import copy from textwrap import indent from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union @@ -405,7 +405,7 @@ def observation_spec(self) -> TensorSpec: """Observation spec of the transformed environment.""" if self._observation_spec is None or not self.cache_specs: observation_spec = self.transform.transform_observation_spec( - deepcopy(self.base_env.observation_spec) + self.base_env.observation_spec.clone() ) if self.cache_specs: self.__dict__["_observation_spec"] = observation_spec @@ -423,7 +423,7 @@ def input_spec(self) -> TensorSpec: """Action spec of the transformed environment.""" if self._input_spec is None or not self.cache_specs: input_spec = self.transform.transform_input_spec( - deepcopy(self.base_env.input_spec) + self.base_env.input_spec.clone() ) if self.cache_specs: self.__dict__["_input_spec"] = input_spec @@ -436,7 +436,7 @@ def reward_spec(self) -> TensorSpec: """Reward spec of the transformed environment.""" if self._reward_spec is None or not self.cache_specs: reward_spec = self.transform.transform_reward_spec( - deepcopy(self.base_env.reward_spec) + self.base_env.reward_spec.clone() ) if self.cache_specs: self.__dict__["_reward_spec"] = reward_spec diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 5140900c9ab..9c2db338775 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -227,44 +227,6 @@ def _get_metadata( ) self._set_properties() - # def _prepare_dummy_env( - # self, create_env_fn: List[Callable], create_env_kwargs: List[Dict] - # ): - # self._dummy_env_instance = None - # if self._single_task: - # # if EnvCreator, the metadata are already there - # if isinstance(create_env_fn[0], EnvCreator): - # self._dummy_env_fun = create_env_fn[0] - # self._dummy_env_fun.create_env_kwargs.update(create_env_kwargs[0]) - # # get the metadata - # - # try: - # self._dummy_env_fun = CloudpickleWrapper( - # create_env_fn[0], **create_env_kwargs[0] - # ) - # except RuntimeError as err: - # if isinstance(create_env_fn[0], EnvCreator): - # self._dummy_env_fun = create_env_fn[0] - # self._dummy_env_fun.create_env_kwargs.update(create_env_kwargs[0]) - # else: - # raise err - # else: - # n_tasks = len(create_env_fn) - # self._dummy_env_fun = [] - # for i in range(n_tasks): - # try: - # self._dummy_env_fun.append( - # CloudpickleWrapper(create_env_fn[i], **create_env_kwargs[i]) - # ) - # except RuntimeError as err: - # if isinstance(create_env_fn[i], EnvCreator): - # self._dummy_env_fun.append(create_env_fn[i]) - # 
self._dummy_env_fun[i].create_env_kwargs.update( - # create_env_kwargs[i] - # ) - # else: - # raise err - def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: """Updates the kwargs of each environment given a dictionary or a list of dictionaries. @@ -280,7 +242,7 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: _kwargs.update(_new_kwargs) def _set_properties(self): - meta_data = deepcopy(self.meta_data) + meta_data = self.meta_data.clone() if self._single_task: self._batch_size = meta_data.batch_size observation_spec = meta_data.specs["observation_spec"] From 9c61a375546a2b89e1cdc1ca1e9b4f192b81b704 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 22 Jan 2023 10:13:03 +0000 Subject: [PATCH 26/30] amend --- torchrl/envs/common.py | 6 +++--- torchrl/envs/env_creator.py | 2 +- torchrl/envs/vec_env.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 04e3a44142b..62bdd7a2ac7 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -59,7 +59,7 @@ def tensordict(self): return self._tensordict.to(self.device) @property - def spec(self): + def specs(self): return self._specs.to(self.device) @tensordict.setter @@ -67,8 +67,8 @@ def tensordict(self, value: TensorDictBase): self._tensordict = value.to("cpu") @spec.setter - def spec(self, value: CompositeSpec): - self._spec = value.to("cpu") + def specs(self, value: CompositeSpec): + self._specs = value.to("cpu") @staticmethod def build_metadata_from_env(env) -> EnvMetaData: diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index c9121c76dd1..a2cea9ba872 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -188,7 +188,7 @@ def get_env_metadata( f"got EnvCreator.create_env_kwargs={env_or_creator.create_env_kwargs} and " f"kwargs = {kwargs}" ) - return env_or_creator.meta_data + return env_or_creator.meta_data.clone() else: raise NotImplementedError( f"env of type {type(env_or_creator)} is not supported by get_env_metadata." 
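The `meta_data.clone()` returned just above, together with the `clone()`/`to()` methods introduced in PATCH 16 and 17, makes spec copies and device casts out-of-place, so callers can expand or cast the metadata without mutating the copy cached by the `EnvCreator`. A minimal sketch of the resulting semantics (constructor arguments follow the tests in this series; the key name `observation` is arbitrary):

    import torch
    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    spec = CompositeSpec(
        observation=UnboundedContinuousTensorSpec(shape=(3,), device="cpu"),
        shape=[],
    )
    copy = spec.clone()      # deep copy: equal content, distinct object
    assert copy == spec and copy is not spec
    moved = spec.to("cpu")   # device casting also returns a new spec now
    assert moved is not spec and moved.device == torch.device("cpu")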
diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 9c2db338775..6f59160db50 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -215,6 +215,7 @@ def _get_metadata( if self._single_task: # if EnvCreator, the metadata are already there meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0]) + print("device of ", type(self), ":", meta_data.device) self.meta_data = meta_data.expand( *(self.num_workers, *meta_data.batch_size) ) @@ -242,7 +243,7 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: _kwargs.update(_new_kwargs) def _set_properties(self): - meta_data = self.meta_data.clone() + meta_data = self.meta_data if self._single_task: self._batch_size = meta_data.batch_size observation_spec = meta_data.specs["observation_spec"] From ea53803a84f1ed203e0d8b038bb65b3e7e6c3127 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 22 Jan 2023 10:13:24 +0000 Subject: [PATCH 27/30] remove print --- torchrl/envs/vec_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 6f59160db50..ada35610294 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -215,7 +215,6 @@ def _get_metadata( if self._single_task: # if EnvCreator, the metadata are already there meta_data = get_env_metadata(create_env_fn[0], create_env_kwargs[0]) - print("device of ", type(self), ":", meta_data.device) self.meta_data = meta_data.expand( *(self.num_workers, *meta_data.batch_size) ) From 36a5b88c5a52dbe183afb5af437723cf4dd8fef1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 22 Jan 2023 10:14:12 +0000 Subject: [PATCH 28/30] bf --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 62bdd7a2ac7..c8e26a7ce3e 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -66,7 +66,7 @@ def specs(self): def tensordict(self, value: TensorDictBase): self._tensordict = value.to("cpu") - @spec.setter + @specs.setter def specs(self, value: CompositeSpec): self._specs = value.to("cpu") From 9797dab1994ce45ebcf75c04063bb90fff52917f Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 22 Jan 2023 11:28:44 +0000 Subject: [PATCH 29/30] bf --- torchrl/data/tensor_specs.py | 8 ++++++-- torchrl/modules/distributions/continuous.py | 5 ++++- torchrl/modules/distributions/truncated_normal.py | 1 - 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 069dfc5f3e8..611d75346a2 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -900,8 +900,12 @@ def __init__( min_value = False max_value = True else: - min_value = torch.iinfo(dtype).min - max_value = torch.iinfo(dtype).max + if dtype.is_floating_point: + min_value = torch.finfo(dtype).min + max_value = torch.finfo(dtype).max + else: + min_value = torch.iinfo(dtype).min + max_value = torch.iinfo(dtype).max space = ContinuousBox( torch.full(shape, min_value, device=device), torch.full(shape, max_value, device=device), diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index f0644174be0..65100e5bb8c 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -94,7 +94,10 @@ def _call(self, x: torch.Tensor) -> torch.Tensor: return y def _inverse(self, y: torch.Tensor) -> torch.Tensor: - eps = torch.finfo(y.dtype).eps + if y.dtype.is_floating_point: + eps = torch.finfo(y.dtype).eps + 
else: + raise NotImplementedError("No inverse tanh for integer inputs.") y = y.clamp(-1 + eps, 1 - eps) x = super()._inverse(y) return x diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index f733dcac5f7..1dfde393709 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -52,7 +52,6 @@ def __init__(self, a, b, validate_args=None): .tolist() ): raise ValueError("Incorrect truncation range") - # eps = torch.finfo(self.a.dtype).eps * 10 eps = self.eps self._dtype_min_gt_0 = eps self._dtype_max_lt_1 = 1 - eps From 470de2c19584d523bb945626bae1b0e5da89f7e2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 22 Jan 2023 11:41:26 +0000 Subject: [PATCH 30/30] init --- torchrl/envs/transforms/transforms.py | 49 +++++++++++++-------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 1322f627ea5..99741a8add8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1547,6 +1547,15 @@ def __init__( if cat_dim > 0: raise ValueError(self._CAT_DIM_ERR) self.cat_dim = cat_dim + for in_key in self.in_keys: + buffer_name = f"_cat_buffers_{in_key}" + setattr( + self, + buffer_name, + torch.nn.parameter.UninitializedBuffer( + device=torch.device("cpu"), dtype=torch.get_default_dtype() + ), + ) def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Resets _buffers.""" @@ -1554,12 +1563,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: if len(tensordict.batch_size) < 1 or tensordict.batch_size[0] == 1: for in_key in self.in_keys: buffer_name = f"_cat_buffers_{in_key}" - try: - buffer = getattr(self, buffer_name) - buffer.fill_(0.0) - except AttributeError: - # we'll instantiate later, when needed - pass + buffer = getattr(self, buffer_name) + if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): + continue + buffer.fill_(0.0) # Batched environments else: @@ -1573,12 +1580,10 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) for in_key in self.in_keys: buffer_name = f"_cat_buffers_{in_key}" - try: - buffer = getattr(self, buffer_name) - buffer[_reset] = 0.0 - except AttributeError: - # we'll instantiate later, when needed - pass + buffer = getattr(self, buffer_name) + if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): + continue + buffer[_reset] = 0.0 return tensordict @@ -1587,15 +1592,9 @@ def _make_missing_buffer(self, data, buffer_name): d = shape[self.cat_dim] shape[self.cat_dim] = d * self.N shape = torch.Size(shape) - self.register_buffer( - buffer_name, - torch.zeros( - shape, - dtype=data.dtype, - device=data.device, - ), - ) - buffer = getattr(self, buffer_name) + getattr(self, buffer_name).materialize(shape) + buffer = getattr(self, buffer_name).to(data.dtype).to(data.device).zero_() + setattr(self, buffer_name, buffer) return buffer def _call(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -1605,12 +1604,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: buffer_name = f"_cat_buffers_{in_key}" data = tensordict[in_key] d = data.size(self.cat_dim) - try: - buffer = getattr(self, buffer_name) + buffer = getattr(self, buffer_name) + if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): + buffer = self._make_missing_buffer(data, buffer_name) + else: # shift obs 1 position to the right buffer.copy_(torch.roll(buffer, shifts=-d, 
dims=self.cat_dim)) - except AttributeError: - buffer = self._make_missing_buffer(data, buffer_name) # add new obs idx = self.cat_dim if idx < 0: