-
Notifications
You must be signed in to change notification settings - Fork 335
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main' into refactor-losses-funct…
…ional
- Loading branch information
Showing
21 changed files
with
2,141 additions
and
188 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results | ||
|
||
This repository contains scripts that enable training agents using the IMPALA Algorithm on MuJoCo and Atari environments. We follow the original paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Espeholt et al. 2018. | ||
|
||
## Examples Structure | ||
|
||
Please note that we provide 2 examples, one for single node training and one for distributed training. Both examples rely on the same utils file, but besides that are independent. Each example contains the following files: | ||
|
||
1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. impala_single_node_ray.py). | ||
|
||
2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils.py). | ||
|
||
3. **Configuration File:** This file includes default hyperparameters specified in the original paper. For the multi-node case, the file also includes the configuration file of the Ray cluster. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml). | ||
|
||
|
||
## Running the Examples | ||
|
||
You can execute the single node IMPALA algorithm on Atari environments by running the following command: | ||
|
||
```bash | ||
python impala_single_node.py | ||
``` | ||
|
||
You can execute the multi-node IMPALA algorithm on Atari environments by running the following command: | ||
|
||
```bash | ||
python impala_single_node_ray.py | ||
``` | ||
or | ||
|
||
```bash | ||
python impala_single_node_submitit.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Environment | ||
env: | ||
env_name: PongNoFrameskip-v4 | ||
|
||
# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html | ||
ray_init_config: | ||
address: null | ||
num_cpus: null | ||
num_gpus: null | ||
resources: null | ||
object_store_memory: null | ||
local_mode: False | ||
ignore_reinit_error: False | ||
include_dashboard: null | ||
dashboard_host: 127.0.0.1 | ||
dashboard_port: null | ||
job_config: null | ||
configure_logging: True | ||
logging_level: info | ||
logging_format: null | ||
log_to_driver: True | ||
namespace: null | ||
runtime_env: null | ||
storage: null | ||
|
||
# Device for the forward and backward passes | ||
local_device: "cuda:0" | ||
|
||
# Resources assigned to each IMPALA rollout collection worker | ||
remote_worker_resources: | ||
num_cpus: 1 | ||
num_gpus: 0.25 | ||
memory: 1073741824 # 1*1024**3 - 1GB | ||
|
||
# collector | ||
collector: | ||
frames_per_batch: 80 | ||
total_frames: 200_000_000 | ||
num_workers: 12 | ||
|
||
# logger | ||
logger: | ||
backend: wandb | ||
exp_name: Atari_IMPALA | ||
test_interval: 200_000_000 | ||
num_test_episodes: 3 | ||
|
||
# Optim | ||
optim: | ||
lr: 0.0006 | ||
eps: 1e-8 | ||
weight_decay: 0.0 | ||
momentum: 0.0 | ||
alpha: 0.99 | ||
max_grad_norm: 40.0 | ||
anneal_lr: True | ||
|
||
# loss | ||
loss: | ||
gamma: 0.99 | ||
batch_size: 32 | ||
sgd_updates: 1 | ||
critic_coef: 0.5 | ||
entropy_coef: 0.01 | ||
loss_critic_type: l2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Environment | ||
env: | ||
env_name: PongNoFrameskip-v4 | ||
|
||
# Device for the forward and backward passes | ||
local_device: "cuda:0" | ||
|
||
# SLURM config | ||
slurm_config: | ||
timeout_min: 10 | ||
slurm_partition: train | ||
slurm_cpus_per_task: 1 | ||
slurm_gpus_per_node: 1 | ||
|
||
# collector | ||
collector: | ||
backend: gloo | ||
frames_per_batch: 80 | ||
total_frames: 200_000_000 | ||
num_workers: 1 | ||
|
||
# logger | ||
logger: | ||
backend: wandb | ||
exp_name: Atari_IMPALA | ||
test_interval: 200_000_000 | ||
num_test_episodes: 3 | ||
|
||
# Optim | ||
optim: | ||
lr: 0.0006 | ||
eps: 1e-8 | ||
weight_decay: 0.0 | ||
momentum: 0.0 | ||
alpha: 0.99 | ||
max_grad_norm: 40.0 | ||
anneal_lr: True | ||
|
||
# loss | ||
loss: | ||
gamma: 0.99 | ||
batch_size: 32 | ||
sgd_updates: 1 | ||
critic_coef: 0.5 | ||
entropy_coef: 0.01 | ||
loss_critic_type: l2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Environment | ||
env: | ||
env_name: PongNoFrameskip-v4 | ||
|
||
# Device for the forward and backward passes | ||
device: "cuda:0" | ||
|
||
# collector | ||
collector: | ||
frames_per_batch: 80 | ||
total_frames: 200_000_000 | ||
num_workers: 12 | ||
|
||
# logger | ||
logger: | ||
backend: wandb | ||
exp_name: Atari_IMPALA | ||
test_interval: 200_000_000 | ||
num_test_episodes: 3 | ||
|
||
# Optim | ||
optim: | ||
lr: 0.0006 | ||
eps: 1e-8 | ||
weight_decay: 0.0 | ||
momentum: 0.0 | ||
alpha: 0.99 | ||
max_grad_norm: 40.0 | ||
anneal_lr: True | ||
|
||
# loss | ||
loss: | ||
gamma: 0.99 | ||
batch_size: 32 | ||
sgd_updates: 1 | ||
critic_coef: 0.5 | ||
entropy_coef: 0.01 | ||
loss_critic_type: l2 |
Oops, something went wrong.