diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 54179f76941..f1592c11e70 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -292,7 +292,7 @@ def _set_properties(self):
             devices = set()
             for _meta_data in meta_data:
                 device = _meta_data.device
-                devices.append(device)
+                devices.add(device)
             if self._device is None:
                 if len(devices) > 1:
                     raise ValueError(
@@ -620,19 +620,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
                 _td.select(*self._selected_reset_keys_filt, strict=False)
             )
         selected_output_keys = self._selected_reset_keys_filt
+        device = self.device
         if self._single_task:
             # select + clone creates 2 tds, but we can create one only
             out = TensorDict(
-                {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+                {}, batch_size=self.shared_tensordict_parent.shape, device=device
             )
             for key in selected_output_keys:
-                _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
-            return out
+                _set_single_key(
+                    self.shared_tensordict_parent, out, key, clone=True, device=device
+                )
         else:
-            return self.shared_tensordict_parent.select(
+            out = self.shared_tensordict_parent.select(
                 *selected_output_keys,
                 strict=False,
-            ).clone()
+            )
+            if out.device == device:
+                out = out.clone()
+            else:
+                out = out.to(self.device, non_blocking=True)
+        return out
 
     def _reset_proc_data(self, tensordict, tensordict_reset):
         # since we call `reset` directly, all the postproc has been completed
@@ -640,31 +647,6 @@ def _reset_proc_data(self, tensordict, tensordict_reset):
             return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
         return tensordict_reset
 
-    # @_check_start
-    # def _step(
-    #     self,
-    #     tensordict: TensorDict,
-    # ) -> TensorDict:
-    #     tensordict_in = tensordict.clone(False)
-    #     next_td = self.shared_tensordict_parent.get("next")
-    #     for i in range(self.num_workers):
-    #         # shared_tensordicts are locked, and we need to select the keys since we update in-place.
-    #         # There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
-    #         out_td = self._envs[i]._step(tensordict_in[i])
-    #         next_td[i].update_(out_td.select(*self._env_output_keys, strict=False))
-    #     # We must pass a clone of the tensordict, as the values of this tensordict
-    #     # will be modified in-place at further steps
-    #     if self._single_task:
-    #         out = TensorDict(
-    #             {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
-    #         )
-    #         for key in self._selected_step_keys:
-    #             _set_single_key(next_td, out, key, clone=True)
-    #     else:
-    #         # strict=False ensures that non-homogeneous keys are still there
-    #         out = next_td.select(*self._selected_step_keys, strict=False).clone()
-    #     return out
-
     @_check_start
     def _step(
         self,
@@ -684,15 +666,20 @@ def _step(
             next_td[i].update_(out_td.select(*self._env_output_keys, strict=False))
         # We must pass a clone of the tensordict, as the values of this tensordict
         # will be modified in-place at further steps
+        device = self.device
         if self._single_task:
             out = TensorDict(
-                {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+                {}, batch_size=self.shared_tensordict_parent.shape, device=device
             )
             for key in self._selected_step_keys:
-                _set_single_key(next_td, out, key, clone=True)
+                _set_single_key(next_td, out, key, clone=True, device=device)
         else:
             # strict=False ensures that non-homogeneous keys are still there
-            out = next_td.select(*self._selected_step_keys, strict=False).clone()
+            out = next_td.select(*self._selected_step_keys, strict=False)
+            if out.device == device:
+                out = out.clone()
+            else:
+                out = out.to(self.device, non_blocking=True)
         return out
 
     def __getattr__(self, attr: str) -> Any:
@@ -890,56 +877,18 @@ def step_and_maybe_reset(
 
         # We must pass a clone of the tensordict, as the values of this tensordict
        # will be modified in-place at further steps
-        tensordict.set("next", self.shared_tensordict_parent.get("next").clone())
-        tensordict_ = self.shared_tensordict_parent.exclude(
-            "next", *self.reset_keys
-        ).clone()
+        next_td = self.shared_tensordict_parent.get("next")
+        tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys)
+        device = self.device
+        if self.shared_tensordict_parent.device == device:
+            next_td = next_td.clone()
+            tensordict_ = tensordict_.clone()
+        else:
+            next_td = next_td.to(device, non_blocking=True)
+            tensordict_ = tensordict_.to(device, non_blocking=True)
+        tensordict.set("next", next_td)
         return tensordict, tensordict_
 
-    # @_check_start
-    # def step_and_maybe_reset(
-    #     self, tensordict: TensorDictBase
-    # ) -> Tuple[TensorDictBase, TensorDictBase]:
-    #     if self._single_task and not self.has_lazy_inputs:
-    #         # We must use the in_keys and nothing else for the following reasons:
-    #         # - efficiency: copying all the keys will in practice mean doing a lot
-    #         #   of writing operations since the input tensordict may (and often will)
-    #         #   contain all the previous output data.
-    #         # - value mismatch: if the batched env is placed within a transform
-    #         #   and this transform overrides an observation key (eg, CatFrames)
-    #         #   the shape, dtype or device may not necessarily match and writing
-    #         #   the value in-place will fail.
-    #         for key in tensordict.keys(True, True):
-    #             # we copy the input keys as well as the keys in the 'next' td, if any
-    #             # as this mechanism can be used by a policy to set anticipatively the
-    #             # keys of the next call (eg, with recurrent nets)
-    #             if key in self._env_input_keys or (
-    #                 isinstance(key, tuple)
-    #                 and key[0] == "next"
-    #                 and key in self.shared_tensordict_parent.keys(True, True)
-    #             ):
-    #                 val = tensordict.get(key)
-    #                 self.shared_tensordict_parent.set_(key, val)
-    #     else:
-    #         self.shared_tensordict_parent.update_(
-    #             tensordict.select(*self._env_input_keys, "next", strict=False)
-    #         )
-    #     for i in range(self.num_workers):
-    #         self.parent_channels[i].send(("step_and_maybe_reset", None))
-    #
-    #     for i in range(self.num_workers):
-    #         event = self._events[i]
-    #         event.wait()
-    #         event.clear()
-    #
-    #     # We must pass a clone of the tensordict, as the values of this tensordict
-    #     # will be modified in-place at further steps
-    #     tensordict.set("next", self.shared_tensordict_parent.get("next").clone())
-    #     tensordict_ = self.shared_tensordict_parent.exclude(
-    #         "next", *self.reset_keys
-    #     ).clone()
-    #     return tensordict, tensordict_
-
     @_check_start
     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self._single_task and not self.has_lazy_inputs:
@@ -984,15 +933,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # We must pass a clone of the tensordict, as the values of this tensordict
         # will be modified in-place at further steps
         next_td = self.shared_tensordict_parent.get("next")
+        device = self.device
         if self._single_task:
             out = TensorDict(
-                {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+                {}, batch_size=self.shared_tensordict_parent.shape, device=device
            )
             for key in self._selected_step_keys:
-                _set_single_key(next_td, out, key, clone=True)
+                _set_single_key(next_td, out, key, clone=True, device=device)
         else:
             # strict=False ensures that non-homogeneous keys are still there
-            out = next_td.select(*self._selected_step_keys, strict=False).clone()
+            out = next_td.select(*self._selected_step_keys, strict=False)
+            if out.device == device:
+                out = out.clone()
+            else:
+                out = out.to(device, non_blocking=True)
         return out
 
     @_check_start
@@ -1055,19 +1009,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
                 self._cuda_stream.wait_event(event)
 
         selected_output_keys = self._selected_reset_keys_filt
+        device = self.device
         if self._single_task:
             # select + clone creates 2 tds, but we can create one only
             out = TensorDict(
-                {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+                {}, batch_size=self.shared_tensordict_parent.shape, device=device
             )
             for key in selected_output_keys:
-                _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
-            return out
+                _set_single_key(
+                    self.shared_tensordict_parent, out, key, clone=True, device=device
+                )
         else:
-            return self.shared_tensordict_parent.select(
+            out = self.shared_tensordict_parent.select(
                 *selected_output_keys,
                 strict=False,
-            ).clone()
+            )
+            if out.device == device:
+                out = out.clone()
+            else:
+                out = out.to(device, non_blocking=True)
+        return out
 
     @_check_start
     def _shutdown_workers(self) -> None:
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 06eec73be97..9a2a71f24bd 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -237,7 +237,11 @@ def step_mdp(
 
 
 def _set_single_key(
-    source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False
+    source: TensorDictBase,
+    dest: TensorDictBase,
+    key: str | tuple,
+    clone: bool = False,
+    device=None,
 ):
     # key should be already unraveled
     if isinstance(key, str):
@@ -253,7 +257,9 @@ def _set_single_key(
                 source = val
                 dest = new_val
             else:
-                if clone:
+                if device is not None and val.device != device:
+                    val = val.to(device, non_blocking=True)
+                elif clone:
                     val = val.clone()
                 dest._set_str(k, val, inplace=False, validated=True)
         # This is a temporary solution to understand if a key is heterogeneous
@@ -262,7 +268,7 @@ def _set_single_key(
             if re.match(r"Found more than one unique shape in the tensors", str(err)):
                 # this is a het key
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
-                    _set_single_key(s_td, d_td, k, clone)
+                    _set_single_key(s_td, d_td, k, clone=clone, device=device)
                 break
            else:
                 raise err
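
Not part of the patch: a minimal sketch of the copy rule the diff applies in `_reset`, `_step` and `step_and_maybe_reset` — clone when the shared tensordict already lives on the requested device, otherwise let `.to(device, non_blocking=True)` perform the (possibly asynchronous) transfer, which already returns fresh storage. It assumes the `tensordict` package is installed; `_clone_or_transfer` is a hypothetical helper name, not torchrl API.

```python
import torch
from tensordict import TensorDict


def _clone_or_transfer(td: TensorDict, device: torch.device) -> TensorDict:
    """Hypothetical helper mirroring the pattern used in this diff.

    A plain clone is enough when the source already sits on the target device;
    otherwise `.to(..., non_blocking=True)` moves the data and returns new
    storage, so no extra clone is needed.
    """
    if td.device == device:
        return td.clone()
    return td.to(device, non_blocking=True)


# Example: a CPU-shared buffer read back on the consumer's device.
shared = TensorDict({"obs": torch.zeros(4, 3)}, batch_size=[4], device="cpu")
target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
out = _clone_or_transfer(shared, target)
assert out.device == target
```

The `device` argument added to `_set_single_key` applies the same rule leaf by leaf: values already on the target device fall back to the existing `clone` flag, while mismatched values are transferred with `non_blocking=True`.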