Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 23, 2024
1 parent 0889a5b commit 5f9828a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,14 +967,14 @@ def zero(self, shape=None) -> TensorDictBase:
dim = self.dim + len(shape)
else:
dim = self.dim
return torch.stack([spec.zero(shape) for spec in self._specs], dim)
return LazyStackedTensorDict.maybe_dense_stack([spec.zero(shape) for spec in self._specs], dim)

def rand(self, shape=None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
dim = self.dim
return torch.stack([spec.rand(shape) for spec in self._specs], dim)
return LazyStackedTensorDict.maybe_dense_stack([spec.rand(shape) for spec in self._specs], dim)

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T:
if dest is None:
Expand Down Expand Up @@ -4229,7 +4229,7 @@ def project(self, val: TensorDictBase) -> TensorDictBase:
vals.append(spec.project(subval))
else:
vals.append(subval)
res = torch.stack(vals, dim=self.dim)
res = LazyStackedTensorDict.maybe_dense_stack(vals, dim=self.dim)
if not isinstance(val, LazyStackedTensorDict):
res = res.to_tensordict()
return res
Expand Down

0 comments on commit 5f9828a

Please sign in to comment.