
Commit

amend
vmoens committed Nov 27, 2023
1 parent ea88bb4 commit 7c04f62
Showing 2 changed files with 60 additions and 93 deletions.
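Note: the common thread in both files is that an output tensordict is cloned only when it already sits on the target device; when devices differ, to(device, non_blocking=True) already allocates fresh storage, so an extra clone would copy the data twice. A minimal sketch of that decision (the helper name is illustrative, not part of the codebase):

import torch
from tensordict import TensorDict

def clone_or_transfer(td: TensorDict, device: torch.device) -> TensorDict:
    # Same device: the shared buffers will be overwritten at the next step,
    # so hand back an independent copy.
    if td.device == device:
        return td.clone()
    # Different device: .to() materializes new storage on `device`,
    # so no additional clone is needed.
    return td.to(device, non_blocking=True)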
141 changes: 51 additions & 90 deletions torchrl/envs/batched_envs.py
@@ -292,7 +292,7 @@ def _set_properties(self):
devices = set()
for _meta_data in meta_data:
device = _meta_data.device
- devices.append(device)
+ devices.add(device)
if self._device is None:
if len(devices) > 1:
raise ValueError(
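The append/add fix above is not cosmetic: devices is created as a set a few lines up, so the removed devices.append(device) would raise AttributeError, and a set is what lets the len(devices) > 1 check count only distinct worker devices. A toy illustration:

devices = set()
for dev in ("cpu", "cpu", "cuda:0"):
    devices.add(dev)          # duplicates collapse to one entry
assert len(devices) == 2      # only the distinct devices remain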
@@ -620,51 +620,33 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
_td.select(*self._selected_reset_keys_filt, strict=False)
)
selected_output_keys = self._selected_reset_keys_filt
+ device = self.device
if self._single_task:
# select + clone creates 2 tds, but we can create one only
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in selected_output_keys:
- _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
- return out
+ _set_single_key(
+ self.shared_tensordict_parent, out, key, clone=True, device=device
+ )
else:
- return self.shared_tensordict_parent.select(
+ out = self.shared_tensordict_parent.select(
*selected_output_keys,
strict=False,
- ).clone()
+ )
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(self.device, non_blocking=True)
+ return out
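The single-task branch follows the comment above ("select + clone creates 2 tds"): instead of materializing an intermediate tensordict with select() and then cloning it, it creates one empty TensorDict and copies each selected key into it once. A rough sketch of the two routes on a toy tensordict:

import torch
from tensordict import TensorDict

src = TensorDict({"obs": torch.zeros(4, 3), "reward": torch.zeros(4, 1)}, batch_size=[4])

# Route taken in the multi-task branch: select builds an intermediate td, clone then copies it.
out1 = src.select("obs", strict=False).clone()

# Route taken in the single-task branch: one destination, one copy per key.
out2 = TensorDict({}, batch_size=src.shape)
out2.set("obs", src.get("obs").clone())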

def _reset_proc_data(self, tensordict, tensordict_reset):
# since we call `reset` directly, all the postproc has been completed
if tensordict is not None:
return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
return tensordict_reset

- # @_check_start
- # def _step(
- #     self,
- #     tensordict: TensorDict,
- # ) -> TensorDict:
- #     tensordict_in = tensordict.clone(False)
- #     next_td = self.shared_tensordict_parent.get("next")
- #     for i in range(self.num_workers):
- #         # shared_tensordicts are locked, and we need to select the keys since we update in-place.
- #         # There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
- #         out_td = self._envs[i]._step(tensordict_in[i])
- #         next_td[i].update_(out_td.select(*self._env_output_keys, strict=False))
- #     # We must pass a clone of the tensordict, as the values of this tensordict
- #     # will be modified in-place at further steps
- #     if self._single_task:
- #         out = TensorDict(
- #             {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
- #         )
- #         for key in self._selected_step_keys:
- #             _set_single_key(next_td, out, key, clone=True)
- #     else:
- #         # strict=False ensures that non-homogeneous keys are still there
- #         out = next_td.select(*self._selected_step_keys, strict=False).clone()
- #     return out

@_check_start
def _step(
self,
@@ -684,15 +666,20 @@ def _step(
next_td[i].update_(out_td.select(*self._env_output_keys, strict=False))
# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
+ device = self.device
if self._single_task:
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in self._selected_step_keys:
- _set_single_key(next_td, out, key, clone=True)
+ _set_single_key(next_td, out, key, clone=True, device=device)
else:
# strict=False ensures that non-homogeneous keys are still there
- out = next_td.select(*self._selected_step_keys, strict=False).clone()
+ out = next_td.select(*self._selected_step_keys, strict=False)
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(self.device, non_blocking=True)
return out
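strict=False in the select above matters for the multi-task case: different workers can expose different key sets, and a strict select would raise on keys that are missing from the shared tensordict. For example, on a toy tensordict:

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2])
td.select("obs", "pixels", strict=False)   # missing "pixels" is silently skipped
# td.select("obs", "pixels")               # with the default strict=True this raises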

def __getattr__(self, attr: str) -> Any:
@@ -890,56 +877,18 @@ def step_and_maybe_reset(

# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
- tensordict.set("next", self.shared_tensordict_parent.get("next").clone())
- tensordict_ = self.shared_tensordict_parent.exclude(
- "next", *self.reset_keys
- ).clone()
+ next_td = self.shared_tensordict_parent.get("next")
+ tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys)
+ device = self.device
+ if self.shared_tensordict_parent.device == device:
+ next_td = next_td.clone()
+ tensordict_ = tensordict_.clone()
+ else:
+ next_td = next_td.to(device, non_blocking=True)
+ tensordict_ = tensordict_.to(device, non_blocking=True)
+ tensordict.set("next", next_td)
return tensordict, tensordict_
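The rewrite above drops the unconditional .clone(): when the shared buffers and self.device differ, .to(device, non_blocking=True) already returns tensors backed by new storage, so cloning afterwards would pay for the copy twice. At the tensor level (a sketch assuming a CUDA device is available):

import torch

x = torch.ones(3)                    # stands in for a value in the shared buffer
y = x.to("cuda", non_blocking=True)  # the transfer itself allocates new storage
y.add_(1)                            # does not touch x: no extra clone was needed
assert x[0].item() == 1
assert y[0].item() == 2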

- # @_check_start
- # def step_and_maybe_reset(
- #     self, tensordict: TensorDictBase
- # ) -> Tuple[TensorDictBase, TensorDictBase]:
- #     if self._single_task and not self.has_lazy_inputs:
- #         # We must use the in_keys and nothing else for the following reasons:
- #         # - efficiency: copying all the keys will in practice mean doing a lot
- #         # of writing operations since the input tensordict may (and often will)
- #         # contain all the previous output data.
- #         # - value mismatch: if the batched env is placed within a transform
- #         # and this transform overrides an observation key (eg, CatFrames)
- #         # the shape, dtype or device may not necessarily match and writing
- #         # the value in-place will fail.
- #         for key in tensordict.keys(True, True):
- #             # we copy the input keys as well as the keys in the 'next' td, if any
- #             # as this mechanism can be used by a policy to set anticipatively the
- #             # keys of the next call (eg, with recurrent nets)
- #             if key in self._env_input_keys or (
- #                 isinstance(key, tuple)
- #                 and key[0] == "next"
- #                 and key in self.shared_tensordict_parent.keys(True, True)
- #             ):
- #                 val = tensordict.get(key)
- #                 self.shared_tensordict_parent.set_(key, val)
- #     else:
- #         self.shared_tensordict_parent.update_(
- #             tensordict.select(*self._env_input_keys, "next", strict=False)
- #         )
- #     for i in range(self.num_workers):
- #         self.parent_channels[i].send(("step_and_maybe_reset", None))
- #
- #     for i in range(self.num_workers):
- #         event = self._events[i]
- #         event.wait()
- #         event.clear()
- #
- #     # We must pass a clone of the tensordict, as the values of this tensordict
- #     # will be modified in-place at further steps
- #     tensordict.set("next", self.shared_tensordict_parent.get("next").clone())
- #     tensordict_ = self.shared_tensordict_parent.exclude(
- #         "next", *self.reset_keys
- #     ).clone()
- #     return tensordict, tensordict_

@_check_start
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
if self._single_task and not self.has_lazy_inputs:
@@ -984,15 +933,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
next_td = self.shared_tensordict_parent.get("next")
+ device = self.device
if self._single_task:
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in self._selected_step_keys:
- _set_single_key(next_td, out, key, clone=True)
+ _set_single_key(next_td, out, key, clone=True, device=device)
else:
# strict=False ensures that non-homogeneous keys are still there
- out = next_td.select(*self._selected_step_keys, strict=False).clone()
+ out = next_td.select(*self._selected_step_keys, strict=False)
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(device, non_blocking=True)
return out

@_check_start
@@ -1055,19 +1009,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
self._cuda_stream.wait_event(event)

selected_output_keys = self._selected_reset_keys_filt
+ device = self.device
if self._single_task:
# select + clone creates 2 tds, but we can create one only
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in selected_output_keys:
- _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
- return out
+ _set_single_key(
+ self.shared_tensordict_parent, out, key, clone=True, device=device
+ )
else:
- return self.shared_tensordict_parent.select(
+ out = self.shared_tensordict_parent.select(
*selected_output_keys,
strict=False,
- ).clone()
+ )
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(device, non_blocking=True)
+ return out
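The _cuda_stream.wait_event(event) call near the top of this hunk is what makes consuming non_blocking copies safe: the reader waits on an event recorded after the producer's writes before touching the data. The general PyTorch pattern looks roughly like this (a generic sketch assuming a CUDA device is available, not the batched-env implementation itself):

import torch

cpu_buf = torch.ones(1024).pin_memory()   # pinned memory enables truly async H2D copies
copy_stream = torch.cuda.Stream()
done = torch.cuda.Event()

with torch.cuda.stream(copy_stream):
    gpu_buf = cpu_buf.to("cuda", non_blocking=True)
    done.record(copy_stream)

torch.cuda.current_stream().wait_event(done)  # consumers wait for the copy before reading gpu_buf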

@_check_start
def _shutdown_workers(self) -> None:
12 changes: 9 additions & 3 deletions torchrl/envs/utils.py
@@ -237,7 +237,11 @@ def step_mdp(


def _set_single_key(
- source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False
+ source: TensorDictBase,
+ dest: TensorDictBase,
+ key: str | tuple,
+ clone: bool = False,
+ device=None,
):
# key should be already unraveled
if isinstance(key, str):
@@ -253,7 +257,9 @@ source = val
source = val
dest = new_val
else:
- if clone:
+ if device is not None and val.device != device:
+ val = val.to(device, non_blocking=True)
+ elif clone:
val = val.clone()
dest._set_str(k, val, inplace=False, validated=True)
# This is a temporary solution to understand if a key is heterogeneous
@@ -262,7 +268,7 @@
if re.match(r"Found more than one unique shape in the tensors", str(err)):
# this is a het key
for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
- _set_single_key(s_td, d_td, k, clone)
+ _set_single_key(s_td, d_td, k, clone=clone, device=device)
break
else:
raise err
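With the new device argument, _set_single_key makes the clone-or-transfer decision per leaf value: a tensor already on the requested device falls through to the clone branch, while a mismatched one is moved with non_blocking=True and needs no additional clone. A stripped-down sketch of that leaf-level logic (flat string keys only, no heterogeneous-key handling):

import torch
from tensordict import TensorDict

def set_single_key_sketch(source, dest, key, clone=False, device=None):
    val = source.get(key)
    if device is not None and val.device != device:
        val = val.to(device, non_blocking=True)  # transfer already yields new storage
    elif clone:
        val = val.clone()                        # same device: copy explicitly
    dest.set(key, val)

src = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2])
dst = TensorDict({}, batch_size=[2])
set_single_key_sketch(src, dst, "obs", clone=True, device=torch.device("cpu"))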
