Commit

🔨 enable multi-GPU support using custom PPO
mwulfman committed Apr 27, 2024
1 parent daeb266 commit 3cdc6a6
Showing 1 changed file with 4 additions and 0 deletions.
python/gym_jiminy/rllib/gym_jiminy/rllib/ppo.py (4 additions, 0 deletions)
@@ -155,6 +155,10 @@ def _compute_mirrored_value(value: torch.Tensor,
     """Compute mirrored value from observation space based on provided
     mirroring transformation.
     """
+    # Make sure value and mirror_mat are on the same device.
+    # This is needed for multi-GPU training.
+    mirror_mat = mirror_mat.to(value.device)
+
     def _update_flattened_slice(data: torch.Tensor,
                                 shape: Tuple[int, ...],
                                 mirror_mat: torch.Tensor) -> torch.Tensor:
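Why the one-line fix works: `Tensor.to(device)` returns the tensor unchanged when it already lives on the target device and copies it across otherwise, so the added line costs nothing in the single-GPU case. Below is a minimal sketch of the failure mode it guards against; the tensors are hypothetical stand-ins rather than Jiminy's actual observation batches, and the matmul stands in for whatever mirroring operation the real code performs.

import torch

# Stand-in tensors: a batch of 8 flattened observations of size 4 and an
# identity "mirroring" matrix (hypothetical; the real ones come from
# gym_jiminy's PPO mirroring machinery).
value = torch.randn(8, 4)
mirror_mat = torch.eye(4)

# Under multi-GPU training, each worker receives its batch on its own
# device, while mirror_mat may still live on the CPU (or another GPU).
if torch.cuda.is_available():
    value = value.to("cuda:0")

# Without this line, the matmul below raises
# "RuntimeError: Expected all tensors to be on the same device" whenever
# the two tensors end up on different devices; when they already match,
# the call is a no-op.
mirror_mat = mirror_mat.to(value.device)

mirrored = value @ mirror_mat  # safe regardless of device placement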
