From 2f8b545926daf52f4ba33086be711e60e7dc006d Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 16:09:21 +0100 Subject: [PATCH] fixes --- examples/impala/config_multi_node_ray.yaml | 3 +++ examples/impala/config_multi_node_submitit.yaml | 3 +++ examples/impala/config_single_node.yaml | 3 +++ examples/impala/impala_multi_node_ray.py | 2 +- examples/impala/impala_multi_node_submitit.py | 2 +- examples/impala/impala_single_node.py | 2 +- 6 files changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index 7117578ded1..925a655e9c2 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -23,6 +23,9 @@ ray_init_config: runtime_env: null storage: null +# Device for the forward and backward passes +local_device: "cuda:0" + # Resources assigned to each IMPALA rollout collection worker remote_worker_resources: num_cpus: 1 diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index 8ad08292c7e..f924e34fc27 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -2,6 +2,9 @@ env: env_name: PongNoFrameskip-v4 +# Device for the forward and backward passes +local_device: "cuda:0" + # SLURM config slurm_config: timeout_min: 10 diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml index 86a11d6b40c..de6fc718552 100644 --- a/examples/impala/config_single_node.yaml +++ b/examples/impala/config_single_node.yaml @@ -2,6 +2,9 @@ env: env_name: PongNoFrameskip-v4 +# Device for the forward and backward passes +local_device: "cuda:0" + # collector collector: frames_per_batch: 80 diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index be7a2ea81ec..592bd839821 100644 --- a/examples/impala/impala_multi_node_ray.py +++ 
b/examples/impala/impala_multi_node_ray.py @@ -29,7 +29,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = torch.device(cfg.local_device) # Correct for frame_skip frame_skip = 4 diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 118913699f9..8d80e200030 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -31,7 +31,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = torch.device(cfg.local_device) # Correct for frame_skip frame_skip = 4 diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 2cd1043f46f..8d587064f26 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -28,7 +28,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = torch.device(cfg.local_device) # Correct for frame_skip frame_skip = 4