From 5ec14bb7279608dd8a34770ffc3e7a1dd33e269c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:44 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f7a25c1bd5c..5b6763f6910 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -730,19 +730,20 @@ def _create_td(self) -> None: ) ) env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) - env_obs_keys = [ - key for key in env_obs_keys if key not in self._non_tensor_keys - ] - env_input_keys = [ - key for key in env_input_keys if key not in self._non_tensor_keys - ] - env_output_keys = [ - key for key in env_output_keys if key not in self._non_tensor_keys - ] self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self._env_input_keys = sorted(env_input_keys, key=_sort_keys) self._env_output_keys = sorted(env_output_keys, key=_sort_keys) + self._env_obs_keys = [ + key for key in self._env_obs_keys if key not in self._non_tensor_keys + ] + self._env_input_keys = [ + key for key in self._env_input_keys if key not in self._non_tensor_keys + ] + self._env_output_keys = [ + key for key in self._env_output_keys if key not in self._non_tensor_keys + ] + reset_keys = self.reset_keys self._selected_keys = ( set(self._env_output_keys)