diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index c00167b5b29..49ebdc57f1c 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2632,6 +2632,7 @@ def encode( ) -> torch.Tensor | TensorDictBase: return val + class _UnboundedMeta(abc.ABCMeta): def __call__(cls, *args, **kwargs): instance = super().__call__(*args, **kwargs) @@ -4930,7 +4931,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase: _dict[key] = item.rand(shape) # No need to run checks since we know Composite is compliant with # TensorDict requirements - return TensorDict._new_unsafe( + return TensorDict( _dict, batch_size=_size([*shape, *self.shape]), device=self._device, diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index aa1883958ca..bb5a4ddea43 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -342,7 +342,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for key, val in TensorDict(obs_dict, []).items(True, True) ) else: - tensordict_out = TensorDict._new_unsafe( + tensordict_out = TensorDict( obs_dict, batch_size=tensordict.batch_size, ) @@ -376,7 +376,8 @@ def _reset( source = self.read_obs(obs) - tensordict_out = TensorDict._new_unsafe( + # _new_unsafe cannot be used because it won't wrap non-tensor correctly + tensordict_out = TensorDict( source=source, batch_size=self.batch_size, )