diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml new file mode 100644 index 00000000000..98af413285f --- /dev/null +++ b/.github/pytorch-probot.yml @@ -0,0 +1,5 @@ +# List of workflows that will be re-run in case of failures +# https://github.com/pytorch/test-infra/blob/main/torchci/lib/bot/retryBot.ts +retryable_workflows: +- Build M1 +- Wheels diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index a69e79c69a3..92475420867 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -58,6 +58,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offlin # ==================================================================================== # # ================================ Gymnasium ========================================= # +python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \ + collector.total_frames=80 \ + collector.frames_per_batch=20 \ + collector.num_workers=1 \ + logger.backend= \ + logger.test_interval=10 python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \ env.env_name=HalfCheetah-v4 \ collector.total_frames=40 \ diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 1a2384a1df1..01d880708f4 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -34,7 +34,8 @@ jobs: python -m pip install git+https://github.com/pytorch/tensordict python setup.py develop python -m pip install pytest pytest-benchmark - python -m pip install dm_control + python3 -m pip install "gym[accept-rom-license,atari]" + python3 -m pip install dm_control - name: Run benchmarks run: | cd benchmarks/ @@ -57,62 +58,65 @@ jobs: benchmark_gpu: name: GPU Pytest benchmark - runs-on: ubuntu-20.04 - strategy: - matrix: - include: - - os: linux.4xlarge.nvidia.gpu - python-version: 3.8 + runs-on: linux.g5.4xlarge.nvidia.gpu defaults: run: shell: bash -l {0} - container: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 + container: + image: nvidia/cuda:12.3.0-base-ubuntu22.04 + options: --gpus all steps: - - name: Install deps - run: | - export TZ=Europe/London - export DEBIAN_FRONTEND=noninteractive # tzdata bug - apt-get update -y - apt-get install software-properties-common -y - add-apt-repository ppa:git-core/candidate -y - apt-get update -y - apt-get upgrade -y - apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev - - name: Check ldd --version - run: ldd --version - - name: Checkout - uses: actions/checkout@v3 - - name: Update pip - run: | - apt-get install python3.8 python3-pip -y - pip3 install --upgrade pip - - name: Setup git - run: git config --global --add safe.directory /__w/rl/rl - - name: setup Path - run: | - echo /usr/local/bin >> $GITHUB_PATH - - name: Setup Environment - run: | - python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 - python3 -m pip install git+https://github.com/pytorch/tensordict - python3 setup.py develop - python3 -m pip install pytest pytest-benchmark - python3 -m pip install dm_control - - name: Run benchmarks - run: | - cd benchmarks/ - python3 -m pytest --benchmark-json output.json - - name: Store benchmark results - uses: benchmark-action/github-action-benchmark@v1 - if: ${{ github.ref == 'refs/heads/main' || github.event_name == 
'workflow_dispatch' }} - with: - name: GPU Benchmark Results - tool: 'pytest' - output-file-path: benchmarks/output.json - fail-on-alert: true - alert-threshold: '200%' - alert-comment-cc-users: '@vmoens' - comment-on-alert: true - github-token: ${{ secrets.GITHUB_TOKEN }} - gh-pages-branch: gh-pages - auto-push: true + - name: Install deps + run: | + export TZ=Europe/London + export DEBIAN_FRONTEND=noninteractive # tzdata bug + apt-get update -y + apt-get install software-properties-common -y + add-apt-repository ppa:git-core/candidate -y + apt-get update -y + apt-get upgrade -y + apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev + - name: Check ldd --version + run: ldd --version + - name: Checkout + uses: actions/checkout@v3 + - name: Python Setup + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Setup git + run: git config --global --add safe.directory /__w/rl/rl + - name: setup Path + run: | + echo /usr/local/bin >> $GITHUB_PATH + - name: Setup Environment + run: | + python3 -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 + python3 -m pip install git+https://github.com/pytorch/tensordict + python3 setup.py develop + python3 -m pip install pytest pytest-benchmark + python3 -m pip install "gym[accept-rom-license,atari]" + python3 -m pip install dm_control + - name: check GPU presence + run: | + python -c """import torch + assert torch.cuda.device_count() + """ + - name: Run benchmarks + run: | + cd benchmarks/ + python3 -m pytest --benchmark-json output.json + - name: Store benchmark results + uses: benchmark-action/github-action-benchmark@v1 + if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }} + with: + name: GPU Benchmark Results + tool: 'pytest' + output-file-path: benchmarks/output.json + fail-on-alert: true + alert-threshold: '200%' + alert-comment-cc-users: '@vmoens' + comment-on-alert: true + github-token: ${{ secrets.GITHUB_TOKEN }} + gh-pages-branch: gh-pages + auto-push: true diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index e44c683a6d6..0f0ad3e5723 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -33,7 +33,8 @@ jobs: python -m pip install git+https://github.com/pytorch/tensordict python setup.py develop python -m pip install pytest pytest-benchmark - python -m pip install dm_control + python3 -m pip install "gym[accept-rom-license,atari]" + python3 -m pip install dm_control - name: Setup benchmarks run: | echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV @@ -63,75 +64,78 @@ jobs: benchmark_gpu: name: GPU Pytest benchmark - runs-on: ubuntu-20.04 - strategy: - matrix: - include: - - os: linux.4xlarge.nvidia.gpu - python-version: 3.8 + runs-on: linux.g5.4xlarge.nvidia.gpu defaults: run: shell: bash -l {0} - container: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 + container: + image: nvidia/cuda:12.3.0-base-ubuntu22.04 + options: --gpus all steps: - - name: Who triggered this? 
- run: | - echo "Action triggered by ${{ github.event.pull_request.html_url }}" - - name: Install deps - run: | - export TZ=Europe/London - export DEBIAN_FRONTEND=noninteractive # tzdata bug - apt-get update -y - apt-get install software-properties-common -y - add-apt-repository ppa:git-core/candidate -y - apt-get update -y - apt-get upgrade -y - apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev - - name: Check ldd --version - run: ldd --version - - name: Checkout - uses: actions/checkout@v3 - with: - fetch-depth: 50 # this is to make sure we obtain the target base commit - - name: Update pip - run: | - apt-get install python3.8 python3-pip -y - pip3 install --upgrade pip - - name: Setup git - run: git config --global --add safe.directory /__w/rl/rl - - name: setup Path - run: | - echo /usr/local/bin >> $GITHUB_PATH - - name: Setup Environment - run: | - python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 - python3 -m pip install git+https://github.com/pytorch/tensordict - python3 setup.py develop - python3 -m pip install pytest pytest-benchmark - python3 -m pip install dm_control - - name: Setup benchmarks - run: | - echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV - echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> $GITHUB_ENV - echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV - echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV - echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV - - name: Run benchmarks - run: | - cd benchmarks/ - RUN_BENCHMARK="pytest --rank 0 --benchmark-json " - git checkout ${{ github.event.pull_request.base.sha }} - $RUN_BENCHMARK ${{ env.BASELINE_JSON }} - git checkout ${{ github.event.pull_request.head.sha }} - $RUN_BENCHMARK ${{ env.CONTENDER_JSON }} - - name: Publish results - uses: apbard/pytest-benchmark-commenter@v3 - with: - token: ${{ secrets.GITHUB_TOKEN }} - benchmark-file: ${{ env.CONTENDER_JSON }} - comparison-benchmark-file: ${{ env.BASELINE_JSON }} - benchmark-metrics: 'name,max,mean,ops' - comparison-benchmark-metric: 'ops' - comparison-higher-is-better: true - comparison-threshold: 5 - benchmark-title: 'Result of GPU Benchmark Tests' + - name: Who triggered this? 
+ run: | + echo "Action triggered by ${{ github.event.pull_request.html_url }}" + - name: Install deps + run: | + export TZ=Europe/London + export DEBIAN_FRONTEND=noninteractive # tzdata bug + apt-get update -y + apt-get install software-properties-common -y + add-apt-repository ppa:git-core/candidate -y + apt-get update -y + apt-get upgrade -y + apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev + - name: Check ldd --version + run: ldd --version + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 50 # this is to make sure we obtain the target base commit + - name: Python Setup + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Setup git + run: git config --global --add safe.directory /__w/rl/rl + - name: setup Path + run: | + echo /usr/local/bin >> $GITHUB_PATH + - name: Setup Environment + run: | + python3 -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 + python3 -m pip install git+https://github.com/pytorch/tensordict + python3 setup.py develop + python3 -m pip install pytest pytest-benchmark + python3 -m pip install "gym[accept-rom-license,atari]" + python3 -m pip install dm_control + - name: check GPU presence + run: | + python -c """import torch + assert torch.cuda.device_count() + """ + - name: Setup benchmarks + run: | + echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV + echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> $GITHUB_ENV + echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV + echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV + echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV + - name: Run benchmarks + run: | + cd benchmarks/ + RUN_BENCHMARK="pytest --rank 0 --benchmark-json " + git checkout ${{ github.event.pull_request.base.sha }} + $RUN_BENCHMARK ${{ env.BASELINE_JSON }} + git checkout ${{ github.event.pull_request.head.sha }} + $RUN_BENCHMARK ${{ env.CONTENDER_JSON }} + - name: Publish results + uses: apbard/pytest-benchmark-commenter@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + benchmark-file: ${{ env.CONTENDER_JSON }} + comparison-benchmark-file: ${{ env.BASELINE_JSON }} + benchmark-metrics: 'name,max,mean,ops' + comparison-benchmark-metric: 'ops' + comparison-higher-is-better: true + comparison-threshold: 5 + benchmark-title: 'Result of GPU Benchmark Tests' diff --git a/README.md b/README.md index 05c2e9843c2..905e8d28a4c 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ pypi nightly version [![Downloads](https://static.pepy.tech/personalized-badge/torchrl?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads)](https://pepy.tech/project/torchrl) [![Downloads](https://static.pepy.tech/personalized-badge/torchrl-nightly?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads%20(nightly))](https://pepy.tech/project/torchrl-nightly) -[![Discord Shield](https://dcbadge.vercel.app/api/server/xSURYdvu)](https://discord.gg/xSURYdvu) +[![Discord Shield](https://dcbadge.vercel.app/api/server/cZs26Qq3Dd)](https://discord.gg/cZs26Qq3Dd) # TorchRL diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 71b7a481ce0..246c5ee15f0 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -76,12 +76,12 @@ def make(envname=envname, 
gym_backend=gym_backend): # regular parallel env for device in avail_devices: - def make(envname=envname, gym_backend=gym_backend, device=device): + def make(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") # env_make = EnvCreator(make) - penv = ParallelEnv(num_workers, EnvCreator(make)) + penv = ParallelEnv(num_workers, EnvCreator(make), device=device) with torch.inference_mode(): # warmup penv.rollout(2) @@ -103,13 +103,13 @@ def make(envname=envname, gym_backend=gym_backend, device=device): for device in avail_devices: - def make(envname=envname, gym_backend=gym_backend, device=device): + def make(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") env_make = EnvCreator(make) # penv = SerialEnv(num_workers, env_make) - penv = ParallelEnv(num_workers, env_make) + penv = ParallelEnv(num_workers, env_make, device=device) collector = SyncDataCollector( penv, RandomPolicy(penv.action_spec), @@ -164,14 +164,14 @@ def make_env( for device in avail_devices: # async collector # + torchrl parallel env - def make_env( - envname=envname, gym_backend=gym_backend, device=device - ): + def make_env(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") penv = ParallelEnv( - num_workers // num_collectors, EnvCreator(make_env) + num_workers // num_collectors, + EnvCreator(make_env), + device=device, ) collector = MultiaSyncDataCollector( [penv] * num_collectors, @@ -206,10 +206,9 @@ def make_env( envname=envname, num_workers=num_workers, gym_backend=gym_backend, - device=device, ): with set_gym_backend(gym_backend): - penv = GymEnv(envname, num_envs=num_workers, device=device) + penv = GymEnv(envname, num_envs=num_workers, device="cpu") return penv penv = EnvCreator( @@ -247,14 +246,14 @@ def make_env( for device in avail_devices: # sync collector # + torchrl parallel env - def make_env( - envname=envname, gym_backend=gym_backend, device=device - ): + def make_env(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") penv = ParallelEnv( - num_workers // num_collectors, EnvCreator(make_env) + num_workers // num_collectors, + EnvCreator(make_env), + device=device, ) collector = MultiSyncDataCollector( [penv] * num_collectors, @@ -289,10 +288,9 @@ def make_env( envname=envname, num_workers=num_workers, gym_backend=gym_backend, - device=device, ): with set_gym_backend(gym_backend): - penv = GymEnv(envname, num_envs=num_workers, device=device) + penv = GymEnv(envname, num_envs=num_workers, device="cpu") return penv penv = EnvCreator( diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py index 9f6c4599587..1e9634f643f 100644 --- a/benchmarks/test_collectors_benchmark.py +++ b/benchmarks/test_collectors_benchmark.py @@ -13,7 +13,7 @@ MultiSyncDataCollector, RandomPolicy, ) -from torchrl.envs import EnvCreator, StepCounter, TransformedEnv +from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv from torchrl.envs.libs.dm_control import DMControlEnv @@ -78,9 +78,10 @@ def async_collector_setup(): def single_collector_setup_pixels(): device = "cuda:0" if torch.cuda.device_count() else "cpu" - env = TransformedEnv( - DMControlEnv("cheetah", "run", 
device=device, from_pixels=True), StepCounter(50) - ) + # env = TransformedEnv( + # DMControlEnv("cheetah", "run", device=device, from_pixels=True), StepCounter(50) + # ) + env = TransformedEnv(GymEnv("ALE/Pong-v5"), StepCounter(50)) c = SyncDataCollector( env, RandomPolicy(env.action_spec), @@ -99,7 +100,8 @@ def sync_collector_setup_pixels(): device = "cuda:0" if torch.cuda.device_count() else "cpu" env = EnvCreator( lambda: TransformedEnv( - DMControlEnv("cheetah", "run", device=device, from_pixels=True), + # DMControlEnv("cheetah", "run", device=device, from_pixels=True), + GymEnv("ALE/Pong-v5"), StepCounter(50), ) ) @@ -121,7 +123,8 @@ def async_collector_setup_pixels(): device = "cuda:0" if torch.cuda.device_count() else "cpu" env = EnvCreator( lambda: TransformedEnv( - DMControlEnv("cheetah", "run", device=device, from_pixels=True), + # DMControlEnv("cheetah", "run", device=device, from_pixels=True), + GymEnv("ALE/Pong-v5"), StepCounter(50), ) ) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index ca5b7eb82ed..d07e8f5da90 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -123,7 +123,7 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps): gamma = 0.99 if gamma_tensor: - gamma = torch.full(size, gamma) + gamma = torch.full(size, gamma, device=device) lmbda = 0.95 benchmark( diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py index 44a37cb3ce6..4598c11844b 100644 --- a/examples/a2c/a2c_atari.py +++ b/examples/a2c/a2c_atari.py @@ -117,9 +117,9 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(data.numel()) # Get training rewards and lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] + episode_length = data["next", "step_count"][data["next", "terminated"]] log_info.update( { "train/reward": episode_rewards.mean().item(), diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index a5265f442b7..360c6daac28 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -117,7 +117,7 @@ "object_store_memory": 1024**3, } collector = RayCollector( - env_makers=[env] * num_collectors, + create_env_fn=[env] * num_collectors, policy=policy_module, collector_class=SyncDataCollector, collector_kwargs={ diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index fba4247e2a7..385e4a53aab 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -147,6 +147,7 @@ def transformed_env_constructor( state_dim_gsde: Optional[int] = None, batch_dims: Optional[int] = 0, obs_norm_state_dict: Optional[dict] = None, + ignore_device: bool = False, ) -> Union[Callable, EnvCreator]: """ Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. @@ -179,6 +180,7 @@ def transformed_env_constructor( it should be set to 1 (or the number of dims of the batch). obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the environment + ignore_device (bool, optional): if True, the device is ignored. 
""" def make_transformed_env(**kwargs) -> TransformedEnv: @@ -189,14 +191,17 @@ def make_transformed_env(**kwargs) -> TransformedEnv: from_pixels = cfg.from_pixels if custom_env is None and custom_env_maker is None: - if isinstance(cfg.collector_device, str): - device = cfg.collector_device - elif isinstance(cfg.collector_device, Sequence): - device = cfg.collector_device[0] + if not ignore_device: + if isinstance(cfg.collector_device, str): + device = cfg.collector_device + elif isinstance(cfg.collector_device, Sequence): + device = cfg.collector_device[0] + else: + raise ValueError( + "collector_device must be either a string or a sequence of strings" + ) else: - raise ValueError( - "collector_device must be either a string or a sequence of strings" - ) + device = None env_kwargs = { "env_name": env_name, "device": device, @@ -252,19 +257,19 @@ def parallel_env_constructor( kwargs: keyword arguments for the `transformed_env_constructor` method. """ batch_transform = cfg.batch_transform + kwargs.update({"cfg": cfg, "use_env_creator": True}) if cfg.env_per_collector == 1: - kwargs.update({"cfg": cfg, "use_env_creator": True}) make_transformed_env = transformed_env_constructor(**kwargs) return make_transformed_env - kwargs.update({"cfg": cfg, "use_env_creator": True}) make_transformed_env = transformed_env_constructor( - return_transformed_envs=not batch_transform, **kwargs + return_transformed_envs=not batch_transform, ignore_device=True, **kwargs ) parallel_env = ParallelEnv( num_workers=cfg.env_per_collector, create_env_fn=make_transformed_env, create_env_kwargs=None, pin_memory=cfg.pin_memory, + device=cfg.collector_device, ) if batch_transform: kwargs.update( diff --git a/examples/impala/README.md b/examples/impala/README.md new file mode 100644 index 00000000000..00e0d010b82 --- /dev/null +++ b/examples/impala/README.md @@ -0,0 +1,33 @@ +## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results + +This repository contains scripts that enable training agents using the IMPALA Algorithm on MuJoCo and Atari environments. We follow the original paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) by Espeholt et al. 2018. + +## Examples Structure + +Please note that we provide 2 examples, one for single node training and one for distributed training. Both examples rely on the same utils file, but besides that are independent. Each example contains the following files: + +1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. impala_single_node_ray.py). + +2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils.py). + +3. **Configuration File:** This file includes default hyperparameters specified in the original paper. For the multi-node case, the file also includes the configuration file of the Ray cluster. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml). 
+
+
+## Running the Examples
+
+You can execute the single-node IMPALA algorithm on Atari environments by running the following command:
+
+```bash
+python impala_single_node.py
+```
+
+You can execute the multi-node IMPALA algorithm on Atari environments by running one of the following commands:
+
+```bash
+python impala_multi_node_ray.py
+```
+or
+
+```bash
+python impala_multi_node_submitit.py
+```
diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml
new file mode 100644
index 00000000000..e312b336651
--- /dev/null
+++ b/examples/impala/config_multi_node_ray.yaml
@@ -0,0 +1,65 @@
+# Environment
+env:
+  env_name: PongNoFrameskip-v4
+
+# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
+ray_init_config:
+  address: null
+  num_cpus: null
+  num_gpus: null
+  resources: null
+  object_store_memory: null
+  local_mode: False
+  ignore_reinit_error: False
+  include_dashboard: null
+  dashboard_host: 127.0.0.1
+  dashboard_port: null
+  job_config: null
+  configure_logging: True
+  logging_level: info
+  logging_format: null
+  log_to_driver: True
+  namespace: null
+  runtime_env: null
+  storage: null
+
+# Device for the forward and backward passes
+local_device: "cuda:0"
+
+# Resources assigned to each IMPALA rollout collection worker
+remote_worker_resources:
+  num_cpus: 1
+  num_gpus: 0.25
+  memory: 1073741824 # 1*1024**3 - 1GB
+
+# collector
+collector:
+  frames_per_batch: 80
+  total_frames: 200_000_000
+  num_workers: 12
+
+# logger
+logger:
+  backend: wandb
+  exp_name: Atari_IMPALA
+  test_interval: 200_000_000
+  num_test_episodes: 3
+
+# Optim
+optim:
+  lr: 0.0006
+  eps: 1e-8
+  weight_decay: 0.0
+  momentum: 0.0
+  alpha: 0.99
+  max_grad_norm: 40.0
+  anneal_lr: True
+
+# loss
+loss:
+  gamma: 0.99
+  batch_size: 32
+  sgd_updates: 1
+  critic_coef: 0.5
+  entropy_coef: 0.01
+  loss_critic_type: l2
diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml
new file mode 100644
index 00000000000..f632ba15dc2
--- /dev/null
+++ b/examples/impala/config_multi_node_submitit.yaml
@@ -0,0 +1,46 @@
+# Environment
+env:
+  env_name: PongNoFrameskip-v4
+
+# Device for the forward and backward passes
+local_device: "cuda:0"
+
+# SLURM config
+slurm_config:
+  timeout_min: 10
+  slurm_partition: train
+  slurm_cpus_per_task: 1
+  slurm_gpus_per_node: 1
+
+# collector
+collector:
+  backend: gloo
+  frames_per_batch: 80
+  total_frames: 200_000_000
+  num_workers: 1
+
+# logger
+logger:
+  backend: wandb
+  exp_name: Atari_IMPALA
+  test_interval: 200_000_000
+  num_test_episodes: 3
+
+# Optim
+optim:
+  lr: 0.0006
+  eps: 1e-8
+  weight_decay: 0.0
+  momentum: 0.0
+  alpha: 0.99
+  max_grad_norm: 40.0
+  anneal_lr: True
+
+# loss
+loss:
+  gamma: 0.99
+  batch_size: 32
+  sgd_updates: 1
+  critic_coef: 0.5
+  entropy_coef: 0.01
+  loss_critic_type: l2
diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml
new file mode 100644
index 00000000000..d39407c1a69
--- /dev/null
+++ b/examples/impala/config_single_node.yaml
@@ -0,0 +1,38 @@
+# Environment
+env:
+  env_name: PongNoFrameskip-v4
+
+# Device for the forward and backward passes
+device: "cuda:0"
+
+# collector
+collector:
+  frames_per_batch: 80
+  total_frames: 200_000_000
+  num_workers: 12
+
+# logger
+logger:
+  backend: wandb
+  exp_name: Atari_IMPALA
+  test_interval: 200_000_000
+  num_test_episodes: 3
+
+# Optim
+optim:
+  lr: 0.0006
+  eps: 1e-8
+  weight_decay: 0.0
+  momentum: 0.0
+  alpha: 0.99
+  max_grad_norm: 40.0
+  anneal_lr: True
+
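+# NOTE (editor): the optim block above is consumed by impala_single_node.py as
+# torch.optim.RMSprop(lr, weight_decay, eps, alpha); max_grad_norm caps the gradient
+# norm via clip_grad_norm_ and anneal_lr enables the linear learning-rate decay.
+# As of this diff, the momentum entry is not forwarded to the optimizer.
+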
+# loss +loss: + gamma: 0.99 + batch_size: 32 + sgd_updates: 1 + critic_coef: 0.5 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py new file mode 100644 index 00000000000..a0d2d88c5a2 --- /dev/null +++ b/examples/impala/impala_multi_node_ray.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.collectors.distributed import RayCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_env, make_ppo_models + + device = torch.device(cfg.local_device) + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + ray_init_config = { + "address": cfg.ray_init_config.address, + "num_cpus": cfg.ray_init_config.num_cpus, + "num_gpus": cfg.ray_init_config.num_gpus, + "resources": cfg.ray_init_config.resources, + "object_store_memory": cfg.ray_init_config.object_store_memory, + "local_mode": cfg.ray_init_config.local_mode, + "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error, + "include_dashboard": cfg.ray_init_config.include_dashboard, + "dashboard_host": cfg.ray_init_config.dashboard_host, + "dashboard_port": cfg.ray_init_config.dashboard_port, + "job_config": cfg.ray_init_config.job_config, + "configure_logging": cfg.ray_init_config.configure_logging, + "logging_level": cfg.ray_init_config.logging_level, + "logging_format": cfg.ray_init_config.logging_format, + "log_to_driver": cfg.ray_init_config.log_to_driver, + "namespace": cfg.ray_init_config.namespace, + "runtime_env": cfg.ray_init_config.runtime_env, + "storage": cfg.ray_init_config.storage, + } + remote_config = { + "num_cpus": cfg.remote_worker_resources.num_cpus, + "num_gpus": cfg.remote_worker_resources.num_gpus + if torch.cuda.device_count() + else 0, + "memory": cfg.remote_worker_resources.memory, + } + 
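+    # NOTE (editor): RayCollector dispatches `num_workers` SyncDataCollector instances
+    # across the Ray cluster, each capped by the resources declared in `remote_config`.
+    # With sync=False, a batch is yielded as soon as any worker finishes collecting it
+    # (the asynchronous actor/learner pattern IMPALA relies on), and
+    # update_after_each_batch=True refreshes that worker's policy weights before it
+    # resumes acting.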
collector = RayCollector( + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, + policy=actor, + collector_class=SyncDataCollector, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + max_frames_per_traj=-1, + ray_init_config=ray_init_config, + remote_configs=remote_config, + sync=False, + update_after_each_batch=True, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + loss_module.set_keys(done="eol", terminated="eol") + + # Create optimizer + optim = torch.optim.RMSprop( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + start_time = sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "terminated"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = stacked_data.to(device, non_blocking=True) + + # Compute advantage + with torch.no_grad(): + stacked_data = adv_module(stacked_data) + + # Add to replay buffer + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device, non_blocking=True) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + 
loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_reward, + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py new file mode 100644 index 00000000000..3355febbfaf --- /dev/null +++ b/examples/impala/impala_multi_node_submitit.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. 
+""" +import hydra + + +@hydra.main( + config_path=".", config_name="config_multi_node_submitit", version_base="1.1" +) +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.collectors.distributed import DistributedDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_env, make_ppo_models + + device = torch.device(cfg.local_device) + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + slurm_kwargs = { + "timeout_min": cfg.slurm_config.timeout_min, + "slurm_partition": cfg.slurm_config.slurm_partition, + "slurm_cpus_per_task": cfg.slurm_config.slurm_cpus_per_task, + "slurm_gpus_per_node": cfg.slurm_config.slurm_gpus_per_node, + } + # Create collector + device_str = "device" if num_workers <= 1 else "devices" + if cfg.collector.backend == "nccl": + collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"} + elif cfg.collector.backend == "gloo": + collector_kwargs = {device_str: "cpu", f"storing_{device_str}": "cpu"} + else: + raise NotImplementedError( + f"device assignment not implemented for backend {cfg.collector.backend}" + ) + collector = DistributedDataCollector( + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, + policy=actor, + num_workers_per_collector=1, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + collector_class=SyncDataCollector, + collector_kwargs=collector_kwargs, + slurm_kwargs=slurm_kwargs, + storing_device="cuda:0" if cfg.collector.backend == "nccl" else "cpu", + launcher="submitit", + # update_after_each_batch=True, + backend=cfg.collector.backend, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + loss_module.set_keys(done="eol", terminated="eol") + + # Create optimizer + optim = torch.optim.RMSprop( + 
loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + start_time = sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = stacked_data.to(device, non_blocking=True) + + # Compute advantage + with torch.no_grad(): + stacked_data = adv_module(stacked_data) + + # Add to replay buffer + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_reward, + "eval/time": eval_time, 
+ } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py new file mode 100644 index 00000000000..cd270f4c9e9 --- /dev/null +++ b/examples/impala/impala_single_node.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_single_node", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import MultiaSyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_env, make_ppo_models + + device = torch.device(cfg.device) + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = MultiaSyncDataCollector( + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + update_at_each_batch=True, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + loss_module.set_keys(done="eol", terminated="eol") + + # Create optimizer + optim = torch.optim.RMSprop( + 
loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + start_time = sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "terminated"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = stacked_data.to(device, non_blocking=True) + + # Compute advantage + with torch.no_grad(): + stacked_data = adv_module(stacked_data) + + # Add to replay buffer + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device, non_blocking=True) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": 
test_reward, + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/utils.py b/examples/impala/utils.py new file mode 100644 index 00000000000..2983f8a0193 --- /dev/null +++ b/examples/impala/utils.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + CatFrames, + DoubleToFloat, + EndOfLifeTransform, + ExplorationType, + GrayScale, + GymEnv, + NoopResetEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + ValueOperator, +) + + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name, device, is_test=False): + env = GymEnv( + env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + env.append_transform(EndOfLifeTransform()) + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(ToTensorImage(from_int=False)) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + 
num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_env(env_name, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + + del proof_environment + + return actor, critic + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = torch.zeros(num_episodes, dtype=torch.float32) + for i in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards[i] = reward.sum() + del td_test + return test_rewards.mean() diff --git a/examples/ppo/config_mujoco.yaml b/examples/ppo/config_mujoco.yaml index 1272c1f4bff..0322526e7b1 100644 --- a/examples/ppo/config_mujoco.yaml +++ b/examples/ppo/config_mujoco.yaml @@ -18,7 +18,7 @@ logger: optim: lr: 3e-4 weight_decay: 0.0 - anneal_lr: False + anneal_lr: True # loss loss: diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py index eb2ce15ec5a..1bfbccdeba4 100644 --- a/examples/ppo/ppo_atari.py +++ b/examples/ppo/ppo_atari.py @@ -134,9 +134,9 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(data.numel()) # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "stop"]] + episode_length = data["next", "step_count"][data["next", "terminated"]] log_info.update( { "train/reward": episode_rewards.mean().item(), diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py index ff6aeda51d2..52b12f688e1 100644 --- a/examples/ppo/ppo_mujoco.py +++ b/examples/ppo/ppo_mujoco.py @@ -28,7 +28,6 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils_mujoco import eval_model, make_env, make_ppo_models - # Define paper hyperparameters device = "cpu" if not torch.cuda.device_count() else "cuda" num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size total_network_updates = ( @@ -67,6 +66,7 @@ def main(cfg: "DictConfig"): # noqa: F821 
value_network=critic, average_gae=False, ) + loss_module = ClipPPOLoss( actor=actor, critic=critic, @@ -78,8 +78,8 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr, eps=1e-5) + critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr, eps=1e-5) # Create logger logger = None @@ -187,7 +187,9 @@ def main(cfg: "DictConfig"): # noqa: F821 "train/lr": alpha * cfg_optim_lr, "train/sampling_time": sampling_time, "train/training_time": training_time, - "train/clip_epsilon": alpha * cfg_loss_clip_epsilon, + "train/clip_epsilon": alpha * cfg_loss_clip_epsilon + if cfg_loss_anneal_clip_eps + else cfg_loss_clip_epsilon, } ) diff --git a/examples/ppo/utils_mujoco.py b/examples/ppo/utils_mujoco.py index 8fa2a53fd92..7be234b322d 100644 --- a/examples/ppo/utils_mujoco.py +++ b/examples/ppo/utils_mujoco.py @@ -28,10 +28,10 @@ def make_env(env_name="HalfCheetah-v4", device="cpu"): env = GymEnv(env_name, device=device) env = TransformedEnv(env) + env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2)) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) env.append_transform(RewardSum()) env.append_transform(StepCounter()) - env.append_transform(VecNorm(in_keys=["observation"])) - env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) env.append_transform(DoubleToFloat(in_keys=["observation"])) return env @@ -72,7 +72,9 @@ def make_ppo_models_state(proof_environment): # Add state-independent normal scale policy_mlp = torch.nn.Sequential( policy_mlp, - AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], scale_lb=1e-8 + ), ) # Add probabilistic sampling of the actions diff --git a/test/assets/generate.py b/test/assets/generate.py index 75a87bb71b5..deb47f95999 100644 --- a/test/assets/generate.py +++ b/test/assets/generate.py @@ -4,6 +4,12 @@ # LICENSE file in the root directory of this source tree. 
"""Script used to generate the mini datasets.""" +import multiprocessing as mp + +try: + mp.set_start_method("spawn") +except Exception: + pass from tempfile import TemporaryDirectory from datasets import Dataset, DatasetDict, load_dataset @@ -36,8 +42,9 @@ def get_minibatch(): batch_size=16, block_size=33, tensorclass_type=PromptData, - dataset_name="test/datasets_mini/openai_summarize_tldr", + dataset_name="../datasets_mini/openai_summarize_tldr", device="cpu", + num_workers=2, infinite=False, prefetch=0, split="train", @@ -47,3 +54,8 @@ def get_minibatch(): for data in dl: data = data.clone().memmap_("test/datasets_mini/tldr_batch/") break + print("done") + + +if __name__ == "__main__": + get_minibatch() diff --git a/test/assets/tldr_batch.zip b/test/assets/tldr_batch.zip index 6293560f580..252b9e3c999 100644 Binary files a/test/assets/tldr_batch.zip and b/test/assets/tldr_batch.zip differ diff --git a/test/test_cost.py b/test/test_cost.py index ba8d5b43b7f..a4b87e9f746 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -47,7 +47,7 @@ get_default_devices, ) from mocking_classes import ContinuousActionConvMockEnv -from tensordict.nn import get_functional, NormalParamExtractor, TensorDictModule +from tensordict.nn import NormalParamExtractor, TensorDictModule from tensordict.nn.utils import Buffer # from torchrl.data.postprocs.utils import expand_as_right @@ -130,6 +130,7 @@ GAE, TD1Estimator, TDLambdaEstimator, + VTrace, ) from torchrl.objectives.value.functional import ( _transpose_time, @@ -140,6 +141,7 @@ vec_generalized_advantage_estimate, vec_td1_advantage_estimate, vec_td_lambda_advantage_estimate, + vtrace_advantage_estimate, ) from torchrl.objectives.value.utils import ( _custom_conv1d, @@ -437,7 +439,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -915,7 +917,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1400,7 +1402,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): delay_actor=delay_actor, delay_value=delay_value, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2009,7 +2011,7 @@ def test_td3( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2696,7 +2698,7 @@ def test_sac( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3481,7 +3483,7 @@ def test_discrete_sac( loss_function="l2", **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with 
pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4091,7 +4093,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4458,7 +4460,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4475,7 +4477,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_deprec.make_value_estimator(td_est) return @@ -4888,7 +4890,7 @@ def test_cql( delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4984,6 +4986,18 @@ def test_cql( else: raise NotImplementedError(k) loss_fn.zero_grad() + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) sum([item for _, item in loss.items()]).backward() named_parameters = list(loss_fn.named_parameters()) @@ -5324,7 +5338,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -5555,6 +5569,7 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -5569,6 +5584,8 @@ def _create_mock_actor( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, + log_prob_key=sample_log_prob_key, ) return actor.to(device) @@ -5606,6 +5623,7 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Sequential(base_layer, nn.Linear(5, 1)) value = ValueOperator( @@ -5632,6 +5650,7 @@ def _create_mock_actor_value_shared( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Linear(5, 1) value_head = ValueOperator( @@ -5739,7 +5758,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + 
[None]) def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): @@ -5752,6 +5771,13 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5818,7 +5844,7 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_shared(self, loss_class, device, advantage): torch.manual_seed(self.seed) @@ -5831,6 +5857,12 @@ def test_ppo_shared(self, loss_class, device, advantage): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5892,6 +5924,7 @@ def test_ppo_shared(self, loss_class, device, advantage): "advantage", ( "gae", + "vtrace", "td", "td_lambda", ), @@ -5911,6 +5944,12 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5962,11 +6001,9 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): - if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): - raise pytest.skip("make_functional_with_buffers needs to be changed") torch.manual_seed(self.seed) td = self._create_seq_mock_data_ppo(device=device) @@ -5976,6 +6013,13 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5991,21 +6035,31 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2") - floss_fn, params, buffers = make_functional_with_buffers(loss_fn) + params = TensorDict.from_module(loss_fn, as_module=True) + # fill params with zero - for p in params: - p.data.zero_() + def zero_param(p): + if isinstance(p, nn.Parameter): + p.data.zero_() + + params.apply(zero_param) + # assert len(list(floss_fn.parameters())) == 0 - if advantage is not None: - advantage(td) - loss = floss_fn(params, buffers, td) + with 
params.to_module(loss_fn): + if advantage is not None: + advantage(td) + loss = loss_fn(td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) # check that grads are independent and non null named_parameters = loss_fn.named_parameters() - for (name, _), p in zip(named_parameters, params): + for name, p in params.items(True, True): + if isinstance(name, tuple): + name = "-".join(name) + if not isinstance(p, nn.Parameter): + continue if p.grad is not None and p.grad.norm() > 0.0: assert "actor" not in name assert "critic" in name @@ -6013,12 +6067,12 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): assert "actor" in name assert "critic" not in name - for param in params: - param.grad = None + for p in params.values(True, True): + p.grad = None loss_objective.backward() named_parameters = loss_fn.named_parameters() - - for (name, other_p), p in zip(named_parameters, params): + for (name, other_p) in named_parameters: + p = params.get(tuple(name.split("."))) assert other_p.shape == p.shape assert other_p.dtype == p.dtype assert other_p.device == p.device @@ -6028,7 +6082,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): if p.grad is None: assert "actor" not in name assert "critic" in name - for param in params: + for param in params.values(True, True): param.grad = None @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @@ -6038,6 +6092,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6079,7 +6134,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" @@ -6097,7 +6152,9 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], ) - actor = self._create_mock_actor() + actor = self._create_mock_actor( + sample_log_prob_key=tensor_keys["sample_log_prob"] + ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) if advantage == "gae": @@ -6107,6 +6164,13 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): value_network=value, differentiable=gradient_mode, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -6200,7 +6264,9 @@ def test_ppo_notensordict( terminated_key=terminated_key, ) - actor = self._create_mock_actor(observation_key=observation_key) + actor = self._create_mock_actor( + observation_key=observation_key, sample_log_prob_key=sample_log_prob_key + ) value = self._create_mock_value(observation_key=observation_key) loss = loss_class(actor=actor, critic=value) @@ -6259,6 +6325,7 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", + 
sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -6273,6 +6340,8 @@ def _create_mock_actor( in_keys=["loc", "scale"], spec=action_spec, distribution_class=TanhNormal, + return_log_prob=True, + log_prob_key=sample_log_prob_key, ) return actor.to(device) @@ -6363,6 +6432,7 @@ def _create_seq_mock_data_a2c( reward_key="reward", done_key="done", terminated_key="terminated", + sample_log_prob_key="sample_log_prob", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -6392,7 +6462,7 @@ def _create_seq_mock_data_a2c( }, "collector": {"mask": mask}, action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), - "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_( + sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_( ~mask, 0.0 ) / 10, @@ -6405,7 +6475,7 @@ def _create_seq_mock_data_a2c( return td @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_a2c(self, device, gradient_mode, advantage, td_est): @@ -6418,6 +6488,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -6462,6 +6539,8 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): assert ("critic" not in name) or ("target_" in name) value.zero_grad() + for n, p in loss_fn.named_parameters(): + assert p.grad is None or p.grad.norm() == 0, n loss_objective.backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: @@ -6542,7 +6621,7 @@ def test_a2c_separate_losses(self, separate_losses): not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" ) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_a2c_diff(self, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -6560,6 +6639,13 @@ def test_a2c_diff(self, device, gradient_mode, advantage): advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td_lambda": advantage = TDLambdaEstimator( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode @@ -6609,6 +6695,7 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6626,6 +6713,7 @@ def test_a2c_tensordict_keys(self, td_est): "reward": "reward", "done": "done", "terminated": "terminated", + "sample_log_prob": "sample_log_prob", } self.tensordict_keys_test( @@ -6648,8 +6736,16 @@ def test_a2c_tensordict_keys(self, td_est): 
} self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + @pytest.mark.parametrize( + "td_est", + [ + ValueEstimators.GAE, + ValueEstimators.VTrace, + ], + ) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device): + def test_a2c_tensordict_keys_run(self, device, advantage, td_est): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -6658,6 +6754,7 @@ def test_a2c_tensordict_keys_run(self, device): value_key = "state_value_test" action_key = "action_test" reward_key = "reward_test" + sample_log_prob_key = "sample_log_prob_test" done_key = ("done", "test") terminated_key = ("terminated", "test") @@ -6667,24 +6764,29 @@ def test_a2c_tensordict_keys_run(self, device): reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + sample_log_prob_key=sample_log_prob_key, ) - actor = self._create_mock_actor(device=device) - value = self._create_mock_value(device=device, out_keys=[value_key]) - advantage = GAE( - gamma=0.9, - lmbda=0.9, - value_network=value, - differentiable=gradient_mode, - ) - advantage.set_keys( - advantage=advantage_key, - value_target=value_target_key, - value=value_key, - reward=reward_key, - done=done_key, - terminated=terminated_key, + actor = self._create_mock_actor( + device=device, sample_log_prob_key=sample_log_prob_key ) + value = self._create_mock_value(device=device, out_keys=[value_key]) + if advantage == "gae": + advantage = GAE( + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode + ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) + elif advantage is None: + pass + else: + raise NotImplementedError + loss_fn = A2CLoss(actor, value, loss_critic_type="l2") loss_fn.set_keys( advantage=advantage_key, @@ -6694,9 +6796,23 @@ def test_a2c_tensordict_keys_run(self, device): reward=reward_key, done=done_key, terminated=done_key, + sample_log_prob=sample_log_prob_key, ) - advantage(td) + if advantage is not None: + advantage.set_keys( + advantage=advantage_key, + value_target=value_target_key, + value=value_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + sample_log_prob=sample_log_prob_key, + ) + advantage(td) + else: + if td_est is not None: + loss_fn.make_value_estimator(td_est) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -6794,7 +6910,16 @@ class TestReinforce(LossModuleTestBase): @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) - @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize( + "td_est", + [ + ValueEstimators.TD1, + ValueEstimators.TD0, + ValueEstimators.GAE, + ValueEstimators.TDLambda, + None, + ], + ) def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): n_obs = 3 n_act = 5 @@ -6816,20 +6941,20 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est advantage = GAE( gamma=gamma, lmbda=0.9, - value_network=get_functional(value_net), + value_network=value_net, differentiable=gradient_mode, ) elif advantage == "td": advantage = TD1Estimator( gamma=gamma, - value_network=get_functional(value_net), + value_network=value_net, differentiable=gradient_mode, ) elif advantage 
== "td_lambda": advantage = TDLambdaEstimator( gamma=0.9, lmbda=0.9, - value_network=get_functional(value_net), + value_network=value_net, differentiable=gradient_mode, ) elif advantage is None: @@ -7512,7 +7637,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) imagination_horizon=imagination_horizon, discount_loss=discount_loss, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_module.make_value_estimator(td_est) return @@ -8254,7 +8379,7 @@ def test_iql( expectile=expectile, loss_function="l2", ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -9615,6 +9740,113 @@ def test_gae_multidim( torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) + @pytest.mark.parametrize("T", [200, 5, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("has_done", [False, True]) + def test_vtrace(self, device, gamma, N, T, dtype, has_done): + torch.manual_seed(0) + + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() + if has_done: + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated + reward = torch.randn(*N, T, 1, device=device, dtype=dtype) + state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + log_pi = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype)) + log_mu = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype)) + + _, value_target = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) + + assert not torch.isnan(value_target).any() + assert not torch.isinf(value_target).any() + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("N", [(3,), (7, 3)]) + @pytest.mark.parametrize("T", [100, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("feature_dim", [[5], [2, 5]]) + @pytest.mark.parametrize("has_done", [True, False]) + def test_vtrace_multidim(self, device, gamma, N, T, dtype, has_done, feature_dim): + D = feature_dim + time_dim = -1 - len(D) + + torch.manual_seed(0) + + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone() + if has_done: + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated + reward = torch.randn(*N, T, *D, device=device, dtype=dtype) + state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) + log_pi = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype)) + log_mu = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype)) + + r1 = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + time_dim=time_dim, + ) + if len(D) == 2: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1, j], + log_mu[..., i : i + 
1, j], + state_value[..., i : i + 1, j], + next_state_value[..., i : i + 1, j], + reward[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + done=done[..., i : i + 1, j], + time_dim=-2, + ) + for i in range(D[0]) + for j in range(D[1]) + ] + else: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1], + log_mu[..., i : i + 1], + state_value[..., i : i + 1], + next_state_value[..., i : i + 1], + reward[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + time_dim=-2, + ) + for i in range(D[0]) + ] + + list2 = list(zip(*r2)) + r2 = [torch.cat(list2[0], -1), torch.cat(list2[1], -1)] + if len(D) == 2: + r2 = [r2[0].unflatten(-1, D), r2[1].unflatten(-1, D)] + torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @@ -9638,9 +9870,6 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done): next_state_value = torch.randn(*N, T, 1, device=device) gamma_tensor = torch.full((*N, T, 1), gamma, device=device) - # if len(N) == 2: - # print(terminated[4, 0, -10:]) - # print(done[4, 0, -10:]) v1 = vec_td_lambda_advantage_estimate( gamma, lmbda, @@ -10549,6 +10778,7 @@ class TestAdv: [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_dispatch( @@ -10559,18 +10789,46 @@ def test_dispatch( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=False, - **kwargs, - ) - kwargs = { - "obs": torch.randn(1, 10, 3), - "next_reward": torch.randn(1, 10, 1, requires_grad=True), - "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), - "next_obs": torch.randn(1, 10, 3), - } + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + **kwargs, + ) + kwargs = { + "obs": torch.randn(1, 10, 3), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next_reward": torch.randn(1, 10, 1, requires_grad=True), + "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_obs": torch.randn(1, 10, 3), + } + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + **kwargs, + ) + kwargs = { + "obs": torch.randn(1, 10, 3), + "next_reward": torch.randn(1, 10, 1, requires_grad=True), + "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_obs": torch.randn(1, 10, 3), + } advantage, value_target = module(**kwargs) assert advantage.shape == torch.Size([1, 10, 1]) assert value_target.shape == torch.Size([1, 10, 1]) @@ -10581,6 +10839,7 @@ def test_dispatch( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_diff_reward( @@ -10591,23 +10850,55 @@ def test_diff_reward( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=True, - **kwargs, - ) - td = TensorDict( - 
{ - "obs": torch.randn(1, 10, 3), - "next": { + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=True, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - ) + [1, 10], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=True, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + ) td = module(td.clone(False)) # check that the advantage can't backprop to the value params td["advantage"].sum().backward() @@ -10622,6 +10913,7 @@ def test_diff_reward( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) @pytest.mark.parametrize("shifted", [True, False]) @@ -10629,25 +10921,60 @@ def test_non_differentiable(self, adv, shifted, kwargs): value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=False, - shifted=shifted, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "next": { + + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + shifted=shifted, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - names=[None, "time"], - ) + [1, 10], + names=[None, "time"], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + shifted=shifted, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + names=[None, "time"], + ) td = module(td.clone(False)) assert td["advantage"].is_leaf @@ -10657,6 +10984,7 @@ def test_non_differentiable(self, adv, shifted, kwargs): [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) @pytest.mark.parametrize("has_value_net", [True, False]) @@ -10679,28 
+11007,65 @@ def test_skip_existing( else: value_net = None - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=True, - shifted=shifted, - skip_existing=skip_existing, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "state_value": torch.ones(1, 10, 1), - "next": { + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=True, + shifted=shifted, + skip_existing=skip_existing, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), "state_value": torch.ones(1, 10, 1), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - names=[None, "time"], - ) + [1, 10], + names=[None, "time"], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=True, + shifted=shifted, + skip_existing=skip_existing, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "next": { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + names=[None, "time"], + ) td = module(td.clone(False)) if has_value_net and not skip_existing: exp_val = 0 @@ -10718,15 +11083,34 @@ def test_skip_existing( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_set_keys(self, value, adv, kwargs): value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=[value]) - module = adv( - gamma=0.98, - value_network=value_net, - **kwargs, - ) + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + **kwargs, + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + **kwargs, + ) module.set_keys(value=value) assert module.tensor_keys.value == value @@ -10740,6 +11124,7 @@ def test_set_keys(self, value, adv, kwargs): [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_set_deprecated_keys(self, adv, kwargs): @@ -10748,14 +11133,36 @@ def test_set_deprecated_keys(self, adv, kwargs): ) with pytest.warns(DeprecationWarning): - module = adv( - gamma=0.98, - value_network=value_net, - value_key="test_value", - advantage_key="advantage_test", - value_target_key="value_target_test", - **kwargs, - ) + + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + 
distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + value_key="test_value", + advantage_key="advantage_test", + value_target_key="value_target_test", + **kwargs, + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + value_key="test_value", + advantage_key="advantage_test", + value_target_key="value_target_test", + **kwargs, + ) assert module.tensor_keys.value == "test_value" assert module.tensor_keys.advantage == "advantage_test" assert module.tensor_keys.value_target == "value_target_test" diff --git a/test/test_env.py b/test/test_env.py index 6cee7f545d7..aed4e07b0b7 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -354,6 +354,48 @@ def test_mb_env_batch_lock(self, device, seed=0): class TestParallel: + @pytest.mark.skipif( + not torch.cuda.device_count(), reason="No cuda device detected." + ) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("hetero", [True, False]) + @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"]) + @pytest.mark.parametrize("edevice", ["cpu", "cuda"]) + @pytest.mark.parametrize("bwad", [True, False]) + def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + if not hetero: + env = cls( + 2, lambda: ContinuousActionVecMockEnv(device=edevice), device=pdevice + ) + else: + env1 = lambda: ContinuousActionVecMockEnv(device=edevice) + env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice)) + env = cls(2, [env1, env2], device=pdevice) + + r = env.rollout(2, break_when_any_done=bwad) + if pdevice is not None: + assert env.device.type == torch.device(pdevice).type + assert r.device.type == torch.device(pdevice).type + assert all( + item.device.type == torch.device(pdevice).type + for item in r.values(True, True) + ) + else: + assert env.device.type == torch.device(edevice).type + assert r.device.type == torch.device(edevice).type + assert all( + item.device.type == torch.device(edevice).type + for item in r.values(True, True) + ) + if parallel: + assert ( + env.shared_tensordict_parent.device.type == torch.device(edevice).type + ) + @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) def test_env_with_batch_size(self, num_parallel_env, env_batch_size): diff --git a/test/test_exploration.py b/test/test_exploration.py index 8a374cd9009..24bb8c246d0 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -51,7 +51,7 @@ class TestEGreedy: - @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) + @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1]) @pytest.mark.parametrize("module", [True, False]) def test_egreedy(self, eps_init, module): torch.manual_seed(0) @@ -78,7 +78,7 @@ def test_egreedy(self, eps_init, module): assert (action == 0).any() assert ((action == 1) | (action == 0)).all() - @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) + @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1]) @pytest.mark.parametrize("module", [True, False]) @pytest.mark.parametrize("spec_class", ["discrete", "one_hot"]) def test_egreedy_masked(self, module, eps_init, spec_class): diff --git a/test/test_libs.py b/test/test_libs.py index f1715a550f4..c3379021510 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -400,7 +400,7 @@ def test_vecenvs_wrapper(self, envname): ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + (["FetchReach-v2"] if 
_has_gym_robotics else []), ) - @pytest.mark.flaky(reruns=3, reruns_delay=1) + @pytest.mark.flaky(reruns=8, reruns_delay=1) def test_vecenvs_env(self, envname): from _utils_internal import rollout_consistency_assertion @@ -1896,12 +1896,8 @@ def test_direct_download(self, task): keys = keys.intersection(data_d4rl._storage._storage.keys(True, True)) assert len(keys) assert_allclose_td( - data_direct._storage._storage.select(*keys).apply( - lambda t: t.as_tensor().float() - ), - data_d4rl._storage._storage.select(*keys).apply( - lambda t: t.as_tensor().float() - ), + data_direct._storage._storage.select(*keys).apply(lambda t: t.float()), + data_d4rl._storage._storage.select(*keys).apply(lambda t: t.float()), ) @pytest.mark.parametrize( diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py index 8a46b1a006d..548f04dc41d 100644 --- a/test/test_rb_distributed.py +++ b/test/test_rb_distributed.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse import os + import sys import time @@ -22,10 +23,10 @@ class ReplayBufferNode(RemoteTensorDictReplayBuffer): - def __init__(self, capacity: int): + def __init__(self, capacity: int, scratch_dir=None): super().__init__( storage=LazyMemmapStorage( - max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu") + max_size=capacity, scratch_dir=scratch_dir, device=torch.device("cpu") ), sampler=RandomSampler(), writer=RoundRobinWriter(), diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 2abb9a6d386..31ef96681df 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -14,7 +14,12 @@ import torch.nn.functional as F from _utils_internal import get_default_devices -from tensordict import is_tensor_collection, MemmapTensor, TensorDict, TensorDictBase +from tensordict import ( + is_tensor_collection, + MemoryMappedTensor, + TensorDict, + TensorDictBase, +) from tensordict.nn import TensorDictModule from torchrl.data.rlhf import TensorDictTokenizer from torchrl.data.rlhf.dataset import ( @@ -188,8 +193,8 @@ def test_dataset_to_tensordict(tmpdir, suffix): else: assert ("c", "d", "a") in td.keys(True) assert ("c", "d", "b") in td.keys(True) - assert isinstance(td.get((suffix, "a")), MemmapTensor) - assert isinstance(td.get((suffix, "b")), MemmapTensor) + assert isinstance(td.get((suffix, "a")), MemoryMappedTensor) + assert isinstance(td.get((suffix, "b")), MemoryMappedTensor) @pytest.mark.skipif( diff --git a/test/test_shared.py b/test/test_shared.py index c4790597359..186c8ae9525 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -144,24 +144,7 @@ def test_shared(self, shared): ) -# @pytest.mark.skipif( -# sys.platform == "win32", -# reason="RuntimeError from Torch serialization.py when creating td_saved on Windows", -# ) -@pytest.mark.parametrize( - "idx", - [ - torch.tensor( - [ - 3, - 5, - 7, - 8, - ] - ), - slice(200), - ], -) +@pytest.mark.parametrize("idx", [0, slice(200)]) @pytest.mark.parametrize("dtype", [torch.float, torch.bool]) def test_memmap(idx, dtype, large_scale=False): N = 5000 if large_scale else 10 diff --git a/test/test_transforms.py b/test/test_transforms.py index da8bc12c126..cff1d33b34a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -9,6 +9,7 @@ import itertools import pickle +import re import sys from copy import copy from functools import partial @@ -4878,6 +4879,39 @@ def test_sum_reward(self, keys, device): def test_transform_inverse(self): raise pytest.skip("No inverse for RewardSum") + @pytest.mark.parametrize("in_keys", [["reward"], 
["reward_1", "reward_2"]]) + @pytest.mark.parametrize( + "out_keys", [["episode_reward"], ["episode_reward_1", "episode_reward_2"]] + ) + @pytest.mark.parametrize("reset_keys", [["_reset"], ["_reset1", "_reset2"]]) + def test_keys_length_errors(self, in_keys, reset_keys, out_keys, batch=10): + reset_dict = { + reset_key: torch.zeros(batch, dtype=torch.bool) for reset_key in reset_keys + } + reward_sum_dict = {out_key: torch.randn(batch) for out_key in out_keys} + reset_dict.update(reward_sum_dict) + td = TensorDict(reset_dict, []) + + if len(in_keys) != len(out_keys): + with pytest.raises( + ValueError, + match="RewardSum expects the same number of input and output keys", + ): + RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys) + else: + t = RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys) + + if len(in_keys) != len(reset_keys): + with pytest.raises( + ValueError, + match=re.escape( + f"Could not match the env reset_keys {reset_keys} with the in_keys {in_keys}" + ), + ): + t.reset(td) + else: + t.reset(td) + class TestReward2Go(TransformBase): @pytest.mark.parametrize("device", get_default_devices()) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index ef790b6f9f6..9c8417b9c97 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -12,7 +12,7 @@ import torch from tensordict import is_tensorclass -from tensordict.memmap import MemmapTensor +from tensordict.memmap import MemmapTensor, MemoryMappedTensor from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.utils import expand_right @@ -482,7 +482,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: if self.device == "auto": self.device = data.device if isinstance(data, torch.Tensor): - # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype + # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype out = torch.empty( self.max_size, *data.shape, @@ -531,12 +531,12 @@ class LazyMemmapStorage(LazyTensorStorage): >>> storage.get(0) TensorDict( fields={ - some data: MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), + some data: MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), some: TensorDict( fields={ nested: TensorDict( fields={ - data: MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)}, + data: MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([11]), device=cpu, is_shared=False)}, @@ -560,8 +560,8 @@ class LazyMemmapStorage(LazyTensorStorage): >>> storage.set(range(10), data) >>> storage.get(0) MyClass( - bar=MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False), - foo=MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), + bar=MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False), + foo=MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False), batch_size=torch.Size([11]), device=cpu, is_shared=False) @@ -603,7 +603,12 @@ def load_state_dict(self, state_dict): if isinstance(self._storage, torch.Tensor): _mem_map_tensor_as_tensor(self._storage).copy_(_storage) elif self._storage is None: - self._storage = MemmapTensor(_storage) + self._storage = 
_make_memmap( + _storage, + path=self.scratch_dir + "/tensor.memmap" + if self.scratch_dir is not None + else None, + ) else: raise RuntimeError( f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}" @@ -638,7 +643,8 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: self.device = data.device if self.device.type != "cpu": warnings.warn( - "Support for Memmap device other than CPU will be deprecated in v0.4.0.", + "Support for Memmap device other than CPU will be deprecated in v0.4.0. " + "Using a 'cuda' device may be suboptimal.", category=DeprecationWarning, ) if is_tensor_collection(data): @@ -656,9 +662,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: ) else: # If not a tensorclass/tensordict, it must be a tensor(-like) - # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype - out = MemmapTensor( - self.max_size, *data.shape, device=self.device, dtype=data.dtype + # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype + out = _make_empty_memmap( + (self.max_size, *data.shape), + dtype=data.dtype, + path=self.scratch_dir + "/tensor.memmap" + if self.scratch_dir is not None + else None, ) if VERBOSE: filesize = os.path.getsize(out.filename) / 1024 / 1024 @@ -668,6 +678,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: self._storage = out self.initialized = True + def get(self, index: Union[int, Sequence[int], slice]) -> Any: + result = super().get(index) + # to be deprecated in v0.4 + if result.device != self.device: + return result.to(self.device, non_blocking=True) + return result + # Utils def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: @@ -677,6 +694,7 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: f"Supported backends are {_CKPT_BACKEND.backends}" ) if isinstance(mem_map_tensor, torch.Tensor): + # This will account for MemoryMappedTensors return mem_map_tensor if _CKPT_BACKEND == "torchsnapshot": # TorchSnapshot doesn't know how to stream MemmapTensor, so we view MemmapTensor @@ -737,25 +755,27 @@ def _collate_list_tensordict(x): return out -def _collate_contiguous(x): +def _collate_id(x): return x -def _collate_as_tensor(x): - return x.as_tensor() - - def _get_default_collate(storage, _is_tensordict=False): if isinstance(storage, ListStorage): if _is_tensordict: return _collate_list_tensordict else: return torch.utils.data._utils.collate.default_collate - elif isinstance(storage, LazyMemmapStorage): - return _collate_as_tensor - elif isinstance(storage, (TensorStorage,)): - return _collate_contiguous + elif isinstance(storage, TensorStorage): + return _collate_id else: raise NotImplementedError( f"Could not find a default collate_fn for storage {type(storage)}." 
) + + +def _make_memmap(tensor, path): + return MemoryMappedTensor.from_tensor(tensor, filename=path) + + +def _make_empty_memmap(shape, dtype, path): + return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path) diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index db2b6a418d6..09086bfad65 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -77,8 +77,8 @@ class TokenizedDatasetLoader: >>> print(dataset) TensorDict( fields={ - attention_mask: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)}, + attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([185068]), device=None, is_shared=False) @@ -137,6 +137,7 @@ def load(self): data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0] data_dir_total = data_dir / split / str(max_length) # search for data + print("Looking for data in", data_dir_total) if os.path.exists(data_dir_total): dataset = TensorDict.load_memmap(data_dir_total) return dataset @@ -270,8 +271,8 @@ def dataset_to_tensordict( fields={ prefix: TensorDict( fields={ - labels: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False), - tokens: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)}, + labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False), + tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False)}, diff --git a/torchrl/data/rlhf/prompt.py b/torchrl/data/rlhf/prompt.py index d534a95379e..d50653c9967 100644 --- a/torchrl/data/rlhf/prompt.py +++ b/torchrl/data/rlhf/prompt.py @@ -74,10 +74,10 @@ def from_dataset( >>> data = PromptData.from_dataset("train") >>> print(data) PromptDataTLDR( - attention_mask=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), - prompt_rindex=MemmapTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False), - labels=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), + prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False), + labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False), logits=None, loss=None, batch_size=torch.Size([116722]), diff --git a/torchrl/data/rlhf/reward.py b/torchrl/data/rlhf/reward.py index e7843e02f46..20f379ef659 100644 --- a/torchrl/data/rlhf/reward.py +++ b/torchrl/data/rlhf/reward.py @@ -41,16 +41,16 @@ class PairwiseDataset: >>> print(data) PairwiseDataset( chosen_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - 
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]), device=None, is_shared=False), rejected_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]), @@ -97,16 +97,16 @@ def from_dataset( >>> print(data) PairwiseDataset( chosen_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]), device=None, is_shared=False), rejected_data=RewardData( - attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), - input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), + input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False), rewards=None, end_scores=None, batch_size=torch.Size([92534]), diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f0e132eb092..ac0a136c7f9 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -122,11 +122,16 @@ class _BatchedEnv(EnvBase): memmap (bool): whether or not the returned tensordict will be placed in memory map. policy_proof (callable, optional): if provided, it'll be used to get the list of tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc. - device (str, int, torch.device): for consistency, this argument is kept. However this - argument should not be passed, as the device will be inferred from the environments. - It is assumed that all environments will run on the same device as a common shared - tensordict will be used to pass data from process to process. The device can be - changed after instantiation using :obj:`env.to(device)`. + device (str, int, torch.device): The device of the batched environment can be passed. + If not, it is inferred from the env. In this case, it is assumed that + the device of all environments match. If it is provided, it can differ + from the sub-environment device(s). In that case, the data will be + automatically cast to the appropriate device during collection. 
+ This can be used to speed up collection in case casting to device + introduces an overhead (eg, numpy-based environents etc.): by using + a ``"cuda"`` device for the batched environment but a ``"cpu"`` + device for the nested environments, one can keep the overhead to a + minimum. num_threads (int, optional): number of threads for this process. Defaults to the number of workers. This parameter has no effect for the :class:`~SerialEnv` class. @@ -162,14 +167,7 @@ def __init__( num_threads: int = None, num_sub_threads: int = 1, ): - if device is not None: - raise ValueError( - "Device setting for batched environment can't be done at initialization. " - "The device will be inferred from the constructed environment. " - "It can be set through the `to(device)` method." - ) - - super().__init__(device=None) + super().__init__(device=device) self.is_closed = True if num_threads is None: num_threads = num_workers + 1 # 1 more thread for this proc @@ -218,7 +216,7 @@ def __init__( "memmap and shared memory are mutually exclusive features." ) self._batch_size = None - self._device = None + self._device = torch.device(device) if device is not None else device self._dummy_env_str = None self._seeds = None self.__dict__["_input_spec"] = None @@ -273,7 +271,9 @@ def _set_properties(self): self._properties_set = True if self._single_task: self._batch_size = meta_data.batch_size - device = self._device = meta_data.device + device = meta_data.device + if self._device is None: + self._device = device input_spec = meta_data.specs["input_spec"].to(device) output_spec = meta_data.specs["output_spec"].to(device) @@ -289,8 +289,18 @@ def _set_properties(self): self._batch_locked = meta_data.batch_locked else: self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) - device = self._device = meta_data[0].device - # TODO: check that all action_spec and reward spec match (issue #351) + devices = set() + for _meta_data in meta_data: + device = _meta_data.device + devices.add(device) + if self._device is None: + if len(devices) > 1: + raise ValueError( + f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. " + f"Please indicate a device to be used for collection." 
+ ) + device = list(devices)[0] + self._device = device input_spec = [] for md in meta_data: @@ -413,7 +423,7 @@ def _create_td(self) -> None: *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, ) - self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) + self.shared_tensordict_parent = shared_tensordict_parent else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ @@ -421,7 +431,7 @@ def _create_td(self) -> None: *self._selected_keys, *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, - ).to(self.device) + ) for tensordict in shared_tensordict_parent ] shared_tensordict_parent = torch.stack( @@ -440,13 +450,11 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self.device.type == "cpu": + if self.shared_tensordict_parent.device.type == "cpu": if self._share_memory: - for td in self.shared_tensordicts: - td.share_memory_() + self.shared_tensordict_parent.share_memory_() elif self._memmap: - for td in self.shared_tensordicts: - td.memmap_() + self.shared_tensordict_parent.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -483,7 +491,6 @@ def close(self) -> None: self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None self._properties_set = False - self.event = None self._shutdown_workers() self.is_closed = True @@ -507,11 +514,6 @@ def to(self, device: DEVICE_TYPING): if device == self.device: return self self._device = device - self.meta_data = ( - self.meta_data.to(device) - if self._single_task - else [meta_data.to(device) for meta_data in self.meta_data] - ) if not self.is_closed: warn( "Casting an open environment to another device requires closing and re-opening it. " @@ -543,7 +545,7 @@ def _start_workers(self) -> None: for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) - self._envs.append(env.to(self.device)) + self._envs.append(env) self.is_closed = False @_check_start @@ -603,29 +605,39 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if tensordict_.is_empty(): tensordict_ = None else: - # reset will do modifications in-place. 
We want the original - # tensorict to be unchaned, so we clone it - tensordict_ = tensordict_.clone(False) + env_device = _env.device + if env_device != self.device: + tensordict_ = tensordict_.to(env_device) + else: + tensordict_ = tensordict_.clone(False) else: tensordict_ = None + _td = _env.reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( _td.select(*self._selected_reset_keys_filt, strict=False) ) selected_output_keys = self._selected_reset_keys_filt + device = self.device if self._single_task: # select + clone creates 2 tds, but we can create one only out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in selected_output_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) - return out + _set_single_key( + self.shared_tensordict_parent, out, key, clone=True, device=device + ) else: - return self.shared_tensordict_parent.select( + out = self.shared_tensordict_parent.select( *selected_output_keys, strict=False, - ).clone() + ) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) + return out def _reset_proc_data(self, tensordict, tensordict_reset): # since we call `reset` directly, all the postproc has been completed @@ -643,19 +655,29 @@ def _step( for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. - out_td = self._envs[i]._step(tensordict_in[i]) + env_device = self._envs[i].device + if env_device != self.device: + data_in = tensordict_in[i].to(env_device, non_blocking=True) + else: + data_in = tensordict_in[i] + out_td = self._envs[i]._step(data_in) next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps + device = self.device if self._single_task: out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True, device=device) else: # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() + out = next_td.select(*self._selected_step_keys, strict=False) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) return out def __getattr__(self, attr: str) -> Any: @@ -710,6 +732,32 @@ class ParallelEnv(_BatchedEnv): """ __doc__ += _BatchedEnv.__doc__ + __doc__ += """ + + .. note:: + The choice of the devices where ParallelEnv needs to be executed can + drastically influence its performance. The rule of thumb is: + + - If the base environment (backend, e.g., Gym) is executed on CPU, the + sub-environments should be executed on CPU and the data should be + passed via shared physical memory. + - If the base environment is (or can be) executed on CUDA, the sub-environments + should be placed on CUDA too. + - If a CUDA device is available and the policy is to be executed on CUDA, + the ParallelEnv device should be set to CUDA.
+ + Therefore, supposing a CUDA device is available, we have the following scenarios: + + >>> # The sub-envs are executed on CPU, but the policy is on GPU + >>> env = ParallelEnv(N, MyEnv(..., device="cpu"), device="cuda") + >>> # The sub-envs are executed on CUDA + >>> env = ParallelEnv(N, MyEnv(..., device="cuda"), device="cuda") + >>> # this will create the exact same environment + >>> env = ParallelEnv(N, MyEnv(..., device="cuda")) + >>> # If no cuda device is available + >>> env = ParallelEnv(N, MyEnv(..., device="cpu")) + + """ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator @@ -722,39 +770,39 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] - self._events = [] - if self.device.type == "cuda": + func = _run_worker_pipe_shared_mem + if self.shared_tensordict_parent.device.type == "cuda": self.event = torch.cuda.Event() else: self.event = None + self._events = [ctx.Event() for _ in range(_num_workers)] + kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)] with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: print(f"initiating worker {idx}") # No certainty which module multiprocessing_context is parent_pipe, child_pipe = ctx.Pipe() - event = ctx.Event() - self._events.append(event) env_fun = self.create_env_fn[idx] if not isinstance(env_fun, EnvCreator): env_fun = CloudpickleWrapper(env_fun) - + kwargs[idx].update( + { + "parent_pipe": parent_pipe, + "child_pipe": child_pipe, + "env_fun": env_fun, + "env_fun_kwargs": self.create_env_kwargs[idx], + "shared_tensordict": self.shared_tensordicts[idx], + "_selected_input_keys": self._selected_input_keys, + "_selected_reset_keys": self._selected_reset_keys, + "_selected_step_keys": self._selected_step_keys, + "has_lazy_inputs": self.has_lazy_inputs, + } + ) process = _ProcessNoWarn( - target=_run_worker_pipe_shared_mem, + target=func, num_threads=self.num_sub_threads, - args=( - parent_pipe, - child_pipe, - env_fun, - self.create_env_kwargs[idx], - self.device, - event, - self.shared_tensordicts[idx], - self._selected_input_keys, - self._selected_reset_keys, - self._selected_step_keys, - self.has_lazy_inputs, - ), + kwargs=kwargs[idx], ) process.daemon = True process.start() @@ -834,10 +882,16 @@ def step_and_maybe_reset( # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - tensordict.set("next", self.shared_tensordict_parent.get("next").clone()) - tensordict_ = self.shared_tensordict_parent.exclude( - "next", *self.reset_keys - ).clone() + next_td = self.shared_tensordict_parent.get("next") + tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys) + device = self.device + if self.shared_tensordict_parent.device == device: + next_td = next_td.clone() + tensordict_ = tensordict_.clone() + else: + next_td = next_td.to(device, non_blocking=True) + tensordict_ = tensordict_.to(device, non_blocking=True) + tensordict.set("next", next_td) return tensordict, tensordict_ @_check_start @@ -880,15 +934,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps next_td = self.shared_tensordict_parent.get("next") + device = self.device if self._single_task: out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in 
self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True, device=device) else: # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() + out = next_td.select(*self._selected_step_keys, strict=False) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) return out @_check_start @@ -944,19 +1003,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: event.clear() selected_output_keys = self._selected_reset_keys_filt + device = self.device if self._single_task: # select + clone creates 2 tds, but we can create one only out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in selected_output_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) - return out + _set_single_key( + self.shared_tensordict_parent, out, key, clone=True, device=device + ) else: - return self.shared_tensordict_parent.select( + out = self.shared_tensordict_parent.select( *selected_output_keys, strict=False, - ).clone() + ) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) + return out @_check_start def _shutdown_workers(self) -> None: @@ -981,6 +1047,7 @@ def _shutdown_workers(self) -> None: del self.parent_channels self._cuda_events = None self._events = None + self.event = None @_check_start def set_seed( @@ -1063,7 +1130,6 @@ def _run_worker_pipe_shared_mem( child_pipe: connection.Connection, env_fun: Union[EnvBase, Callable], env_fun_kwargs: Dict[str, Any], - device: DEVICE_TYPING = None, mp_event: mp.Event = None, shared_tensordict: TensorDictBase = None, _selected_input_keys=None, @@ -1072,13 +1138,11 @@ def _run_worker_pipe_shared_mem( has_lazy_inputs: bool = False, verbose: bool = False, ) -> None: - if device is None: - device = torch.device("cpu") + device = shared_tensordict.device if device.type == "cuda": event = torch.cuda.Event() else: event = None - parent_pipe.close() pid = os.getpid() if not isinstance(env_fun, EnvBase): @@ -1089,7 +1153,6 @@ def _run_worker_pipe_shared_mem( "env_fun_kwargs must be empty if an environment is passed to a process." 
) env = env_fun - env = env.to(device) del env_fun i = -1 @@ -1144,7 +1207,8 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step(shared_tensordict) + env_input = shared_tensordict + next_td = env._step(env_input) next_shared_tensordict.update_(next_td) if event is not None: event.record() @@ -1155,7 +1219,8 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) + env_input = shared_tensordict + td, root_next_td = env.step_and_maybe_reset(env_input) next_shared_tensordict.update_(td.get("next")) root_shared_tensordict.update_(root_next_td) if event is not None: @@ -1208,3 +1273,10 @@ def _run_worker_pipe_shared_mem( else: # don't send env through pipe child_pipe.send(("_".join([cmd, "done"]), None)) + + +def _update_cuda(t_dest, t_source): + if t_source is None: + return + t_dest.copy_(t_source.pin_memory(), non_blocking=True) + return diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index f3a9f2aa469..5645785117d 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -148,7 +148,7 @@ def _step(self, tensordict, next_tensordict): lives = self._get_lives() end_of_life = torch.tensor( - tensordict.get(self.lives_key) < lives, device=self.parent.device + tensordict.get(self.lives_key) > lives, device=self.parent.device ) try: done = next_tensordict.get(self.done_key) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 240c1029486..79ee94318cb 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -5,18 +5,13 @@ from copy import copy, deepcopy import torch -from tensordict import TensorDictBase, unravel_key -from tensordict.nn import ( - make_functional, - ProbabilisticTensorDictModule, - repopulate_module, - TensorDictParams, -) +from tensordict import TensorDict, TensorDictBase, unravel_key +from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams from tensordict.utils import is_seq_of_nested_key from torch import nn from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.transforms.transforms import Transform -from torchrl.envs.transforms.utils import _set_missing_tolerance +from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param class KLRewardTransform(Transform): @@ -116,11 +111,10 @@ def __init__( self.in_keys = self.in_keys + actor.in_keys # check that the model has parameters - params = make_functional( - actor, keep_params=False, funs_to_decorate=["forward", "get_dist"] - ) - self.functional_actor = deepcopy(actor) - repopulate_module(actor, params) + params = TensorDict.from_module(actor) + with params.apply(_stateless_param, device="meta").to_module(actor): + # copy a stateless actor + self.__dict__["functional_actor"] = deepcopy(actor) # we need to register these params as buffer to have `to` and similar # methods work properly @@ -170,9 +164,8 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.out_keys[0] != ("reward",) and self.parent is not None: tensordict.set(self.out_keys[0], self.parent.reward_spec.zero()) return tensordict - dist = self.functional_actor.get_dist( - tensordict.clone(False), params=self.frozen_params - ) + with self.frozen_params.to_module(self.functional_actor): + dist = 
self.functional_actor.get_dist(tensordict.clone(False)) # get the log_prob given the original model log_prob = dist.log_prob(action) reward_key = self.in_keys[0] diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3e6d597dffd..de8baf2e403 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4692,6 +4692,7 @@ def __init__( """Initialises the transform. Filters out non-reward input keys and defines output keys.""" super().__init__(in_keys=in_keys, out_keys=out_keys) self._reset_keys = reset_keys + self._keys_checked = False @property def in_keys(self): @@ -4770,9 +4771,7 @@ def _check_match(reset_keys, in_keys): return False return True - if len(reset_keys) != len(self.in_keys) or not _check_match( - reset_keys, self.in_keys - ): + if not _check_match(reset_keys, self.in_keys): raise ValueError( f"Could not match the env reset_keys {reset_keys} with the {type(self)} in_keys {self.in_keys}. " f"Please provide the reset_keys manually. Reset entries can be " @@ -4781,6 +4780,14 @@ def _check_match(reset_keys, in_keys): ) reset_keys = copy(reset_keys) self._reset_keys = reset_keys + + if not self._keys_checked and len(reset_keys) != len(self.in_keys): + raise ValueError( + f"Could not match the env reset_keys {reset_keys} with the in_keys {self.in_keys}. " + "Please make sure that these have the same length." + ) + self._keys_checked = True + return reset_keys @reset_keys.setter diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py index a99c22a87da..a1b30cb1aca 100644 --- a/torchrl/envs/transforms/utils.py +++ b/torchrl/envs/transforms/utils.py @@ -5,6 +5,7 @@ import torch +from torch import nn def check_finite(tensor: torch.Tensor): @@ -59,3 +60,11 @@ def _get_reset(reset_key, tensordict): if _reset.ndim > parent_td.ndim: _reset = _reset.flatten(parent_td.ndim, -1).any(-1) return _reset + + +def _stateless_param(param): + is_param = isinstance(param, nn.Parameter) + param = param.data.to("meta") + if is_param: + return nn.Parameter(param, requires_grad=False) + return param diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 06eec73be97..9a2a71f24bd 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -237,7 +237,11 @@ def step_mdp( def _set_single_key( - source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False + source: TensorDictBase, + dest: TensorDictBase, + key: str | tuple, + clone: bool = False, + device=None, ): # key should be already unraveled if isinstance(key, str): @@ -253,7 +257,9 @@ def _set_single_key( source = val dest = new_val else: - if clone: + if device is not None and val.device != device: + val = val.to(device, non_blocking=True) + elif clone: val = val.clone() dest._set_str(k, val, inplace=False, validated=True) # This is a temporary solution to understand if a key is heterogeneous @@ -262,7 +268,7 @@ def _set_single_key( if re.match(r"Found more than one unique shape in the tensors", str(err)): # this is a het key for s_td, d_td in zip(source.tensordicts, dest.tensordicts): - _set_single_key(s_td, d_td, k, clone) + _set_single_key(s_td, d_td, k, clone=clone, device=device) break else: raise err diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 1e5a557546a..bf81cfd5dfd 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -183,7 +183,7 @@ class 
ProbabilisticActor(SafeProbabilisticTensorDictSequential): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn import TensorDictModule, make_functional + >>> from tensordict.nn import TensorDictModule >>> from torchrl.data import BoundedTensorSpec >>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) @@ -197,8 +197,9 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): ... in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... ) - >>> params = make_functional(td_module) - >>> td = td_module(td, params=params) + >>> params = TensorDict.from_module(td_module) + >>> with params.to_module(td_module): + ... td = td_module(td) >>> td TensorDict( fields={ @@ -319,7 +320,6 @@ class ValueOperator(TensorDictModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn import make_functional >>> from torch import nn >>> from torchrl.data import UnboundedContinuousTensorSpec >>> from torchrl.modules import ValueOperator @@ -334,8 +334,9 @@ class ValueOperator(TensorDictModule): >>> td_module = ValueOperator( ... in_keys=["observation", "action"], module=module ... ) - >>> params = make_functional(td_module) - >>> td = td_module(td, params=params) + >>> params = TensorDict.from_module(td_module) + >>> with params.to_module(td_module): + ... td = td_module(td) >>> print(td) TensorDict( fields={ @@ -792,7 +793,6 @@ class QValueHook: Examples: >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor @@ -878,7 +878,6 @@ class DistributionalQValueHook(QValueHook): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor @@ -893,12 +892,13 @@ class DistributionalQValueHook(QValueHook): ... return self.linear(x).view(-1, nbins, 4).log_softmax(-2) ... >>> module = CustomDistributionalQval() - >>> params = make_functional(module) + >>> params = TensorDict.from_module(module) >>> action_spec = OneHotDiscreteTensorSpec(4) >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins)) >>> module.register_forward_hook(hook) >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) - >>> qvalue_actor(td, params=params) + >>> with params.to_module(module): + ... 
qvalue_actor(td) >>> print(td) TensorDict( fields={ @@ -992,7 +992,6 @@ class QValueActor(SafeSequential): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import QValueActor diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index c5f34a7774d..22786519681 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -138,7 +138,6 @@ class SafeModule(TensorDictModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional >>> from torchrl.data import UnboundedContinuousTensorSpec >>> from torchrl.modules import TensorDictModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) @@ -150,8 +149,9 @@ ... in_keys=["input", "hidden"], ... out_keys=["output"], ... ) - >>> params = make_functional(td_fmodule) - >>> td_functional = td_fmodule(td.clone(), params=params) + >>> params = TensorDict.from_module(td_fmodule) + >>> with params.to_module(td_fmodule): + ... td_functional = td_fmodule(td.clone()) >>> print(td_functional) TensorDict( fields={ diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 46f71e2b3d6..5c8ae799061 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -110,7 +110,7 @@ def __init__( self.register_buffer("eps_init", torch.tensor([eps_init])) self.register_buffer("eps_end", torch.tensor([eps_end])) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init])) + self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: @@ -259,7 +259,7 @@ def __init__( if self.eps_end > self.eps_init: raise RuntimeError("eps should decrease over time or be constant") self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init])) + self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) self.action_key = action_key self.action_mask_key = action_mask_key if spec is not None: @@ -405,7 +405,7 @@ def __init__( self.annealing_num_steps = annealing_num_steps self.register_buffer("mean", torch.tensor([mean])) self.register_buffer("std", torch.tensor([std])) - self.register_buffer("sigma", torch.tensor([sigma_init])) + self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32)) self.action_key = action_key self.out_keys = list(self.td_module.out_keys) if action_key not in self.out_keys: @@ -613,7 +613,7 @@ def __init__( f"got eps_init={eps_init} and eps_end={eps_end}" ) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init])) + self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys self.is_init_key = is_init_key noise_key = self.ou.noise_key diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 71167c5106f..28f721ba6a1 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ 
b/torchrl/modules/tensordict_module/sequence.py @@ -33,7 +33,6 @@ class SafeSequential(TensorDictSequential, SafeModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamWrapper >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule @@ -58,8 +57,9 @@ class SafeSequential(TensorDictSequential, SafeModule): ... out_keys=["output"], ... ) >>> td_module = SafeSequential(td_module1, td_module2) - >>> params = make_functional(td_module) - >>> td_module(td, params=params) + >>> params = TensorDict.from_module(td_module) + >>> with params.to_module(td_module): + ... td_module(td) >>> print(td) TensorDict( fields={ diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index bb7b9014f0d..4384ccef282 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,11 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -20,7 +26,13 @@ distance_loss, ValueEstimators, ) -from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from torchrl.objectives.value import ( + GAE, + TD0Estimator, + TD1Estimator, + TDLambdaEstimator, + VTrace, +) class A2CLoss(LossModule): @@ -202,6 +214,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + sample_log_prob: NestedKey = "sample_log_prob" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE @@ -314,8 +327,8 @@ def _log_probs( f"tensordict stored {self.tensor_keys.action} require grad." ) tensordict_clone = tensordict.select(*self.actor.in_keys).clone() - - dist = self.actor.get_dist(tensordict_clone, params=self.actor_params) + with self.actor_params.to_module(self.actor): + dist = self.actor.get_dist(tensordict_clone) log_prob = dist.log_prob(action) log_prob = log_prob.unsqueeze(-1) return log_prob, dist @@ -326,10 +339,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: # overhead that we could easily reduce. 
target_return = tensordict.get(self.tensor_keys.value_target) tensordict_select = tensordict.select(*self.critic.in_keys) - state_value = self.critic( - tensordict_select, - params=self.critic_params, - ).get(self.tensor_keys.value) + with self.critic_params.to_module(self.critic): + state_value = self.critic( + tensordict_select, + ).get(self.tensor_keys.value) loss_value = distance_loss( target_return, state_value, @@ -361,6 +374,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: target_params=self.target_critic_params, ) advantage = tensordict.get(self.tensor_keys.advantage) + assert not advantage.requires_grad log_probs, dist = self._log_probs(tensordict) loss = -(log_probs * advantage) td_out = TensorDict({"loss_objective": loss.mean()}, []) @@ -379,6 +393,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self.value_type = value_type hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) + if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: @@ -389,6 +404,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=actor_with_params, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -399,5 +422,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index bdccbda3808..00ba8cf456a 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -10,15 +10,10 @@ from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple -from tensordict import TensorDictBase - -from tensordict.nn import ( - make_functional, - repopulate_module, - TensorDictModule, - TensorDictModuleBase, - TensorDictParams, -) +import torch +from tensordict import TensorDict, TensorDictBase + +from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams from torch import nn from torch.nn import Parameter @@ -87,7 +82,7 @@ class _AcceptedKeys: pass default_value_estimator: ValueEstimators = None - SEP = "_sep_" + SEP = "." TARGET_NET_WARNING = ( "No target network updater has been associated " "with this loss module, but target parameters have been found. 
" @@ -138,7 +133,7 @@ def set_keys(self, **kwargs) -> None: """ for key, value in kwargs.items(): if key not in self._AcceptedKeys.__dict__: - raise ValueError(f"{key} it not an accepted tensordict key") + raise ValueError(f"{key} is not an accepted tensordict key") if value is not None: setattr(self.tensor_keys, key, value) else: @@ -178,21 +173,15 @@ def convert_to_functional( expand_dim: Optional[int] = None, create_target_params: bool = False, compare_against: Optional[List[Parameter]] = None, - funs_to_decorate=None, + **kwargs, ) -> None: """Converts a module to functional to be used in the loss. Args: module (TensorDictModule or compatible): a stateful tensordict module. - This module will be made functional, yet still stateful, meaning - that it will be callable with the following alternative signatures: - - >>> module(tensordict) - >>> module(tensordict, params=params) - - ``params`` is a :class:`tensordict.TensorDict` instance with parameters - stuctured as the output of :func:`tensordict.nn.make_functional` - is. + Parameters from this module will be isolated in the `_params` + attribute and a stateless version of the module will be registed + under the `module_name` attribute. module_name (str): name where the module will be found. The parameters of the module will be found under ``loss_module._params`` whereas the module will be found under ``loss_module.``. @@ -223,45 +212,27 @@ def convert_to_functional( the resulting parameters will be a detached version of the original parameters. If ``None``, the resulting parameters will carry gradients as expected. - funs_to_decorate (list of str, optional): if provided, the list of - methods of ``module`` to make functional, ie the list of - methods that will accept the ``params`` keyword argument. """ - if funs_to_decorate is None: - funs_to_decorate = ["forward"] + if kwargs.pop("funs_to_decorate", None) is not None: + warnings.warn( + "funs_to_decorate is without effect with the new objective API.", + category=DeprecationWarning, + ) + if kwargs: + raise TypeError(f"Unrecognised keyword arguments {list(kwargs.keys())}") # To make it robust to device casting, we must register list of # tensors as lazy calls to `getattr(self, name_of_tensor)`. # Otherwise, casting the module to a device will keep old references # to uncast tensors sep = self.SEP - params = make_functional(module, funs_to_decorate=funs_to_decorate) - # buffer_names = next(itertools.islice(zip(*module.named_buffers()), 1)) - buffer_names = [] - for key, value in params.items(True, True): - # we just consider all that is not param as a buffer, but if the module has been made - # functional and the params have been replaced this may break - if not isinstance(value, nn.Parameter): - key = sep.join(key) if not isinstance(key, str) else key - buffer_names.append(key) - functional_module = deepcopy(module) - repopulate_module(module, params) - - params_and_buffers = params - # we transform the buffers in params to make sure they follow the device - # as tensor = nn.Parameter(tensor) keeps its identity when moved to another device - - # separate params and buffers - params_and_buffers = TensorDictParams(params_and_buffers, no_convert=True) - # sanity check - for key in params_and_buffers.keys(True): + params = TensorDict.from_module(module, as_module=True) + + for key in params.keys(True): if sep in key: raise KeyError( f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer." 
) - params_and_buffers_flat = params_and_buffers.flatten_keys(sep) - buffers = params_and_buffers_flat.select(*buffer_names) - params = params_and_buffers_flat.exclude(*buffer_names) if compare_against is not None: compare_against = set(compare_against) else: @@ -273,6 +244,9 @@ def convert_to_functional( # For buffers, a cloned expansion (or equivalently a repeat) is returned. def _compare_and_expand(param): + if not isinstance(param, nn.Parameter): + buffer = param.expand(expand_dim, *param.shape).clone() + return buffer if param in compare_against: expanded_param = param.data.expand(expand_dim, *param.shape) # the expanded parameter must be sent to device when to() @@ -287,45 +261,40 @@ def _compare_and_expand(param): ) return p_out - params = params.apply( - _compare_and_expand, batch_size=[expand_dim, *params.shape] - ) - - buffers = buffers.apply( - lambda buffer: buffer.expand(expand_dim, *buffer.shape).clone(), - batch_size=[expand_dim, *buffers.shape], + params = TensorDictParams( + params.apply( + _compare_and_expand, batch_size=[expand_dim, *params.shape] + ), + no_convert=True, ) - params_and_buffers.update(params.unflatten_keys(sep)) - params_and_buffers.update(buffers.unflatten_keys(sep)) - params_and_buffers.batch_size = params.batch_size - - # self.params_to_map = params_to_map - param_name = module_name + "_params" prev_set_params = set(self.parameters()) # register parameters and buffers - for key, parameter in list(params_and_buffers.items(True, True)): + for key, parameter in list(params.items(True, True)): if parameter not in prev_set_params: pass elif compare_against is not None and parameter in compare_against: - params_and_buffers.set(key, parameter.data) + params.set(key, parameter.data) - setattr(self, param_name, params_and_buffers) + setattr(self, param_name, params) - # set the functional module - setattr(self, module_name, functional_module) + # set the functional module: we need to convert the params to non-differentiable params + # otherwise they will appear twice in parameters + with params.apply( + self._make_meta_params, device=torch.device("meta") + ).to_module(module): + # avoid buffers and params being exposed + self.__dict__[module_name] = deepcopy(module) name_params_target = "target_" + module_name if create_target_params: # if create_target_params: # we create a TensorDictParams to keep the target params as Buffer instances target_params = TensorDictParams( - params_and_buffers.apply( - _make_target_param(clone=create_target_params) - ), + params.apply(_make_target_param(clone=create_target_params)), no_convert=True, ) setattr(self, name_params_target + "_params", target_params) @@ -447,6 +416,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." ) + elif value_type == ValueEstimators.VTrace: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) elif value_type == ValueEstimators.TDLambda: raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." 
@@ -454,86 +427,18 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams else: raise NotImplementedError(f"Unknown value type {value_type}") - # def _apply(self, fn, recurse=True): - # """Modifies torch.nn.Module._apply to work with Buffer class.""" - # if recurse: - # for module in self.children(): - # module._apply(fn) - # - # def compute_should_use_set_data(tensor, tensor_applied): - # if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): - # # If the new tensor has compatible tensor type as the existing tensor, - # # the current behavior is to change the tensor in-place using `.data =`, - # # and the future behavior is to overwrite the existing tensor. However, - # # changing the current behavior is a BC-breaking change, and we want it - # # to happen in future releases. So for now we introduce the - # # `torch.__future__.get_overwrite_module_params_on_conversion()` - # # global flag to let the user control whether they want the future - # # behavior of overwriting the existing tensor or not. - # return not torch.__future__.get_overwrite_module_params_on_conversion() - # else: - # return False - # - # for key, param in self._parameters.items(): - # if param is None: - # continue - # # Tensors stored in modules are graph leaves, and we don't want to - # # track autograd history of `param_applied`, so we have to use - # # `with torch.no_grad():` - # with torch.no_grad(): - # param_applied = fn(param) - # should_use_set_data = compute_should_use_set_data(param, param_applied) - # if should_use_set_data: - # param.data = param_applied - # out_param = param - # else: - # assert isinstance(param, Parameter) - # assert param.is_leaf - # out_param = Parameter(param_applied, param.requires_grad) - # self._parameters[key] = out_param - # - # if param.grad is not None: - # with torch.no_grad(): - # grad_applied = fn(param.grad) - # should_use_set_data = compute_should_use_set_data(param.grad, grad_applied) - # if should_use_set_data: - # assert out_param.grad is not None - # out_param.grad.data = grad_applied - # else: - # assert param.grad.is_leaf - # out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad) - # - # for key, buffer in self._buffers.items(): - # if buffer is None: - # continue - # # Tensors stored in modules are graph leaves, and we don't want to - # # track autograd history of `buffer_applied`, so we have to use - # # `with torch.no_grad():` - # with torch.no_grad(): - # buffer_applied = fn(buffer) - # should_use_set_data = compute_should_use_set_data(buffer, buffer_applied) - # if should_use_set_data: - # buffer.data = buffer_applied - # out_buffer = buffer - # else: - # assert isinstance(buffer, Buffer) - # assert buffer.is_leaf - # out_buffer = Buffer(buffer_applied, buffer.requires_grad) - # self._buffers[key] = out_buffer - # - # if buffer.grad is not None: - # with torch.no_grad(): - # grad_applied = fn(buffer.grad) - # should_use_set_data = compute_should_use_set_data(buffer.grad, grad_applied) - # if should_use_set_data: - # assert out_buffer.grad is not None - # out_buffer.grad.data = grad_applied - # else: - # assert buffer.grad.is_leaf - # out_buffer.grad = grad_applied.requires_grad_(buffer.grad.requires_grad) - return self + @staticmethod + def _make_meta_params(param): + is_param = isinstance(param, nn.Parameter) + + pd = param.detach().to("meta") + + if is_param: + pd = nn.Parameter(pd, requires_grad=False) + return pd + class _make_target_param: def __init__(self, clone): diff --git 
a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 86d5461b15f..9b2c5eace7a 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import math import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -26,6 +27,7 @@ from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_WARNING, + _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, @@ -33,18 +35,6 @@ from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -try: - try: - from torch import vmap - except ImportError: - from functorch import vmap - - _has_functorch = True - err = "" -except ImportError as err: - _has_functorch = False - FUNCTORCH_ERROR = err - class CQLLoss(LossModule): """TorchRL implementation of the continuous CQL loss. @@ -281,8 +271,6 @@ def __init__( lagrange_thresh: float = 0.0, ) -> None: self._out_keys = None - if not _has_functorch: - raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR super().__init__() # Actor @@ -291,7 +279,6 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist"], ) # Q value @@ -362,8 +349,8 @@ def __init__( torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)), ) - self._vmap_qvalue_networkN0 = vmap(self.qvalue_network, (None, 0)) - self._vmap_qvalue_network00 = vmap(self.qvalue_network) + self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) + self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network) @property def target_entropy(self): @@ -590,8 +577,10 @@ def _get_policy_actions(self, data, actor_params, num_actions=10): batch_size=batch_size, ) with torch.no_grad(): - with set_exploration_type(ExplorationType.RANDOM): - dist = self.actor_network.get_dist(tensordict, params=actor_params) + with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module( + self.actor_network + ): + dist = self.actor_network.get_dist(tensordict) action = dist.rsample() tensordict.set(self.tensor_keys.action, action) sample_log_prob = dist.log_prob(action) @@ -607,11 +596,11 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): tensordict = tensordict.clone(False) # get actions and log-probs with torch.no_grad(): - with set_exploration_type(ExplorationType.RANDOM): + with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module( + self.actor_network + ): next_tensordict = tensordict.get("next").clone(False) - next_dist = self.actor_network.get_dist( - next_tensordict, params=actor_params - ) + next_dist = self.actor_network.get_dist(next_tensordict) next_action = next_dist.rsample() next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) @@ -1066,7 +1055,8 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self.value_type = value_type # we will take care of computing the next value inside this module - value_net = self.value_network + value_net = deepcopy(self.value_network) + self.value_network_params.to_module(value_net, return_swap=False) hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) @@ -1117,10 +1107,8 @@ def value_loss( tensordict: TensorDictBase, ) -> Tuple[torch.Tensor, dict]: td_copy = tensordict.clone(False) - self.value_network( - td_copy, - params=self.value_network_params, - ) + with 
self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) action = tensordict.get(self.tensor_keys.action) pred_val = td_copy.get(self.tensor_keys.action_value) @@ -1137,8 +1125,7 @@ def value_loss( # calculate target value with torch.no_grad(): target_value = self.value_estimator.value_estimate( - td_copy, - target_params=self._cached_detached_target_value_params, + td_copy, params=self._cached_detached_target_value_params ).squeeze(-1) with torch.no_grad(): diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 1795f785716..3b4debe6259 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -11,7 +11,7 @@ from typing import Tuple import torch -from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule +from tensordict.nn import dispatch, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey, unravel_key @@ -197,10 +197,10 @@ def __init__( self.delay_value = delay_value actor_critic = ActorCriticWrapper(actor_network, value_network) - params = make_functional(actor_critic) - self.actor_critic = deepcopy(actor_critic) - repopulate_module(actor_network, params["module", "0"]) - repopulate_module(value_network, params["module", "1"]) + params = TensorDict.from_module(actor_critic) + params_meta = params.apply(self._make_meta_params, device=torch.device("meta")) + with params_meta.to_module(actor_critic): + self.__dict__["actor_critic"] = deepcopy(actor_critic) self.convert_to_functional( actor_network, @@ -295,14 +295,10 @@ def loss_actor( td_copy = tensordict.select( *self.actor_in_keys, *self.value_exclusive_keys ).detach() - td_copy = self.actor_network( - td_copy, - params=self.actor_network_params, - ) - td_copy = self.value_network( - td_copy, - params=self._cached_detached_value_params, - ) + with self.actor_network_params.to_module(self.actor_network): + td_copy = self.actor_network(td_copy) + with self._cached_detached_value_params.to_module(self.value_network): + td_copy = self.value_network(td_copy) loss_actor = -td_copy.get(self.tensor_keys.state_action_value) metadata = {} return loss_actor.mean(), metadata @@ -313,10 +309,8 @@ def loss_value( ) -> Tuple[torch.Tensor, dict]: # value loss td_copy = tensordict.select(*self.value_network.in_keys).detach() - self.value_network( - td_copy, - params=self.value_network_params, - ) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1) target_value = self.value_estimator.value_estimate( diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index db3cf633aef..52339d583dd 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -88,7 +88,6 @@ def __init__( actor_network, "actor_network", create_target_params=False, - funs_to_decorate=["forward", "get_dist"], ) try: device = next(self.parameters()).device @@ -208,9 +207,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if target_actions.requires_grad: raise RuntimeError("target action cannot be part of a graph.") - action_dist = self.actor_network.get_dist( - tensordict, params=self.actor_network_params - ) + with self.actor_network_params.to_module(self.actor_network): + action_dist = self.actor_network.get_dist(tensordict) log_likelihood = action_dist.log_prob(target_actions).mean() entropy = 
self.get_entropy_bonus(action_dist).mean() @@ -319,9 +317,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone(False) target_actions = tensordict.get(self.tensor_keys.action_target).detach() - pred_actions = self.actor_network( - tensordict, params=self.actor_network_params - ).get(self.tensor_keys.action_pred) + with self.actor_network_params.to_module(self.actor_network): + pred_actions = self.actor_network(tensordict).get( + self.tensor_keys.action_pred + ) loss = distance_loss( pred_actions, target_actions, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 696efbdc650..947a7574967 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -21,21 +21,13 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators from torchrl.objectives.common import LossModule -from torchrl.objectives.utils import _cache_values, _GAMMA_LMBDA_DEPREC_WARNING +from torchrl.objectives.utils import ( + _cache_values, + _GAMMA_LMBDA_DEPREC_WARNING, + _vmap_func, +) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -try: - try: - from torch import vmap - except ImportError: - from functorch import vmap - - FUNCTORCH_ERR = "" - _has_functorch = True -except ImportError as err: - FUNCTORCH_ERR = str(err) - _has_functorch = False - class REDQLoss_deprecated(LossModule): """REDQ Loss module. @@ -149,8 +141,6 @@ def __init__( ): self._in_keys = None self._out_keys = None - if not _has_functorch: - raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -208,7 +198,7 @@ def __init__( self.target_entropy_buffer = None self.gSDE = gSDE - self._vmap_qvalue_networkN0 = vmap(self.qvalue_network, (None, 0)) + self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) @@ -328,11 +318,10 @@ def _cached_detach_qvalue_network_params(self): def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: obs_keys = self.actor_network.in_keys tensordict_clone = tensordict.select(*obs_keys) - with set_exploration_type(ExplorationType.RANDOM): - self.actor_network( - tensordict_clone, - params=self.actor_network_params, - ) + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + self.actor_network(tensordict_clone) tensordict_expand = self._vmap_qvalue_networkN0( tensordict_clone.select(*self.qvalue_network.in_keys), @@ -364,11 +353,10 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: ) # next_observation -> # observation # select pseudo-action - with set_exploration_type(ExplorationType.RANDOM): - self.actor_network( - next_td, - params=self.target_actor_network_params, - ) + with set_exploration_type( + ExplorationType.RANDOM + ), self.target_actor_network_params.to_module(self.actor_network): + self.actor_network(next_td) sample_log_prob = next_td.get("sample_log_prob") # get q-values next_td = self._vmap_qvalue_networkN0( diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index a9dd4314a35..6af8c165c51 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -289,10 +289,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: """ td_copy = 
tensordict.clone(False) - self.value_network( - td_copy, - params=self.value_network_params, - ) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) action = tensordict.get(self.tensor_keys.action) pred_val = td_copy.get(self.tensor_keys.action_value) @@ -462,10 +460,10 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: # Calculate current state probabilities (online network noise already # sampled) td_clone = tensordict.clone() - self.value_network( - td_clone, - params=self.value_network_params, - ) # Log probabilities log p(s_t, ·; θonline) + with self.value_network_params.to_module(self.value_network): + self.value_network( + td_clone, + ) # Log probabilities log p(s_t, ·; θonline) action_log_softmax = td_clone.get(self.tensor_keys.action_value) if self.action_space == "categorical": @@ -475,24 +473,18 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: action, action_log_softmax, batch_size, atoms ) - with torch.no_grad(): + with torch.no_grad(), self.value_network_params.to_module(self.value_network): # Calculate nth next state probabilities next_td = step_mdp(tensordict) - self.value_network( - next_td, - params=self.value_network_params, - ) # Probabilities p(s_t+n, ·; θonline) + self.value_network(next_td) # Probabilities p(s_t+n, ·; θonline) next_td_action = next_td.get(self.tensor_keys.action) if self.action_space == "categorical": argmax_indices_ns = next_td_action.squeeze(-1) else: argmax_indices_ns = next_td_action.argmax(-1) # one-hot encoding - - self.value_network( - next_td, - params=self.target_value_network_params, - ) # Probabilities p(s_t+n, ·; θtarget) + with self.target_value_network_params.to_module(self.value_network): + self.value_network(next_td) # Probabilities p(s_t+n, ·; θtarget) pns = next_td.get(self.tensor_keys.action_value).exp() # Double-Q probabilities # p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 8d7fa0b53c1..a741d83ba13 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -14,26 +14,16 @@ from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule + from torchrl.objectives.utils import ( _GAMMA_LMBDA_DEPREC_WARNING, + _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -try: - try: - from torch import vmap - except ImportError: - from functorch import vmap - - _has_functorch = True - err = "" -except ImportError as err: - _has_functorch = False - FUNCTORCH_ERROR = err - class IQLLoss(LossModule): r"""TorchRL implementation of the IQL loss. 
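Across the objective modules in the hunks above and below (DQN, CQL, IQL, DDPG, PPO, ...), the functional call `module(td, params=params)` / `module.get_dist(td, params=params)` is replaced by binding the parameters with `to_module` used as a context manager. A short sketch of the new calling convention; the value network, keys and shapes below are hypothetical and purely illustrative:

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> value_net = TensorDictModule(nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"])
>>> params = TensorDict.from_module(value_net)
>>> td = TensorDict({"observation": torch.randn(8, 4)}, [8])
>>> # before this change: value_net(td, params=params)
>>> with params.to_module(value_net):
...     td = value_net(td)
>>> # detached or target parameters can be bound the same way,
>>> # e.g. ``with params.detach().to_module(value_net): ...``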
@@ -248,8 +238,6 @@ def __init__( ) -> None: self._in_keys = None self._out_keys = None - if not _has_functorch: - raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR super().__init__() self._set_deprecated_ctor_keys(priority=priority_key) @@ -262,7 +250,6 @@ def __init__( actor_network, "actor_network", create_target_params=False, - funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -299,7 +286,7 @@ def __init__( if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qvalue_networkN0 = vmap(self.qvalue_network, (None, 0)) + self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) @property def device(self) -> torch.device: @@ -393,10 +380,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def actor_loss(self, tensordict: TensorDictBase) -> Tensor: # KL loss - dist = self.actor_network.get_dist( - tensordict, - params=self.actor_network_params, - ) + with self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) log_prob = dist.log_prob(tensordict[self.tensor_keys.action]) @@ -412,10 +397,8 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: # state value with torch.no_grad(): td_copy = tensordict.select(*self.value_network.in_keys).detach() - self.value_network( - td_copy, - params=self.value_network_params, - ) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) value = td_copy.get(self.tensor_keys.value).squeeze( -1 ) # assert has no gradient @@ -434,10 +417,8 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) # state value td_copy = tensordict.select(*self.value_network.in_keys) - self.value_network( - td_copy, - params=self.value_network_params, - ) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) value = td_copy.get(self.tensor_keys.value).squeeze(-1) value_loss = self.loss_value_diff(min_q - value, self.expectile).mean() return value_loss, {} diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 00106571744..23947696c9f 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -12,7 +12,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import nn @@ -212,10 +212,11 @@ def __init__( ) global_value_network = SafeSequential(local_value_network, mixer_network) - params = make_functional(global_value_network) - self.global_value_network = deepcopy(global_value_network) - repopulate_module(local_value_network, params["module", "0"]) - repopulate_module(mixer_network, params["module", "1"]) + params = TensorDict.from_module(global_value_network) + with params.apply( + self._make_meta_params, device=torch.device("meta") + ).to_module(global_value_network): + self.__dict__["global_value_network"] = deepcopy(global_value_network) self.convert_to_functional( local_value_network, @@ -326,10 +327,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDict: 
td_copy = tensordict.clone(False) - self.local_value_network( - td_copy, - params=self.local_value_network_params, - ) + with self.local_value_network_params.to_module(self.local_value_network): + self.local_value_network( + td_copy, + ) action = tensordict.get(self.tensor_keys.action) pred_val = td_copy.get( @@ -346,7 +347,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: pred_val_index = (pred_val * action).sum(-1, keepdim=True) td_copy.set(self.tensor_keys.local_value, pred_val_index) # [*B, n_agents, 1] - self.mixer_network(td_copy, params=self.mixer_network_params) + with self.mixer_network_params.to_module(self.mixer_network): + self.mixer_network(td_copy) pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1) # [*B] this is global and shared among the agents as will be the target diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index e576ca33c1c..11b5fef2ae7 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -4,11 +4,17 @@ # LICENSE file in the root directory of this source tree. import math import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -22,7 +28,7 @@ ) from .common import LossModule -from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace class PPOLoss(LossModule): @@ -271,9 +277,7 @@ def __init__( self._in_keys = None self._out_keys = None super().__init__() - self.convert_to_functional( - actor, "actor", funs_to_decorate=["forward", "get_dist"] - ) + self.convert_to_functional(actor, "actor") if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared @@ -374,7 +378,8 @@ def _log_weight( f"tensordict stored {self.tensor_keys.action} requires grad." ) - dist = self.actor.get_dist(tensordict, params=self.actor_params) + with self.actor_params.to_module(self.actor): + dist = self.actor.get_dist(tensordict) log_prob = dist.log_prob(action) prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) @@ -400,10 +405,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: f"can be used for the value loss." 
) - state_value_td = self.critic( - tensordict, - params=self.critic_params, - ) + with self.critic_params.to_module(self.critic): + state_value_td = self.critic(tensordict) try: state_value = state_value_td.get(self.tensor_keys.value) @@ -469,6 +472,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=actor_with_params, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -479,6 +490,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) @@ -848,7 +860,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: neg_loss = log_weight.exp() * advantage previous_dist = self.actor.build_dist_from_params(tensordict) - current_dist = self.actor.get_dist(tensordict, params=self.actor_params) + with self.actor_params.to_module(self.actor): + current_dist = self.actor.get_dist(tensordict) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index dd64a4bc033..347becc24ae 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -18,27 +18,17 @@ from torchrl.data import CompositeSpec from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule + from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_WARNING, + _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -try: - try: - from torch import vmap - except ImportError: - from functorch import vmap - - FUNCTORCH_ERR = "" - _has_functorch = True -except ImportError as err: - FUNCTORCH_ERR = str(err) - _has_functorch = False - class REDQLoss(LossModule): """REDQ Loss module. 
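With ``ValueEstimators.VTrace`` now selectable (see the ``torchrl/objectives/utils.py`` hunk further down), a PPO-style loss can opt into the new estimator through the usual ``make_value_estimator`` call. A rough usage sketch, mirroring the docstring example added later in this diff; the network sizes and the ``observation`` key are placeholders:

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.modules import OneHotCategorical, ProbabilisticActor, ValueOperator
from torchrl.objectives import PPOLoss
from torchrl.objectives.utils import ValueEstimators

# Placeholder actor/critic over a 3-dim observation and 4 discrete actions.
actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["observation"], out_keys=["logits"])
actor = ProbabilisticActor(
    module=actor_net,
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)
critic = ValueOperator(nn.Linear(3, 1), in_keys=["observation"])

loss_module = PPOLoss(actor, critic)
# Requested exactly like the existing estimators (GAE, TD0, TDLambda, ...).
loss_module.make_value_estimator(ValueEstimators.VTrace, gamma=0.99)

Note that V-trace also needs the behaviour-policy log probability in the rollout data, which is why ``sample_log_prob`` is forwarded to the estimator's ``set_keys`` in the hunk above.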
@@ -265,8 +255,6 @@ def __init__( priority_key: str = None, separate_losses: bool = False, ): - if not _has_functorch: - raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR super().__init__() self._in_keys = None @@ -276,7 +264,6 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist_params"], ) # let's make sure that actor_network has `return_log_prob` set to True @@ -331,8 +318,8 @@ def __init__( warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qvalue_network00 = vmap(self.qvalue_network) - self._vmap_getdist = vmap(self.actor_network.get_dist_params) + self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network) + self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params") @property def target_entropy(self): diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 93910f1eebf..832af829c64 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -3,12 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Optional import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule @@ -18,7 +24,13 @@ distance_loss, ValueEstimators, ) -from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from torchrl.objectives.value import ( + GAE, + TD0Estimator, + TD1Estimator, + TDLambdaEstimator, + VTrace, +) class ReinforceLoss(LossModule): @@ -285,10 +297,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: advantage = tensordict.get(self.tensor_keys.advantage) # compute log-prob - tensordict = self.actor_network( - tensordict, - params=self.actor_network_params, - ) + with self.actor_network_params.to_module(self.actor_network): + tensordict = self.actor_network(tensordict) log_prob = tensordict.get(self.tensor_keys.sample_log_prob) if log_prob.shape == advantage.shape[:-1]: @@ -305,10 +315,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: try: target_return = tensordict.get(self.tensor_keys.value_target) tensordict_select = tensordict.select(*self.critic.in_keys) - state_value = self.critic( - tensordict_select, - params=self.critic_params, - ).get(self.tensor_keys.value) + with self.critic_params.to_module(self.critic): + state_value = self.critic(tensordict_select).get(self.tensor_keys.value) loss_value = distance_loss( target_return, state_value, @@ -340,6 +348,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=actor_with_params, **hp + ) else: raise 
NotImplementedError(f"Unknown value type {value_type}") @@ -350,5 +366,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 076df1c54a4..d0617dedc74 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -12,7 +12,7 @@ import numpy as np import torch -from tensordict.nn import dispatch, make_functional, TensorDictModule +from tensordict.nn import dispatch, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import Tensor @@ -22,27 +22,17 @@ from torchrl.modules import ProbabilisticActor from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.common import LossModule + from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_WARNING, + _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -try: - try: - from torch import vmap - except ImportError: - from functorch import vmap - - _has_functorch = True - err = "" -except ImportError as err: - _has_functorch = False - FUNCTORCH_ERROR = err - def _delezify(func): @wraps(func) @@ -293,8 +283,6 @@ def __init__( ) -> None: self._in_keys = None self._out_keys = None - if not _has_functorch: - raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -382,16 +370,15 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec if self._version == 1: - self.actor_critic = ActorCriticWrapper( + self.__dict__["actor_critic"] = ActorCriticWrapper( self.actor_network, self.value_network ) - make_functional(self.actor_critic) if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0)) + self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0)) if self._version == 1: - self._vmap_qnetwork00 = vmap(qvalue_network) + self._vmap_qnetwork00 = _vmap_func(qvalue_network) @property def target_entropy_buffer(self): @@ -589,11 +576,10 @@ def _cached_detached_qvalue_params(self): def _actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: - with set_exploration_type(ExplorationType.RANDOM): - dist = self.actor_network.get_dist( - tensordict, - params=self.actor_network_params, - ) + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) a_reparm = dist.rsample() log_prob = dist.log_prob(a_reparm) @@ -680,11 +666,11 @@ def _compute_target_v2(self, tensordict) -> Tensor: tensordict = tensordict.clone(False) # get actions and log-probs with torch.no_grad(): - with set_exploration_type(ExplorationType.RANDOM): + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): next_tensordict = tensordict.get("next").clone(False) - next_dist = self.actor_network.get_dist( - next_tensordict, params=self.actor_network_params - ) + next_dist = self.actor_network.get_dist(next_tensordict) 
next_action = next_dist.rsample() next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) @@ -736,16 +722,11 @@ def _value_loss( ) -> Tuple[Tensor, Dict[str, Tensor]]: # value loss td_copy = tensordict.select(*self.value_network.in_keys).detach() - self.value_network( - td_copy, - params=self.value_network_params, - ) + with self.value_network_params.to_module(self.value_network): + self.value_network(td_copy) pred_val = td_copy.get(self.tensor_keys.value).squeeze(-1) - - action_dist = self.actor_network.get_dist( - td_copy, - params=self.target_actor_network_params, - ) # resample an action + with self.target_actor_network_params.to_module(self.actor_network): + action_dist = self.actor_network.get_dist(td_copy) # resample an action action = action_dist.rsample() td_copy.set(self.tensor_keys.action, action, inplace=False) @@ -991,8 +972,6 @@ def __init__( separate_losses: bool = False, ): self._in_keys = None - if not _has_functorch: - raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR super().__init__() self._set_deprecated_ctor_keys(priority_key=priority_key) @@ -1070,7 +1049,7 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) - self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0)) + self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0)) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -1154,9 +1133,8 @@ def _compute_target(self, tensordict) -> Tensor: next_tensordict = tensordict.get("next").clone(False) # get probs and log probs for actions computed from "next" - next_dist = self.actor_network.get_dist( - next_tensordict, params=self.actor_network_params - ) + with self.actor_network_params.to_module(self.actor_network): + next_dist = self.actor_network.get_dist(next_tensordict) next_prob = next_dist.probs next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob)) @@ -1221,10 +1199,8 @@ def _actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: # get probs and log probs for actions - dist = self.actor_network.get_dist( - tensordict, - params=self.actor_network_params, - ) + with self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) prob = dist.probs log_prob = torch.log(torch.where(prob == 0, 1e-8, prob)) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 9912c143ae6..082873a2358 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -15,27 +15,17 @@ from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule + from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_WARNING, + _vmap_func, default_value_kwargs, distance_loss, ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -try: - try: - from torch import vmap - except ImportError: - from functorch import vmap - - FUNCTORCH_ERR = "" - _has_functorch = True -except ImportError as err: - FUNCTORCH_ERR = str(err) - _has_functorch = False - class TD3Loss(LossModule): """TD3 Loss module. 
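As in IQL and REDQ above, the SAC and TD3 hunks replace the raw ``vmap(...)`` wrappers with the ``_vmap_func`` helper added to ``torchrl/objectives/utils.py`` below: the module is decorated so that its last argument is treated as a parameter TensorDict, swapped in with ``to_module`` before vmapping. A sketch of the intended use on a toy Q-ensemble (illustrative only; the shapes, keys and two-member ensemble are invented, and the exact output batch size is an assumption):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.utils import _vmap_func

qvalue = TensorDictModule(
    torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_action_value"]
)
# Two parameter sets stacked along a leading dim, as in a Q-network ensemble.
params = torch.stack([TensorDict.from_module(qvalue) for _ in range(2)], 0).contiguous()
data = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])

vmapped_q = _vmap_func(qvalue, (None, 0))  # broadcast the data, map over the params
out = vmapped_q(data, params)
print(out.batch_size)  # expected: torch.Size([2, 4]), the ensemble dim prepended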
@@ -229,10 +219,6 @@ def __init__( priority_key: str = None, separate_losses: bool = False, ) -> None: - if not _has_functorch: - raise ImportError( - f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}" - ) super().__init__() self._in_keys = None @@ -310,8 +296,8 @@ def __init__( if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma - self._vmap_qvalue_network00 = vmap(self.qvalue_network) - self._vmap_actor_network00 = vmap(self.actor_network) + self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network) + self._vmap_actor_network00 = _vmap_func(self.actor_network) def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -359,9 +345,8 @@ def _cached_stack_actor_params(self): def actor_loss(self, tensordict): tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys) - tensordict_actor_grad = self.actor_network( - tensordict_actor_grad, self.actor_network_params - ) + with self.actor_network_params.to_module(self.actor_network): + tensordict_actor_grad = self.actor_network(tensordict_actor_grad) actor_loss_td = tensordict_actor_grad.select( *self.qvalue_network.in_keys ).expand( @@ -395,9 +380,8 @@ def value_loss(self, tensordict): next_td_actor = step_mdp(tensordict).select( *self.actor_network.in_keys ) # next_observation -> - next_td_actor = self.actor_network( - next_td_actor, self.target_actor_network_params - ) + with self.target_actor_network_params.to_module(self.actor_network): + next_td_actor = self.actor_network(next_td_actor) next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp( self.min_action, self.max_action ) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index bc678ed0154..c3e7dbc68ce 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -10,10 +10,17 @@ import torch from tensordict.nn import TensorDictModule -from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase +from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor from torch.nn import functional as F +try: + from torch import vmap +except ImportError as err: + try: + from functorch import vmap + except ImportError as err_ft: + raise err_ft from err from torchrl.envs.utils import step_mdp _GAMMA_LMBDA_DEPREC_WARNING = ( @@ -39,6 +46,7 @@ class ValueEstimators(Enum): TD1 = "TD(1) (infinity-step return)" TDLambda = "TD(lambda)" GAE = "Generalized advantage estimate" + VTrace = "V-trace" def default_value_kwargs(value_type: ValueEstimators): @@ -61,6 +69,8 @@ def default_value_kwargs(value_type: ValueEstimators): return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} elif value_type == ValueEstimators.TDLambda: return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} + elif value_type == ValueEstimators.VTrace: + return {"gamma": 0.99, "differentiable": True} else: raise NotImplementedError(f"Unknown value type {value_type}.") @@ -353,18 +363,19 @@ class hold_out_net(_context_manager): def __init__(self, network: nn.Module) -> None: self.network = network - try: - self.p_example = next(network.parameters()) - except (AttributeError, StopIteration): - self.p_example = torch.tensor([]) - self._prev_state = [] + for p in network.parameters(): + self.mode = p.requires_grad + break + else: + self.mode = True def __enter__(self) -> None: - self._prev_state.append(self.p_example.requires_grad) - self.network.requires_grad_(False) + if self.mode: + 
self.network.requires_grad_(False) def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.network.requires_grad_(self._prev_state.pop()) + if self.mode: + self.network.requires_grad_() class hold_out_params(_context_manager): @@ -457,9 +468,23 @@ def new_fun(self, netname=None): out = fun(self, netname) else: out = fun(self) - if is_tensor_collection(out): - out.lock_() + # TODO: decide what to do with locked tds in functional calls + # if is_tensor_collection(out): + # out.lock_() _cache[attr_name] = out return out return new_fun + + +def _vmap_func(module, *args, func=None, **kwargs): + def decorated_module(*module_args_params): + params = module_args_params[-1] + module_args = module_args_params[:-1] + with params.to_module(module): + if func is None: + return module(*module_args) + else: + return getattr(module, func)(*module_args) + + return vmap(decorated_module, *args, **kwargs) # noqa: TOR101 diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index 11ae2e6d9e2..51496986153 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -12,4 +12,5 @@ TDLambdaEstimate, TDLambdaEstimator, ValueEstimatorBase, + VTrace, ) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 4d3a25279a1..f3aff0da1d2 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -2,9 +2,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + + import abc import functools import warnings +from contextlib import nullcontext from dataclasses import asdict, dataclass from functools import wraps from typing import Callable, List, Optional, Union @@ -24,7 +27,7 @@ from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import hold_out_net +from torchrl.objectives.utils import _vmap_func, hold_out_net from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -32,8 +35,10 @@ vec_generalized_advantage_estimate, vec_td1_return_estimate, vec_td_lambda_return_estimate, + vtrace_advantage_estimate, ) + try: from torch import vmap except ImportError as err: @@ -117,7 +122,8 @@ def _call_value_nets( "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." ) if params is not None: - value_est = value_net(data_in, params).get(value_key) + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) else: value_est = value_net(data_in).get(value_key) value, value_ = value_est[idx], value_est[idx_] @@ -134,8 +140,8 @@ def _call_value_nets( "params and next_params must be either both provided or not." ) elif params is not None: - params_stack = torch.stack([params, next_params], 0) - data_out = vmap(value_net, (0, 0))(data_in, params_stack) + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack) else: data_out = vmap(value_net, (0,))(data_in) value_est = data_out.get(value_key) @@ -147,6 +153,17 @@ def _call_value_nets( return value, value_ +def _call_actor_net( + actor_net: TensorDictModuleBase, + data: TensorDictBase, + params: TensorDictBase, + log_prob_key: NestedKey, +): + # TODO: extend to handle time dimension (and vmap?) 
+ log_pi = actor_net(data.select(actor_net.in_keys)).get(log_prob_key) + return log_pi + + class ValueEstimatorBase(TensorDictModuleBase): """An abstract parent class for value function modules. @@ -179,9 +196,11 @@ class _AcceptedKeys: whether a trajectory is done. Defaults to ``"done"``. terminated (NestedKey): The key in the input TensorDict that indicates whether a trajectory is terminated. Defaults to ``"terminated"``. - steps_to_next_obs_key (NestedKey): The key in the input tensordict + steps_to_next_obs (NestedKey): The key in the input tensordict that indicates the number of steps to the next observation. Defaults to ``"steps_to_next_obs"``. + sample_log_prob (NestedKey): The key in the input tensordict that + indicates the log probability of the sampled action. Defaults to ``"sample_log_prob"``. """ advantage: NestedKey = "advantage" @@ -191,6 +210,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" + sample_log_prob: NestedKey = "sample_log_prob" default_keys = _AcceptedKeys() value_network: Union[TensorDictModule, Callable] @@ -223,6 +243,10 @@ def terminated_key(self): def steps_to_next_obs_key(self): return self.tensor_keys.steps_to_next_obs + @property + def sample_log_prob_key(self): + return self.tensor_keys.sample_log_prob + @abc.abstractmethod def forward( self, @@ -270,7 +294,7 @@ def __init__( self._tensor_keys = None self.differentiable = differentiable self.skip_existing = skip_existing - self.value_network = value_network + self.__dict__["value_network"] = value_network self.dep_keys = {} self.shifted = shifted @@ -341,7 +365,7 @@ def set_keys(self, **kwargs) -> None: raise ValueError("tensordict keys cannot be None") if key not in self._AcceptedKeys.__dict__: raise KeyError( - f"{key} it not an accepted tensordict key for advantages" + f"{key} is not an accepted tensordict key for advantages" ) if ( key == "value" @@ -403,10 +427,10 @@ def is_stateless(self): def _next_value(self, tensordict, target_params, kwargs): step_td = step_mdp(tensordict, keep_other=False) if self.value_network is not None: - if target_params is not None: - kwargs["params"] = target_params - with hold_out_net(self.value_network): - self.value_network(step_td, **kwargs) + with hold_out_net( + self.value_network + ) if target_params is None else target_params.to_module(self.value_network): + self.value_network(step_td) next_value = step_td.get(self.tensor_keys.value) return next_value @@ -447,6 +471,7 @@ class TD0Estimator(ValueEstimatorBase): of the advantage entry. Defaults to ``"value_target"``. value_key (str or tuple of str, optional): [Deprecated] the value key to read from the input tensordict. Defaults to ``"state_value"``. + device (torch.device, optional): device of the module. 
""" @@ -462,6 +487,7 @@ def __init__( value_target_key: NestedKey = None, value_key: NestedKey = None, skip_existing: Optional[bool] = None, + device: Optional[torch.device] = None, ): super().__init__( value_network=value_network, @@ -472,10 +498,6 @@ def __init__( value_key=value_key, skip_existing=skip_existing, ) - try: - device = next(value_network.parameters()).device - except (AttributeError, StopIteration): - device = torch.device("cpu") self.register_buffer("gamma", torch.tensor(gamma, device=device)) self.average_rewards = average_rewards @@ -560,7 +582,9 @@ def forward( params = params.detach() if target_params is None: target_params = params.clone(False) - with hold_out_net(self.value_network): + with hold_out_net(self.value_network) if ( + params is None and target_params is None + ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params value, next_value = _call_value_nets( @@ -597,7 +621,7 @@ def value_estimate( if self.average_rewards: reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) + reward = reward / reward.std().clamp_min(1e-5) tensordict.set( ("next", self.tensor_keys.reward), reward ) # we must update the rewards if they are used later in the code @@ -649,6 +673,7 @@ class TD1Estimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. + device (torch.device, optional): device of the module. """ @@ -664,6 +689,7 @@ def __init__( value_target_key: NestedKey = None, value_key: NestedKey = None, shifted: bool = False, + device: Optional[torch.device] = None, ): super().__init__( value_network=value_network, @@ -674,10 +700,6 @@ def __init__( shifted=shifted, skip_existing=skip_existing, ) - try: - device = next(value_network.parameters()).device - except (AttributeError, StopIteration): - device = torch.device("cpu") self.register_buffer("gamma", torch.tensor(gamma, device=device)) self.average_rewards = average_rewards @@ -761,7 +783,9 @@ def forward( params = params.detach() if target_params is None: target_params = params.clone(False) - with hold_out_net(self.value_network): + with hold_out_net(self.value_network) if ( + params is None and target_params is None + ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params value, next_value = _call_value_nets( @@ -799,7 +823,7 @@ def value_estimate( if self.average_rewards: reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) + reward = reward / reward.std().clamp_min(1e-5) tensordict.set( ("next", self.tensor_keys.reward), reward ) # we must update the rewards if they are used later in the code @@ -855,6 +879,7 @@ class TDLambdaEstimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. + device (torch.device, optional): device of the module. 
""" @@ -872,6 +897,7 @@ def __init__( value_target_key: NestedKey = None, value_key: NestedKey = None, shifted: bool = False, + device: Optional[torch.device] = None, ): super().__init__( value_network=value_network, @@ -882,10 +908,6 @@ def __init__( skip_existing=skip_existing, shifted=shifted, ) - try: - device = next(value_network.parameters()).device - except (AttributeError, StopIteration): - device = torch.device("cpu") self.register_buffer("gamma", torch.tensor(gamma, device=device)) self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) self.average_rewards = average_rewards @@ -972,7 +994,9 @@ def forward( params = params.detach() if target_params is None: target_params = params.clone(False) - with hold_out_net(self.value_network): + with hold_out_net(self.value_network) if ( + params is None and target_params is None + ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params value, next_value = _call_value_nets( @@ -1083,6 +1107,7 @@ class GAE(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. + device (torch.device, optional): device of the module. GAE will return an :obj:`"advantage"` entry containing the advange value. It will also return a :obj:`"value_target"` entry with the return value that is to be used @@ -1112,6 +1137,7 @@ def __init__( value_target_key: NestedKey = None, value_key: NestedKey = None, shifted: bool = False, + device: Optional[torch.device] = None, ): super().__init__( shifted=shifted, @@ -1122,10 +1148,6 @@ def __init__( value_key=value_key, skip_existing=skip_existing, ) - try: - device = next(value_network.parameters()).device - except (AttributeError, StopIteration): - device = torch.device("cpu") self.register_buffer("gamma", torch.tensor(gamma, device=device)) self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) self.average_gae = average_gae @@ -1137,7 +1159,7 @@ def __init__( def forward( self, tensordict: TensorDictBase, - *unused_args, + *, params: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, ) -> TensorDictBase: @@ -1218,7 +1240,9 @@ def forward( params = params.detach() if target_params is None: target_params = params.clone(False) - with hold_out_net(self.value_network): + with hold_out_net(self.value_network) if ( + params is None and target_params is None + ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params value, next_value = _call_value_nets( @@ -1298,7 +1322,9 @@ def value_estimate( params = params.detach() if target_params is None: target_params = params.clone(False) - with hold_out_net(self.value_network): + with hold_out_net(self.value_network) if ( + params is None and target_params is None + ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params value, next_value = _call_value_nets( @@ -1328,6 +1354,284 @@ def value_estimate( return value_target +class VTrace(ValueEstimatorBase): + """A class wrapper around V-Trace estimate functional. + + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + :ref:`here `_ for more context. + + Args: + gamma (scalar): exponential mean discount. + value_network (TensorDictModule): value operator used to retrieve the value estimates. 
+ actor_network (TensorDictModule): actor operator used to retrieve the log prob. + rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. + Defaults to ``1.0``. + c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. + Defaults to ``1.0``. + average_adv (bool): if ``True``, the resulting advantage values will be standardized. + Default is ``False``. + differentiable (bool, optional): if ``True``, gradients are propagated through + the computation of the value function. Default is ``False``. + + .. note:: + The proper way to make the function call non-differentiable is to + decorate it in a `torch.no_grad()` context manager/decorator or + pass detached parameters for functional modules. + skip_existing (bool, optional): if ``True``, the value network will skip + modules whose outputs are already present in the tensordict. + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` + is not affected. + advantage_key (str or tuple of str, optional): [Deprecated] the key of + the advantage entry. Defaults to ``"advantage"``. + value_target_key (str or tuple of str, optional): [Deprecated] the key + of the value target entry. Defaults to ``"value_target"``. + value_key (str or tuple of str, optional): [Deprecated] the value key to + read from the input tensordict. Defaults to ``"state_value"``. + shifted (bool, optional): if ``True``, the value and next value are + estimated with a single call to the value network. This is faster + but is only valid whenever (1) the ``"next"`` value is shifted by + only one time step (which is not the case with multi-step value + estimation, for instance) and (2) when the parameters used at time + ``t`` and ``t+1`` are identical (which is not the case when target + parameters are to be used). Defaults to ``False``. + device (torch.device, optional): device of the module. + + VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also + return a :obj:`"value_target"` entry with the V-Trace target value. + + .. note:: + As other advantage functions do, if the ``value_key`` is already present + in the input tensordict, the VTrace module will ignore the calls to the value + network (if any) and use the provided value instead.
+ + """ + + def __init__( + self, + *, + gamma: Union[float, torch.Tensor], + actor_network: TensorDictModule, + value_network: TensorDictModule, + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, + average_adv: bool = False, + differentiable: bool = False, + skip_existing: Optional[bool] = None, + advantage_key: Optional[NestedKey] = None, + value_target_key: Optional[NestedKey] = None, + value_key: Optional[NestedKey] = None, + shifted: bool = False, + device: Optional[torch.device] = None, + ): + super().__init__( + shifted=shifted, + value_network=value_network, + differentiable=differentiable, + advantage_key=advantage_key, + value_target_key=value_target_key, + value_key=value_key, + skip_existing=skip_existing, + ) + if not isinstance(gamma, torch.Tensor): + gamma = torch.tensor(gamma, device=device) + if not isinstance(rho_thresh, torch.Tensor): + rho_thresh = torch.tensor(rho_thresh, device=device) + if not isinstance(c_thresh, torch.Tensor): + c_thresh = torch.tensor(c_thresh, device=device) + + self.register_buffer("gamma", gamma) + self.register_buffer("rho_thresh", rho_thresh) + self.register_buffer("c_thresh", c_thresh) + self.average_adv = average_adv + self.actor_network = actor_network + + if isinstance(gamma, torch.Tensor) and gamma.shape != (): + raise NotImplementedError( + "Per-value gamma is not supported yet. Gamma must be a scalar." + ) + + @property + def in_keys(self): + parent_in_keys = super().in_keys + extended_in_keys = parent_in_keys + [self.tensor_keys.sample_log_prob] + return extended_in_keys + + @_self_set_skip_existing + @_self_set_grad_enabled + @dispatch + def forward( + self, + tensordict: TensorDictBase, + *, + params: Optional[List[Tensor]] = None, + target_params: Optional[List[Tensor]] = None, + ) -> TensorDictBase: + """Computes the V-Trace correction given the data in tensordict. + + If a functional module is provided, a nested TensorDict containing the parameters + (and if relevant the target parameters) can be passed to the module. + + Args: + tensordict (TensorDictBase): A TensorDict containing the data + (an observation key, "action", "reward", "done" and "next" tensordict state + as returned by the environment) necessary to compute the value estimates and the GAE. + The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + params (TensorDictBase, optional): A nested TensorDict containing the params + to be passed to the functional value network module. + target_params (TensorDictBase, optional): A nested TensorDict containing the + target params to be passed to the functional value network module. + + Returns: + An updated TensorDict with an advantage and a value_error keys as defined in the constructor. + + Examples: + >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) + >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) + >>> actor_net = ProbabilisticActor( + ... module=actor_net, + ... in_keys=["logits"], + ... out_keys=["action"], + ... distribution_class=OneHotCategorical, + ... return_log_prob=True, + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... value_network=value_net, + ... actor_network=actor_net, + ... differentiable=False, + ... 
) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> sample_log_prob = torch.randn(1, 10, 1) + >>> tensordict = TensorDict({ + ... "obs": obs, + ... "done": done, + ... "terminated": terminated, + ... "sample_log_prob": sample_log_prob, + ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}, + ... }, batch_size=[1, 10]) + >>> _ = module(tensordict) + >>> assert "advantage" in tensordict.keys() + + The module supports non-tensordict (i.e. unpacked tensordict) inputs too: + + Examples: + >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) + >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) + >>> actor_net = ProbabilisticActor( + ... module=actor_net, + ... in_keys=["logits"], + ... out_keys=["action"], + ... distribution_class=OneHotCategorical, + ... return_log_prob=True, + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... value_network=value_net, + ... actor_network=actor_net, + ... differentiable=False, + ... ) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> sample_log_prob = torch.randn(1, 10, 1) + >>> tensordict = TensorDict({ + ... "obs": obs, + ... "done": done, + ... "terminated": terminated, + ... "sample_log_prob": sample_log_prob, + ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}, + ... }, batch_size=[1, 10]) + >>> advantage, value_target = module( + ... obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated, sample_log_prob=sample_log_prob + ... 
) + + """ + if tensordict.batch_dims < 1: + raise RuntimeError( + "Expected input tensordict to have at least one dimensions, got " + f"tensordict.batch_size = {tensordict.batch_size}" + ) + reward = tensordict.get(("next", self.tensor_keys.reward)) + device = reward.device + gamma = self.gamma.to(device) + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) + if steps_to_next_obs is not None: + gamma = gamma ** steps_to_next_obs.view_as(reward) + + # Make sure we have the value and next value + if self.value_network is not None: + if params is not None: + params = params.detach() + if target_params is None: + target_params = params.clone(False) + with hold_out_net(self.value_network): + # we may still need to pass gradient, but we don't want to assign grads to + # value net params + value, next_value = _call_value_nets( + value_net=self.value_network, + data=tensordict, + params=params, + next_params=target_params, + single_call=self.shifted, + value_key=self.tensor_keys.value, + detach_next=True, + ) + else: + value = tensordict.get(self.tensor_keys.value) + next_value = tensordict.get(("next", self.tensor_keys.value)) + + # Make sure we have the log prob computed at collection time + if self.tensor_keys.sample_log_prob not in tensordict.keys(): + raise ValueError( + f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict" + ) + log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) + + # Compute log prob with current policy + with hold_out_net(self.actor_network): + log_pi = _call_actor_net( + actor_net=self.actor_network, + data=tensordict, + params=None, + log_prob_key=self.tensor_keys.sample_log_prob, + ).view_as(value) + + # Compute the V-Trace correction + done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated)) + + adv, value_target = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + value, + next_value, + reward, + done, + terminated, + rho_thresh=self.rho_thresh, + c_thresh=self.c_thresh, + time_dim=tensordict.ndim - 1, + ) + + if self.average_adv: + loc = adv.mean() + scale = adv.std().clamp_min(1e-5) + adv = adv - loc + adv = adv / scale + + tensordict.set(self.tensor_keys.advantage, adv) + tensordict.set(self.tensor_keys.value_target, value_target) + + return tensordict + + def _deprecate_class(cls, new_cls): @wraps(cls.__init__) def new_init(self, *args, **kwargs): diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 7c33895e965..6c43af02aeb 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -27,6 +27,7 @@ "vec_td_lambda_return_estimate", "td_lambda_advantage_estimate", "vec_td_lambda_advantage_estimate", + "vtrace_advantage_estimate", ] from torchrl.objectives.value.utils import ( @@ -1212,6 +1213,93 @@ def vec_td_lambda_advantage_estimate( ) +######################################################################## +# V-Trace +# ----- + + +@_transpose_time +def vtrace_advantage_estimate( + gamma: float, + log_pi: torch.Tensor, + log_mu: torch.Tensor, + state_value: torch.Tensor, + next_state_value: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + terminated: torch.Tensor | None = None, + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, + time_dim: int = -2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes V-Trace off-policy actor critic targets. 
+ + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + https://arxiv.org/abs/1802.01561 for more context. + + Args: + gamma (scalar): exponential mean discount. + log_pi (Tensor): collection actor log probability of taking actions in the environment. + log_mu (Tensor): current actor log probability of taking actions in the environment. + state_value (Tensor): value function result with state input. + next_state_value (Tensor): value function result with next_state input. + reward (Tensor): reward of taking actions in the environment. + done (Tensor): boolean flag for end of episode. + terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states. + rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. + c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. + time_dim (int): dimension where the time is unrolled. Defaults to -2. + + All tensors (values, reward and done) must have shape + ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. + """ + if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + raise RuntimeError(SHAPE_ERR) + + device = state_value.device + + if not isinstance(rho_thresh, torch.Tensor): + rho_thresh = torch.tensor(rho_thresh, device=device) + if not isinstance(c_thresh, torch.Tensor): + c_thresh = torch.tensor(c_thresh, device=device) + + c_thresh = c_thresh.to(device) + rho_thresh = rho_thresh.to(device) + + not_done = (~done).int() + not_terminated = not_done if terminated is None else (~terminated).int() + *batch_size, time_steps, lastdim = not_done.shape + done_discounts = gamma * not_done + terminated_discounts = gamma * not_terminated + + rho = (log_pi - log_mu).exp() + clipped_rho = rho.clamp_max(rho_thresh) + deltas = clipped_rho * ( + reward + terminated_discounts * next_state_value - state_value + ) + clipped_c = rho.clamp_max(c_thresh) + + vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] + for i in reversed(range(time_steps)): + discount_t, c_t, delta_t = ( + done_discounts[..., i, :], + clipped_c[..., i, :], + deltas[..., i, :], + ) + vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1]) + vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim) + vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim]) + vs = vs_minus_v_xs + state_value + vs_t_plus_1 = torch.cat( + [vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim + ) + advantages = clipped_rho * ( + reward + terminated_discounts * vs_t_plus_1 - state_value + ) + + return advantages, vs + + ######################################################################## # Reward to go # ------------ diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py deleted file mode 100644 index 43f5246502f..00000000000 --- a/torchrl/objectives/value/vtrace.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import math -from typing import Tuple, Union - -import torch - - -def _c_val( - log_pi: torch.Tensor, - log_mu: torch.Tensor, - c: Union[float, torch.Tensor] = 1, -) -> torch.Tensor: - return (log_pi - log_mu).clamp_max(math.log(c)).exp() - - -def _dv_val( - rewards: torch.Tensor, - vals: torch.Tensor, - gamma: Union[float, torch.Tensor], - rho_bar: Union[float, torch.Tensor], - log_pi: torch.Tensor, - log_mu: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - rho = _c_val(log_pi, log_mu, rho_bar) - next_vals = torch.cat([vals[:, 1:], torch.zeros_like(vals[:, :1])], 1) - dv = rho * (rewards + gamma * next_vals - vals) - return dv, rho - - -def _vtrace( - rewards: torch.Tensor, - vals: torch.Tensor, - log_pi: torch.Tensor, - log_mu: torch.Tensor, - gamma: Union[torch.Tensor, float], - rho_bar: Union[float, torch.Tensor] = 1.0, - c_bar: Union[float, torch.Tensor] = 1.0, -) -> Tuple[torch.Tensor, torch.Tensor]: - T = vals.shape[1] - if not isinstance(gamma, torch.Tensor): - gamma = torch.full_like(vals, gamma) - - dv, rho = _dv_val(rewards, vals, gamma, rho_bar, log_pi, log_mu) - c = _c_val(log_pi, log_mu, c_bar) - - v_out = [] - v_out.append(vals[:, -1] + dv[:, -1]) - for t in range(T - 2, -1, -1): - _v_out = ( - vals[:, t] + dv[:, t] + gamma[:, t] * c[:, t] * (v_out[-1] - vals[:, t + 1]) - ) - v_out.append(_v_out) - v_out = torch.stack(list(reversed(v_out)), 1) - return v_out, rho diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index ee343aa438e..3782de64fa2 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -657,6 +657,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): out_keys=[action_key], default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": True}, spec=CompositeSpec(**{action_key: proof_environment.action_spec}), ), ) @@ -703,8 +704,9 @@ def _dreamer_make_actor_real( SafeProbabilisticModule( in_keys=["loc", "scale"], out_keys=[action_key], - default_interaction_type=InteractionType.RANDOM, + default_interaction_type=InteractionType.MODE, distribution_class=TanhNormal, + distribution_kwargs={"tanh_loc": True}, spec=CompositeSpec( **{action_key: proof_environment.action_spec.to("cpu")} ), diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 5237b344e56..be6e607c1b5 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -46,6 +46,7 @@ # replay buffer is a straightforward process, as shown in the following # example: # +import tempfile from torchrl.data import ReplayBuffer @@ -175,9 +176,8 @@ ###################################################################### # We can also customize the storage location on disk: # -buffer_lazymemmap = ReplayBuffer( - storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/") -) +tempdir = tempfile.TemporaryDirectory() +buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir)) buffer_lazymemmap.extend(data) print(f"The buffer has {len(buffer_lazymemmap)} elements") print("the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename) @@ -207,8 +207,9 @@ from torchrl.data import TensorDictReplayBuffer +tempdir = tempfile.TemporaryDirectory() buffer_lazymemmap = TensorDictReplayBuffer( - storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/"), batch_size=12 + storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12 ) 
buffer_lazymemmap.extend(data) print(f"The buffer has {len(buffer_lazymemmap)} elements") @@ -248,8 +249,9 @@ class MyData: batch_size=[1000], ) +tempdir = tempfile.TemporaryDirectory() buffer_lazymemmap = TensorDictReplayBuffer( - storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/"), batch_size=12 + storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12 ) buffer_lazymemmap.extend(data) print(f"The buffer has {len(buffer_lazymemmap)} elements") diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index ccb1c9d4ea7..ef995030c9d 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -25,6 +25,7 @@ # will pass the arguments and keyword arguments to the root library builder. # # With gym, it means that building an environment is as easy as: + # sphinx_gallery_start_ignore import warnings