Skip to content

Commit

Permalink
use set_default instead
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jan 3, 2023
1 parent ec9da99 commit 7602e09
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchrl/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,13 +1002,17 @@ def _run_worker_pipe_shared_mem(
raise RuntimeError("call 'init' before resetting")
# _td = tensordict.select("observation").to(env.device).clone()
_td = env._reset(**reset_kwargs)
_td.set_default(
"done",
torch.zeros(*_td.batch_size, 1, dtype=torch.bool, device=_td.device),
)
if reset_keys is None:
reset_keys = set(_td.keys())
if pin_memory:
_td.pin_memory()
tensordict.update_(_td)
child_pipe.send(("reset_obs", reset_keys))
if _td.get("done", torch.zeros([], dtype=torch.bool)).any():
if _td.get("done").any():
raise RuntimeError(f"{env.__class__.__name__} is done after reset")

elif cmd == "step":
Expand Down

0 comments on commit 7602e09

Please sign in to comment.