Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 7, 2025
1 parent bcf0afb commit 901c999
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2632,6 +2632,7 @@ def encode(
) -> torch.Tensor | TensorDictBase:
return val


class _UnboundedMeta(abc.ABCMeta):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
Expand Down Expand Up @@ -4930,7 +4931,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
_dict[key] = item.rand(shape)
# No need to run checks since we know Composite is compliant with
# TensorDict requirements
return TensorDict._new_unsafe(
return TensorDict(
_dict,
batch_size=_size([*shape, *self.shape]),
device=self._device,
Expand Down
5 changes: 3 additions & 2 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, val in TensorDict(obs_dict, []).items(True, True)
)
else:
tensordict_out = TensorDict._new_unsafe(
tensordict_out = TensorDict(
obs_dict,
batch_size=tensordict.batch_size,
)
Expand Down Expand Up @@ -376,7 +376,8 @@ def _reset(

source = self.read_obs(obs)

tensordict_out = TensorDict._new_unsafe(
# _new_unsafe cannot be used because it won't wrap non-tensor correctly
tensordict_out = TensorDict(
source=source,
batch_size=self.batch_size,
)
Expand Down

0 comments on commit 901c999

Please sign in to comment.