From f7177cc8cfd3e19bb2bbe4d029c7da5ffa828591 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 25 Jun 2022 22:33:21 +0100 Subject: [PATCH] BugFix: skip split_trajectories condition (#226) --- torchrl/collectors/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index d2b086f5018..560bfedf17c 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -36,7 +36,7 @@ def split_trajectories(rollout_tensordict: _TensorDict) -> _TensorDict: splits = traj_ids.view(-1) splits = [(splits == i).sum().item() for i in splits.unique_consecutive()] # if all splits are identical then we can skip this function - if len(set(splits)) == 1: + if len(set(splits)) == 1 and splits[0] == traj_ids.shape[1]: rollout_tensordict.set( "mask", torch.ones(