Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Nov 22, 2023
1 parent 638c0d6 commit 2f8b545
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 3 deletions.
3 changes: 3 additions & 0 deletions examples/impala/config_multi_node_ray.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ ray_init_config:
runtime_env: null
storage: null

# Device for the forward and backward passes
device: "cuda:0"

# Resources assigned to each IMPALA rollout collection worker
remote_worker_resources:
num_cpus: 1
Expand Down
3 changes: 3 additions & 0 deletions examples/impala/config_multi_node_submitit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"

# SLURM config
slurm_config:
timeout_min: 10
Expand Down
3 changes: 3 additions & 0 deletions examples/impala/config_single_node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
env:
env_name: PongNoFrameskip-v4

# Device for the forward and backward passes
local_device: "cuda:0"

# collector
collector:
frames_per_batch: 80
Expand Down
2 changes: 1 addition & 1 deletion examples/impala/impala_multi_node_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = "cpu" if not torch.cuda.device_count() else "cuda"
device = torch.device(cfg.local_device)

# Correct for frame_skip
frame_skip = 4
Expand Down
2 changes: 1 addition & 1 deletion examples/impala/impala_multi_node_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = "cpu" if not torch.cuda.device_count() else "cuda"
device = torch.device(cfg.local_device)

# Correct for frame_skip
frame_skip = 4
Expand Down
2 changes: 1 addition & 1 deletion examples/impala/impala_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import eval_model, make_env, make_ppo_models

device = "cpu" if not torch.cuda.device_count() else "cuda"
device = torch.device(cfg.device)

# Correct for frame_skip
frame_skip = 4
Expand Down

0 comments on commit 2f8b545

Please sign in to comment.