diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 153ba14d9a4..07cd1f9c6f8 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -182,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create a single batch of trajectories stacked_data = torch.stack(accumulator, dim=0) - stacked_data = stacked_data.to(device) + stacked_data = stacked_data.to(device, non_blocking=True) # Compute advantage stacked_data = adv_module(stacked_data)