Skip to content

Commit

Permalink
BugFix: skip split_trajectories condition (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 25, 2022
1 parent 1e9941b commit f7177cc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrl/collectors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def split_trajectories(rollout_tensordict: _TensorDict) -> _TensorDict:
splits = traj_ids.view(-1)
splits = [(splits == i).sum().item() for i in splits.unique_consecutive()]
# if all splits are identical then we can skip this function
if len(set(splits)) == 1:
if len(set(splits)) == 1 and splits[0] == traj_ids.shape[1]:
rollout_tensordict.set(
"mask",
torch.ones(
Expand Down

0 comments on commit f7177cc

Please sign in to comment.