diff --git a/python/gym_jiminy/rllib/gym_jiminy/rllib/ppo.py b/python/gym_jiminy/rllib/gym_jiminy/rllib/ppo.py index 2e967ceb0..477086e87 100644 --- a/python/gym_jiminy/rllib/gym_jiminy/rllib/ppo.py +++ b/python/gym_jiminy/rllib/gym_jiminy/rllib/ppo.py @@ -155,6 +155,10 @@ def _compute_mirrored_value(value: torch.Tensor, """Compute mirrored value from observation space based on provided mirroring transformation. """ + # Make sure value and mirror_mat are on the same device. + # This is needed for multi-GPU training. + mirror_mat = mirror_mat.to(value.device) + def _update_flattened_slice(data: torch.Tensor, shape: Tuple[int, ...], mirror_mat: torch.Tensor) -> torch.Tensor: