diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 04ca3ddf07e..079c8b71e12 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -484,6 +484,10 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: # Adds the pixel observation spec by calling render on the parent env + switch = False + if not self.enabled: + switch = True + self.switch() parent = self.parent td_in = TensorDict({}, batch_size=parent.batch_size, device=parent.device) self._call(td_in) @@ -495,6 +499,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec device=obs.device, dtype=obs.dtype, shape=obs.shape ) observation_spec[self.out_keys[0]] = spec + if switch: + self.switch() return observation_spec def switch(self, mode: str | bool = None):