Skip to content

Commit

Permalink
[Minor] Add env.shape attribute (#1938)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 20, 2024
1 parent c45ee1f commit 799f939
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,13 @@ def run_type_checks(self, run_type_checks: bool) -> None:

@property
def batch_size(self) -> torch.Size:
"""Number of envs batched in this environment instance organised in a `torch.Size()` object.
Environment may be similar or different but it is assumed that they have little if
not no interactions between them (e.g., multi-task or batched execution
in parallel).
"""
_batch_size = self.__dict__["_batch_size"]
if _batch_size is None:
_batch_size = self._batch_size = torch.Size([])
Expand All @@ -439,6 +446,11 @@ def batch_size(self, value: torch.Size) -> None:
self.input_spec.shape = value
self.input_spec.lock_()

@property
def shape(self):
"""Equivalent to :attr:`~.batch_size`."""
return self.batch_size

@property
def device(self) -> torch.device:
device = self.__dict__.get("_device", None)
Expand Down Expand Up @@ -2162,7 +2174,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
self.batch_locked or self.batch_size != ()
) and tensordict.batch_size != self.batch_size:
raise RuntimeError(
f"Expected a tensordict with shape==env.shape, "
f"Expected a tensordict with shape==env.batch_size, "
f"got {tensordict.batch_size} and {self.batch_size}"
)

Expand Down

0 comments on commit 799f939

Please sign in to comment.