diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml
new file mode 100644
index 00000000000..98af413285f
--- /dev/null
+++ b/.github/pytorch-probot.yml
@@ -0,0 +1,5 @@
+# List of workflows that will be re-run in case of failures
+# https://github.com/pytorch/test-infra/blob/main/torchci/lib/bot/retryBot.ts
+retryable_workflows:
+- Build M1
+- Wheels
diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh
index a69e79c69a3..92475420867 100755
--- a/.github/unittest/linux_examples/scripts/run_test.sh
+++ b/.github/unittest/linux_examples/scripts/run_test.sh
@@ -58,6 +58,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/cql/cql_offlin
# ==================================================================================== #
# ================================ Gymnasium ========================================= #
+python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \
+ collector.total_frames=80 \
+ collector.frames_per_batch=20 \
+ collector.num_workers=1 \
+ logger.backend= \
+ logger.test_interval=10
python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \
env.env_name=HalfCheetah-v4 \
collector.total_frames=40 \
diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index 1a2384a1df1..01d880708f4 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -34,7 +34,8 @@ jobs:
python -m pip install git+https://github.com/pytorch/tensordict
python setup.py develop
python -m pip install pytest pytest-benchmark
- python -m pip install dm_control
+ python3 -m pip install "gym[accept-rom-license,atari]"
+ python3 -m pip install dm_control
- name: Run benchmarks
run: |
cd benchmarks/
@@ -57,62 +58,65 @@ jobs:
benchmark_gpu:
name: GPU Pytest benchmark
- runs-on: ubuntu-20.04
- strategy:
- matrix:
- include:
- - os: linux.4xlarge.nvidia.gpu
- python-version: 3.8
+ runs-on: linux.g5.4xlarge.nvidia.gpu
defaults:
run:
shell: bash -l {0}
- container: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
+ container:
+ image: nvidia/cuda:12.3.0-base-ubuntu22.04
+ options: --gpus all
steps:
- - name: Install deps
- run: |
- export TZ=Europe/London
- export DEBIAN_FRONTEND=noninteractive # tzdata bug
- apt-get update -y
- apt-get install software-properties-common -y
- add-apt-repository ppa:git-core/candidate -y
- apt-get update -y
- apt-get upgrade -y
- apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev
- - name: Check ldd --version
- run: ldd --version
- - name: Checkout
- uses: actions/checkout@v3
- - name: Update pip
- run: |
- apt-get install python3.8 python3-pip -y
- pip3 install --upgrade pip
- - name: Setup git
- run: git config --global --add safe.directory /__w/rl/rl
- - name: setup Path
- run: |
- echo /usr/local/bin >> $GITHUB_PATH
- - name: Setup Environment
- run: |
- python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
- python3 -m pip install git+https://github.com/pytorch/tensordict
- python3 setup.py develop
- python3 -m pip install pytest pytest-benchmark
- python3 -m pip install dm_control
- - name: Run benchmarks
- run: |
- cd benchmarks/
- python3 -m pytest --benchmark-json output.json
- - name: Store benchmark results
- uses: benchmark-action/github-action-benchmark@v1
- if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
- with:
- name: GPU Benchmark Results
- tool: 'pytest'
- output-file-path: benchmarks/output.json
- fail-on-alert: true
- alert-threshold: '200%'
- alert-comment-cc-users: '@vmoens'
- comment-on-alert: true
- github-token: ${{ secrets.GITHUB_TOKEN }}
- gh-pages-branch: gh-pages
- auto-push: true
+ - name: Install deps
+ run: |
+ export TZ=Europe/London
+ export DEBIAN_FRONTEND=noninteractive # tzdata bug
+ apt-get update -y
+ apt-get install software-properties-common -y
+ add-apt-repository ppa:git-core/candidate -y
+ apt-get update -y
+ apt-get upgrade -y
+ apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev
+ - name: Check ldd --version
+ run: ldd --version
+ - name: Checkout
+ uses: actions/checkout@v3
+ - name: Python Setup
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+ - name: Setup git
+ run: git config --global --add safe.directory /__w/rl/rl
+ - name: setup Path
+ run: |
+ echo /usr/local/bin >> $GITHUB_PATH
+ - name: Setup Environment
+ run: |
+ python3 -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+ python3 -m pip install git+https://github.com/pytorch/tensordict
+ python3 setup.py develop
+ python3 -m pip install pytest pytest-benchmark
+ python3 -m pip install "gym[accept-rom-license,atari]"
+ python3 -m pip install dm_control
+ - name: check GPU presence
+ run: |
+ python -c """import torch
+ assert torch.cuda.device_count()
+ """
+ - name: Run benchmarks
+ run: |
+ cd benchmarks/
+ python3 -m pytest --benchmark-json output.json
+ - name: Store benchmark results
+ uses: benchmark-action/github-action-benchmark@v1
+ if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
+ with:
+ name: GPU Benchmark Results
+ tool: 'pytest'
+ output-file-path: benchmarks/output.json
+ fail-on-alert: true
+ alert-threshold: '200%'
+ alert-comment-cc-users: '@vmoens'
+ comment-on-alert: true
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ gh-pages-branch: gh-pages
+ auto-push: true
diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml
index e44c683a6d6..0f0ad3e5723 100644
--- a/.github/workflows/benchmarks_pr.yml
+++ b/.github/workflows/benchmarks_pr.yml
@@ -33,7 +33,8 @@ jobs:
python -m pip install git+https://github.com/pytorch/tensordict
python setup.py develop
python -m pip install pytest pytest-benchmark
- python -m pip install dm_control
+ python3 -m pip install "gym[accept-rom-license,atari]"
+ python3 -m pip install dm_control
- name: Setup benchmarks
run: |
echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
@@ -63,75 +64,78 @@ jobs:
benchmark_gpu:
name: GPU Pytest benchmark
- runs-on: ubuntu-20.04
- strategy:
- matrix:
- include:
- - os: linux.4xlarge.nvidia.gpu
- python-version: 3.8
+ runs-on: linux.g5.4xlarge.nvidia.gpu
defaults:
run:
shell: bash -l {0}
- container: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
+ container:
+ image: nvidia/cuda:12.3.0-base-ubuntu22.04
+ options: --gpus all
steps:
- - name: Who triggered this?
- run: |
- echo "Action triggered by ${{ github.event.pull_request.html_url }}"
- - name: Install deps
- run: |
- export TZ=Europe/London
- export DEBIAN_FRONTEND=noninteractive # tzdata bug
- apt-get update -y
- apt-get install software-properties-common -y
- add-apt-repository ppa:git-core/candidate -y
- apt-get update -y
- apt-get upgrade -y
- apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev
- - name: Check ldd --version
- run: ldd --version
- - name: Checkout
- uses: actions/checkout@v3
- with:
- fetch-depth: 50 # this is to make sure we obtain the target base commit
- - name: Update pip
- run: |
- apt-get install python3.8 python3-pip -y
- pip3 install --upgrade pip
- - name: Setup git
- run: git config --global --add safe.directory /__w/rl/rl
- - name: setup Path
- run: |
- echo /usr/local/bin >> $GITHUB_PATH
- - name: Setup Environment
- run: |
- python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
- python3 -m pip install git+https://github.com/pytorch/tensordict
- python3 setup.py develop
- python3 -m pip install pytest pytest-benchmark
- python3 -m pip install dm_control
- - name: Setup benchmarks
- run: |
- echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
- echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> $GITHUB_ENV
- echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV
- echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV
- echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV
- - name: Run benchmarks
- run: |
- cd benchmarks/
- RUN_BENCHMARK="pytest --rank 0 --benchmark-json "
- git checkout ${{ github.event.pull_request.base.sha }}
- $RUN_BENCHMARK ${{ env.BASELINE_JSON }}
- git checkout ${{ github.event.pull_request.head.sha }}
- $RUN_BENCHMARK ${{ env.CONTENDER_JSON }}
- - name: Publish results
- uses: apbard/pytest-benchmark-commenter@v3
- with:
- token: ${{ secrets.GITHUB_TOKEN }}
- benchmark-file: ${{ env.CONTENDER_JSON }}
- comparison-benchmark-file: ${{ env.BASELINE_JSON }}
- benchmark-metrics: 'name,max,mean,ops'
- comparison-benchmark-metric: 'ops'
- comparison-higher-is-better: true
- comparison-threshold: 5
- benchmark-title: 'Result of GPU Benchmark Tests'
+ - name: Who triggered this?
+ run: |
+ echo "Action triggered by ${{ github.event.pull_request.html_url }}"
+ - name: Install deps
+ run: |
+ export TZ=Europe/London
+ export DEBIAN_FRONTEND=noninteractive # tzdata bug
+ apt-get update -y
+ apt-get install software-properties-common -y
+ add-apt-repository ppa:git-core/candidate -y
+ apt-get update -y
+ apt-get upgrade -y
+ apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev
+ - name: Check ldd --version
+ run: ldd --version
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 50 # this is to make sure we obtain the target base commit
+ - name: Python Setup
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+ - name: Setup git
+ run: git config --global --add safe.directory /__w/rl/rl
+ - name: setup Path
+ run: |
+ echo /usr/local/bin >> $GITHUB_PATH
+ - name: Setup Environment
+ run: |
+ python3 -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
+ python3 -m pip install git+https://github.com/pytorch/tensordict
+ python3 setup.py develop
+ python3 -m pip install pytest pytest-benchmark
+ python3 -m pip install "gym[accept-rom-license,atari]"
+ python3 -m pip install dm_control
+ - name: check GPU presence
+ run: |
+ python -c """import torch
+ assert torch.cuda.device_count()
+ """
+ - name: Setup benchmarks
+ run: |
+ echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
+ echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> $GITHUB_ENV
+ echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV
+ echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV
+ echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV
+ - name: Run benchmarks
+ run: |
+ cd benchmarks/
+ RUN_BENCHMARK="pytest --rank 0 --benchmark-json "
+ git checkout ${{ github.event.pull_request.base.sha }}
+ $RUN_BENCHMARK ${{ env.BASELINE_JSON }}
+ git checkout ${{ github.event.pull_request.head.sha }}
+ $RUN_BENCHMARK ${{ env.CONTENDER_JSON }}
+ - name: Publish results
+ uses: apbard/pytest-benchmark-commenter@v3
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ benchmark-file: ${{ env.CONTENDER_JSON }}
+ comparison-benchmark-file: ${{ env.BASELINE_JSON }}
+ benchmark-metrics: 'name,max,mean,ops'
+ comparison-benchmark-metric: 'ops'
+ comparison-higher-is-better: true
+ comparison-threshold: 5
+ benchmark-title: 'Result of GPU Benchmark Tests'
diff --git a/README.md b/README.md
index 05c2e9843c2..905e8d28a4c 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
[![Downloads](https://static.pepy.tech/personalized-badge/torchrl?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads)](https://pepy.tech/project/torchrl)
[![Downloads](https://static.pepy.tech/personalized-badge/torchrl-nightly?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads%20(nightly))](https://pepy.tech/project/torchrl-nightly)
-[![Discord Shield](https://dcbadge.vercel.app/api/server/xSURYdvu)](https://discord.gg/xSURYdvu)
+[![Discord Shield](https://dcbadge.vercel.app/api/server/cZs26Qq3Dd)](https://discord.gg/cZs26Qq3Dd)
# TorchRL
diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py
index 71b7a481ce0..246c5ee15f0 100644
--- a/benchmarks/ecosystem/gym_env_throughput.py
+++ b/benchmarks/ecosystem/gym_env_throughput.py
@@ -76,12 +76,12 @@ def make(envname=envname, gym_backend=gym_backend):
# regular parallel env
for device in avail_devices:
- def make(envname=envname, gym_backend=gym_backend, device=device):
+ def make(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
- return GymEnv(envname, device=device)
+ return GymEnv(envname, device="cpu")
# env_make = EnvCreator(make)
- penv = ParallelEnv(num_workers, EnvCreator(make))
+ penv = ParallelEnv(num_workers, EnvCreator(make), device=device)
with torch.inference_mode():
# warmup
penv.rollout(2)
@@ -103,13 +103,13 @@ def make(envname=envname, gym_backend=gym_backend, device=device):
for device in avail_devices:
- def make(envname=envname, gym_backend=gym_backend, device=device):
+ def make(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
- return GymEnv(envname, device=device)
+ return GymEnv(envname, device="cpu")
env_make = EnvCreator(make)
# penv = SerialEnv(num_workers, env_make)
- penv = ParallelEnv(num_workers, env_make)
+ penv = ParallelEnv(num_workers, env_make, device=device)
collector = SyncDataCollector(
penv,
RandomPolicy(penv.action_spec),
@@ -164,14 +164,14 @@ def make_env(
for device in avail_devices:
# async collector
# + torchrl parallel env
- def make_env(
- envname=envname, gym_backend=gym_backend, device=device
- ):
+ def make_env(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
- return GymEnv(envname, device=device)
+ return GymEnv(envname, device="cpu")
penv = ParallelEnv(
- num_workers // num_collectors, EnvCreator(make_env)
+ num_workers // num_collectors,
+ EnvCreator(make_env),
+ device=device,
)
collector = MultiaSyncDataCollector(
[penv] * num_collectors,
@@ -206,10 +206,9 @@ def make_env(
envname=envname,
num_workers=num_workers,
gym_backend=gym_backend,
- device=device,
):
with set_gym_backend(gym_backend):
- penv = GymEnv(envname, num_envs=num_workers, device=device)
+ penv = GymEnv(envname, num_envs=num_workers, device="cpu")
return penv
penv = EnvCreator(
@@ -247,14 +246,14 @@ def make_env(
for device in avail_devices:
# sync collector
# + torchrl parallel env
- def make_env(
- envname=envname, gym_backend=gym_backend, device=device
- ):
+ def make_env(envname=envname, gym_backend=gym_backend):
with set_gym_backend(gym_backend):
- return GymEnv(envname, device=device)
+ return GymEnv(envname, device="cpu")
penv = ParallelEnv(
- num_workers // num_collectors, EnvCreator(make_env)
+ num_workers // num_collectors,
+ EnvCreator(make_env),
+ device=device,
)
collector = MultiSyncDataCollector(
[penv] * num_collectors,
@@ -289,10 +288,9 @@ def make_env(
envname=envname,
num_workers=num_workers,
gym_backend=gym_backend,
- device=device,
):
with set_gym_backend(gym_backend):
- penv = GymEnv(envname, num_envs=num_workers, device=device)
+ penv = GymEnv(envname, num_envs=num_workers, device="cpu")
return penv
penv = EnvCreator(
diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py
index 9f6c4599587..1e9634f643f 100644
--- a/benchmarks/test_collectors_benchmark.py
+++ b/benchmarks/test_collectors_benchmark.py
@@ -13,7 +13,7 @@
MultiSyncDataCollector,
RandomPolicy,
)
-from torchrl.envs import EnvCreator, StepCounter, TransformedEnv
+from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
@@ -78,9 +78,10 @@ def async_collector_setup():
def single_collector_setup_pixels():
device = "cuda:0" if torch.cuda.device_count() else "cpu"
- env = TransformedEnv(
- DMControlEnv("cheetah", "run", device=device, from_pixels=True), StepCounter(50)
- )
+ # env = TransformedEnv(
+ # DMControlEnv("cheetah", "run", device=device, from_pixels=True), StepCounter(50)
+ # )
+ env = TransformedEnv(GymEnv("ALE/Pong-v5"), StepCounter(50))
c = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
@@ -99,7 +100,8 @@ def sync_collector_setup_pixels():
device = "cuda:0" if torch.cuda.device_count() else "cpu"
env = EnvCreator(
lambda: TransformedEnv(
- DMControlEnv("cheetah", "run", device=device, from_pixels=True),
+ # DMControlEnv("cheetah", "run", device=device, from_pixels=True),
+ GymEnv("ALE/Pong-v5"),
StepCounter(50),
)
)
@@ -121,7 +123,8 @@ def async_collector_setup_pixels():
device = "cuda:0" if torch.cuda.device_count() else "cpu"
env = EnvCreator(
lambda: TransformedEnv(
- DMControlEnv("cheetah", "run", device=device, from_pixels=True),
+ # DMControlEnv("cheetah", "run", device=device, from_pixels=True),
+ GymEnv("ALE/Pong-v5"),
StepCounter(50),
)
)
diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py
index ca5b7eb82ed..d07e8f5da90 100644
--- a/benchmarks/test_objectives_benchmarks.py
+++ b/benchmarks/test_objectives_benchmarks.py
@@ -123,7 +123,7 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
gamma = 0.99
if gamma_tensor:
- gamma = torch.full(size, gamma)
+ gamma = torch.full(size, gamma, device=device)
lmbda = 0.95
benchmark(
diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py
index 44a37cb3ce6..4598c11844b 100644
--- a/examples/a2c/a2c_atari.py
+++ b/examples/a2c/a2c_atari.py
@@ -117,9 +117,9 @@ def main(cfg: "DictConfig"): # noqa: F821
pbar.update(data.numel())
# Get training rewards and lengths
- episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+ episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
if len(episode_rewards) > 0:
- episode_length = data["next", "step_count"][data["next", "done"]]
+ episode_length = data["next", "step_count"][data["next", "terminated"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py
index a5265f442b7..360c6daac28 100644
--- a/examples/distributed/collectors/multi_nodes/ray_train.py
+++ b/examples/distributed/collectors/multi_nodes/ray_train.py
@@ -117,7 +117,7 @@
"object_store_memory": 1024**3,
}
collector = RayCollector(
- env_makers=[env] * num_collectors,
+ create_env_fn=[env] * num_collectors,
policy=policy_module,
collector_class=SyncDataCollector,
collector_kwargs={
diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py
index fba4247e2a7..385e4a53aab 100644
--- a/examples/dreamer/dreamer_utils.py
+++ b/examples/dreamer/dreamer_utils.py
@@ -147,6 +147,7 @@ def transformed_env_constructor(
state_dim_gsde: Optional[int] = None,
batch_dims: Optional[int] = 0,
obs_norm_state_dict: Optional[dict] = None,
+ ignore_device: bool = False,
) -> Union[Callable, EnvCreator]:
"""
Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
@@ -179,6 +180,7 @@ def transformed_env_constructor(
it should be set to 1 (or the number of dims of the batch).
obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded
into the environment
+        ignore_device (bool, optional): if True, the device specified in the config is ignored and left unset, letting the caller (e.g. ParallelEnv) set it.
"""
def make_transformed_env(**kwargs) -> TransformedEnv:
@@ -189,14 +191,17 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
from_pixels = cfg.from_pixels
if custom_env is None and custom_env_maker is None:
- if isinstance(cfg.collector_device, str):
- device = cfg.collector_device
- elif isinstance(cfg.collector_device, Sequence):
- device = cfg.collector_device[0]
+ if not ignore_device:
+ if isinstance(cfg.collector_device, str):
+ device = cfg.collector_device
+ elif isinstance(cfg.collector_device, Sequence):
+ device = cfg.collector_device[0]
+ else:
+ raise ValueError(
+ "collector_device must be either a string or a sequence of strings"
+ )
else:
- raise ValueError(
- "collector_device must be either a string or a sequence of strings"
- )
+ device = None
env_kwargs = {
"env_name": env_name,
"device": device,
@@ -252,19 +257,19 @@ def parallel_env_constructor(
kwargs: keyword arguments for the `transformed_env_constructor` method.
"""
batch_transform = cfg.batch_transform
+ kwargs.update({"cfg": cfg, "use_env_creator": True})
if cfg.env_per_collector == 1:
- kwargs.update({"cfg": cfg, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(**kwargs)
return make_transformed_env
- kwargs.update({"cfg": cfg, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(
- return_transformed_envs=not batch_transform, **kwargs
+ return_transformed_envs=not batch_transform, ignore_device=True, **kwargs
)
parallel_env = ParallelEnv(
num_workers=cfg.env_per_collector,
create_env_fn=make_transformed_env,
create_env_kwargs=None,
pin_memory=cfg.pin_memory,
+ device=cfg.collector_device,
)
if batch_transform:
kwargs.update(
diff --git a/examples/impala/README.md b/examples/impala/README.md
new file mode 100644
index 00000000000..00e0d010b82
--- /dev/null
+++ b/examples/impala/README.md
@@ -0,0 +1,33 @@
+## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results
+
+This repository contains scripts that enable training agents with the IMPALA Algorithm on Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al. (2018).
+
+## Examples Structure
+
+Please note that we provide three examples: one for single-node training and two for distributed (multi-node) training, using either Ray or submitit. All examples rely on the same utils file but are otherwise independent. Each example contains the following files (a short file overview follows this list):
+
+1. **Main Script:** The definition of the algorithm components and the training loop can be found in the main script (e.g. impala_single_node.py).
+
+2. **Utils File:** A utility file contains various helper functions, mainly to create the environment and the models (e.g. utils.py).
+
+3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. In the multi-node case, it also includes the Ray cluster or SLURM configuration. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml).
+
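+As a quick overview, these are the files added for the IMPALA examples:
+
+```
+config_single_node.yaml          # single-node hyperparameters
+config_multi_node_ray.yaml       # multi-node hyperparameters and Ray cluster config
+config_multi_node_submitit.yaml  # multi-node hyperparameters and SLURM config
+impala_single_node.py            # single-node training script
+impala_multi_node_ray.py         # multi-node training script (Ray launcher)
+impala_multi_node_submitit.py    # multi-node training script (submitit/SLURM launcher)
+utils.py                         # shared environment and model helpers
+```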
+
+## Running the Examples
+
+You can execute the single-node IMPALA algorithm on Atari environments by running the following command:
+
+```bash
+python impala_single_node.py
+```
+
+You can execute the multi-node IMPALA algorithm on Atari environments by running one of the following commands:
+
+```bash
+python impala_multi_node_ray.py
+```
+or
+
+```bash
+python impala_multi_node_submitit.py
+```
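+
+All three scripts are configured with hydra, so any hyperparameter can also be overridden from the command line. For instance, a quick smoke test of the single-node script (the values below are illustrative, not the paper settings; an empty `logger.backend=` disables logging):
+
+```bash
+python impala_single_node.py \
+  collector.total_frames=80 \
+  collector.frames_per_batch=20 \
+  collector.num_workers=1 \
+  logger.backend=
+```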
diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml
new file mode 100644
index 00000000000..e312b336651
--- /dev/null
+++ b/examples/impala/config_multi_node_ray.yaml
@@ -0,0 +1,65 @@
+# Environment
+env:
+ env_name: PongNoFrameskip-v4
+
+# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
+ray_init_config:
+ address: null
+ num_cpus: null
+ num_gpus: null
+ resources: null
+ object_store_memory: null
+ local_mode: False
+ ignore_reinit_error: False
+ include_dashboard: null
+ dashboard_host: 127.0.0.1
+ dashboard_port: null
+ job_config: null
+ configure_logging: True
+ logging_level: info
+ logging_format: null
+ log_to_driver: True
+ namespace: null
+ runtime_env: null
+ storage: null
+
+# Device for the forward and backward passes
+local_device: "cuda:0"
+
+# Resources assigned to each IMPALA rollout collection worker
+remote_worker_resources:
+ num_cpus: 1
+ num_gpus: 0.25
+ memory: 1073741824 # 1*1024**3 bytes = 1GB
+
+# collector
+collector:
+ frames_per_batch: 80
+ total_frames: 200_000_000
+ num_workers: 12
+
+# logger
+logger:
+ backend: wandb
+ exp_name: Atari_IMPALA
+ test_interval: 200_000_000
+ num_test_episodes: 3
+
+# Optim
+optim:
+ lr: 0.0006
+ eps: 1e-8
+ weight_decay: 0.0
+ momentum: 0.0
+ alpha: 0.99
+ max_grad_norm: 40.0
+ anneal_lr: True
+
+# loss
+loss:
+ gamma: 0.99
+ batch_size: 32
+ sgd_updates: 1
+ critic_coef: 0.5
+ entropy_coef: 0.01
+ loss_critic_type: l2
diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml
new file mode 100644
index 00000000000..f632ba15dc2
--- /dev/null
+++ b/examples/impala/config_multi_node_submitit.yaml
@@ -0,0 +1,46 @@
+# Environment
+env:
+ env_name: PongNoFrameskip-v4
+
+# Device for the forward and backward passes
+local_device: "cuda:0"
+
+# SLURM config
+slurm_config:
+ timeout_min: 10
+ slurm_partition: train
+ slurm_cpus_per_task: 1
+ slurm_gpus_per_node: 1
+
+# collector
+collector:
+ backend: gloo
+ frames_per_batch: 80
+ total_frames: 200_000_000
+ num_workers: 1
+
+# logger
+logger:
+ backend: wandb
+ exp_name: Atari_IMPALA
+ test_interval: 200_000_000
+ num_test_episodes: 3
+
+# Optim
+optim:
+ lr: 0.0006
+ eps: 1e-8
+ weight_decay: 0.0
+ momentum: 0.0
+ alpha: 0.99
+ max_grad_norm: 40.0
+ anneal_lr: True
+
+# loss
+loss:
+ gamma: 0.99
+ batch_size: 32
+ sgd_updates: 1
+ critic_coef: 0.5
+ entropy_coef: 0.01
+ loss_critic_type: l2
diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml
new file mode 100644
index 00000000000..d39407c1a69
--- /dev/null
+++ b/examples/impala/config_single_node.yaml
@@ -0,0 +1,38 @@
+# Environment
+env:
+ env_name: PongNoFrameskip-v4
+
+# Device for the forward and backward passes
+device: "cuda:0"
+
+# collector
+collector:
+ frames_per_batch: 80
+ total_frames: 200_000_000
+ num_workers: 12
+
+# logger
+logger:
+ backend: wandb
+ exp_name: Atari_IMPALA
+ test_interval: 200_000_000
+ num_test_episodes: 3
+
+# Optim
+optim:
+ lr: 0.0006
+ eps: 1e-8
+ weight_decay: 0.0
+ momentum: 0.0
+ alpha: 0.99
+ max_grad_norm: 40.0
+ anneal_lr: True
+
+# loss
+loss:
+ gamma: 0.99
+ batch_size: 32
+ sgd_updates: 1
+ critic_coef: 0.5
+ entropy_coef: 0.01
+ loss_critic_type: l2
diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py
new file mode 100644
index 00000000000..a0d2d88c5a2
--- /dev/null
+++ b/examples/impala/impala_multi_node_ray.py
@@ -0,0 +1,278 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+This script reproduces the IMPALA Algorithm
+results from Espeholt et al. 2018 on Atari environments.
+"""
+import hydra
+
+
+@hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1")
+def main(cfg: "DictConfig"): # noqa: F821
+
+ import time
+
+ import torch.optim
+ import tqdm
+
+ from tensordict import TensorDict
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.collectors.distributed import RayCollector
+ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+ from torchrl.envs import ExplorationType, set_exploration_type
+ from torchrl.objectives import A2CLoss
+ from torchrl.objectives.value import VTrace
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import eval_model, make_env, make_ppo_models
+
+ device = torch.device(cfg.local_device)
+
+ # Correct for frame_skip
+ frame_skip = 4
+ total_frames = cfg.collector.total_frames // frame_skip
+ frames_per_batch = cfg.collector.frames_per_batch // frame_skip
+ test_interval = cfg.logger.test_interval // frame_skip
+
+ # Extract other config parameters
+ batch_size = cfg.loss.batch_size # Number of rollouts per batch
+ num_workers = (
+ cfg.collector.num_workers
+ ) # Number of parallel workers collecting rollouts
+ lr = cfg.optim.lr
+ anneal_lr = cfg.optim.anneal_lr
+ sgd_updates = cfg.loss.sgd_updates
+ max_grad_norm = cfg.optim.max_grad_norm
+ num_test_episodes = cfg.logger.num_test_episodes
+ total_network_updates = (
+ total_frames // (frames_per_batch * batch_size)
+ ) * cfg.loss.sgd_updates
+
+ # Create models (check utils.py)
+ actor, critic = make_ppo_models(cfg.env.env_name)
+ actor, critic = actor.to(device), critic.to(device)
+
+ # Create collector
+ ray_init_config = {
+ "address": cfg.ray_init_config.address,
+ "num_cpus": cfg.ray_init_config.num_cpus,
+ "num_gpus": cfg.ray_init_config.num_gpus,
+ "resources": cfg.ray_init_config.resources,
+ "object_store_memory": cfg.ray_init_config.object_store_memory,
+ "local_mode": cfg.ray_init_config.local_mode,
+ "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error,
+ "include_dashboard": cfg.ray_init_config.include_dashboard,
+ "dashboard_host": cfg.ray_init_config.dashboard_host,
+ "dashboard_port": cfg.ray_init_config.dashboard_port,
+ "job_config": cfg.ray_init_config.job_config,
+ "configure_logging": cfg.ray_init_config.configure_logging,
+ "logging_level": cfg.ray_init_config.logging_level,
+ "logging_format": cfg.ray_init_config.logging_format,
+ "log_to_driver": cfg.ray_init_config.log_to_driver,
+ "namespace": cfg.ray_init_config.namespace,
+ "runtime_env": cfg.ray_init_config.runtime_env,
+ "storage": cfg.ray_init_config.storage,
+ }
+ remote_config = {
+ "num_cpus": cfg.remote_worker_resources.num_cpus,
+ "num_gpus": cfg.remote_worker_resources.num_gpus
+ if torch.cuda.device_count()
+ else 0,
+ "memory": cfg.remote_worker_resources.memory,
+ }
+ collector = RayCollector(
+ create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+ policy=actor,
+ collector_class=SyncDataCollector,
+ frames_per_batch=frames_per_batch,
+ total_frames=total_frames,
+ max_frames_per_traj=-1,
+ ray_init_config=ray_init_config,
+ remote_configs=remote_config,
+ sync=False,
+ update_after_each_batch=True,
+ )
+
+ # Create data buffer
+ sampler = SamplerWithoutReplacement()
+ data_buffer = TensorDictReplayBuffer(
+ storage=LazyMemmapStorage(frames_per_batch * batch_size),
+ sampler=sampler,
+ batch_size=frames_per_batch * batch_size,
+ )
+
+ # Create loss and adv modules
+ adv_module = VTrace(
+ gamma=cfg.loss.gamma,
+ value_network=critic,
+ actor_network=actor,
+ average_adv=False,
+ )
+ loss_module = A2CLoss(
+ actor=actor,
+ critic=critic,
+ loss_critic_type=cfg.loss.loss_critic_type,
+ entropy_coef=cfg.loss.entropy_coef,
+ critic_coef=cfg.loss.critic_coef,
+ )
+ loss_module.set_keys(done="eol", terminated="eol")
+
+ # Create optimizer
+ optim = torch.optim.RMSprop(
+ loss_module.parameters(),
+ lr=cfg.optim.lr,
+ weight_decay=cfg.optim.weight_decay,
+ eps=cfg.optim.eps,
+ alpha=cfg.optim.alpha,
+ )
+
+ # Create logger
+ logger = None
+ if cfg.logger.backend:
+ exp_name = generate_exp_name(
+ "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
+ )
+ logger = get_logger(
+ cfg.logger.backend,
+ logger_name="impala",
+ experiment_name=exp_name,
+ project="impala",
+ )
+
+ # Create test environment
+ test_env = make_env(cfg.env.env_name, device, is_test=True)
+ test_env.eval()
+
+ # Main loop
+ collected_frames = 0
+ num_network_updates = 0
+ pbar = tqdm.tqdm(total=total_frames)
+ accumulator = []
+ start_time = sampling_start = time.time()
+ for i, data in enumerate(collector):
+
+ log_info = {}
+ sampling_time = time.time() - sampling_start
+ frames_in_batch = data.numel()
+ collected_frames += frames_in_batch * frame_skip
+ pbar.update(data.numel())
+
+ # Get training rewards and episode lengths
+ episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+ if len(episode_rewards) > 0:
+ episode_length = data["next", "step_count"][data["next", "terminated"]]
+ log_info.update(
+ {
+ "train/reward": episode_rewards.mean().item(),
+ "train/episode_length": episode_length.sum().item()
+ / len(episode_length),
+ }
+ )
+
+ if len(accumulator) < batch_size:
+ accumulator.append(data)
+ if logger:
+ for key, value in log_info.items():
+ logger.log_scalar(key, value, collected_frames)
+ continue
+
+ losses = TensorDict({}, batch_size=[sgd_updates])
+ training_start = time.time()
+ for j in range(sgd_updates):
+
+ # Create a single batch of trajectories
+ stacked_data = torch.stack(accumulator, dim=0).contiguous()
+ stacked_data = stacked_data.to(device, non_blocking=True)
+
+ # Compute advantage
+ with torch.no_grad():
+ stacked_data = adv_module(stacked_data)
+
+ # Add to replay buffer
+ for stacked_d in stacked_data:
+ stacked_data_reshape = stacked_d.reshape(-1)
+ data_buffer.extend(stacked_data_reshape)
+
+ for batch in data_buffer:
+
+ # Linearly decrease the learning rate
+ alpha = 1.0
+ if anneal_lr:
+ alpha = 1 - (num_network_updates / total_network_updates)
+ for group in optim.param_groups:
+ group["lr"] = lr * alpha
+ num_network_updates += 1
+
+ # Get a data batch
+ batch = batch.to(device, non_blocking=True)
+
+ # Forward pass loss
+ loss = loss_module(batch)
+ losses[j] = loss.select(
+ "loss_critic", "loss_entropy", "loss_objective"
+ ).detach()
+ loss_sum = (
+ loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+ )
+
+ # Backward pass
+ loss_sum.backward()
+ torch.nn.utils.clip_grad_norm_(
+ list(loss_module.parameters()), max_norm=max_grad_norm
+ )
+
+ # Update the networks
+ optim.step()
+ optim.zero_grad()
+
+ # Get training losses and times
+ training_time = time.time() - training_start
+ losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+ for key, value in losses.items():
+ log_info.update({f"train/{key}": value.item()})
+ log_info.update(
+ {
+ "train/lr": alpha * lr,
+ "train/sampling_time": sampling_time,
+ "train/training_time": training_time,
+ }
+ )
+
+ # Get test rewards
+ with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
+ i * frames_in_batch * frame_skip
+ ) // test_interval:
+ actor.eval()
+ eval_start = time.time()
+ test_reward = eval_model(
+ actor, test_env, num_episodes=num_test_episodes
+ )
+ eval_time = time.time() - eval_start
+ log_info.update(
+ {
+ "eval/reward": test_reward,
+ "eval/time": eval_time,
+ }
+ )
+ actor.train()
+
+ if logger:
+ for key, value in log_info.items():
+ logger.log_scalar(key, value, collected_frames)
+
+ collector.update_policy_weights_()
+ sampling_start = time.time()
+ accumulator = []
+
+ collector.shutdown()
+ end_time = time.time()
+ execution_time = end_time - start_time
+ print(f"Training took {execution_time:.2f} seconds to finish")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py
new file mode 100644
index 00000000000..3355febbfaf
--- /dev/null
+++ b/examples/impala/impala_multi_node_submitit.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+This script reproduces the IMPALA Algorithm
+results from Espeholt et al. 2018 on Atari environments.
+"""
+import hydra
+
+
+@hydra.main(
+ config_path=".", config_name="config_multi_node_submitit", version_base="1.1"
+)
+def main(cfg: "DictConfig"): # noqa: F821
+
+ import time
+
+ import torch.optim
+ import tqdm
+
+ from tensordict import TensorDict
+ from torchrl.collectors import SyncDataCollector
+ from torchrl.collectors.distributed import DistributedDataCollector
+ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+ from torchrl.envs import ExplorationType, set_exploration_type
+ from torchrl.objectives import A2CLoss
+ from torchrl.objectives.value import VTrace
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import eval_model, make_env, make_ppo_models
+
+ device = torch.device(cfg.local_device)
+
+ # Correct for frame_skip
+ frame_skip = 4
+ total_frames = cfg.collector.total_frames // frame_skip
+ frames_per_batch = cfg.collector.frames_per_batch // frame_skip
+ test_interval = cfg.logger.test_interval // frame_skip
+
+ # Extract other config parameters
+ batch_size = cfg.loss.batch_size # Number of rollouts per batch
+ num_workers = (
+ cfg.collector.num_workers
+ ) # Number of parallel workers collecting rollouts
+ lr = cfg.optim.lr
+ anneal_lr = cfg.optim.anneal_lr
+ sgd_updates = cfg.loss.sgd_updates
+ max_grad_norm = cfg.optim.max_grad_norm
+ num_test_episodes = cfg.logger.num_test_episodes
+ total_network_updates = (
+ total_frames // (frames_per_batch * batch_size)
+ ) * cfg.loss.sgd_updates
+
+ # Create models (check utils.py)
+ actor, critic = make_ppo_models(cfg.env.env_name)
+ actor, critic = actor.to(device), critic.to(device)
+
+ slurm_kwargs = {
+ "timeout_min": cfg.slurm_config.timeout_min,
+ "slurm_partition": cfg.slurm_config.slurm_partition,
+ "slurm_cpus_per_task": cfg.slurm_config.slurm_cpus_per_task,
+ "slurm_gpus_per_node": cfg.slurm_config.slurm_gpus_per_node,
+ }
+ # Create collector
+ device_str = "device" if num_workers <= 1 else "devices"
+ if cfg.collector.backend == "nccl":
+ collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"}
+ elif cfg.collector.backend == "gloo":
+ collector_kwargs = {device_str: "cpu", f"storing_{device_str}": "cpu"}
+ else:
+ raise NotImplementedError(
+ f"device assignment not implemented for backend {cfg.collector.backend}"
+ )
+ collector = DistributedDataCollector(
+ create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+ policy=actor,
+ num_workers_per_collector=1,
+ frames_per_batch=frames_per_batch,
+ total_frames=total_frames,
+ collector_class=SyncDataCollector,
+ collector_kwargs=collector_kwargs,
+ slurm_kwargs=slurm_kwargs,
+ storing_device="cuda:0" if cfg.collector.backend == "nccl" else "cpu",
+ launcher="submitit",
+ # update_after_each_batch=True,
+ backend=cfg.collector.backend,
+ )
+
+ # Create data buffer
+ sampler = SamplerWithoutReplacement()
+ data_buffer = TensorDictReplayBuffer(
+ storage=LazyMemmapStorage(frames_per_batch * batch_size),
+ sampler=sampler,
+ batch_size=frames_per_batch * batch_size,
+ )
+
+ # Create loss and adv modules
+ adv_module = VTrace(
+ gamma=cfg.loss.gamma,
+ value_network=critic,
+ actor_network=actor,
+ average_adv=False,
+ )
+ loss_module = A2CLoss(
+ actor=actor,
+ critic=critic,
+ loss_critic_type=cfg.loss.loss_critic_type,
+ entropy_coef=cfg.loss.entropy_coef,
+ critic_coef=cfg.loss.critic_coef,
+ )
+ loss_module.set_keys(done="eol", terminated="eol")
+
+ # Create optimizer
+ optim = torch.optim.RMSprop(
+ loss_module.parameters(),
+ lr=cfg.optim.lr,
+ weight_decay=cfg.optim.weight_decay,
+ eps=cfg.optim.eps,
+ alpha=cfg.optim.alpha,
+ )
+
+ # Create logger
+ logger = None
+ if cfg.logger.backend:
+ exp_name = generate_exp_name(
+ "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
+ )
+ logger = get_logger(
+ cfg.logger.backend,
+ logger_name="impala",
+ experiment_name=exp_name,
+ project="impala",
+ )
+
+ # Create test environment
+ test_env = make_env(cfg.env.env_name, device, is_test=True)
+ test_env.eval()
+
+ # Main loop
+ collected_frames = 0
+ num_network_updates = 0
+ pbar = tqdm.tqdm(total=total_frames)
+ accumulator = []
+ start_time = sampling_start = time.time()
+ for i, data in enumerate(collector):
+
+ log_info = {}
+ sampling_time = time.time() - sampling_start
+ frames_in_batch = data.numel()
+ collected_frames += frames_in_batch * frame_skip
+ pbar.update(data.numel())
+
+ # Get training rewards and episode lengths
+ episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+ if len(episode_rewards) > 0:
+ episode_length = data["next", "step_count"][data["next", "done"]]
+ log_info.update(
+ {
+ "train/reward": episode_rewards.mean().item(),
+ "train/episode_length": episode_length.sum().item()
+ / len(episode_length),
+ }
+ )
+
+ if len(accumulator) < batch_size:
+ accumulator.append(data)
+ if logger:
+ for key, value in log_info.items():
+ logger.log_scalar(key, value, collected_frames)
+ continue
+
+ losses = TensorDict({}, batch_size=[sgd_updates])
+ training_start = time.time()
+ for j in range(sgd_updates):
+
+ # Create a single batch of trajectories
+ stacked_data = torch.stack(accumulator, dim=0).contiguous()
+ stacked_data = stacked_data.to(device, non_blocking=True)
+
+ # Compute advantage
+ with torch.no_grad():
+ stacked_data = adv_module(stacked_data)
+
+ # Add to replay buffer
+ for stacked_d in stacked_data:
+ stacked_data_reshape = stacked_d.reshape(-1)
+ data_buffer.extend(stacked_data_reshape)
+
+ for batch in data_buffer:
+
+ # Linearly decrease the learning rate
+ alpha = 1.0
+ if anneal_lr:
+ alpha = 1 - (num_network_updates / total_network_updates)
+ for group in optim.param_groups:
+ group["lr"] = lr * alpha
+ num_network_updates += 1
+
+ # Get a data batch
+ batch = batch.to(device)
+
+ # Forward pass loss
+ loss = loss_module(batch)
+ losses[j] = loss.select(
+ "loss_critic", "loss_entropy", "loss_objective"
+ ).detach()
+ loss_sum = (
+ loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+ )
+
+ # Backward pass
+ loss_sum.backward()
+ torch.nn.utils.clip_grad_norm_(
+ list(loss_module.parameters()), max_norm=max_grad_norm
+ )
+
+ # Update the networks
+ optim.step()
+ optim.zero_grad()
+
+ # Get training losses and times
+ training_time = time.time() - training_start
+ losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+ for key, value in losses.items():
+ log_info.update({f"train/{key}": value.item()})
+ log_info.update(
+ {
+ "train/lr": alpha * lr,
+ "train/sampling_time": sampling_time,
+ "train/training_time": training_time,
+ }
+ )
+
+ # Get test rewards
+ with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
+ i * frames_in_batch * frame_skip
+ ) // test_interval:
+ actor.eval()
+ eval_start = time.time()
+ test_reward = eval_model(
+ actor, test_env, num_episodes=num_test_episodes
+ )
+ eval_time = time.time() - eval_start
+ log_info.update(
+ {
+ "eval/reward": test_reward,
+ "eval/time": eval_time,
+ }
+ )
+ actor.train()
+
+ if logger:
+ for key, value in log_info.items():
+ logger.log_scalar(key, value, collected_frames)
+
+ collector.update_policy_weights_()
+ sampling_start = time.time()
+ accumulator = []
+
+ collector.shutdown()
+ end_time = time.time()
+ execution_time = end_time - start_time
+ print(f"Training took {execution_time:.2f} seconds to finish")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py
new file mode 100644
index 00000000000..cd270f4c9e9
--- /dev/null
+++ b/examples/impala/impala_single_node.py
@@ -0,0 +1,248 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+This script reproduces the IMPALA Algorithm
+results from Espeholt et al. 2018 on Atari environments.
+"""
+import hydra
+
+
+@hydra.main(config_path=".", config_name="config_single_node", version_base="1.1")
+def main(cfg: "DictConfig"): # noqa: F821
+
+ import time
+
+ import torch.optim
+ import tqdm
+
+ from tensordict import TensorDict
+ from torchrl.collectors import MultiaSyncDataCollector
+ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+ from torchrl.envs import ExplorationType, set_exploration_type
+ from torchrl.objectives import A2CLoss
+ from torchrl.objectives.value import VTrace
+ from torchrl.record.loggers import generate_exp_name, get_logger
+ from utils import eval_model, make_env, make_ppo_models
+
+ device = torch.device(cfg.device)
+
+ # Correct for frame_skip
+ frame_skip = 4
+ total_frames = cfg.collector.total_frames // frame_skip
+ frames_per_batch = cfg.collector.frames_per_batch // frame_skip
+ test_interval = cfg.logger.test_interval // frame_skip
+
+ # Extract other config parameters
+ batch_size = cfg.loss.batch_size # Number of rollouts per batch
+ num_workers = (
+ cfg.collector.num_workers
+ ) # Number of parallel workers collecting rollouts
+ lr = cfg.optim.lr
+ anneal_lr = cfg.optim.anneal_lr
+ sgd_updates = cfg.loss.sgd_updates
+ max_grad_norm = cfg.optim.max_grad_norm
+ num_test_episodes = cfg.logger.num_test_episodes
+ total_network_updates = (
+ total_frames // (frames_per_batch * batch_size)
+ ) * cfg.loss.sgd_updates
+
+ # Create models (check utils.py)
+ actor, critic = make_ppo_models(cfg.env.env_name)
+ actor, critic = actor.to(device), critic.to(device)
+
+ # Create collector
+ collector = MultiaSyncDataCollector(
+ create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+ policy=actor,
+ frames_per_batch=frames_per_batch,
+ total_frames=total_frames,
+ device=device,
+ storing_device=device,
+ max_frames_per_traj=-1,
+ update_at_each_batch=True,
+ )
+
+ # Create data buffer
+ sampler = SamplerWithoutReplacement()
+ data_buffer = TensorDictReplayBuffer(
+ storage=LazyMemmapStorage(frames_per_batch * batch_size),
+ sampler=sampler,
+ batch_size=frames_per_batch * batch_size,
+ )
+
+ # Create loss and adv modules
+ adv_module = VTrace(
+ gamma=cfg.loss.gamma,
+ value_network=critic,
+ actor_network=actor,
+ average_adv=False,
+ )
+ loss_module = A2CLoss(
+ actor=actor,
+ critic=critic,
+ loss_critic_type=cfg.loss.loss_critic_type,
+ entropy_coef=cfg.loss.entropy_coef,
+ critic_coef=cfg.loss.critic_coef,
+ )
+ loss_module.set_keys(done="eol", terminated="eol")
+
+ # Create optimizer
+ optim = torch.optim.RMSprop(
+ loss_module.parameters(),
+ lr=cfg.optim.lr,
+ weight_decay=cfg.optim.weight_decay,
+ eps=cfg.optim.eps,
+ alpha=cfg.optim.alpha,
+ )
+
+ # Create logger
+ logger = None
+ if cfg.logger.backend:
+ exp_name = generate_exp_name(
+ "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
+ )
+ logger = get_logger(
+ cfg.logger.backend,
+ logger_name="impala",
+ experiment_name=exp_name,
+ project="impala",
+ )
+
+ # Create test environment
+ test_env = make_env(cfg.env.env_name, device, is_test=True)
+ test_env.eval()
+
+ # Main loop
+ collected_frames = 0
+ num_network_updates = 0
+ pbar = tqdm.tqdm(total=total_frames)
+ accumulator = []
+ start_time = sampling_start = time.time()
+ for i, data in enumerate(collector):
+
+ log_info = {}
+ sampling_time = time.time() - sampling_start
+ frames_in_batch = data.numel()
+ collected_frames += frames_in_batch * frame_skip
+ pbar.update(data.numel())
+
+ # Get training rewards and episode lengths
+ episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
+ if len(episode_rewards) > 0:
+ episode_length = data["next", "step_count"][data["next", "terminated"]]
+ log_info.update(
+ {
+ "train/reward": episode_rewards.mean().item(),
+ "train/episode_length": episode_length.sum().item()
+ / len(episode_length),
+ }
+ )
+
+ if len(accumulator) < batch_size:
+ accumulator.append(data)
+ if logger:
+ for key, value in log_info.items():
+ logger.log_scalar(key, value, collected_frames)
+ continue
+
+ losses = TensorDict({}, batch_size=[sgd_updates])
+ training_start = time.time()
+ for j in range(sgd_updates):
+
+ # Create a single batch of trajectories
+ stacked_data = torch.stack(accumulator, dim=0).contiguous()
+ stacked_data = stacked_data.to(device, non_blocking=True)
+
+ # Compute advantage
+ with torch.no_grad():
+ stacked_data = adv_module(stacked_data)
+
+ # Add to replay buffer
+ for stacked_d in stacked_data:
+ stacked_data_reshape = stacked_d.reshape(-1)
+ data_buffer.extend(stacked_data_reshape)
+
+ for batch in data_buffer:
+
+ # Linearly decrease the learning rate
+ alpha = 1.0
+ if anneal_lr:
+ alpha = 1 - (num_network_updates / total_network_updates)
+ for group in optim.param_groups:
+ group["lr"] = lr * alpha
+ num_network_updates += 1
+
+ # Get a data batch
+ batch = batch.to(device, non_blocking=True)
+
+ # Forward pass loss
+ loss = loss_module(batch)
+ losses[j] = loss.select(
+ "loss_critic", "loss_entropy", "loss_objective"
+ ).detach()
+ loss_sum = (
+ loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
+ )
+
+ # Backward pass
+ loss_sum.backward()
+ torch.nn.utils.clip_grad_norm_(
+ list(loss_module.parameters()), max_norm=max_grad_norm
+ )
+
+ # Update the networks
+ optim.step()
+ optim.zero_grad()
+
+ # Get training losses and times
+ training_time = time.time() - training_start
+ losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
+ for key, value in losses.items():
+ log_info.update({f"train/{key}": value.item()})
+ log_info.update(
+ {
+ "train/lr": alpha * lr,
+ "train/sampling_time": sampling_time,
+ "train/training_time": training_time,
+ }
+ )
+
+ # Get test rewards
+ with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+ if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
+ i * frames_in_batch * frame_skip
+ ) // test_interval:
+ actor.eval()
+ eval_start = time.time()
+ test_reward = eval_model(
+ actor, test_env, num_episodes=num_test_episodes
+ )
+ eval_time = time.time() - eval_start
+ log_info.update(
+ {
+ "eval/reward": test_reward,
+ "eval/time": eval_time,
+ }
+ )
+ actor.train()
+
+ if logger:
+ for key, value in log_info.items():
+ logger.log_scalar(key, value, collected_frames)
+
+ collector.update_policy_weights_()
+ sampling_start = time.time()
+ accumulator = []
+
+ collector.shutdown()
+ end_time = time.time()
+ execution_time = end_time - start_time
+ print(f"Training took {execution_time:.2f} seconds to finish")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/impala/utils.py b/examples/impala/utils.py
new file mode 100644
index 00000000000..2983f8a0193
--- /dev/null
+++ b/examples/impala/utils.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn
+import torch.optim
+from tensordict.nn import TensorDictModule
+from torchrl.data import CompositeSpec
+from torchrl.envs import (
+ CatFrames,
+ DoubleToFloat,
+ EndOfLifeTransform,
+ ExplorationType,
+ GrayScale,
+ GymEnv,
+ NoopResetEnv,
+ Resize,
+ RewardClipping,
+ RewardSum,
+ StepCounter,
+ ToTensorImage,
+ TransformedEnv,
+ VecNorm,
+)
+from torchrl.modules import (
+ ActorValueOperator,
+ ConvNet,
+ MLP,
+ OneHotCategorical,
+ ProbabilisticActor,
+ ValueOperator,
+)
+
+
+# ====================================================================
+# Environment utils
+# --------------------------------------------------------------------
+
+
+def make_env(env_name, device, is_test=False):
+ env = GymEnv(
+ env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device
+ )
+ env = TransformedEnv(env)
+ env.append_transform(NoopResetEnv(noops=30, random=True))
+ if not is_test:
+ env.append_transform(EndOfLifeTransform())
+ env.append_transform(RewardClipping(-1, 1))
+ env.append_transform(ToTensorImage(from_int=False))
+ env.append_transform(GrayScale())
+ env.append_transform(Resize(84, 84))
+ env.append_transform(CatFrames(N=4, dim=-3))
+ env.append_transform(RewardSum())
+ env.append_transform(StepCounter(max_steps=4500))
+ env.append_transform(DoubleToFloat())
+ env.append_transform(VecNorm(in_keys=["pixels"]))
+ return env
+
+
+# ====================================================================
+# Model utils
+# --------------------------------------------------------------------
+
+
+def make_ppo_modules_pixels(proof_environment):
+
+ # Define input shape
+ input_shape = proof_environment.observation_spec["pixels"].shape
+
+ # Define distribution class and kwargs
+ num_outputs = proof_environment.action_spec.space.n
+ distribution_class = OneHotCategorical
+ distribution_kwargs = {}
+
+ # Define input keys
+ in_keys = ["pixels"]
+
+ # Define a shared Module and TensorDictModule (CNN + MLP)
+ common_cnn = ConvNet(
+ activation_class=torch.nn.ReLU,
+ num_cells=[32, 64, 64],
+ kernel_sizes=[8, 4, 3],
+ strides=[4, 2, 1],
+ )
+ common_cnn_output = common_cnn(torch.ones(input_shape))
+ common_mlp = MLP(
+ in_features=common_cnn_output.shape[-1],
+ activation_class=torch.nn.ReLU,
+ activate_last_layer=True,
+ out_features=512,
+ num_cells=[],
+ )
+ common_mlp_output = common_mlp(common_cnn_output)
+
+ # Define shared net as TensorDictModule
+ common_module = TensorDictModule(
+ module=torch.nn.Sequential(common_cnn, common_mlp),
+ in_keys=in_keys,
+ out_keys=["common_features"],
+ )
+
+ # Define one head for the policy
+ policy_net = MLP(
+ in_features=common_mlp_output.shape[-1],
+ out_features=num_outputs,
+ activation_class=torch.nn.ReLU,
+ num_cells=[],
+ )
+ policy_module = TensorDictModule(
+ module=policy_net,
+ in_keys=["common_features"],
+ out_keys=["logits"],
+ )
+
+ # Add probabilistic sampling of the actions
+ policy_module = ProbabilisticActor(
+ policy_module,
+ in_keys=["logits"],
+ spec=CompositeSpec(action=proof_environment.action_spec),
+ distribution_class=distribution_class,
+ distribution_kwargs=distribution_kwargs,
+ return_log_prob=True,
+ default_interaction_type=ExplorationType.RANDOM,
+ )
+
+ # Define another head for the value
+ value_net = MLP(
+ activation_class=torch.nn.ReLU,
+ in_features=common_mlp_output.shape[-1],
+ out_features=1,
+ num_cells=[],
+ )
+ value_module = ValueOperator(
+ value_net,
+ in_keys=["common_features"],
+ )
+
+ return common_module, policy_module, value_module
+
+
+def make_ppo_models(env_name):
+
+ proof_environment = make_env(env_name, device="cpu")
+ common_module, policy_module, value_module = make_ppo_modules_pixels(
+ proof_environment
+ )
+
+ # Wrap modules in a single ActorCritic operator
+ actor_critic = ActorValueOperator(
+ common_operator=common_module,
+ policy_operator=policy_module,
+ value_operator=value_module,
+ )
+
+ actor = actor_critic.get_policy_operator()
+ critic = actor_critic.get_value_operator()
+
+ del proof_environment
+
+ return actor, critic
+
+
+# ====================================================================
+# Evaluation utils
+# --------------------------------------------------------------------
+
+
+def eval_model(actor, test_env, num_episodes=3):
+ test_rewards = torch.zeros(num_episodes, dtype=torch.float32)
+ for i in range(num_episodes):
+ td_test = test_env.rollout(
+ policy=actor,
+ auto_reset=True,
+ auto_cast_to_device=True,
+ break_when_any_done=True,
+ max_steps=10_000_000,
+ )
+ reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+ test_rewards[i] = reward.sum()
+ del td_test
+ return test_rewards.mean()
diff --git a/examples/ppo/config_mujoco.yaml b/examples/ppo/config_mujoco.yaml
index 1272c1f4bff..0322526e7b1 100644
--- a/examples/ppo/config_mujoco.yaml
+++ b/examples/ppo/config_mujoco.yaml
@@ -18,7 +18,7 @@ logger:
optim:
lr: 3e-4
weight_decay: 0.0
- anneal_lr: False
+ anneal_lr: True
# loss
loss:
diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py
index eb2ce15ec5a..1bfbccdeba4 100644
--- a/examples/ppo/ppo_atari.py
+++ b/examples/ppo/ppo_atari.py
@@ -134,9 +134,9 @@ def main(cfg: "DictConfig"): # noqa: F821
pbar.update(data.numel())
# Get training rewards and episode lengths
- episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+ episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
if len(episode_rewards) > 0:
- episode_length = data["next", "step_count"][data["next", "stop"]]
+ episode_length = data["next", "step_count"][data["next", "terminated"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py
index ff6aeda51d2..52b12f688e1 100644
--- a/examples/ppo/ppo_mujoco.py
+++ b/examples/ppo/ppo_mujoco.py
@@ -28,7 +28,6 @@ def main(cfg: "DictConfig"): # noqa: F821
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_mujoco import eval_model, make_env, make_ppo_models
- # Define paper hyperparameters
device = "cpu" if not torch.cuda.device_count() else "cuda"
num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size
total_network_updates = (
@@ -67,6 +66,7 @@ def main(cfg: "DictConfig"): # noqa: F821
value_network=critic,
average_gae=False,
)
+
loss_module = ClipPPOLoss(
actor=actor,
critic=critic,
@@ -78,8 +78,8 @@ def main(cfg: "DictConfig"): # noqa: F821
)
# Create optimizers
- actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
- critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr)
+ actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr, eps=1e-5)
+ critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr, eps=1e-5)
# Create logger
logger = None
@@ -187,7 +187,9 @@ def main(cfg: "DictConfig"): # noqa: F821
"train/lr": alpha * cfg_optim_lr,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
- "train/clip_epsilon": alpha * cfg_loss_clip_epsilon,
+ "train/clip_epsilon": alpha * cfg_loss_clip_epsilon
+ if cfg_loss_anneal_clip_eps
+ else cfg_loss_clip_epsilon,
}
)
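
The logging tweak above only reports the effective clip epsilon; the annealing itself happens earlier in the training loop and is not part of this hunk. A hedged sketch of that logic as a standalone helper (parameter names are assumptions; it relies on `ClipPPOLoss` keeping `clip_epsilon` as a tensor buffer that can be updated in place):

```python
import torch


def anneal_ppo_hyperparameters(
    actor_optim: torch.optim.Optimizer,
    critic_optim: torch.optim.Optimizer,
    loss_module,            # assumed to be a torchrl ClipPPOLoss instance
    update: int,
    total_updates: int,
    base_lr: float,
    base_clip_eps: float,
    anneal_lr: bool = True,
    anneal_clip_eps: bool = False,
) -> float:
    """Linearly decay the learning rate and/or PPO clip epsilon over training."""
    alpha = 1.0 - update / total_updates
    if anneal_lr:
        for group in (*actor_optim.param_groups, *critic_optim.param_groups):
            group["lr"] = base_lr * alpha
    if anneal_clip_eps:
        # clip_epsilon is assumed to be a registered buffer, hence the in-place copy.
        loss_module.clip_epsilon.copy_(base_clip_eps * alpha)
    return alpha
```
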
diff --git a/examples/ppo/utils_mujoco.py b/examples/ppo/utils_mujoco.py
index 8fa2a53fd92..7be234b322d 100644
--- a/examples/ppo/utils_mujoco.py
+++ b/examples/ppo/utils_mujoco.py
@@ -28,10 +28,10 @@
def make_env(env_name="HalfCheetah-v4", device="cpu"):
env = GymEnv(env_name, device=device)
env = TransformedEnv(env)
+ env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
+ env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
env.append_transform(RewardSum())
env.append_transform(StepCounter())
- env.append_transform(VecNorm(in_keys=["observation"]))
- env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
env.append_transform(DoubleToFloat(in_keys=["observation"]))
return env
@@ -72,7 +72,9 @@ def make_ppo_models_state(proof_environment):
# Add state-independent normal scale
policy_mlp = torch.nn.Sequential(
policy_mlp,
- AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]),
+ AddStateIndependentNormalScale(
+ proof_environment.action_spec.shape[-1], scale_lb=1e-8
+ ),
)
# Add probabilistic sampling of the actions
diff --git a/test/assets/generate.py b/test/assets/generate.py
index 75a87bb71b5..deb47f95999 100644
--- a/test/assets/generate.py
+++ b/test/assets/generate.py
@@ -4,6 +4,12 @@
# LICENSE file in the root directory of this source tree.
"""Script used to generate the mini datasets."""
+import multiprocessing as mp
+
+try:
+ mp.set_start_method("spawn")
+except Exception:
+ pass
from tempfile import TemporaryDirectory
from datasets import Dataset, DatasetDict, load_dataset
@@ -36,8 +42,9 @@ def get_minibatch():
batch_size=16,
block_size=33,
tensorclass_type=PromptData,
- dataset_name="test/datasets_mini/openai_summarize_tldr",
+ dataset_name="../datasets_mini/openai_summarize_tldr",
device="cpu",
+ num_workers=2,
infinite=False,
prefetch=0,
split="train",
@@ -47,3 +54,8 @@ def get_minibatch():
for data in dl:
data = data.clone().memmap_("test/datasets_mini/tldr_batch/")
break
+ print("done")
+
+
+if __name__ == "__main__":
+ get_minibatch()
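
The guard added above follows the usual multiprocessing recipe: choose the start method once, before any worker exists, and protect the entry point so spawned children do not re-execute it on import. The same pattern in isolation (the worker function is a made-up example):

```python
import multiprocessing as mp

# set_start_method raises RuntimeError if a method was already chosen for this interpreter.
try:
    mp.set_start_method("spawn")
except RuntimeError:
    pass


def _square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    # Without the __main__ guard, spawned children would re-run this block on import.
    with mp.Pool(2) as pool:
        print(pool.map(_square, range(4)))
```
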
diff --git a/test/assets/tldr_batch.zip b/test/assets/tldr_batch.zip
index 6293560f580..252b9e3c999 100644
Binary files a/test/assets/tldr_batch.zip and b/test/assets/tldr_batch.zip differ
diff --git a/test/test_cost.py b/test/test_cost.py
index ba8d5b43b7f..a4b87e9f746 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -47,7 +47,7 @@
get_default_devices,
)
from mocking_classes import ContinuousActionConvMockEnv
-from tensordict.nn import get_functional, NormalParamExtractor, TensorDictModule
+from tensordict.nn import NormalParamExtractor, TensorDictModule
from tensordict.nn.utils import Buffer
# from torchrl.data.postprocs.utils import expand_as_right
@@ -130,6 +130,7 @@
GAE,
TD1Estimator,
TDLambdaEstimator,
+ VTrace,
)
from torchrl.objectives.value.functional import (
_transpose_time,
@@ -140,6 +141,7 @@
vec_generalized_advantage_estimate,
vec_td1_advantage_estimate,
vec_td_lambda_advantage_estimate,
+ vtrace_advantage_estimate,
)
from torchrl.objectives.value.utils import (
_custom_conv1d,
@@ -437,7 +439,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est):
action_spec_type=action_spec_type, device=device
)
loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -915,7 +917,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est):
action_spec_type=action_spec_type, device=device
)
loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -1400,7 +1402,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
delay_actor=delay_actor,
delay_value=delay_value,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -2009,7 +2011,7 @@ def test_td3(
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -2696,7 +2698,7 @@ def test_sac(
**kwargs,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -3481,7 +3483,7 @@ def test_discrete_sac(
loss_function="l2",
**kwargs,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -4091,7 +4093,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est):
loss_function="l2",
delay_qvalue=delay_qvalue,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -4458,7 +4460,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est):
loss_function="l2",
delay_qvalue=delay_qvalue,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -4475,7 +4477,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est):
loss_function="l2",
delay_qvalue=delay_qvalue,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn_deprec.make_value_estimator(td_est)
return
@@ -4888,7 +4890,7 @@ def test_cql(
delay_qvalue=delay_qvalue,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -4984,6 +4986,18 @@ def test_cql(
else:
raise NotImplementedError(k)
loss_fn.zero_grad()
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.actor_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
+ assert all(
+ (p.grad is None) or (p.grad == 0).all()
+ for p in loss_fn.qvalue_network_params.values(
+ include_nested=True, leaves_only=True
+ )
+ )
sum([item for _, item in loss.items()]).backward()
named_parameters = list(loss_fn.named_parameters())
@@ -5324,7 +5338,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est):
action_spec_type=action_spec_type, device=device
)
loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -5555,6 +5569,7 @@ def _create_mock_actor(
action_dim=4,
device="cpu",
observation_key="observation",
+ sample_log_prob_key="sample_log_prob",
):
# Actor
action_spec = BoundedTensorSpec(
@@ -5569,6 +5584,8 @@ def _create_mock_actor(
distribution_class=TanhNormal,
in_keys=["loc", "scale"],
spec=action_spec,
+ return_log_prob=True,
+ log_prob_key=sample_log_prob_key,
)
return actor.to(device)
@@ -5606,6 +5623,7 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu
distribution_class=TanhNormal,
in_keys=["loc", "scale"],
spec=action_spec,
+ return_log_prob=True,
)
module = nn.Sequential(base_layer, nn.Linear(5, 1))
value = ValueOperator(
@@ -5632,6 +5650,7 @@ def _create_mock_actor_value_shared(
distribution_class=TanhNormal,
in_keys=["loc", "scale"],
spec=action_spec,
+ return_log_prob=True,
)
module = nn.Linear(5, 1)
value_head = ValueOperator(
@@ -5739,7 +5758,7 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("gradient_mode", (True, False))
- @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None))
+ @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est):
@@ -5752,6 +5771,13 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est):
advantage = GAE(
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ differentiable=gradient_mode,
+ )
elif advantage == "td":
advantage = TD1Estimator(
gamma=0.9, value_network=value, differentiable=gradient_mode
@@ -5818,7 +5844,7 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode):
loss_fn2.load_state_dict(sd)
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
- @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None))
+ @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_shared(self, loss_class, device, advantage):
torch.manual_seed(self.seed)
@@ -5831,6 +5857,12 @@ def test_ppo_shared(self, loss_class, device, advantage):
lmbda=0.9,
value_network=value,
)
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ )
elif advantage == "td":
advantage = TD1Estimator(
gamma=0.9,
@@ -5892,6 +5924,7 @@ def test_ppo_shared(self, loss_class, device, advantage):
"advantage",
(
"gae",
+ "vtrace",
"td",
"td_lambda",
),
@@ -5911,6 +5944,12 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
lmbda=0.9,
value_network=value,
)
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ )
elif advantage == "td":
advantage = TD1Estimator(
gamma=0.9,
@@ -5962,11 +6001,9 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses):
)
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("gradient_mode", (True, False))
- @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None))
+ @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
- if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
- raise pytest.skip("make_functional_with_buffers needs to be changed")
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_ppo(device=device)
@@ -5976,6 +6013,13 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
advantage = GAE(
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ differentiable=gradient_mode,
+ )
elif advantage == "td":
advantage = TD1Estimator(
gamma=0.9, value_network=value, differentiable=gradient_mode
@@ -5991,21 +6035,31 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2")
- floss_fn, params, buffers = make_functional_with_buffers(loss_fn)
+ params = TensorDict.from_module(loss_fn, as_module=True)
+
# fill params with zero
- for p in params:
- p.data.zero_()
+ def zero_param(p):
+ if isinstance(p, nn.Parameter):
+ p.data.zero_()
+
+ params.apply(zero_param)
+
# assert len(list(floss_fn.parameters())) == 0
- if advantage is not None:
- advantage(td)
- loss = floss_fn(params, buffers, td)
+ with params.to_module(loss_fn):
+ if advantage is not None:
+ advantage(td)
+ loss = loss_fn(td)
loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
# check that grads are independent and non null
named_parameters = loss_fn.named_parameters()
- for (name, _), p in zip(named_parameters, params):
+ for name, p in params.items(True, True):
+ if isinstance(name, tuple):
+ name = "-".join(name)
+ if not isinstance(p, nn.Parameter):
+ continue
if p.grad is not None and p.grad.norm() > 0.0:
assert "actor" not in name
assert "critic" in name
@@ -6013,12 +6067,12 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
assert "actor" in name
assert "critic" not in name
- for param in params:
- param.grad = None
+ for p in params.values(True, True):
+ p.grad = None
loss_objective.backward()
named_parameters = loss_fn.named_parameters()
-
- for (name, other_p), p in zip(named_parameters, params):
+ for (name, other_p) in named_parameters:
+ p = params.get(tuple(name.split(".")))
assert other_p.shape == p.shape
assert other_p.dtype == p.dtype
assert other_p.device == p.device
@@ -6028,7 +6082,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
if p.grad is None:
assert "actor" not in name
assert "critic" in name
- for param in params:
+ for param in params.values(True, True):
param.grad = None
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@@ -6038,6 +6092,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
ValueEstimators.TD1,
ValueEstimators.TD0,
ValueEstimators.GAE,
+ ValueEstimators.VTrace,
ValueEstimators.TDLambda,
],
)
@@ -6079,7 +6134,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
- @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None))
+ @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
"""Test PPO loss module with non-default tensordict keys."""
@@ -6097,7 +6152,9 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
sample_log_prob_key=tensor_keys["sample_log_prob"],
action_key=tensor_keys["action"],
)
- actor = self._create_mock_actor()
+ actor = self._create_mock_actor(
+ sample_log_prob_key=tensor_keys["sample_log_prob"]
+ )
value = self._create_mock_value(out_keys=[tensor_keys["value"]])
if advantage == "gae":
@@ -6107,6 +6164,13 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
value_network=value,
differentiable=gradient_mode,
)
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ differentiable=gradient_mode,
+ )
elif advantage == "td":
advantage = TD1Estimator(
gamma=0.9,
@@ -6200,7 +6264,9 @@ def test_ppo_notensordict(
terminated_key=terminated_key,
)
- actor = self._create_mock_actor(observation_key=observation_key)
+ actor = self._create_mock_actor(
+ observation_key=observation_key, sample_log_prob_key=sample_log_prob_key
+ )
value = self._create_mock_value(observation_key=observation_key)
loss = loss_class(actor=actor, critic=value)
@@ -6259,6 +6325,7 @@ def _create_mock_actor(
action_dim=4,
device="cpu",
observation_key="observation",
+ sample_log_prob_key="sample_log_prob",
):
# Actor
action_spec = BoundedTensorSpec(
@@ -6273,6 +6340,8 @@ def _create_mock_actor(
in_keys=["loc", "scale"],
spec=action_spec,
distribution_class=TanhNormal,
+ return_log_prob=True,
+ log_prob_key=sample_log_prob_key,
)
return actor.to(device)
@@ -6363,6 +6432,7 @@ def _create_seq_mock_data_a2c(
reward_key="reward",
done_key="done",
terminated_key="terminated",
+ sample_log_prob_key="sample_log_prob",
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -6392,7 +6462,7 @@ def _create_seq_mock_data_a2c(
},
"collector": {"mask": mask},
action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0),
- "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_(
+ sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_(
~mask, 0.0
)
/ 10,
@@ -6405,7 +6475,7 @@ def _create_seq_mock_data_a2c(
return td
@pytest.mark.parametrize("gradient_mode", (True, False))
- @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None))
+ @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_a2c(self, device, gradient_mode, advantage, td_est):
@@ -6418,6 +6488,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est):
advantage = GAE(
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
)
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ differentiable=gradient_mode,
+ )
elif advantage == "td":
advantage = TD1Estimator(
gamma=0.9, value_network=value, differentiable=gradient_mode
@@ -6462,6 +6539,8 @@ def test_a2c(self, device, gradient_mode, advantage, td_est):
assert ("critic" not in name) or ("target_" in name)
value.zero_grad()
+ for n, p in loss_fn.named_parameters():
+ assert p.grad is None or p.grad.norm() == 0, n
loss_objective.backward()
named_parameters = loss_fn.named_parameters()
for name, p in named_parameters:
@@ -6542,7 +6621,7 @@ def test_a2c_separate_losses(self, separate_losses):
not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}"
)
@pytest.mark.parametrize("gradient_mode", (True, False))
- @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None))
+ @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
def test_a2c_diff(self, device, gradient_mode, advantage):
if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
@@ -6560,6 +6639,13 @@ def test_a2c_diff(self, device, gradient_mode, advantage):
advantage = TD1Estimator(
gamma=0.9, value_network=value, differentiable=gradient_mode
)
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ differentiable=gradient_mode,
+ )
elif advantage == "td_lambda":
advantage = TDLambdaEstimator(
gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
@@ -6609,6 +6695,7 @@ def test_a2c_diff(self, device, gradient_mode, advantage):
ValueEstimators.TD1,
ValueEstimators.TD0,
ValueEstimators.GAE,
+ ValueEstimators.VTrace,
ValueEstimators.TDLambda,
],
)
@@ -6626,6 +6713,7 @@ def test_a2c_tensordict_keys(self, td_est):
"reward": "reward",
"done": "done",
"terminated": "terminated",
+ "sample_log_prob": "sample_log_prob",
}
self.tensordict_keys_test(
@@ -6648,8 +6736,16 @@ def test_a2c_tensordict_keys(self, td_est):
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)
+ @pytest.mark.parametrize(
+ "td_est",
+ [
+ ValueEstimators.GAE,
+ ValueEstimators.VTrace,
+ ],
+ )
+ @pytest.mark.parametrize("advantage", ("gae", "vtrace", None))
@pytest.mark.parametrize("device", get_default_devices())
- def test_a2c_tensordict_keys_run(self, device):
+ def test_a2c_tensordict_keys_run(self, device, advantage, td_est):
"""Test A2C loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
@@ -6658,6 +6754,7 @@ def test_a2c_tensordict_keys_run(self, device):
value_key = "state_value_test"
action_key = "action_test"
reward_key = "reward_test"
+ sample_log_prob_key = "sample_log_prob_test"
done_key = ("done", "test")
terminated_key = ("terminated", "test")
@@ -6667,24 +6764,29 @@ def test_a2c_tensordict_keys_run(self, device):
reward_key=reward_key,
done_key=done_key,
terminated_key=terminated_key,
+ sample_log_prob_key=sample_log_prob_key,
)
- actor = self._create_mock_actor(device=device)
- value = self._create_mock_value(device=device, out_keys=[value_key])
- advantage = GAE(
- gamma=0.9,
- lmbda=0.9,
- value_network=value,
- differentiable=gradient_mode,
- )
- advantage.set_keys(
- advantage=advantage_key,
- value_target=value_target_key,
- value=value_key,
- reward=reward_key,
- done=done_key,
- terminated=terminated_key,
+ actor = self._create_mock_actor(
+ device=device, sample_log_prob_key=sample_log_prob_key
)
+ value = self._create_mock_value(device=device, out_keys=[value_key])
+ if advantage == "gae":
+ advantage = GAE(
+ gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
+ )
+ elif advantage == "vtrace":
+ advantage = VTrace(
+ gamma=0.9,
+ value_network=value,
+ actor_network=actor,
+ differentiable=gradient_mode,
+ )
+ elif advantage is None:
+ pass
+ else:
+ raise NotImplementedError
+
loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
loss_fn.set_keys(
advantage=advantage_key,
@@ -6694,9 +6796,23 @@ def test_a2c_tensordict_keys_run(self, device):
reward=reward_key,
done=done_key,
terminated=done_key,
+ sample_log_prob=sample_log_prob_key,
)
- advantage(td)
+ if advantage is not None:
+ advantage.set_keys(
+ advantage=advantage_key,
+ value_target=value_target_key,
+ value=value_key,
+ reward=reward_key,
+ done=done_key,
+ terminated=terminated_key,
+ sample_log_prob=sample_log_prob_key,
+ )
+ advantage(td)
+ else:
+ if td_est is not None:
+ loss_fn.make_value_estimator(td_est)
loss = loss_fn(td)
loss_critic = loss["loss_critic"]
@@ -6794,7 +6910,16 @@ class TestReinforce(LossModuleTestBase):
@pytest.mark.parametrize("delay_value", [True, False])
@pytest.mark.parametrize("gradient_mode", [True, False])
@pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
- @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+ @pytest.mark.parametrize(
+ "td_est",
+ [
+ ValueEstimators.TD1,
+ ValueEstimators.TD0,
+ ValueEstimators.GAE,
+ ValueEstimators.TDLambda,
+ None,
+ ],
+ )
def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est):
n_obs = 3
n_act = 5
@@ -6816,20 +6941,20 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est
advantage = GAE(
gamma=gamma,
lmbda=0.9,
- value_network=get_functional(value_net),
+ value_network=value_net,
differentiable=gradient_mode,
)
elif advantage == "td":
advantage = TD1Estimator(
gamma=gamma,
- value_network=get_functional(value_net),
+ value_network=value_net,
differentiable=gradient_mode,
)
elif advantage == "td_lambda":
advantage = TDLambdaEstimator(
gamma=0.9,
lmbda=0.9,
- value_network=get_functional(value_net),
+ value_network=value_net,
differentiable=gradient_mode,
)
elif advantage is None:
@@ -7512,7 +7637,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est)
imagination_horizon=imagination_horizon,
discount_loss=discount_loss,
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_module.make_value_estimator(td_est)
return
@@ -8254,7 +8379,7 @@ def test_iql(
expectile=expectile,
loss_function="l2",
)
- if td_est is ValueEstimators.GAE:
+ if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -9615,6 +9740,113 @@ def test_gae_multidim(
torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4)
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1])
+ @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)])
+ @pytest.mark.parametrize("T", [200, 5, 3])
+ @pytest.mark.parametrize("dtype", [torch.float, torch.double])
+ @pytest.mark.parametrize("has_done", [False, True])
+ def test_vtrace(self, device, gamma, N, T, dtype, has_done):
+ torch.manual_seed(0)
+
+ done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool)
+ terminated = done.clone()
+ if has_done:
+ terminated = terminated.bernoulli_(0.1)
+ done = done.bernoulli_(0.1) | terminated
+ reward = torch.randn(*N, T, 1, device=device, dtype=dtype)
+ state_value = torch.randn(*N, T, 1, device=device, dtype=dtype)
+ next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype)
+ log_pi = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype))
+ log_mu = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype))
+
+ _, value_target = vtrace_advantage_estimate(
+ gamma,
+ log_pi,
+ log_mu,
+ state_value,
+ next_state_value,
+ reward,
+ done=done,
+ terminated=terminated,
+ )
+
+ assert not torch.isnan(value_target).any()
+ assert not torch.isinf(value_target).any()
+
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1])
+ @pytest.mark.parametrize("N", [(3,), (7, 3)])
+ @pytest.mark.parametrize("T", [100, 3])
+ @pytest.mark.parametrize("dtype", [torch.float, torch.double])
+ @pytest.mark.parametrize("feature_dim", [[5], [2, 5]])
+ @pytest.mark.parametrize("has_done", [True, False])
+ def test_vtrace_multidim(self, device, gamma, N, T, dtype, has_done, feature_dim):
+ D = feature_dim
+ time_dim = -1 - len(D)
+
+ torch.manual_seed(0)
+
+ done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool)
+ terminated = done.clone()
+ if has_done:
+ terminated = terminated.bernoulli_(0.1)
+ done = done.bernoulli_(0.1) | terminated
+ reward = torch.randn(*N, T, *D, device=device, dtype=dtype)
+ state_value = torch.randn(*N, T, *D, device=device, dtype=dtype)
+ next_state_value = torch.randn(*N, T, *D, device=device, dtype=dtype)
+ log_pi = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype))
+ log_mu = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype))
+
+ r1 = vtrace_advantage_estimate(
+ gamma,
+ log_pi,
+ log_mu,
+ state_value,
+ next_state_value,
+ reward,
+ done=done,
+ terminated=terminated,
+ time_dim=time_dim,
+ )
+ if len(D) == 2:
+ r2 = [
+ vtrace_advantage_estimate(
+ gamma,
+ log_pi[..., i : i + 1, j],
+ log_mu[..., i : i + 1, j],
+ state_value[..., i : i + 1, j],
+ next_state_value[..., i : i + 1, j],
+ reward[..., i : i + 1, j],
+ terminated=terminated[..., i : i + 1, j],
+ done=done[..., i : i + 1, j],
+ time_dim=-2,
+ )
+ for i in range(D[0])
+ for j in range(D[1])
+ ]
+ else:
+ r2 = [
+ vtrace_advantage_estimate(
+ gamma,
+ log_pi[..., i : i + 1],
+ log_mu[..., i : i + 1],
+ state_value[..., i : i + 1],
+ next_state_value[..., i : i + 1],
+ reward[..., i : i + 1],
+ done=done[..., i : i + 1],
+ terminated=terminated[..., i : i + 1],
+ time_dim=-2,
+ )
+ for i in range(D[0])
+ ]
+
+ list2 = list(zip(*r2))
+ r2 = [torch.cat(list2[0], -1), torch.cat(list2[1], -1)]
+ if len(D) == 2:
+ r2 = [r2[0].unflatten(-1, D), r2[1].unflatten(-1, D)]
+ torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4)
+
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1])
@pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99])
@@ -9638,9 +9870,6 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done):
next_state_value = torch.randn(*N, T, 1, device=device)
gamma_tensor = torch.full((*N, T, 1), gamma, device=device)
- # if len(N) == 2:
- # print(terminated[4, 0, -10:])
- # print(done[4, 0, -10:])
v1 = vec_td_lambda_advantage_estimate(
gamma,
lmbda,
@@ -10549,6 +10778,7 @@ class TestAdv:
[GAE, {"lmbda": 0.95}],
[TD1Estimator, {}],
[TDLambdaEstimator, {"lmbda": 0.95}],
+ [VTrace, {}],
],
)
def test_dispatch(
@@ -10559,18 +10789,46 @@ def test_dispatch(
value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
- module = adv(
- gamma=0.98,
- value_network=value_net,
- differentiable=False,
- **kwargs,
- )
- kwargs = {
- "obs": torch.randn(1, 10, 3),
- "next_reward": torch.randn(1, 10, 1, requires_grad=True),
- "next_done": torch.zeros(1, 10, 1, dtype=torch.bool),
- "next_obs": torch.randn(1, 10, 3),
- }
+ if adv is VTrace:
+ actor_net = TensorDictModule(
+ nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
+ )
+ actor_net = ProbabilisticActor(
+ module=actor_net,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=OneHotCategorical,
+ return_log_prob=True,
+ )
+ module = adv(
+ gamma=0.98,
+ actor_network=actor_net,
+ value_network=value_net,
+ differentiable=False,
+ **kwargs,
+ )
+ kwargs = {
+ "obs": torch.randn(1, 10, 3),
+ "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+ "next_reward": torch.randn(1, 10, 1, requires_grad=True),
+ "next_done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "next_obs": torch.randn(1, 10, 3),
+ }
+ else:
+ module = adv(
+ gamma=0.98,
+ value_network=value_net,
+ differentiable=False,
+ **kwargs,
+ )
+ kwargs = {
+ "obs": torch.randn(1, 10, 3),
+ "next_reward": torch.randn(1, 10, 1, requires_grad=True),
+ "next_done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "next_obs": torch.randn(1, 10, 3),
+ }
advantage, value_target = module(**kwargs)
assert advantage.shape == torch.Size([1, 10, 1])
assert value_target.shape == torch.Size([1, 10, 1])
@@ -10581,6 +10839,7 @@ def test_dispatch(
[GAE, {"lmbda": 0.95}],
[TD1Estimator, {}],
[TDLambdaEstimator, {"lmbda": 0.95}],
+ [VTrace, {}],
],
)
def test_diff_reward(
@@ -10591,23 +10850,55 @@ def test_diff_reward(
value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
- module = adv(
- gamma=0.98,
- value_network=value_net,
- differentiable=True,
- **kwargs,
- )
- td = TensorDict(
- {
- "obs": torch.randn(1, 10, 3),
- "next": {
+ if adv is VTrace:
+ actor_net = TensorDictModule(
+ nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
+ )
+ actor_net = ProbabilisticActor(
+ module=actor_net,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=OneHotCategorical,
+ return_log_prob=True,
+ )
+ module = adv(
+ gamma=0.98,
+ actor_network=actor_net,
+ value_network=value_net,
+ differentiable=True,
+ **kwargs,
+ )
+ td = TensorDict(
+ {
"obs": torch.randn(1, 10, 3),
- "reward": torch.randn(1, 10, 1, requires_grad=True),
- "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+ "next": {
+ "obs": torch.randn(1, 10, 3),
+ "reward": torch.randn(1, 10, 1, requires_grad=True),
+ "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
+ },
},
- },
- [1, 10],
- )
+ [1, 10],
+ )
+ else:
+ module = adv(
+ gamma=0.98,
+ value_network=value_net,
+ differentiable=True,
+ **kwargs,
+ )
+ td = TensorDict(
+ {
+ "obs": torch.randn(1, 10, 3),
+ "next": {
+ "obs": torch.randn(1, 10, 3),
+ "reward": torch.randn(1, 10, 1, requires_grad=True),
+ "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ },
+ },
+ [1, 10],
+ )
td = module(td.clone(False))
# check that the advantage can't backprop to the value params
td["advantage"].sum().backward()
@@ -10622,6 +10913,7 @@ def test_diff_reward(
[GAE, {"lmbda": 0.95}],
[TD1Estimator, {}],
[TDLambdaEstimator, {"lmbda": 0.95}],
+ [VTrace, {}],
],
)
@pytest.mark.parametrize("shifted", [True, False])
@@ -10629,25 +10921,60 @@ def test_non_differentiable(self, adv, shifted, kwargs):
value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
- module = adv(
- gamma=0.98,
- value_network=value_net,
- differentiable=False,
- shifted=shifted,
- **kwargs,
- )
- td = TensorDict(
- {
- "obs": torch.randn(1, 10, 3),
- "next": {
+
+ if adv is VTrace:
+ actor_net = TensorDictModule(
+ nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
+ )
+ actor_net = ProbabilisticActor(
+ module=actor_net,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=OneHotCategorical,
+ return_log_prob=True,
+ )
+ module = adv(
+ gamma=0.98,
+ actor_network=actor_net,
+ value_network=value_net,
+ differentiable=False,
+ shifted=shifted,
+ **kwargs,
+ )
+ td = TensorDict(
+ {
"obs": torch.randn(1, 10, 3),
- "reward": torch.randn(1, 10, 1, requires_grad=True),
- "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+ "next": {
+ "obs": torch.randn(1, 10, 3),
+ "reward": torch.randn(1, 10, 1, requires_grad=True),
+ "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
+ },
},
- },
- [1, 10],
- names=[None, "time"],
- )
+ [1, 10],
+ names=[None, "time"],
+ )
+ else:
+ module = adv(
+ gamma=0.98,
+ value_network=value_net,
+ differentiable=False,
+ shifted=shifted,
+ **kwargs,
+ )
+ td = TensorDict(
+ {
+ "obs": torch.randn(1, 10, 3),
+ "next": {
+ "obs": torch.randn(1, 10, 3),
+ "reward": torch.randn(1, 10, 1, requires_grad=True),
+ "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ },
+ },
+ [1, 10],
+ names=[None, "time"],
+ )
td = module(td.clone(False))
assert td["advantage"].is_leaf
@@ -10657,6 +10984,7 @@ def test_non_differentiable(self, adv, shifted, kwargs):
[GAE, {"lmbda": 0.95}],
[TD1Estimator, {}],
[TDLambdaEstimator, {"lmbda": 0.95}],
+ [VTrace, {}],
],
)
@pytest.mark.parametrize("has_value_net", [True, False])
@@ -10679,28 +11007,65 @@ def test_skip_existing(
else:
value_net = None
- module = adv(
- gamma=0.98,
- value_network=value_net,
- differentiable=True,
- shifted=shifted,
- skip_existing=skip_existing,
- **kwargs,
- )
- td = TensorDict(
- {
- "obs": torch.randn(1, 10, 3),
- "state_value": torch.ones(1, 10, 1),
- "next": {
+ if adv is VTrace:
+ actor_net = TensorDictModule(
+ nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
+ )
+ actor_net = ProbabilisticActor(
+ module=actor_net,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=OneHotCategorical,
+ return_log_prob=True,
+ )
+ module = adv(
+ gamma=0.98,
+ actor_network=actor_net,
+ value_network=value_net,
+ differentiable=True,
+ shifted=shifted,
+ skip_existing=skip_existing,
+ **kwargs,
+ )
+ td = TensorDict(
+ {
"obs": torch.randn(1, 10, 3),
+ "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
"state_value": torch.ones(1, 10, 1),
- "reward": torch.randn(1, 10, 1, requires_grad=True),
- "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "next": {
+ "obs": torch.randn(1, 10, 3),
+ "state_value": torch.ones(1, 10, 1),
+ "reward": torch.randn(1, 10, 1, requires_grad=True),
+ "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ "terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
+ },
},
- },
- [1, 10],
- names=[None, "time"],
- )
+ [1, 10],
+ names=[None, "time"],
+ )
+ else:
+ module = adv(
+ gamma=0.98,
+ value_network=value_net,
+ differentiable=True,
+ shifted=shifted,
+ skip_existing=skip_existing,
+ **kwargs,
+ )
+ td = TensorDict(
+ {
+ "obs": torch.randn(1, 10, 3),
+ "state_value": torch.ones(1, 10, 1),
+ "next": {
+ "obs": torch.randn(1, 10, 3),
+ "state_value": torch.ones(1, 10, 1),
+ "reward": torch.randn(1, 10, 1, requires_grad=True),
+ "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+ },
+ },
+ [1, 10],
+ names=[None, "time"],
+ )
td = module(td.clone(False))
if has_value_net and not skip_existing:
exp_val = 0
@@ -10718,15 +11083,34 @@ def test_skip_existing(
[GAE, {"lmbda": 0.95}],
[TD1Estimator, {}],
[TDLambdaEstimator, {"lmbda": 0.95}],
+ [VTrace, {}],
],
)
def test_set_keys(self, value, adv, kwargs):
value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=[value])
- module = adv(
- gamma=0.98,
- value_network=value_net,
- **kwargs,
- )
+ if adv is VTrace:
+ actor_net = TensorDictModule(
+ nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
+ )
+ actor_net = ProbabilisticActor(
+ module=actor_net,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=OneHotCategorical,
+ return_log_prob=True,
+ )
+ module = adv(
+ gamma=0.98,
+ actor_network=actor_net,
+ value_network=value_net,
+ **kwargs,
+ )
+ else:
+ module = adv(
+ gamma=0.98,
+ value_network=value_net,
+ **kwargs,
+ )
module.set_keys(value=value)
assert module.tensor_keys.value == value
@@ -10740,6 +11124,7 @@ def test_set_keys(self, value, adv, kwargs):
[GAE, {"lmbda": 0.95}],
[TD1Estimator, {}],
[TDLambdaEstimator, {"lmbda": 0.95}],
+ [VTrace, {}],
],
)
def test_set_deprecated_keys(self, adv, kwargs):
@@ -10748,14 +11133,36 @@ def test_set_deprecated_keys(self, adv, kwargs):
)
with pytest.warns(DeprecationWarning):
- module = adv(
- gamma=0.98,
- value_network=value_net,
- value_key="test_value",
- advantage_key="advantage_test",
- value_target_key="value_target_test",
- **kwargs,
- )
+
+ if adv is VTrace:
+ actor_net = TensorDictModule(
+ nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
+ )
+ actor_net = ProbabilisticActor(
+ module=actor_net,
+ in_keys=["logits"],
+ out_keys=["action"],
+ distribution_class=OneHotCategorical,
+ return_log_prob=True,
+ )
+ module = adv(
+ gamma=0.98,
+ actor_network=actor_net,
+ value_network=value_net,
+ value_key="test_value",
+ advantage_key="advantage_test",
+ value_target_key="value_target_test",
+ **kwargs,
+ )
+ else:
+ module = adv(
+ gamma=0.98,
+ value_network=value_net,
+ value_key="test_value",
+ advantage_key="advantage_test",
+ value_target_key="value_target_test",
+ **kwargs,
+ )
assert module.tensor_keys.value == "test_value"
assert module.tensor_keys.advantage == "advantage_test"
assert module.tensor_keys.value_target == "value_target_test"
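
The `test_cost.py` changes above repeatedly construct the new `VTrace` estimator from a small discrete actor plus a state-value network. The recurring fixture, condensed into one self-contained snippet (shapes and key names are copied from the tests rather than from any documented recipe):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.modules import OneHotCategorical, ProbabilisticActor
from torchrl.objectives.value import VTrace

# Value network writes "state_value"; the actor writes "action" and "sample_log_prob".
value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
actor_net = ProbabilisticActor(
    module=TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]),
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)
vtrace = VTrace(gamma=0.98, value_network=value_net, actor_network=actor_net)

td = TensorDict(
    {
        "obs": torch.randn(1, 10, 3),
        "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
        "next": {
            "obs": torch.randn(1, 10, 3),
            "reward": torch.randn(1, 10, 1),
            "done": torch.zeros(1, 10, 1, dtype=torch.bool),
            "terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
        },
    },
    [1, 10],
)
vtrace(td)  # writes "advantage" and "value_target" into td
```
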
diff --git a/test/test_env.py b/test/test_env.py
index 6cee7f545d7..aed4e07b0b7 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -354,6 +354,48 @@ def test_mb_env_batch_lock(self, device, seed=0):
class TestParallel:
+ @pytest.mark.skipif(
+ not torch.cuda.device_count(), reason="No cuda device detected."
+ )
+ @pytest.mark.parametrize("parallel", [True, False])
+ @pytest.mark.parametrize("hetero", [True, False])
+ @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"])
+ @pytest.mark.parametrize("edevice", ["cpu", "cuda"])
+ @pytest.mark.parametrize("bwad", [True, False])
+ def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad):
+ if parallel:
+ cls = ParallelEnv
+ else:
+ cls = SerialEnv
+ if not hetero:
+ env = cls(
+ 2, lambda: ContinuousActionVecMockEnv(device=edevice), device=pdevice
+ )
+ else:
+ env1 = lambda: ContinuousActionVecMockEnv(device=edevice)
+ env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice))
+ env = cls(2, [env1, env2], device=pdevice)
+
+ r = env.rollout(2, break_when_any_done=bwad)
+ if pdevice is not None:
+ assert env.device.type == torch.device(pdevice).type
+ assert r.device.type == torch.device(pdevice).type
+ assert all(
+ item.device.type == torch.device(pdevice).type
+ for item in r.values(True, True)
+ )
+ else:
+ assert env.device.type == torch.device(edevice).type
+ assert r.device.type == torch.device(edevice).type
+ assert all(
+ item.device.type == torch.device(edevice).type
+ for item in r.values(True, True)
+ )
+ if parallel:
+ assert (
+ env.shared_tensordict_parent.device.type == torch.device(edevice).type
+ )
+
@pytest.mark.parametrize("num_parallel_env", [1, 10])
@pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)])
def test_env_with_batch_size(self, num_parallel_env, env_batch_size):
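
The new `test_parallel_devices` above pins down the behaviour introduced in `torchrl/envs/batched_envs.py` further down in this diff: a batched env can now be given a `device` that differs from its workers', with data cast automatically during collection. A hedged usage sketch (the Gym environment is an illustrative stand-in):

```python
import torch
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

# Workers stay on CPU; the batched env casts collected data to its own device.
device = "cuda" if torch.cuda.is_available() else "cpu"
env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1", device="cpu"), device=device)

rollout = env.rollout(3, break_when_any_done=False)
assert env.device.type == torch.device(device).type
assert rollout.device.type == torch.device(device).type
env.close()
```
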
diff --git a/test/test_exploration.py b/test/test_exploration.py
index 8a374cd9009..24bb8c246d0 100644
--- a/test/test_exploration.py
+++ b/test/test_exploration.py
@@ -51,7 +51,7 @@
class TestEGreedy:
- @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
+ @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1])
@pytest.mark.parametrize("module", [True, False])
def test_egreedy(self, eps_init, module):
torch.manual_seed(0)
@@ -78,7 +78,7 @@ def test_egreedy(self, eps_init, module):
assert (action == 0).any()
assert ((action == 1) | (action == 0)).all()
- @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0])
+ @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1])
@pytest.mark.parametrize("module", [True, False])
@pytest.mark.parametrize("spec_class", ["discrete", "one_hot"])
def test_egreedy_masked(self, module, eps_init, spec_class):
diff --git a/test/test_libs.py b/test/test_libs.py
index f1715a550f4..c3379021510 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -400,7 +400,7 @@ def test_vecenvs_wrapper(self, envname):
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
+ (["FetchReach-v2"] if _has_gym_robotics else []),
)
- @pytest.mark.flaky(reruns=3, reruns_delay=1)
+ @pytest.mark.flaky(reruns=8, reruns_delay=1)
def test_vecenvs_env(self, envname):
from _utils_internal import rollout_consistency_assertion
@@ -1896,12 +1896,8 @@ def test_direct_download(self, task):
keys = keys.intersection(data_d4rl._storage._storage.keys(True, True))
assert len(keys)
assert_allclose_td(
- data_direct._storage._storage.select(*keys).apply(
- lambda t: t.as_tensor().float()
- ),
- data_d4rl._storage._storage.select(*keys).apply(
- lambda t: t.as_tensor().float()
- ),
+ data_direct._storage._storage.select(*keys).apply(lambda t: t.float()),
+ data_d4rl._storage._storage.select(*keys).apply(lambda t: t.float()),
)
@pytest.mark.parametrize(
diff --git a/test/test_rb_distributed.py b/test/test_rb_distributed.py
index 8a46b1a006d..548f04dc41d 100644
--- a/test/test_rb_distributed.py
+++ b/test/test_rb_distributed.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
import os
+
import sys
import time
@@ -22,10 +23,10 @@
class ReplayBufferNode(RemoteTensorDictReplayBuffer):
- def __init__(self, capacity: int):
+ def __init__(self, capacity: int, scratch_dir=None):
super().__init__(
storage=LazyMemmapStorage(
- max_size=capacity, scratch_dir="/tmp/", device=torch.device("cpu")
+ max_size=capacity, scratch_dir=scratch_dir, device=torch.device("cpu")
),
sampler=RandomSampler(),
writer=RoundRobinWriter(),
diff --git a/test/test_rlhf.py b/test/test_rlhf.py
index 2abb9a6d386..31ef96681df 100644
--- a/test/test_rlhf.py
+++ b/test/test_rlhf.py
@@ -14,7 +14,12 @@
import torch.nn.functional as F
from _utils_internal import get_default_devices
-from tensordict import is_tensor_collection, MemmapTensor, TensorDict, TensorDictBase
+from tensordict import (
+ is_tensor_collection,
+ MemoryMappedTensor,
+ TensorDict,
+ TensorDictBase,
+)
from tensordict.nn import TensorDictModule
from torchrl.data.rlhf import TensorDictTokenizer
from torchrl.data.rlhf.dataset import (
@@ -188,8 +193,8 @@ def test_dataset_to_tensordict(tmpdir, suffix):
else:
assert ("c", "d", "a") in td.keys(True)
assert ("c", "d", "b") in td.keys(True)
- assert isinstance(td.get((suffix, "a")), MemmapTensor)
- assert isinstance(td.get((suffix, "b")), MemmapTensor)
+ assert isinstance(td.get((suffix, "a")), MemoryMappedTensor)
+ assert isinstance(td.get((suffix, "b")), MemoryMappedTensor)
@pytest.mark.skipif(
diff --git a/test/test_shared.py b/test/test_shared.py
index c4790597359..186c8ae9525 100644
--- a/test/test_shared.py
+++ b/test/test_shared.py
@@ -144,24 +144,7 @@ def test_shared(self, shared):
)
-# @pytest.mark.skipif(
-# sys.platform == "win32",
-# reason="RuntimeError from Torch serialization.py when creating td_saved on Windows",
-# )
-@pytest.mark.parametrize(
- "idx",
- [
- torch.tensor(
- [
- 3,
- 5,
- 7,
- 8,
- ]
- ),
- slice(200),
- ],
-)
+@pytest.mark.parametrize("idx", [0, slice(200)])
@pytest.mark.parametrize("dtype", [torch.float, torch.bool])
def test_memmap(idx, dtype, large_scale=False):
N = 5000 if large_scale else 10
diff --git a/test/test_transforms.py b/test/test_transforms.py
index da8bc12c126..cff1d33b34a 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -9,6 +9,7 @@
import itertools
import pickle
+import re
import sys
from copy import copy
from functools import partial
@@ -4878,6 +4879,39 @@ def test_sum_reward(self, keys, device):
def test_transform_inverse(self):
raise pytest.skip("No inverse for RewardSum")
+ @pytest.mark.parametrize("in_keys", [["reward"], ["reward_1", "reward_2"]])
+ @pytest.mark.parametrize(
+ "out_keys", [["episode_reward"], ["episode_reward_1", "episode_reward_2"]]
+ )
+ @pytest.mark.parametrize("reset_keys", [["_reset"], ["_reset1", "_reset2"]])
+ def test_keys_length_errors(self, in_keys, reset_keys, out_keys, batch=10):
+ reset_dict = {
+ reset_key: torch.zeros(batch, dtype=torch.bool) for reset_key in reset_keys
+ }
+ reward_sum_dict = {out_key: torch.randn(batch) for out_key in out_keys}
+ reset_dict.update(reward_sum_dict)
+ td = TensorDict(reset_dict, [])
+
+ if len(in_keys) != len(out_keys):
+ with pytest.raises(
+ ValueError,
+ match="RewardSum expects the same number of input and output keys",
+ ):
+ RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys)
+ else:
+ t = RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys)
+
+ if len(in_keys) != len(reset_keys):
+ with pytest.raises(
+ ValueError,
+ match=re.escape(
+ f"Could not match the env reset_keys {reset_keys} with the in_keys {in_keys}"
+ ),
+ ):
+ t.reset(td)
+ else:
+ t.reset(td)
+
class TestReward2Go(TransformBase):
@pytest.mark.parametrize("device", get_default_devices())
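
The new `test_keys_length_errors` above documents the contract `RewardSum` enforces: `in_keys`, `out_keys` and `reset_keys` must line up one-to-one. A minimal sketch of a valid configuration (key names and the CartPole environment are illustrative):

```python
from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import RewardSum

# One output key and one reset key per summed reward key.
t = RewardSum(in_keys=["reward"], out_keys=["episode_reward"], reset_keys=["_reset"])
env = TransformedEnv(GymEnv("CartPole-v1"), t)

td = env.rollout(5)
# The running sum is written alongside the reward at each step.
assert ("next", "episode_reward") in td.keys(True)
```
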
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index ef790b6f9f6..9c8417b9c97 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -12,7 +12,7 @@
import torch
from tensordict import is_tensorclass
-from tensordict.memmap import MemmapTensor
+from tensordict.memmap import MemmapTensor, MemoryMappedTensor
from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.utils import expand_right
@@ -482,7 +482,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
if self.device == "auto":
self.device = data.device
if isinstance(data, torch.Tensor):
- # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
+ # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = torch.empty(
self.max_size,
*data.shape,
@@ -531,12 +531,12 @@ class LazyMemmapStorage(LazyTensorStorage):
>>> storage.get(0)
TensorDict(
fields={
- some data: MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
+ some data: MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
some: TensorDict(
fields={
nested: TensorDict(
fields={
- data: MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
+ data: MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([11]),
device=cpu,
is_shared=False)},
@@ -560,8 +560,8 @@ class LazyMemmapStorage(LazyTensorStorage):
>>> storage.set(range(10), data)
>>> storage.get(0)
MyClass(
- bar=MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
- foo=MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
+ bar=MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
+ foo=MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
batch_size=torch.Size([11]),
device=cpu,
is_shared=False)
@@ -603,7 +603,12 @@ def load_state_dict(self, state_dict):
if isinstance(self._storage, torch.Tensor):
_mem_map_tensor_as_tensor(self._storage).copy_(_storage)
elif self._storage is None:
- self._storage = MemmapTensor(_storage)
+ self._storage = _make_memmap(
+ _storage,
+ path=self.scratch_dir + "/tensor.memmap"
+ if self.scratch_dir is not None
+ else None,
+ )
else:
raise RuntimeError(
f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
@@ -638,7 +643,8 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
self.device = data.device
if self.device.type != "cpu":
warnings.warn(
- "Support for Memmap device other than CPU will be deprecated in v0.4.0.",
+ "Support for Memmap device other than CPU will be deprecated in v0.4.0. "
+ "Using a 'cuda' device may be suboptimal.",
category=DeprecationWarning,
)
if is_tensor_collection(data):
@@ -656,9 +662,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
)
else:
# If not a tensorclass/tensordict, it must be a tensor(-like)
- # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
- out = MemmapTensor(
- self.max_size, *data.shape, device=self.device, dtype=data.dtype
+ # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
+ out = _make_empty_memmap(
+ (self.max_size, *data.shape),
+ dtype=data.dtype,
+ path=self.scratch_dir + "/tensor.memmap"
+ if self.scratch_dir is not None
+ else None,
)
if VERBOSE:
filesize = os.path.getsize(out.filename) / 1024 / 1024
@@ -668,6 +678,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
self._storage = out
self.initialized = True
+ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
+ result = super().get(index)
+ # to be deprecated in v0.4
+ if result.device != self.device:
+ return result.to(self.device, non_blocking=True)
+ return result
+
# Utils
def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
@@ -677,6 +694,7 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
f"Supported backends are {_CKPT_BACKEND.backends}"
)
if isinstance(mem_map_tensor, torch.Tensor):
+ # This will account for MemoryMappedTensors
return mem_map_tensor
if _CKPT_BACKEND == "torchsnapshot":
# TorchSnapshot doesn't know how to stream MemmapTensor, so we view MemmapTensor
@@ -737,25 +755,27 @@ def _collate_list_tensordict(x):
return out
-def _collate_contiguous(x):
+def _collate_id(x):
return x
-def _collate_as_tensor(x):
- return x.as_tensor()
-
-
def _get_default_collate(storage, _is_tensordict=False):
if isinstance(storage, ListStorage):
if _is_tensordict:
return _collate_list_tensordict
else:
return torch.utils.data._utils.collate.default_collate
- elif isinstance(storage, LazyMemmapStorage):
- return _collate_as_tensor
- elif isinstance(storage, (TensorStorage,)):
- return _collate_contiguous
+ elif isinstance(storage, TensorStorage):
+ return _collate_id
else:
raise NotImplementedError(
f"Could not find a default collate_fn for storage {type(storage)}."
)
+
+
+def _make_memmap(tensor, path):
+ return MemoryMappedTensor.from_tensor(tensor, filename=path)
+
+
+def _make_empty_memmap(shape, dtype, path):
+ return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path)
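
These two helpers route every storage allocation through `MemoryMappedTensor`, which is what the updated docstrings above reflect. From the user side the change is transparent; a brief sketch following the docstring example (the temporary scratch directory is illustrative):

```python
import tempfile

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage

# Backing files are written under scratch_dir instead of a default temporary location.
scratch_dir = tempfile.mkdtemp()
storage = LazyMemmapStorage(max_size=10, scratch_dir=scratch_dir)

data = TensorDict({"obs": torch.randn(10, 3), "reward": torch.randn(10, 1)}, [10])
storage.set(range(10), data)

# As in the docstring above, stored leaves are now MemoryMappedTensor instances.
print(storage.get(0))
```
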
diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py
index db2b6a418d6..09086bfad65 100644
--- a/torchrl/data/rlhf/dataset.py
+++ b/torchrl/data/rlhf/dataset.py
@@ -77,8 +77,8 @@ class TokenizedDatasetLoader:
>>> print(dataset)
TensorDict(
fields={
- attention_mask: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
- input_ids: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
+ attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([185068]),
device=None,
is_shared=False)
@@ -137,6 +137,7 @@ def load(self):
data_dir = root_dir / str(Path(self.dataset_name).name).split("-")[0]
data_dir_total = data_dir / split / str(max_length)
# search for data
+ print("Looking for data in", data_dir_total)
if os.path.exists(data_dir_total):
dataset = TensorDict.load_memmap(data_dir_total)
return dataset
@@ -270,8 +271,8 @@ def dataset_to_tensordict(
fields={
prefix: TensorDict(
fields={
- labels: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
- tokens: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
+ labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
+ tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([10]),
device=None,
is_shared=False)},
diff --git a/torchrl/data/rlhf/prompt.py b/torchrl/data/rlhf/prompt.py
index d534a95379e..d50653c9967 100644
--- a/torchrl/data/rlhf/prompt.py
+++ b/torchrl/data/rlhf/prompt.py
@@ -74,10 +74,10 @@ def from_dataset(
>>> data = PromptData.from_dataset("train")
>>> print(data)
PromptDataTLDR(
- attention_mask=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
- input_ids=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
- prompt_rindex=MemmapTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
- labels=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
+ labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
logits=None,
loss=None,
batch_size=torch.Size([116722]),
diff --git a/torchrl/data/rlhf/reward.py b/torchrl/data/rlhf/reward.py
index e7843e02f46..20f379ef659 100644
--- a/torchrl/data/rlhf/reward.py
+++ b/torchrl/data/rlhf/reward.py
@@ -41,16 +41,16 @@ class PairwiseDataset:
>>> print(data)
PairwiseDataset(
chosen_data=RewardData(
- attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
- input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
device=None,
is_shared=False),
rejected_data=RewardData(
- attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
- input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
@@ -97,16 +97,16 @@ def from_dataset(
>>> print(data)
PairwiseDataset(
chosen_data=RewardData(
- attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
- input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
device=None,
is_shared=False),
rejected_data=RewardData(
- attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
- input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
+ input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index f0e132eb092..ac0a136c7f9 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -122,11 +122,16 @@ class _BatchedEnv(EnvBase):
memmap (bool): whether or not the returned tensordict will be placed in memory map.
policy_proof (callable, optional): if provided, it'll be used to get the list of
tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc.
- device (str, int, torch.device): for consistency, this argument is kept. However this
- argument should not be passed, as the device will be inferred from the environments.
- It is assumed that all environments will run on the same device as a common shared
- tensordict will be used to pass data from process to process. The device can be
- changed after instantiation using :obj:`env.to(device)`.
+        device (str, int, torch.device): The device of the batched environment can be passed.
+            If not, it is inferred from the sub-environments. In this case, it is
+            assumed that the devices of all sub-environments match. If a device is
+            provided, it can differ from the sub-environment device(s); the data will
+            then be cast automatically to the appropriate device during collection.
+            This can be used to speed up collection when casting to device
+            introduces an overhead (e.g., numpy-based environments): by using
+            a ``"cuda"`` device for the batched environment but a ``"cpu"``
+            device for the nested environments, one can keep the overhead to a
+            minimum.
num_threads (int, optional): number of threads for this process.
Defaults to the number of workers.
This parameter has no effect for the :class:`~SerialEnv` class.
@@ -162,14 +167,7 @@ def __init__(
num_threads: int = None,
num_sub_threads: int = 1,
):
- if device is not None:
- raise ValueError(
- "Device setting for batched environment can't be done at initialization. "
- "The device will be inferred from the constructed environment. "
- "It can be set through the `to(device)` method."
- )
-
- super().__init__(device=None)
+ super().__init__(device=device)
self.is_closed = True
if num_threads is None:
num_threads = num_workers + 1 # 1 more thread for this proc
@@ -218,7 +216,7 @@ def __init__(
"memmap and shared memory are mutually exclusive features."
)
self._batch_size = None
- self._device = None
+ self._device = torch.device(device) if device is not None else device
self._dummy_env_str = None
self._seeds = None
self.__dict__["_input_spec"] = None
@@ -273,7 +271,9 @@ def _set_properties(self):
self._properties_set = True
if self._single_task:
self._batch_size = meta_data.batch_size
- device = self._device = meta_data.device
+ device = meta_data.device
+ if self._device is None:
+ self._device = device
input_spec = meta_data.specs["input_spec"].to(device)
output_spec = meta_data.specs["output_spec"].to(device)
@@ -289,8 +289,18 @@ def _set_properties(self):
self._batch_locked = meta_data.batch_locked
else:
self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
- device = self._device = meta_data[0].device
- # TODO: check that all action_spec and reward spec match (issue #351)
+ devices = set()
+ for _meta_data in meta_data:
+ device = _meta_data.device
+ devices.add(device)
+ if self._device is None:
+ if len(devices) > 1:
+ raise ValueError(
+ f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. "
+ f"Please indicate a device to be used for collection."
+ )
+ device = list(devices)[0]
+ self._device = device
input_spec = []
for md in meta_data:
@@ -413,7 +423,7 @@ def _create_td(self) -> None:
*(unravel_key(("next", key)) for key in self._env_output_keys),
strict=False,
)
- self.shared_tensordict_parent = shared_tensordict_parent.to(self.device)
+ self.shared_tensordict_parent = shared_tensordict_parent
else:
# Multi-task: we share tensordict that *may* have different keys
shared_tensordict_parent = [
@@ -421,7 +431,7 @@ def _create_td(self) -> None:
*self._selected_keys,
*(unravel_key(("next", key)) for key in self._env_output_keys),
strict=False,
- ).to(self.device)
+ )
for tensordict in shared_tensordict_parent
]
shared_tensordict_parent = torch.stack(
@@ -440,13 +450,11 @@ def _create_td(self) -> None:
# Multi-task: we share tensordict that *may* have different keys
# LazyStacked already stores this so we don't need to do anything
self.shared_tensordicts = self.shared_tensordict_parent
- if self.device.type == "cpu":
+ if self.shared_tensordict_parent.device.type == "cpu":
if self._share_memory:
- for td in self.shared_tensordicts:
- td.share_memory_()
+ self.shared_tensordict_parent.share_memory_()
elif self._memmap:
- for td in self.shared_tensordicts:
- td.memmap_()
+ self.shared_tensordict_parent.memmap_()
else:
if self._share_memory:
self.shared_tensordict_parent.share_memory_()
@@ -483,7 +491,6 @@ def close(self) -> None:
self.__dict__["_input_spec"] = None
self.__dict__["_output_spec"] = None
self._properties_set = False
- self.event = None
self._shutdown_workers()
self.is_closed = True
@@ -507,11 +514,6 @@ def to(self, device: DEVICE_TYPING):
if device == self.device:
return self
self._device = device
- self.meta_data = (
- self.meta_data.to(device)
- if self._single_task
- else [meta_data.to(device) for meta_data in self.meta_data]
- )
if not self.is_closed:
warn(
"Casting an open environment to another device requires closing and re-opening it. "
@@ -543,7 +545,7 @@ def _start_workers(self) -> None:
for idx in range(_num_workers):
env = self.create_env_fn[idx](**self.create_env_kwargs[idx])
- self._envs.append(env.to(self.device))
+ self._envs.append(env)
self.is_closed = False
@_check_start
@@ -603,29 +605,39 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
if tensordict_.is_empty():
tensordict_ = None
else:
- # reset will do modifications in-place. We want the original
- # tensorict to be unchaned, so we clone it
- tensordict_ = tensordict_.clone(False)
+ env_device = _env.device
+ if env_device != self.device:
+ tensordict_ = tensordict_.to(env_device)
+ else:
+ tensordict_ = tensordict_.clone(False)
else:
tensordict_ = None
+
_td = _env.reset(tensordict=tensordict_, **kwargs)
self.shared_tensordicts[i].update_(
_td.select(*self._selected_reset_keys_filt, strict=False)
)
selected_output_keys = self._selected_reset_keys_filt
+ device = self.device
if self._single_task:
# select + clone creates 2 tds, but we can create one only
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in selected_output_keys:
- _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
- return out
+ _set_single_key(
+ self.shared_tensordict_parent, out, key, clone=True, device=device
+ )
else:
- return self.shared_tensordict_parent.select(
+ out = self.shared_tensordict_parent.select(
*selected_output_keys,
strict=False,
- ).clone()
+ )
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(device, non_blocking=True)
+ return out
def _reset_proc_data(self, tensordict, tensordict_reset):
# since we call `reset` directly, all the postproc has been completed
@@ -643,19 +655,29 @@ def _step(
for i in range(self.num_workers):
# shared_tensordicts are locked, and we need to select the keys since we update in-place.
# There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
- out_td = self._envs[i]._step(tensordict_in[i])
+ env_device = self._envs[i].device
+ if env_device != self.device:
+ data_in = tensordict_in[i].to(env_device, non_blocking=True)
+ else:
+ data_in = tensordict_in[i]
+ out_td = self._envs[i]._step(data_in)
next_td[i].update_(out_td.select(*self._env_output_keys, strict=False))
# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
+ device = self.device
if self._single_task:
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in self._selected_step_keys:
- _set_single_key(next_td, out, key, clone=True)
+ _set_single_key(next_td, out, key, clone=True, device=device)
else:
# strict=False ensures that non-homogeneous keys are still there
- out = next_td.select(*self._selected_step_keys, strict=False).clone()
+ out = next_td.select(*self._selected_step_keys, strict=False)
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(device, non_blocking=True)
return out
def __getattr__(self, attr: str) -> Any:
@@ -710,6 +732,32 @@ class ParallelEnv(_BatchedEnv):
"""
__doc__ += _BatchedEnv.__doc__
+ __doc__ += """
+
+ .. note::
+      The choice of the devices on which ParallelEnv is executed can
+      drastically influence its performance. The rule of thumb is:
+
+ - If the base environment (backend, e.g., Gym) is executed on CPU, the
+ sub-environments should be executed on CPU and the data should be
+ passed via shared physical memory.
+ - If the base environment is (or can be) executed on CUDA, the sub-environments
+ should be placed on CUDA too.
+ - If a CUDA device is available and the policy is to be executed on CUDA,
+ the ParallelEnv device should be set to CUDA.
+
+ Therefore, supposing a CUDA device is available, we have the following scenarios:
+
+ >>> # The sub-envs are executed on CPU, but the policy is on GPU
+ >>> env = ParallelEnv(N, MyEnv(..., device="cpu"), device="cuda")
+ >>> # The sub-envs are executed on CUDA
+ >>> env = ParallelEnv(N, MyEnv(..., device="cuda"), device="cuda")
+ >>> # this will create the exact same environment
+ >>> env = ParallelEnv(N, MyEnv(..., device="cuda"))
+ >>> # If no cuda device is available
+ >>> env = ParallelEnv(N, MyEnv(..., device="cpu"))
+
+ """
def _start_workers(self) -> None:
from torchrl.envs.env_creator import EnvCreator
@@ -722,39 +770,39 @@ def _start_workers(self) -> None:
self.parent_channels = []
self._workers = []
- self._events = []
- if self.device.type == "cuda":
+ func = _run_worker_pipe_shared_mem
+ if self.shared_tensordict_parent.device.type == "cuda":
self.event = torch.cuda.Event()
else:
self.event = None
+ self._events = [ctx.Event() for _ in range(_num_workers)]
+ kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)]
with clear_mpi_env_vars():
for idx in range(_num_workers):
if self._verbose:
print(f"initiating worker {idx}")
# No certainty which module multiprocessing_context is
parent_pipe, child_pipe = ctx.Pipe()
- event = ctx.Event()
- self._events.append(event)
env_fun = self.create_env_fn[idx]
if not isinstance(env_fun, EnvCreator):
env_fun = CloudpickleWrapper(env_fun)
-
+ kwargs[idx].update(
+ {
+ "parent_pipe": parent_pipe,
+ "child_pipe": child_pipe,
+ "env_fun": env_fun,
+ "env_fun_kwargs": self.create_env_kwargs[idx],
+ "shared_tensordict": self.shared_tensordicts[idx],
+ "_selected_input_keys": self._selected_input_keys,
+ "_selected_reset_keys": self._selected_reset_keys,
+ "_selected_step_keys": self._selected_step_keys,
+ "has_lazy_inputs": self.has_lazy_inputs,
+ }
+ )
process = _ProcessNoWarn(
- target=_run_worker_pipe_shared_mem,
+ target=func,
num_threads=self.num_sub_threads,
- args=(
- parent_pipe,
- child_pipe,
- env_fun,
- self.create_env_kwargs[idx],
- self.device,
- event,
- self.shared_tensordicts[idx],
- self._selected_input_keys,
- self._selected_reset_keys,
- self._selected_step_keys,
- self.has_lazy_inputs,
- ),
+ kwargs=kwargs[idx],
)
process.daemon = True
process.start()
@@ -834,10 +882,16 @@ def step_and_maybe_reset(
# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
- tensordict.set("next", self.shared_tensordict_parent.get("next").clone())
- tensordict_ = self.shared_tensordict_parent.exclude(
- "next", *self.reset_keys
- ).clone()
+ next_td = self.shared_tensordict_parent.get("next")
+ tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys)
+ device = self.device
+ if self.shared_tensordict_parent.device == device:
+ next_td = next_td.clone()
+ tensordict_ = tensordict_.clone()
+ else:
+ next_td = next_td.to(device, non_blocking=True)
+ tensordict_ = tensordict_.to(device, non_blocking=True)
+ tensordict.set("next", next_td)
return tensordict, tensordict_
@_check_start
@@ -880,15 +934,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
next_td = self.shared_tensordict_parent.get("next")
+ device = self.device
if self._single_task:
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in self._selected_step_keys:
- _set_single_key(next_td, out, key, clone=True)
+ _set_single_key(next_td, out, key, clone=True, device=device)
else:
# strict=False ensures that non-homogeneous keys are still there
- out = next_td.select(*self._selected_step_keys, strict=False).clone()
+ out = next_td.select(*self._selected_step_keys, strict=False)
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(device, non_blocking=True)
return out
@_check_start
@@ -944,19 +1003,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
event.clear()
selected_output_keys = self._selected_reset_keys_filt
+ device = self.device
if self._single_task:
# select + clone creates 2 tds, but we can create one only
out = TensorDict(
- {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
+ {}, batch_size=self.shared_tensordict_parent.shape, device=device
)
for key in selected_output_keys:
- _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
- return out
+ _set_single_key(
+ self.shared_tensordict_parent, out, key, clone=True, device=device
+ )
else:
- return self.shared_tensordict_parent.select(
+ out = self.shared_tensordict_parent.select(
*selected_output_keys,
strict=False,
- ).clone()
+ )
+ if out.device == device:
+ out = out.clone()
+ else:
+ out = out.to(device, non_blocking=True)
+ return out
@_check_start
def _shutdown_workers(self) -> None:
@@ -981,6 +1047,7 @@ def _shutdown_workers(self) -> None:
del self.parent_channels
self._cuda_events = None
self._events = None
+ self.event = None
@_check_start
def set_seed(
@@ -1063,7 +1130,6 @@ def _run_worker_pipe_shared_mem(
child_pipe: connection.Connection,
env_fun: Union[EnvBase, Callable],
env_fun_kwargs: Dict[str, Any],
- device: DEVICE_TYPING = None,
mp_event: mp.Event = None,
shared_tensordict: TensorDictBase = None,
_selected_input_keys=None,
@@ -1072,13 +1138,11 @@ def _run_worker_pipe_shared_mem(
has_lazy_inputs: bool = False,
verbose: bool = False,
) -> None:
- if device is None:
- device = torch.device("cpu")
+ device = shared_tensordict.device
if device.type == "cuda":
event = torch.cuda.Event()
else:
event = None
-
parent_pipe.close()
pid = os.getpid()
if not isinstance(env_fun, EnvBase):
@@ -1089,7 +1153,6 @@ def _run_worker_pipe_shared_mem(
"env_fun_kwargs must be empty if an environment is passed to a process."
)
env = env_fun
- env = env.to(device)
del env_fun
i = -1
@@ -1144,7 +1207,8 @@ def _run_worker_pipe_shared_mem(
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
- next_td = env._step(shared_tensordict)
+ env_input = shared_tensordict
+ next_td = env._step(env_input)
next_shared_tensordict.update_(next_td)
if event is not None:
event.record()
@@ -1155,7 +1219,8 @@ def _run_worker_pipe_shared_mem(
if not initialized:
raise RuntimeError("called 'init' before step")
i += 1
- td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False))
+ env_input = shared_tensordict
+ td, root_next_td = env.step_and_maybe_reset(env_input)
next_shared_tensordict.update_(td.get("next"))
root_shared_tensordict.update_(root_next_td)
if event is not None:
@@ -1208,3 +1273,10 @@ def _run_worker_pipe_shared_mem(
else:
# don't send env through pipe
child_pipe.send(("_".join([cmd, "done"]), None))
+
+
+def _update_cuda(t_dest, t_source):
+ if t_source is None:
+ return
+ t_dest.copy_(t_source.pin_memory(), non_blocking=True)
+ return
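The clone-versus-cast branch that recurs in ``_reset``, ``_step`` and ``step_and_maybe_reset`` above follows a single rule: when the shared tensordict already sits on the requested output device, a clone is needed so the returned data is detached from the shared buffers; when the devices differ, the cast itself allocates fresh storage, so ``to(device, non_blocking=True)`` suffices. A minimal sketch of that pattern in isolation (the TensorDict below is only a stand-in for the shared buffer):

    import torch
    from tensordict import TensorDict

    shared = TensorDict({"obs": torch.zeros(4, 3)}, batch_size=[4], device="cpu")  # stand-in for the shared buffer
    device = torch.device("cpu")  # requested output device (assumption for this sketch)

    out = shared.select("obs", strict=False)
    if out.device == device:
        out = out.clone()                        # detach the output from shared storage
    else:
        out = out.to(device, non_blocking=True)  # the cast already yields new storage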
diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py
index f3a9f2aa469..5645785117d 100644
--- a/torchrl/envs/transforms/gym_transforms.py
+++ b/torchrl/envs/transforms/gym_transforms.py
@@ -148,7 +148,7 @@ def _step(self, tensordict, next_tensordict):
lives = self._get_lives()
end_of_life = torch.tensor(
- tensordict.get(self.lives_key) < lives, device=self.parent.device
+ tensordict.get(self.lives_key) > lives, device=self.parent.device
)
try:
done = next_tensordict.get(self.done_key)
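For reference, the corrected comparison above flags an end-of-life when the lives count stored at the previous step exceeds the count reported after the step, i.e. when a life has just been lost. A small illustration with made-up values:

    import torch

    prev_lives = torch.tensor([3])   # value stored under ``lives_key`` before the step
    current_lives = 2                # value reported by the simulator after the step
    end_of_life = prev_lives > current_lives
    assert bool(end_of_life)         # a life was lost, so the flag is set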
diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py
index 240c1029486..79ee94318cb 100644
--- a/torchrl/envs/transforms/rlhf.py
+++ b/torchrl/envs/transforms/rlhf.py
@@ -5,18 +5,13 @@
from copy import copy, deepcopy
import torch
-from tensordict import TensorDictBase, unravel_key
-from tensordict.nn import (
- make_functional,
- ProbabilisticTensorDictModule,
- repopulate_module,
- TensorDictParams,
-)
+from tensordict import TensorDict, TensorDictBase, unravel_key
+from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
from tensordict.utils import is_seq_of_nested_key
from torch import nn
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.transforms.transforms import Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance
+from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
class KLRewardTransform(Transform):
@@ -116,11 +111,10 @@ def __init__(
self.in_keys = self.in_keys + actor.in_keys
# check that the model has parameters
- params = make_functional(
- actor, keep_params=False, funs_to_decorate=["forward", "get_dist"]
- )
- self.functional_actor = deepcopy(actor)
- repopulate_module(actor, params)
+ params = TensorDict.from_module(actor)
+ with params.apply(_stateless_param, device="meta").to_module(actor):
+ # copy a stateless actor
+ self.__dict__["functional_actor"] = deepcopy(actor)
# we need to register these params as buffer to have `to` and similar
# methods work properly
@@ -170,9 +164,8 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.out_keys[0] != ("reward",) and self.parent is not None:
tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
return tensordict
- dist = self.functional_actor.get_dist(
- tensordict.clone(False), params=self.frozen_params
- )
+ with self.frozen_params.to_module(self.functional_actor):
+ dist = self.functional_actor.get_dist(tensordict.clone(False))
# get the log_prob given the original model
log_prob = dist.log_prob(action)
reward_key = self.in_keys[0]
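The construction introduced here can be read in isolation: parameters are pulled out of the actor with ``TensorDict.from_module``, a deepcopy is taken while meta-device placeholders are swapped in (so the copy owns no real storage), and the frozen parameters are re-applied only for the duration of a call. A sketch with a plain ``TensorDictModule`` standing in for the actor (``get_dist`` on a probabilistic module follows the same pattern):

    from copy import deepcopy

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.envs.transforms.utils import _stateless_param

    actor = TensorDictModule(nn.Linear(4, 2), in_keys=["obs"], out_keys=["logits"])
    params = TensorDict.from_module(actor)
    with params.apply(_stateless_param, device="meta").to_module(actor):
        functional_actor = deepcopy(actor)           # stateless copy, holds no real storage

    frozen_params = params.detach().clone()
    td = TensorDict({"obs": torch.randn(3, 4)}, [3])
    with frozen_params.to_module(functional_actor):  # run the copy with the frozen weights
        td = functional_actor(td)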
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 3e6d597dffd..de8baf2e403 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4692,6 +4692,7 @@ def __init__(
"""Initialises the transform. Filters out non-reward input keys and defines output keys."""
super().__init__(in_keys=in_keys, out_keys=out_keys)
self._reset_keys = reset_keys
+ self._keys_checked = False
@property
def in_keys(self):
@@ -4770,9 +4771,7 @@ def _check_match(reset_keys, in_keys):
return False
return True
- if len(reset_keys) != len(self.in_keys) or not _check_match(
- reset_keys, self.in_keys
- ):
+ if not _check_match(reset_keys, self.in_keys):
raise ValueError(
f"Could not match the env reset_keys {reset_keys} with the {type(self)} in_keys {self.in_keys}. "
f"Please provide the reset_keys manually. Reset entries can be "
@@ -4781,6 +4780,14 @@ def _check_match(reset_keys, in_keys):
)
reset_keys = copy(reset_keys)
self._reset_keys = reset_keys
+
+ if not self._keys_checked and len(reset_keys) != len(self.in_keys):
+ raise ValueError(
+ f"Could not match the env reset_keys {reset_keys} with the in_keys {self.in_keys}. "
+ "Please make sure that these have the same length."
+ )
+ self._keys_checked = True
+
return reset_keys
@reset_keys.setter
diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py
index a99c22a87da..a1b30cb1aca 100644
--- a/torchrl/envs/transforms/utils.py
+++ b/torchrl/envs/transforms/utils.py
@@ -5,6 +5,7 @@
import torch
+from torch import nn
def check_finite(tensor: torch.Tensor):
@@ -59,3 +60,11 @@ def _get_reset(reset_key, tensordict):
if _reset.ndim > parent_td.ndim:
_reset = _reset.flatten(parent_td.ndim, -1).any(-1)
return _reset
+
+
+def _stateless_param(param):
+ is_param = isinstance(param, nn.Parameter)
+ param = param.data.to("meta")
+ if is_param:
+ return nn.Parameter(param, requires_grad=False)
+ return param
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 06eec73be97..9a2a71f24bd 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -237,7 +237,11 @@ def step_mdp(
def _set_single_key(
- source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False
+ source: TensorDictBase,
+ dest: TensorDictBase,
+ key: str | tuple,
+ clone: bool = False,
+ device=None,
):
# key should be already unraveled
if isinstance(key, str):
@@ -253,7 +257,9 @@ def _set_single_key(
source = val
dest = new_val
else:
- if clone:
+ if device is not None and val.device != device:
+ val = val.to(device, non_blocking=True)
+ elif clone:
val = val.clone()
dest._set_str(k, val, inplace=False, validated=True)
# This is a temporary solution to understand if a key is heterogeneous
@@ -262,7 +268,7 @@ def _set_single_key(
if re.match(r"Found more than one unique shape in the tensors", str(err)):
# this is a het key
for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
- _set_single_key(s_td, d_td, k, clone)
+ _set_single_key(s_td, d_td, k, clone=clone, device=device)
break
else:
raise err
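A usage sketch of the extended helper (``_set_single_key`` is private; the call assumes the signature shown in this hunk): the key is copied from ``source`` to ``dest``, cloned when it already lives on the requested device and cast with ``non_blocking=True`` otherwise.

    import torch
    from tensordict import TensorDict
    from torchrl.envs.utils import _set_single_key

    source = TensorDict({"obs": torch.ones(3)}, batch_size=[], device="cpu")
    dest = TensorDict({}, batch_size=[], device="cpu")
    _set_single_key(source, dest, ("obs",), clone=True, device=torch.device("cpu"))
    assert (dest["obs"] == 1).all()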
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py
index 1e5a557546a..bf81cfd5dfd 100644
--- a/torchrl/modules/tensordict_module/actors.py
+++ b/torchrl/modules/tensordict_module/actors.py
@@ -183,7 +183,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from tensordict.nn import TensorDictModule, make_functional
+ >>> from tensordict.nn import TensorDictModule
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
@@ -197,8 +197,9 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
... in_keys=["loc", "scale"],
... distribution_class=TanhNormal,
... )
- >>> params = make_functional(td_module)
- >>> td = td_module(td, params=params)
+ >>> params = TensorDict.from_module(td_module)
+ >>> with params.to_module(td_module):
+ ... td = td_module(td)
>>> td
TensorDict(
fields={
@@ -319,7 +320,6 @@ class ValueOperator(TensorDictModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from tensordict.nn import make_functional
>>> from torch import nn
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.modules import ValueOperator
@@ -334,8 +334,9 @@ class ValueOperator(TensorDictModule):
>>> td_module = ValueOperator(
... in_keys=["observation", "action"], module=module
... )
- >>> params = make_functional(td_module)
- >>> td = td_module(td, params=params)
+ >>> params = TensorDict.from_module(td_module)
+ >>> with params.to_module(td_module):
+ ... td = td_module(td)
>>> print(td)
TensorDict(
fields={
@@ -792,7 +793,6 @@ class QValueHook:
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
@@ -878,7 +878,6 @@ class DistributionalQValueHook(QValueHook):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor
@@ -893,12 +892,13 @@ class DistributionalQValueHook(QValueHook):
... return self.linear(x).view(-1, nbins, 4).log_softmax(-2)
...
>>> module = CustomDistributionalQval()
- >>> params = make_functional(module)
+ >>> params = TensorDict.from_module(module)
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins))
>>> module.register_forward_hook(hook)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
- >>> qvalue_actor(td, params=params)
+ >>> with params.to_module(module):
+ ... qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
@@ -992,7 +992,6 @@ class QValueActor(SafeSequential):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import QValueActor
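All docstring updates in this file follow the same migration: ``make_functional(module)`` followed by ``module(td, params=params)`` becomes ``TensorDict.from_module(module)`` followed by a ``to_module`` context. A minimal, self-contained version of the new pattern:

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    module = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["hidden"])
    params = TensorDict.from_module(module)   # parameters and buffers as a TensorDict
    td = TensorDict({"observation": torch.randn(3, 4)}, [3])
    with params.to_module(module):            # functional call with these parameters
        td = module(td)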
diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py
index c5f34a7774d..22786519681 100644
--- a/torchrl/modules/tensordict_module/common.py
+++ b/torchrl/modules/tensordict_module/common.py
@@ -138,7 +138,6 @@ class SafeModule(TensorDictModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from tensordict.nn.functional_modules import make_functional
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.modules import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
@@ -150,8 +149,9 @@ class SafeModule(TensorDictModule):
... in_keys=["input", "hidden"],
... out_keys=["output"],
... )
- >>> params = make_functional(td_fmodule)
- >>> td_functional = td_fmodule(td.clone(), params=params)
+ >>> params = TensorDict.from_module(td_fmodule)
+        >>> with params.to_module(td_fmodule):
+ ... td_functional = td_fmodule(td.clone())
>>> print(td_functional)
TensorDict(
fields={
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index 46f71e2b3d6..5c8ae799061 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -110,7 +110,7 @@ def __init__(
self.register_buffer("eps_init", torch.tensor([eps_init]))
self.register_buffer("eps_end", torch.tensor([eps_end]))
self.annealing_num_steps = annealing_num_steps
- self.register_buffer("eps", torch.tensor([eps_init]))
+ self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
if spec is not None:
if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
@@ -259,7 +259,7 @@ def __init__(
if self.eps_end > self.eps_init:
raise RuntimeError("eps should decrease over time or be constant")
self.annealing_num_steps = annealing_num_steps
- self.register_buffer("eps", torch.tensor([eps_init]))
+ self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
self.action_key = action_key
self.action_mask_key = action_mask_key
if spec is not None:
@@ -405,7 +405,7 @@ def __init__(
self.annealing_num_steps = annealing_num_steps
self.register_buffer("mean", torch.tensor([mean]))
self.register_buffer("std", torch.tensor([std]))
- self.register_buffer("sigma", torch.tensor([sigma_init]))
+ self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32))
self.action_key = action_key
self.out_keys = list(self.td_module.out_keys)
if action_key not in self.out_keys:
@@ -613,7 +613,7 @@ def __init__(
f"got eps_init={eps_init} and eps_end={eps_end}"
)
self.annealing_num_steps = annealing_num_steps
- self.register_buffer("eps", torch.tensor([eps_init]))
+ self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32))
self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
self.is_init_key = is_init_key
noise_key = self.ou.noise_key
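The explicit ``dtype=torch.float32`` presumably guards against ``eps_init``/``sigma_init`` being passed as Python integers: ``torch.tensor([1])`` would create an integer buffer, and the later in-place annealing with float steps would then fail. A quick illustration:

    import torch

    print(torch.tensor([1]).dtype)                       # torch.int64
    print(torch.tensor([1], dtype=torch.float32).dtype)  # torch.float32
    # torch.tensor([1]).sub_(0.1) would raise a RuntimeError (a Float result
    # cannot be written in-place into a Long tensor).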
diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py
index 71167c5106f..28f721ba6a1 100644
--- a/torchrl/modules/tensordict_module/sequence.py
+++ b/torchrl/modules/tensordict_module/sequence.py
@@ -33,7 +33,6 @@ class SafeSequential(TensorDictSequential, SafeModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from tensordict.nn.functional_modules import make_functional
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamWrapper
>>> from torchrl.modules.tensordict_module import SafeProbabilisticModule
@@ -58,8 +57,9 @@ class SafeSequential(TensorDictSequential, SafeModule):
... out_keys=["output"],
... )
>>> td_module = SafeSequential(td_module1, td_module2)
- >>> params = make_functional(td_module)
- >>> td_module(td, params=params)
+ >>> params = TensorDict.from_module(td_module)
+ >>> with params.to_module(td_module):
+ ... td_module(td)
>>> print(td)
TensorDict(
fields={
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index bb7b9014f0d..4384ccef282 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -3,11 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
+from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
import torch
-from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
+from tensordict.nn import (
+ dispatch,
+ ProbabilisticTensorDictSequential,
+ repopulate_module,
+ TensorDictModule,
+)
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import distributions as d
@@ -20,7 +26,13 @@
distance_loss,
ValueEstimators,
)
-from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator
+from torchrl.objectives.value import (
+ GAE,
+ TD0Estimator,
+ TD1Estimator,
+ TDLambdaEstimator,
+ VTrace,
+)
class A2CLoss(LossModule):
@@ -202,6 +214,7 @@ class _AcceptedKeys:
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
+ sample_log_prob: NestedKey = "sample_log_prob"
default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.GAE
@@ -314,8 +327,8 @@ def _log_probs(
f"tensordict stored {self.tensor_keys.action} require grad."
)
tensordict_clone = tensordict.select(*self.actor.in_keys).clone()
-
- dist = self.actor.get_dist(tensordict_clone, params=self.actor_params)
+ with self.actor_params.to_module(self.actor):
+ dist = self.actor.get_dist(tensordict_clone)
log_prob = dist.log_prob(action)
log_prob = log_prob.unsqueeze(-1)
return log_prob, dist
@@ -326,10 +339,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
# overhead that we could easily reduce.
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(*self.critic.in_keys)
- state_value = self.critic(
- tensordict_select,
- params=self.critic_params,
- ).get(self.tensor_keys.value)
+ with self.critic_params.to_module(self.critic):
+ state_value = self.critic(
+ tensordict_select,
+ ).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,
@@ -361,6 +374,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
target_params=self.target_critic_params,
)
advantage = tensordict.get(self.tensor_keys.advantage)
+ assert not advantage.requires_grad
log_probs, dist = self._log_probs(tensordict)
loss = -(log_probs * advantage)
td_out = TensorDict({"loss_objective": loss.mean()}, [])
@@ -379,6 +393,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
self.value_type = value_type
hp = dict(default_value_kwargs(value_type))
hp.update(hyperparams)
+
if hasattr(self, "gamma"):
hp["gamma"] = self.gamma
if value_type == ValueEstimators.TD1:
@@ -389,6 +404,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
self._value_estimator = GAE(value_network=self.critic, **hp)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
+ elif value_type == ValueEstimators.VTrace:
+ # VTrace currently does not support functional call on the actor
+ actor_with_params = repopulate_module(
+ deepcopy(self.actor), self.actor_params
+ )
+ self._value_estimator = VTrace(
+ value_network=self.critic, actor_network=actor_with_params, **hp
+ )
else:
raise NotImplementedError(f"Unknown value type {value_type}")
@@ -399,5 +422,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
"terminated": self.tensor_keys.terminated,
+ "sample_log_prob": self.tensor_keys.sample_log_prob,
}
self._value_estimator.set_keys(**tensor_keys)
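A hedged usage sketch of the new ``VTrace`` branch; the module construction below is illustrative and not part of the diff. The estimator is selected through ``make_value_estimator`` exactly like the existing GAE/TD variants:

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.modules import NormalParamWrapper, ProbabilisticActor, TanhNormal
    from torchrl.objectives import A2CLoss, ValueEstimators

    policy_net = NormalParamWrapper(nn.Linear(4, 4))  # splits its output into loc / scale
    policy_module = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["loc", "scale"])
    actor = ProbabilisticActor(
        policy_module, in_keys=["loc", "scale"], distribution_class=TanhNormal, return_log_prob=True
    )
    critic = TensorDictModule(nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"])

    loss = A2CLoss(actor, critic)
    loss.make_value_estimator(ValueEstimators.VTrace)  # builds the VTrace estimator added above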
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index bdccbda3808..00ba8cf456a 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -10,15 +10,10 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple
-from tensordict import TensorDictBase
-
-from tensordict.nn import (
- make_functional,
- repopulate_module,
- TensorDictModule,
- TensorDictModuleBase,
- TensorDictParams,
-)
+import torch
+from tensordict import TensorDict, TensorDictBase
+
+from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
from torch import nn
from torch.nn import Parameter
@@ -87,7 +82,7 @@ class _AcceptedKeys:
pass
default_value_estimator: ValueEstimators = None
- SEP = "_sep_"
+ SEP = "."
TARGET_NET_WARNING = (
"No target network updater has been associated "
"with this loss module, but target parameters have been found. "
@@ -138,7 +133,7 @@ def set_keys(self, **kwargs) -> None:
"""
for key, value in kwargs.items():
if key not in self._AcceptedKeys.__dict__:
- raise ValueError(f"{key} it not an accepted tensordict key")
+ raise ValueError(f"{key} is not an accepted tensordict key")
if value is not None:
setattr(self.tensor_keys, key, value)
else:
@@ -178,21 +173,15 @@ def convert_to_functional(
expand_dim: Optional[int] = None,
create_target_params: bool = False,
compare_against: Optional[List[Parameter]] = None,
- funs_to_decorate=None,
+ **kwargs,
) -> None:
"""Converts a module to functional to be used in the loss.
Args:
module (TensorDictModule or compatible): a stateful tensordict module.
- This module will be made functional, yet still stateful, meaning
- that it will be callable with the following alternative signatures:
-
- >>> module(tensordict)
- >>> module(tensordict, params=params)
-
- ``params`` is a :class:`tensordict.TensorDict` instance with parameters
- stuctured as the output of :func:`tensordict.nn.make_functional`
- is.
+            Parameters from this module will be isolated in the `_params`
+            attribute and a stateless version of the module will be registered
+            under the `module_name` attribute.
module_name (str): name where the module will be found.
The parameters of the module will be found under ``loss_module._params``
whereas the module will be found under ``loss_module.``.
@@ -223,45 +212,27 @@ def convert_to_functional(
the resulting parameters will be a detached version of the
original parameters. If ``None``, the resulting parameters
will carry gradients as expected.
- funs_to_decorate (list of str, optional): if provided, the list of
- methods of ``module`` to make functional, ie the list of
- methods that will accept the ``params`` keyword argument.
"""
- if funs_to_decorate is None:
- funs_to_decorate = ["forward"]
+ if kwargs.pop("funs_to_decorate", None) is not None:
+ warnings.warn(
+ "funs_to_decorate is without effect with the new objective API.",
+ category=DeprecationWarning,
+ )
+ if kwargs:
+ raise TypeError(f"Unrecognised keyword arguments {list(kwargs.keys())}")
# To make it robust to device casting, we must register list of
# tensors as lazy calls to `getattr(self, name_of_tensor)`.
# Otherwise, casting the module to a device will keep old references
# to uncast tensors
sep = self.SEP
- params = make_functional(module, funs_to_decorate=funs_to_decorate)
- # buffer_names = next(itertools.islice(zip(*module.named_buffers()), 1))
- buffer_names = []
- for key, value in params.items(True, True):
- # we just consider all that is not param as a buffer, but if the module has been made
- # functional and the params have been replaced this may break
- if not isinstance(value, nn.Parameter):
- key = sep.join(key) if not isinstance(key, str) else key
- buffer_names.append(key)
- functional_module = deepcopy(module)
- repopulate_module(module, params)
-
- params_and_buffers = params
- # we transform the buffers in params to make sure they follow the device
- # as tensor = nn.Parameter(tensor) keeps its identity when moved to another device
-
- # separate params and buffers
- params_and_buffers = TensorDictParams(params_and_buffers, no_convert=True)
- # sanity check
- for key in params_and_buffers.keys(True):
+ params = TensorDict.from_module(module, as_module=True)
+
+ for key in params.keys(True):
if sep in key:
raise KeyError(
f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer."
)
- params_and_buffers_flat = params_and_buffers.flatten_keys(sep)
- buffers = params_and_buffers_flat.select(*buffer_names)
- params = params_and_buffers_flat.exclude(*buffer_names)
if compare_against is not None:
compare_against = set(compare_against)
else:
@@ -273,6 +244,9 @@ def convert_to_functional(
# For buffers, a cloned expansion (or equivalently a repeat) is returned.
def _compare_and_expand(param):
+ if not isinstance(param, nn.Parameter):
+ buffer = param.expand(expand_dim, *param.shape).clone()
+ return buffer
if param in compare_against:
expanded_param = param.data.expand(expand_dim, *param.shape)
# the expanded parameter must be sent to device when to()
@@ -287,45 +261,40 @@ def _compare_and_expand(param):
)
return p_out
- params = params.apply(
- _compare_and_expand, batch_size=[expand_dim, *params.shape]
- )
-
- buffers = buffers.apply(
- lambda buffer: buffer.expand(expand_dim, *buffer.shape).clone(),
- batch_size=[expand_dim, *buffers.shape],
+ params = TensorDictParams(
+ params.apply(
+ _compare_and_expand, batch_size=[expand_dim, *params.shape]
+ ),
+ no_convert=True,
)
- params_and_buffers.update(params.unflatten_keys(sep))
- params_and_buffers.update(buffers.unflatten_keys(sep))
- params_and_buffers.batch_size = params.batch_size
-
- # self.params_to_map = params_to_map
-
param_name = module_name + "_params"
prev_set_params = set(self.parameters())
# register parameters and buffers
- for key, parameter in list(params_and_buffers.items(True, True)):
+ for key, parameter in list(params.items(True, True)):
if parameter not in prev_set_params:
pass
elif compare_against is not None and parameter in compare_against:
- params_and_buffers.set(key, parameter.data)
+ params.set(key, parameter.data)
- setattr(self, param_name, params_and_buffers)
+ setattr(self, param_name, params)
- # set the functional module
- setattr(self, module_name, functional_module)
+ # set the functional module: we need to convert the params to non-differentiable params
+ # otherwise they will appear twice in parameters
+ with params.apply(
+ self._make_meta_params, device=torch.device("meta")
+ ).to_module(module):
+ # avoid buffers and params being exposed
+ self.__dict__[module_name] = deepcopy(module)
name_params_target = "target_" + module_name
if create_target_params:
# if create_target_params:
# we create a TensorDictParams to keep the target params as Buffer instances
target_params = TensorDictParams(
- params_and_buffers.apply(
- _make_target_param(clone=create_target_params)
- ),
+ params.apply(_make_target_param(clone=create_target_params)),
no_convert=True,
)
setattr(self, name_params_target + "_params", target_params)
@@ -447,6 +416,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
raise NotImplementedError(
f"Value type {value_type} it not implemented for loss {type(self)}."
)
+ elif value_type == ValueEstimators.VTrace:
+ raise NotImplementedError(
+ f"Value type {value_type} it not implemented for loss {type(self)}."
+ )
elif value_type == ValueEstimators.TDLambda:
raise NotImplementedError(
f"Value type {value_type} it not implemented for loss {type(self)}."
@@ -454,86 +427,18 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
else:
raise NotImplementedError(f"Unknown value type {value_type}")
- # def _apply(self, fn, recurse=True):
- # """Modifies torch.nn.Module._apply to work with Buffer class."""
- # if recurse:
- # for module in self.children():
- # module._apply(fn)
- #
- # def compute_should_use_set_data(tensor, tensor_applied):
- # if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
- # # If the new tensor has compatible tensor type as the existing tensor,
- # # the current behavior is to change the tensor in-place using `.data =`,
- # # and the future behavior is to overwrite the existing tensor. However,
- # # changing the current behavior is a BC-breaking change, and we want it
- # # to happen in future releases. So for now we introduce the
- # # `torch.__future__.get_overwrite_module_params_on_conversion()`
- # # global flag to let the user control whether they want the future
- # # behavior of overwriting the existing tensor or not.
- # return not torch.__future__.get_overwrite_module_params_on_conversion()
- # else:
- # return False
- #
- # for key, param in self._parameters.items():
- # if param is None:
- # continue
- # # Tensors stored in modules are graph leaves, and we don't want to
- # # track autograd history of `param_applied`, so we have to use
- # # `with torch.no_grad():`
- # with torch.no_grad():
- # param_applied = fn(param)
- # should_use_set_data = compute_should_use_set_data(param, param_applied)
- # if should_use_set_data:
- # param.data = param_applied
- # out_param = param
- # else:
- # assert isinstance(param, Parameter)
- # assert param.is_leaf
- # out_param = Parameter(param_applied, param.requires_grad)
- # self._parameters[key] = out_param
- #
- # if param.grad is not None:
- # with torch.no_grad():
- # grad_applied = fn(param.grad)
- # should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
- # if should_use_set_data:
- # assert out_param.grad is not None
- # out_param.grad.data = grad_applied
- # else:
- # assert param.grad.is_leaf
- # out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad)
- #
- # for key, buffer in self._buffers.items():
- # if buffer is None:
- # continue
- # # Tensors stored in modules are graph leaves, and we don't want to
- # # track autograd history of `buffer_applied`, so we have to use
- # # `with torch.no_grad():`
- # with torch.no_grad():
- # buffer_applied = fn(buffer)
- # should_use_set_data = compute_should_use_set_data(buffer, buffer_applied)
- # if should_use_set_data:
- # buffer.data = buffer_applied
- # out_buffer = buffer
- # else:
- # assert isinstance(buffer, Buffer)
- # assert buffer.is_leaf
- # out_buffer = Buffer(buffer_applied, buffer.requires_grad)
- # self._buffers[key] = out_buffer
- #
- # if buffer.grad is not None:
- # with torch.no_grad():
- # grad_applied = fn(buffer.grad)
- # should_use_set_data = compute_should_use_set_data(buffer.grad, grad_applied)
- # if should_use_set_data:
- # assert out_buffer.grad is not None
- # out_buffer.grad.data = grad_applied
- # else:
- # assert buffer.grad.is_leaf
- # out_buffer.grad = grad_applied.requires_grad_(buffer.grad.requires_grad)
-
return self
+ @staticmethod
+ def _make_meta_params(param):
+ is_param = isinstance(param, nn.Parameter)
+
+ pd = param.detach().to("meta")
+
+ if is_param:
+ pd = nn.Parameter(pd, requires_grad=False)
+ return pd
+
class _make_target_param:
def __init__(self, clone):
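The ``expand_dim`` path above can be summarized outside of ``LossModule``: every leaf gains a leading dimension of size ``expand_dim``, buffers are cloned, and the result is wrapped in a ``TensorDictParams``. A simplified sketch that ignores the ``compare_against`` branch:

    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    module = nn.Linear(4, 2)
    params = TensorDict.from_module(module, as_module=True)
    expand_dim = 3

    def _expand(param):
        if not isinstance(param, nn.Parameter):  # buffer: cloned expansion
            return param.expand(expand_dim, *param.shape).clone()
        return nn.Parameter(                     # parameter: expanded copy wrapped as a new Parameter
            param.data.expand(expand_dim, *param.shape).clone(), requires_grad=True
        )

    expanded = TensorDictParams(
        params.apply(_expand, batch_size=[expand_dim, *params.shape]),
        no_convert=True,
    )
    print(expanded["weight"].shape)  # torch.Size([3, 2, 4])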
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index 86d5461b15f..9b2c5eace7a 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import math
import warnings
+from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -26,6 +27,7 @@
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
+ _vmap_func,
default_value_kwargs,
distance_loss,
ValueEstimators,
@@ -33,18 +35,6 @@
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- _has_functorch = True
- err = ""
-except ImportError as err:
- _has_functorch = False
- FUNCTORCH_ERROR = err
-
class CQLLoss(LossModule):
"""TorchRL implementation of the continuous CQL loss.
@@ -281,8 +271,6 @@ def __init__(
lagrange_thresh: float = 0.0,
) -> None:
self._out_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
# Actor
@@ -291,7 +279,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
- funs_to_decorate=["forward", "get_dist"],
)
# Q value
@@ -362,8 +349,8 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)
- self._vmap_qvalue_networkN0 = vmap(self.qvalue_network, (None, 0))
- self._vmap_qvalue_network00 = vmap(self.qvalue_network)
+ self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
+ self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
@property
def target_entropy(self):
@@ -590,8 +577,10 @@ def _get_policy_actions(self, data, actor_params, num_actions=10):
batch_size=batch_size,
)
with torch.no_grad():
- with set_exploration_type(ExplorationType.RANDOM):
- dist = self.actor_network.get_dist(tensordict, params=actor_params)
+ with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
+ self.actor_network
+ ):
+ dist = self.actor_network.get_dist(tensordict)
action = dist.rsample()
tensordict.set(self.tensor_keys.action, action)
sample_log_prob = dist.log_prob(action)
@@ -607,11 +596,11 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
tensordict = tensordict.clone(False)
# get actions and log-probs
with torch.no_grad():
- with set_exploration_type(ExplorationType.RANDOM):
+ with set_exploration_type(ExplorationType.RANDOM), actor_params.to_module(
+ self.actor_network
+ ):
next_tensordict = tensordict.get("next").clone(False)
- next_dist = self.actor_network.get_dist(
- next_tensordict, params=actor_params
- )
+ next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = next_dist.log_prob(next_action)
@@ -1066,7 +1055,8 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
self.value_type = value_type
# we will take care of computing the next value inside this module
- value_net = self.value_network
+ value_net = deepcopy(self.value_network)
+ self.value_network_params.to_module(value_net, return_swap=False)
hp = dict(default_value_kwargs(value_type))
hp.update(hyperparams)
@@ -1117,10 +1107,8 @@ def value_loss(
tensordict: TensorDictBase,
) -> Tuple[torch.Tensor, dict]:
td_copy = tensordict.clone(False)
- self.value_network(
- td_copy,
- params=self.value_network_params,
- )
+ with self.value_network_params.to_module(self.value_network):
+ self.value_network(td_copy)
action = tensordict.get(self.tensor_keys.action)
pred_val = td_copy.get(self.tensor_keys.action_value)
@@ -1137,8 +1125,7 @@ def value_loss(
# calculate target value
with torch.no_grad():
target_value = self.value_estimator.value_estimate(
- td_copy,
- target_params=self._cached_detached_target_value_params,
+ td_copy, params=self._cached_detached_target_value_params
).squeeze(-1)
with torch.no_grad():
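``_vmap_func`` is defined in ``torchrl/objectives/utils.py`` and is not part of this hunk; conceptually it replaces ``vmap(qvalue_network, (None, 0))`` with a vmap over a functional call in which the (stacked) parameters are swapped into the module. A hypothetical sketch of that idea, not the library's exact implementation:

    import torch

    def vmap_over_module(module, *vmap_args, **vmap_kwargs):
        # Hypothetical helper: vmap a tensordict module over (input, params) pairs,
        # applying each params slice to the module for the duration of the call.
        def call(td, params):
            with params.to_module(module):
                return module(td)
        return torch.vmap(call, *vmap_args, **vmap_kwargs)

    # e.g. vmap_over_module(qvalue_network, (None, 0)) maps one tensordict input over a
    # stack of parameter sets (batched along dim 0), mirroring the previous vmap usage.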
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index 1795f785716..3b4debe6259 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -11,7 +11,7 @@
from typing import Tuple
import torch
-from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
+from tensordict.nn import dispatch, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey, unravel_key
@@ -197,10 +197,10 @@ def __init__(
self.delay_value = delay_value
actor_critic = ActorCriticWrapper(actor_network, value_network)
- params = make_functional(actor_critic)
- self.actor_critic = deepcopy(actor_critic)
- repopulate_module(actor_network, params["module", "0"])
- repopulate_module(value_network, params["module", "1"])
+ params = TensorDict.from_module(actor_critic)
+ params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
+ with params_meta.to_module(actor_critic):
+ self.__dict__["actor_critic"] = deepcopy(actor_critic)
self.convert_to_functional(
actor_network,
@@ -295,14 +295,10 @@ def loss_actor(
td_copy = tensordict.select(
*self.actor_in_keys, *self.value_exclusive_keys
).detach()
- td_copy = self.actor_network(
- td_copy,
- params=self.actor_network_params,
- )
- td_copy = self.value_network(
- td_copy,
- params=self._cached_detached_value_params,
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ td_copy = self.actor_network(td_copy)
+ with self._cached_detached_value_params.to_module(self.value_network):
+ td_copy = self.value_network(td_copy)
loss_actor = -td_copy.get(self.tensor_keys.state_action_value)
metadata = {}
return loss_actor.mean(), metadata
@@ -313,10 +309,8 @@ def loss_value(
) -> Tuple[torch.Tensor, dict]:
# value loss
td_copy = tensordict.select(*self.value_network.in_keys).detach()
- self.value_network(
- td_copy,
- params=self.value_network_params,
- )
+ with self.value_network_params.to_module(self.value_network):
+ self.value_network(td_copy)
pred_val = td_copy.get(self.tensor_keys.state_action_value).squeeze(-1)
target_value = self.value_estimator.value_estimate(
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index db3cf633aef..52339d583dd 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -88,7 +88,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=False,
- funs_to_decorate=["forward", "get_dist"],
)
try:
device = next(self.parameters()).device
@@ -208,9 +207,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if target_actions.requires_grad:
raise RuntimeError("target action cannot be part of a graph.")
- action_dist = self.actor_network.get_dist(
- tensordict, params=self.actor_network_params
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ action_dist = self.actor_network.get_dist(tensordict)
log_likelihood = action_dist.log_prob(target_actions).mean()
entropy = self.get_entropy_bonus(action_dist).mean()
@@ -319,9 +317,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
target_actions = tensordict.get(self.tensor_keys.action_target).detach()
- pred_actions = self.actor_network(
- tensordict, params=self.actor_network_params
- ).get(self.tensor_keys.action_pred)
+ with self.actor_network_params.to_module(self.actor_network):
+ pred_actions = self.actor_network(tensordict).get(
+ self.tensor_keys.action_pred
+ )
loss = distance_loss(
pred_actions,
target_actions,
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index 696efbdc650..947a7574967 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -21,21 +21,13 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators
from torchrl.objectives.common import LossModule
-from torchrl.objectives.utils import _cache_values, _GAMMA_LMBDA_DEPREC_WARNING
+from torchrl.objectives.utils import (
+ _cache_values,
+ _GAMMA_LMBDA_DEPREC_WARNING,
+ _vmap_func,
+)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- FUNCTORCH_ERR = ""
- _has_functorch = True
-except ImportError as err:
- FUNCTORCH_ERR = str(err)
- _has_functorch = False
-
class REDQLoss_deprecated(LossModule):
"""REDQ Loss module.
@@ -149,8 +141,6 @@ def __init__(
):
self._in_keys = None
self._out_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR
super().__init__()
self._set_deprecated_ctor_keys(priority_key=priority_key)
@@ -208,7 +198,7 @@ def __init__(
self.target_entropy_buffer = None
self.gSDE = gSDE
- self._vmap_qvalue_networkN0 = vmap(self.qvalue_network, (None, 0))
+ self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
@@ -328,11 +318,10 @@ def _cached_detach_qvalue_network_params(self):
def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
obs_keys = self.actor_network.in_keys
tensordict_clone = tensordict.select(*obs_keys)
- with set_exploration_type(ExplorationType.RANDOM):
- self.actor_network(
- tensordict_clone,
- params=self.actor_network_params,
- )
+ with set_exploration_type(
+ ExplorationType.RANDOM
+ ), self.actor_network_params.to_module(self.actor_network):
+ self.actor_network(tensordict_clone)
tensordict_expand = self._vmap_qvalue_networkN0(
tensordict_clone.select(*self.qvalue_network.in_keys),
@@ -364,11 +353,10 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
) # next_observation ->
# observation
# select pseudo-action
- with set_exploration_type(ExplorationType.RANDOM):
- self.actor_network(
- next_td,
- params=self.target_actor_network_params,
- )
+ with set_exploration_type(
+ ExplorationType.RANDOM
+ ), self.target_actor_network_params.to_module(self.actor_network):
+ self.actor_network(next_td)
sample_log_prob = next_td.get("sample_log_prob")
# get q-values
next_td = self._vmap_qvalue_networkN0(
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index a9dd4314a35..6af8c165c51 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -289,10 +289,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
"""
td_copy = tensordict.clone(False)
- self.value_network(
- td_copy,
- params=self.value_network_params,
- )
+ with self.value_network_params.to_module(self.value_network):
+ self.value_network(td_copy)
action = tensordict.get(self.tensor_keys.action)
pred_val = td_copy.get(self.tensor_keys.action_value)
@@ -462,10 +460,10 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
# Calculate current state probabilities (online network noise already
# sampled)
td_clone = tensordict.clone()
- self.value_network(
- td_clone,
- params=self.value_network_params,
- ) # Log probabilities log p(s_t, ·; θonline)
+ with self.value_network_params.to_module(self.value_network):
+ self.value_network(
+ td_clone,
+ ) # Log probabilities log p(s_t, ·; θonline)
action_log_softmax = td_clone.get(self.tensor_keys.action_value)
if self.action_space == "categorical":
@@ -475,24 +473,18 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict:
action, action_log_softmax, batch_size, atoms
)
- with torch.no_grad():
+ with torch.no_grad(), self.value_network_params.to_module(self.value_network):
# Calculate nth next state probabilities
next_td = step_mdp(tensordict)
- self.value_network(
- next_td,
- params=self.value_network_params,
- ) # Probabilities p(s_t+n, ·; θonline)
+ self.value_network(next_td) # Probabilities p(s_t+n, ·; θonline)
next_td_action = next_td.get(self.tensor_keys.action)
if self.action_space == "categorical":
argmax_indices_ns = next_td_action.squeeze(-1)
else:
argmax_indices_ns = next_td_action.argmax(-1) # one-hot encoding
-
- self.value_network(
- next_td,
- params=self.target_value_network_params,
- ) # Probabilities p(s_t+n, ·; θtarget)
+ with self.target_value_network_params.to_module(self.value_network):
+ self.value_network(next_td) # Probabilities p(s_t+n, ·; θtarget)
pns = next_td.get(self.tensor_keys.action_value).exp()
# Double-Q probabilities
# p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 8d7fa0b53c1..a741d83ba13 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -14,26 +14,16 @@
from torchrl.modules import ProbabilisticActor
from torchrl.objectives.common import LossModule
+
from torchrl.objectives.utils import (
_GAMMA_LMBDA_DEPREC_WARNING,
+ _vmap_func,
default_value_kwargs,
distance_loss,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- _has_functorch = True
- err = ""
-except ImportError as err:
- _has_functorch = False
- FUNCTORCH_ERROR = err
-
class IQLLoss(LossModule):
r"""TorchRL implementation of the IQL loss.
@@ -248,8 +238,6 @@ def __init__(
) -> None:
self._in_keys = None
self._out_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
self._set_deprecated_ctor_keys(priority=priority_key)
@@ -262,7 +250,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=False,
- funs_to_decorate=["forward", "get_dist"],
)
if separate_losses:
# we want to make sure there are no duplicates in the params: the
@@ -299,7 +286,7 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
- self._vmap_qvalue_networkN0 = vmap(self.qvalue_network, (None, 0))
+ self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))
@property
def device(self) -> torch.device:
@@ -393,10 +380,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
# KL loss
- dist = self.actor_network.get_dist(
- tensordict,
- params=self.actor_network_params,
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ dist = self.actor_network.get_dist(tensordict)
log_prob = dist.log_prob(tensordict[self.tensor_keys.action])
@@ -412,10 +397,8 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
# state value
with torch.no_grad():
td_copy = tensordict.select(*self.value_network.in_keys).detach()
- self.value_network(
- td_copy,
- params=self.value_network_params,
- )
+ with self.value_network_params.to_module(self.value_network):
+ self.value_network(td_copy)
value = td_copy.get(self.tensor_keys.value).squeeze(
-1
) # assert has no gradient
@@ -434,10 +417,8 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
# state value
td_copy = tensordict.select(*self.value_network.in_keys)
- self.value_network(
- td_copy,
- params=self.value_network_params,
- )
+ with self.value_network_params.to_module(self.value_network):
+ self.value_network(td_copy)
value = td_copy.get(self.tensor_keys.value).squeeze(-1)
value_loss = self.loss_value_diff(min_q - value, self.expectile).mean()
return value_loss, {}
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index 00106571744..23947696c9f 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -12,7 +12,7 @@
import torch
from tensordict import TensorDict, TensorDictBase
-from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
+from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torch import nn
@@ -212,10 +212,11 @@ def __init__(
)
global_value_network = SafeSequential(local_value_network, mixer_network)
- params = make_functional(global_value_network)
- self.global_value_network = deepcopy(global_value_network)
- repopulate_module(local_value_network, params["module", "0"])
- repopulate_module(mixer_network, params["module", "1"])
+ params = TensorDict.from_module(global_value_network)
+ with params.apply(
+ self._make_meta_params, device=torch.device("meta")
+ ).to_module(global_value_network):
+ self.__dict__["global_value_network"] = deepcopy(global_value_network)
self.convert_to_functional(
local_value_network,
@@ -326,10 +327,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
td_copy = tensordict.clone(False)
- self.local_value_network(
- td_copy,
- params=self.local_value_network_params,
- )
+ with self.local_value_network_params.to_module(self.local_value_network):
+ self.local_value_network(
+ td_copy,
+ )
action = tensordict.get(self.tensor_keys.action)
pred_val = td_copy.get(
@@ -346,7 +347,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
pred_val_index = (pred_val * action).sum(-1, keepdim=True)
td_copy.set(self.tensor_keys.local_value, pred_val_index) # [*B, n_agents, 1]
- self.mixer_network(td_copy, params=self.mixer_network_params)
+ with self.mixer_network_params.to_module(self.mixer_network):
+ self.mixer_network(td_copy)
pred_val_index = td_copy.get(self.tensor_keys.global_value).squeeze(-1)
# [*B] this is global and shared among the agents as will be the target
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index e576ca33c1c..11b5fef2ae7 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -4,11 +4,17 @@
# LICENSE file in the root directory of this source tree.
import math
import warnings
+from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
import torch
-from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
+from tensordict.nn import (
+ dispatch,
+ ProbabilisticTensorDictSequential,
+ repopulate_module,
+ TensorDictModule,
+)
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import distributions as d
@@ -22,7 +28,7 @@
)
from .common import LossModule
-from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator
+from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace
class PPOLoss(LossModule):
@@ -271,9 +277,7 @@ def __init__(
self._in_keys = None
self._out_keys = None
super().__init__()
- self.convert_to_functional(
- actor, "actor", funs_to_decorate=["forward", "get_dist"]
- )
+ self.convert_to_functional(actor, "actor")
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
@@ -374,7 +378,8 @@ def _log_weight(
f"tensordict stored {self.tensor_keys.action} requires grad."
)
- dist = self.actor.get_dist(tensordict, params=self.actor_params)
+ with self.actor_params.to_module(self.actor):
+ dist = self.actor.get_dist(tensordict)
log_prob = dist.log_prob(action)
prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
@@ -400,10 +405,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
f"can be used for the value loss."
)
- state_value_td = self.critic(
- tensordict,
- params=self.critic_params,
- )
+ with self.critic_params.to_module(self.critic):
+ state_value_td = self.critic(tensordict)
try:
state_value = state_value_td.get(self.tensor_keys.value)
@@ -469,6 +472,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
self._value_estimator = GAE(value_network=self.critic, **hp)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
+ elif value_type == ValueEstimators.VTrace:
+ # VTrace currently does not support functional call on the actor
+ actor_with_params = repopulate_module(
+ deepcopy(self.actor), self.actor_params
+ )
+ self._value_estimator = VTrace(
+ value_network=self.critic, actor_network=actor_with_params, **hp
+ )
else:
raise NotImplementedError(f"Unknown value type {value_type}")
@@ -479,6 +490,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
"terminated": self.tensor_keys.terminated,
+ "sample_log_prob": self.tensor_keys.sample_log_prob,
}
self._value_estimator.set_keys(**tensor_keys)
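
As a usage sketch of the new branch (hedged: `loss_module` is an assumed, already-constructed PPOLoss instance, not something defined in this diff):

    from torchrl.objectives.utils import ValueEstimators

    # The VTrace branch added above builds the estimator from the loss's critic and a
    # deep copy of the actor repopulated with the functional parameters.
    loss_module.make_value_estimator(ValueEstimators.VTrace, gamma=0.99)
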
@@ -848,7 +860,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
neg_loss = log_weight.exp() * advantage
previous_dist = self.actor.build_dist_from_params(tensordict)
- current_dist = self.actor.get_dist(tensordict, params=self.actor_params)
+ with self.actor_params.to_module(self.actor):
+ current_dist = self.actor.get_dist(tensordict)
try:
kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
except NotImplementedError:
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index dd64a4bc033..347becc24ae 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -18,27 +18,17 @@
from torchrl.data import CompositeSpec
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule
+
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
+ _vmap_func,
default_value_kwargs,
distance_loss,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- FUNCTORCH_ERR = ""
- _has_functorch = True
-except ImportError as err:
- FUNCTORCH_ERR = str(err)
- _has_functorch = False
-
class REDQLoss(LossModule):
"""REDQ Loss module.
@@ -265,8 +255,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
):
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR
super().__init__()
self._in_keys = None
@@ -276,7 +264,6 @@ def __init__(
actor_network,
"actor_network",
create_target_params=self.delay_actor,
- funs_to_decorate=["forward", "get_dist_params"],
)
# let's make sure that actor_network has `return_log_prob` set to True
@@ -331,8 +318,8 @@ def __init__(
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
- self._vmap_qvalue_network00 = vmap(self.qvalue_network)
- self._vmap_getdist = vmap(self.actor_network.get_dist_params)
+ self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+ self._vmap_getdist = _vmap_func(self.actor_network, func="get_dist_params")
@property
def target_entropy(self):
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 93910f1eebf..832af829c64 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -3,12 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
+from copy import deepcopy
from dataclasses import dataclass
from typing import Optional
import torch
-from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
+from tensordict.nn import (
+ dispatch,
+ ProbabilisticTensorDictSequential,
+ repopulate_module,
+ TensorDictModule,
+)
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torchrl.objectives.common import LossModule
@@ -18,7 +24,13 @@
distance_loss,
ValueEstimators,
)
-from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator
+from torchrl.objectives.value import (
+ GAE,
+ TD0Estimator,
+ TD1Estimator,
+ TDLambdaEstimator,
+ VTrace,
+)
class ReinforceLoss(LossModule):
@@ -285,10 +297,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
advantage = tensordict.get(self.tensor_keys.advantage)
# compute log-prob
- tensordict = self.actor_network(
- tensordict,
- params=self.actor_network_params,
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ tensordict = self.actor_network(tensordict)
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
if log_prob.shape == advantage.shape[:-1]:
@@ -305,10 +315,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
try:
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(*self.critic.in_keys)
- state_value = self.critic(
- tensordict_select,
- params=self.critic_params,
- ).get(self.tensor_keys.value)
+ with self.critic_params.to_module(self.critic):
+ state_value = self.critic(tensordict_select).get(self.tensor_keys.value)
loss_value = distance_loss(
target_return,
state_value,
@@ -340,6 +348,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
self._value_estimator = GAE(value_network=self.critic, **hp)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
+ elif value_type == ValueEstimators.VTrace:
+ # VTrace currently does not support functional call on the actor
+ actor_with_params = repopulate_module(
+ deepcopy(self.actor), self.actor_params
+ )
+ self._value_estimator = VTrace(
+ value_network=self.critic, actor_network=actor_with_params, **hp
+ )
else:
raise NotImplementedError(f"Unknown value type {value_type}")
@@ -350,5 +366,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
"terminated": self.tensor_keys.terminated,
+ "sample_log_prob": self.tensor_keys.sample_log_prob,
}
self._value_estimator.set_keys(**tensor_keys)
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 076df1c54a4..d0617dedc74 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -12,7 +12,7 @@
import numpy as np
import torch
-from tensordict.nn import dispatch, make_functional, TensorDictModule
+from tensordict.nn import dispatch, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import Tensor
@@ -22,27 +22,17 @@
from torchrl.modules import ProbabilisticActor
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.common import LossModule
+
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
+ _vmap_func,
default_value_kwargs,
distance_loss,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- _has_functorch = True
- err = ""
-except ImportError as err:
- _has_functorch = False
- FUNCTORCH_ERROR = err
-
def _delezify(func):
@wraps(func)
@@ -293,8 +283,6 @@ def __init__(
) -> None:
self._in_keys = None
self._out_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
self._set_deprecated_ctor_keys(priority_key=priority_key)
@@ -382,16 +370,15 @@ def __init__(
self._target_entropy = target_entropy
self._action_spec = action_spec
if self._version == 1:
- self.actor_critic = ActorCriticWrapper(
+ self.__dict__["actor_critic"] = ActorCriticWrapper(
self.actor_network, self.value_network
)
- make_functional(self.actor_critic)
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
- self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0))
+ self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
if self._version == 1:
- self._vmap_qnetwork00 = vmap(qvalue_network)
+ self._vmap_qnetwork00 = _vmap_func(qvalue_network)
@property
def target_entropy_buffer(self):
@@ -589,11 +576,10 @@ def _cached_detached_qvalue_params(self):
def _actor_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
- with set_exploration_type(ExplorationType.RANDOM):
- dist = self.actor_network.get_dist(
- tensordict,
- params=self.actor_network_params,
- )
+ with set_exploration_type(
+ ExplorationType.RANDOM
+ ), self.actor_network_params.to_module(self.actor_network):
+ dist = self.actor_network.get_dist(tensordict)
a_reparm = dist.rsample()
log_prob = dist.log_prob(a_reparm)
@@ -680,11 +666,11 @@ def _compute_target_v2(self, tensordict) -> Tensor:
tensordict = tensordict.clone(False)
# get actions and log-probs
with torch.no_grad():
- with set_exploration_type(ExplorationType.RANDOM):
+ with set_exploration_type(
+ ExplorationType.RANDOM
+ ), self.actor_network_params.to_module(self.actor_network):
next_tensordict = tensordict.get("next").clone(False)
- next_dist = self.actor_network.get_dist(
- next_tensordict, params=self.actor_network_params
- )
+ next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = next_dist.log_prob(next_action)
@@ -736,16 +722,11 @@ def _value_loss(
) -> Tuple[Tensor, Dict[str, Tensor]]:
# value loss
td_copy = tensordict.select(*self.value_network.in_keys).detach()
- self.value_network(
- td_copy,
- params=self.value_network_params,
- )
+ with self.value_network_params.to_module(self.value_network):
+ self.value_network(td_copy)
pred_val = td_copy.get(self.tensor_keys.value).squeeze(-1)
-
- action_dist = self.actor_network.get_dist(
- td_copy,
- params=self.target_actor_network_params,
- ) # resample an action
+ with self.target_actor_network_params.to_module(self.actor_network):
+ action_dist = self.actor_network.get_dist(td_copy) # resample an action
action = action_dist.rsample()
td_copy.set(self.tensor_keys.action, action, inplace=False)
@@ -991,8 +972,6 @@ def __init__(
separate_losses: bool = False,
):
self._in_keys = None
- if not _has_functorch:
- raise ImportError("Failed to import functorch.") from FUNCTORCH_ERROR
super().__init__()
self._set_deprecated_ctor_keys(priority_key=priority_key)
@@ -1070,7 +1049,7 @@ def __init__(
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
)
- self._vmap_qnetworkN0 = vmap(self.qvalue_network, (None, 0))
+ self._vmap_qnetworkN0 = _vmap_func(self.qvalue_network, (None, 0))
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
@@ -1154,9 +1133,8 @@ def _compute_target(self, tensordict) -> Tensor:
next_tensordict = tensordict.get("next").clone(False)
# get probs and log probs for actions computed from "next"
- next_dist = self.actor_network.get_dist(
- next_tensordict, params=self.actor_network_params
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ next_dist = self.actor_network.get_dist(next_tensordict)
next_prob = next_dist.probs
next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
@@ -1221,10 +1199,8 @@ def _actor_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
# get probs and log probs for actions
- dist = self.actor_network.get_dist(
- tensordict,
- params=self.actor_network_params,
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ dist = self.actor_network.get_dist(tensordict)
prob = dist.probs
log_prob = torch.log(torch.where(prob == 0, 1e-8, prob))
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 9912c143ae6..082873a2358 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -15,27 +15,17 @@
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
+
from torchrl.objectives.utils import (
_cache_values,
_GAMMA_LMBDA_DEPREC_WARNING,
+ _vmap_func,
default_value_kwargs,
distance_loss,
ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
-try:
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
-
- FUNCTORCH_ERR = ""
- _has_functorch = True
-except ImportError as err:
- FUNCTORCH_ERR = str(err)
- _has_functorch = False
-
class TD3Loss(LossModule):
"""TD3 Loss module.
@@ -229,10 +219,6 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
) -> None:
- if not _has_functorch:
- raise ImportError(
- f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}"
- )
super().__init__()
self._in_keys = None
@@ -310,8 +296,8 @@ def __init__(
if gamma is not None:
warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
self.gamma = gamma
- self._vmap_qvalue_network00 = vmap(self.qvalue_network)
- self._vmap_actor_network00 = vmap(self.actor_network)
+ self._vmap_qvalue_network00 = _vmap_func(self.qvalue_network)
+ self._vmap_actor_network00 = _vmap_func(self.actor_network)
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
@@ -359,9 +345,8 @@ def _cached_stack_actor_params(self):
def actor_loss(self, tensordict):
tensordict_actor_grad = tensordict.select(*self.actor_network.in_keys)
- tensordict_actor_grad = self.actor_network(
- tensordict_actor_grad, self.actor_network_params
- )
+ with self.actor_network_params.to_module(self.actor_network):
+ tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
actor_loss_td = tensordict_actor_grad.select(
*self.qvalue_network.in_keys
).expand(
@@ -395,9 +380,8 @@ def value_loss(self, tensordict):
next_td_actor = step_mdp(tensordict).select(
*self.actor_network.in_keys
) # next_observation ->
- next_td_actor = self.actor_network(
- next_td_actor, self.target_actor_network_params
- )
+ with self.target_actor_network_params.to_module(self.actor_network):
+ next_td_actor = self.actor_network(next_td_actor)
next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
self.min_action, self.max_action
)
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index bc678ed0154..c3e7dbc68ce 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -10,10 +10,17 @@
import torch
from tensordict.nn import TensorDictModule
-from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
+from tensordict.tensordict import TensorDict, TensorDictBase
from torch import nn, Tensor
from torch.nn import functional as F
+try:
+ from torch import vmap
+except ImportError as err:
+ try:
+ from functorch import vmap
+ except ImportError as err_ft:
+ raise err_ft from err
from torchrl.envs.utils import step_mdp
_GAMMA_LMBDA_DEPREC_WARNING = (
@@ -39,6 +46,7 @@ class ValueEstimators(Enum):
TD1 = "TD(1) (infinity-step return)"
TDLambda = "TD(lambda)"
GAE = "Generalized advantage estimate"
+ VTrace = "V-trace"
def default_value_kwargs(value_type: ValueEstimators):
@@ -61,6 +69,8 @@ def default_value_kwargs(value_type: ValueEstimators):
return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}
elif value_type == ValueEstimators.TDLambda:
return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}
+ elif value_type == ValueEstimators.VTrace:
+ return {"gamma": 0.99, "differentiable": True}
else:
raise NotImplementedError(f"Unknown value type {value_type}.")
@@ -353,18 +363,19 @@ class hold_out_net(_context_manager):
def __init__(self, network: nn.Module) -> None:
self.network = network
- try:
- self.p_example = next(network.parameters())
- except (AttributeError, StopIteration):
- self.p_example = torch.tensor([])
- self._prev_state = []
+ for p in network.parameters():
+ self.mode = p.requires_grad
+ break
+ else:
+ self.mode = True
def __enter__(self) -> None:
- self._prev_state.append(self.p_example.requires_grad)
- self.network.requires_grad_(False)
+ if self.mode:
+ self.network.requires_grad_(False)
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
- self.network.requires_grad_(self._prev_state.pop())
+ if self.mode:
+ self.network.requires_grad_()
class hold_out_params(_context_manager):
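
The rewritten `hold_out_net` above is intended to keep the same contract as before: no gradients are accumulated on the wrapped network's parameters inside the block, and the prior `requires_grad` state is restored on exit. A standalone sketch (throwaway linear module, illustrative only):

    import torch
    from torch import nn
    from torchrl.objectives.utils import hold_out_net

    net = nn.Linear(3, 1)
    x = torch.randn(4, 3, requires_grad=True)
    with hold_out_net(net):
        loss = net(x).sum()
    loss.backward()
    assert net.weight.grad is None  # no gradient accumulated on the held-out net
    assert x.grad is not None       # gradients still flow to the inputs
    assert all(p.requires_grad for p in net.parameters())  # state restored on exit
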
@@ -457,9 +468,23 @@ def new_fun(self, netname=None):
out = fun(self, netname)
else:
out = fun(self)
- if is_tensor_collection(out):
- out.lock_()
+ # TODO: decide what to do with locked tds in functional calls
+ # if is_tensor_collection(out):
+ # out.lock_()
_cache[attr_name] = out
return out
return new_fun
+
+
+def _vmap_func(module, *args, func=None, **kwargs):
+ def decorated_module(*module_args_params):
+ params = module_args_params[-1]
+ module_args = module_args_params[:-1]
+ with params.to_module(module):
+ if func is None:
+ return module(*module_args)
+ else:
+ return getattr(module, func)(*module_args)
+
+ return vmap(decorated_module, *args, **kwargs) # noqa: TOR101
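
To make the new `_vmap_func` helper concrete, here is a small sketch (assuming this patch is applied; module, keys and shapes are made up) of how it vmaps a `TensorDictModule` over a stack of parameter TensorDicts, mirroring how the loss modules use it with `(None, 0)` in-dims:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torch import nn
    from torchrl.objectives.utils import _vmap_func

    qnet = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_action_value"])
    # two stacked parameter sets (here simply two copies of the module's own weights)
    params = torch.stack([TensorDict.from_module(qnet) for _ in range(2)], 0).contiguous()
    data = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])
    out = _vmap_func(qnet, (None, 0))(data, params)
    print(out["state_action_value"].shape)  # torch.Size([2, 4, 1]) -- one slice per parameter set
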
diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py
index 11ae2e6d9e2..51496986153 100644
--- a/torchrl/objectives/value/__init__.py
+++ b/torchrl/objectives/value/__init__.py
@@ -12,4 +12,5 @@
TDLambdaEstimate,
TDLambdaEstimator,
ValueEstimatorBase,
+ VTrace,
)
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index 4d3a25279a1..f3aff0da1d2 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -2,9 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+
+
import abc
import functools
import warnings
+from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import wraps
from typing import Callable, List, Optional, Union
@@ -24,7 +27,7 @@
from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import step_mdp
-from torchrl.objectives.utils import hold_out_net
+from torchrl.objectives.utils import _vmap_func, hold_out_net
from torchrl.objectives.value.functional import (
generalized_advantage_estimate,
td0_return_estimate,
@@ -32,8 +35,10 @@
vec_generalized_advantage_estimate,
vec_td1_return_estimate,
vec_td_lambda_return_estimate,
+ vtrace_advantage_estimate,
)
+
try:
from torch import vmap
except ImportError as err:
@@ -117,7 +122,8 @@ def _call_value_nets(
"the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
)
if params is not None:
- value_est = value_net(data_in, params).get(value_key)
+ with params.to_module(value_net):
+ value_est = value_net(data_in).get(value_key)
else:
value_est = value_net(data_in).get(value_key)
value, value_ = value_est[idx], value_est[idx_]
@@ -134,8 +140,8 @@ def _call_value_nets(
"params and next_params must be either both provided or not."
)
elif params is not None:
- params_stack = torch.stack([params, next_params], 0)
- data_out = vmap(value_net, (0, 0))(data_in, params_stack)
+ params_stack = torch.stack([params, next_params], 0).contiguous()
+ data_out = _vmap_func(value_net, (0, 0))(data_in, params_stack)
else:
data_out = vmap(value_net, (0,))(data_in)
value_est = data_out.get(value_key)
@@ -147,6 +153,17 @@ def _call_value_nets(
return value, value_
+def _call_actor_net(
+ actor_net: TensorDictModuleBase,
+ data: TensorDictBase,
+ params: TensorDictBase,
+ log_prob_key: NestedKey,
+):
+ # TODO: extend to handle time dimension (and vmap?)
+ log_pi = actor_net(data.select(actor_net.in_keys)).get(log_prob_key)
+ return log_pi
+
+
class ValueEstimatorBase(TensorDictModuleBase):
"""An abstract parent class for value function modules.
@@ -179,9 +196,11 @@ class _AcceptedKeys:
whether a trajectory is done. Defaults to ``"done"``.
terminated (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is terminated. Defaults to ``"terminated"``.
- steps_to_next_obs_key (NestedKey): The key in the input tensordict
+ steps_to_next_obs (NestedKey): The key in the input tensordict
that indicates the number of steps to the next observation.
Defaults to ``"steps_to_next_obs"``.
+ sample_log_prob (NestedKey): The key in the input tensordict that
+ indicates the log probability of the sampled action. Defaults to ``"sample_log_prob"``.
"""
advantage: NestedKey = "advantage"
@@ -191,6 +210,7 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"
steps_to_next_obs: NestedKey = "steps_to_next_obs"
+ sample_log_prob: NestedKey = "sample_log_prob"
default_keys = _AcceptedKeys()
value_network: Union[TensorDictModule, Callable]
@@ -223,6 +243,10 @@ def terminated_key(self):
def steps_to_next_obs_key(self):
return self.tensor_keys.steps_to_next_obs
+ @property
+ def sample_log_prob_key(self):
+ return self.tensor_keys.sample_log_prob
+
@abc.abstractmethod
def forward(
self,
@@ -270,7 +294,7 @@ def __init__(
self._tensor_keys = None
self.differentiable = differentiable
self.skip_existing = skip_existing
- self.value_network = value_network
+ self.__dict__["value_network"] = value_network
self.dep_keys = {}
self.shifted = shifted
@@ -341,7 +365,7 @@ def set_keys(self, **kwargs) -> None:
raise ValueError("tensordict keys cannot be None")
if key not in self._AcceptedKeys.__dict__:
raise KeyError(
- f"{key} it not an accepted tensordict key for advantages"
+ f"{key} is not an accepted tensordict key for advantages"
)
if (
key == "value"
@@ -403,10 +427,10 @@ def is_stateless(self):
def _next_value(self, tensordict, target_params, kwargs):
step_td = step_mdp(tensordict, keep_other=False)
if self.value_network is not None:
- if target_params is not None:
- kwargs["params"] = target_params
- with hold_out_net(self.value_network):
- self.value_network(step_td, **kwargs)
+ with hold_out_net(
+ self.value_network
+ ) if target_params is None else target_params.to_module(self.value_network):
+ self.value_network(step_td)
next_value = step_td.get(self.tensor_keys.value)
return next_value
@@ -447,6 +471,7 @@ class TD0Estimator(ValueEstimatorBase):
of the advantage entry. Defaults to ``"value_target"``.
value_key (str or tuple of str, optional): [Deprecated] the value key to
read from the input tensordict. Defaults to ``"state_value"``.
+ device (torch.device, optional): device of the module.
"""
@@ -462,6 +487,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
skip_existing: Optional[bool] = None,
+ device: Optional[torch.device] = None,
):
super().__init__(
value_network=value_network,
@@ -472,10 +498,6 @@ def __init__(
value_key=value_key,
skip_existing=skip_existing,
)
- try:
- device = next(value_network.parameters()).device
- except (AttributeError, StopIteration):
- device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.average_rewards = average_rewards
@@ -560,7 +582,9 @@ def forward(
params = params.detach()
if target_params is None:
target_params = params.clone(False)
- with hold_out_net(self.value_network):
+ with hold_out_net(self.value_network) if (
+ params is None and target_params is None
+ ) else nullcontext():
# we may still need to pass gradient, but we don't want to assign grads to
# value net params
value, next_value = _call_value_nets(
@@ -597,7 +621,7 @@ def value_estimate(
if self.average_rewards:
reward = reward - reward.mean()
- reward = reward / reward.std().clamp_min(1e-4)
+ reward = reward / reward.std().clamp_min(1e-5)
tensordict.set(
("next", self.tensor_keys.reward), reward
) # we must update the rewards if they are used later in the code
@@ -649,6 +673,7 @@ class TD1Estimator(ValueEstimatorBase):
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
+ device (torch.device, optional): device of the module.
"""
@@ -664,6 +689,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
+ device: Optional[torch.device] = None,
):
super().__init__(
value_network=value_network,
@@ -674,10 +700,6 @@ def __init__(
shifted=shifted,
skip_existing=skip_existing,
)
- try:
- device = next(value_network.parameters()).device
- except (AttributeError, StopIteration):
- device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.average_rewards = average_rewards
@@ -761,7 +783,9 @@ def forward(
params = params.detach()
if target_params is None:
target_params = params.clone(False)
- with hold_out_net(self.value_network):
+ with hold_out_net(self.value_network) if (
+ params is None and target_params is None
+ ) else nullcontext():
# we may still need to pass gradient, but we don't want to assign grads to
# value net params
value, next_value = _call_value_nets(
@@ -799,7 +823,7 @@ def value_estimate(
if self.average_rewards:
reward = reward - reward.mean()
- reward = reward / reward.std().clamp_min(1e-4)
+ reward = reward / reward.std().clamp_min(1e-5)
tensordict.set(
("next", self.tensor_keys.reward), reward
) # we must update the rewards if they are used later in the code
@@ -855,6 +879,7 @@ class TDLambdaEstimator(ValueEstimatorBase):
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
+ device (torch.device, optional): device of the module.
"""
@@ -872,6 +897,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
+ device: Optional[torch.device] = None,
):
super().__init__(
value_network=value_network,
@@ -882,10 +908,6 @@ def __init__(
skip_existing=skip_existing,
shifted=shifted,
)
- try:
- device = next(value_network.parameters()).device
- except (AttributeError, StopIteration):
- device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
self.average_rewards = average_rewards
@@ -972,7 +994,9 @@ def forward(
params = params.detach()
if target_params is None:
target_params = params.clone(False)
- with hold_out_net(self.value_network):
+ with hold_out_net(self.value_network) if (
+ params is None and target_params is None
+ ) else nullcontext():
# we may still need to pass gradient, but we don't want to assign grads to
# value net params
value, next_value = _call_value_nets(
@@ -1083,6 +1107,7 @@ class GAE(ValueEstimatorBase):
estimation, for instance) and (2) when the parameters used at time
``t`` and ``t+1`` are identical (which is not the case when target
parameters are to be used). Defaults to ``False``.
+ device (torch.device, optional): device of the module.
    GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
return a :obj:`"value_target"` entry with the return value that is to be used
@@ -1112,6 +1137,7 @@ def __init__(
value_target_key: NestedKey = None,
value_key: NestedKey = None,
shifted: bool = False,
+ device: Optional[torch.device] = None,
):
super().__init__(
shifted=shifted,
@@ -1122,10 +1148,6 @@ def __init__(
value_key=value_key,
skip_existing=skip_existing,
)
- try:
- device = next(value_network.parameters()).device
- except (AttributeError, StopIteration):
- device = torch.device("cpu")
self.register_buffer("gamma", torch.tensor(gamma, device=device))
self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
self.average_gae = average_gae
@@ -1137,7 +1159,7 @@ def __init__(
def forward(
self,
tensordict: TensorDictBase,
- *unused_args,
+ *,
params: Optional[List[Tensor]] = None,
target_params: Optional[List[Tensor]] = None,
) -> TensorDictBase:
@@ -1218,7 +1240,9 @@ def forward(
params = params.detach()
if target_params is None:
target_params = params.clone(False)
- with hold_out_net(self.value_network):
+ with hold_out_net(self.value_network) if (
+ params is None and target_params is None
+ ) else nullcontext():
# we may still need to pass gradient, but we don't want to assign grads to
# value net params
value, next_value = _call_value_nets(
@@ -1298,7 +1322,9 @@ def value_estimate(
params = params.detach()
if target_params is None:
target_params = params.clone(False)
- with hold_out_net(self.value_network):
+ with hold_out_net(self.value_network) if (
+ params is None and target_params is None
+ ) else nullcontext():
# we may still need to pass gradient, but we don't want to assign grads to
# value net params
value, next_value = _call_value_nets(
@@ -1328,6 +1354,284 @@ def value_estimate(
return value_target
+class VTrace(ValueEstimatorBase):
+ """A class wrapper around V-Trace estimate functional.
+
+ Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
+    :ref:`here <https://arxiv.org/abs/1802.01561>`_ for more context.
+
+ Args:
+ gamma (scalar): exponential mean discount.
+ value_network (TensorDictModule): value operator used to retrieve the value estimates.
+ actor_network (TensorDictModule): actor operator used to retrieve the log prob.
+ rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights.
+ Defaults to ``1.0``.
+ c_thresh (Union[float, Tensor]): c clipping parameter for importance weights.
+ Defaults to ``1.0``.
+ average_adv (bool): if ``True``, the resulting advantage values will be standardized.
+ Default is ``False``.
+ differentiable (bool, optional): if ``True``, gradients are propagated through
+ the computation of the value function. Default is ``False``.
+
+ .. note::
+ The proper way to make the function call non-differentiable is to
+ decorate it in a `torch.no_grad()` context manager/decorator or
+ pass detached parameters for functional modules.
+        skip_existing (bool, optional): if ``True``, the value network will skip
+            modules whose outputs are already present in the tensordict.
+            Defaults to ``None``, i.e. the value of :func:`tensordict.nn.skip_existing()`
+            is not affected.
+ advantage_key (str or tuple of str, optional): [Deprecated] the key of
+ the advantage entry. Defaults to ``"advantage"``.
+ value_target_key (str or tuple of str, optional): [Deprecated] the key
+ of the advantage entry. Defaults to ``"value_target"``.
+ value_key (str or tuple of str, optional): [Deprecated] the value key to
+ read from the input tensordict. Defaults to ``"state_value"``.
+ shifted (bool, optional): if ``True``, the value and next value are
+ estimated with a single call to the value network. This is faster
+ but is only valid whenever (1) the ``"next"`` value is shifted by
+ only one time step (which is not the case with multi-step value
+ estimation, for instance) and (2) when the parameters used at time
+ ``t`` and ``t+1`` are identical (which is not the case when target
+ parameters are to be used). Defaults to ``False``.
+ device (torch.device, optional): device of the module.
+
+ VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also
+ return a :obj:`"value_target"` entry with the V-Trace target value.
+
+ .. note::
+ As other advantage functions do, if the ``value_key`` is already present
+ in the input tensordict, the VTrace module will ignore the calls to the value
+ network (if any) and use the provided value instead.
+
+ """
+
+ def __init__(
+ self,
+ *,
+ gamma: Union[float, torch.Tensor],
+ actor_network: TensorDictModule,
+ value_network: TensorDictModule,
+ rho_thresh: Union[float, torch.Tensor] = 1.0,
+ c_thresh: Union[float, torch.Tensor] = 1.0,
+ average_adv: bool = False,
+ differentiable: bool = False,
+ skip_existing: Optional[bool] = None,
+ advantage_key: Optional[NestedKey] = None,
+ value_target_key: Optional[NestedKey] = None,
+ value_key: Optional[NestedKey] = None,
+ shifted: bool = False,
+ device: Optional[torch.device] = None,
+ ):
+ super().__init__(
+ shifted=shifted,
+ value_network=value_network,
+ differentiable=differentiable,
+ advantage_key=advantage_key,
+ value_target_key=value_target_key,
+ value_key=value_key,
+ skip_existing=skip_existing,
+ )
+ if not isinstance(gamma, torch.Tensor):
+ gamma = torch.tensor(gamma, device=device)
+ if not isinstance(rho_thresh, torch.Tensor):
+ rho_thresh = torch.tensor(rho_thresh, device=device)
+ if not isinstance(c_thresh, torch.Tensor):
+ c_thresh = torch.tensor(c_thresh, device=device)
+
+ self.register_buffer("gamma", gamma)
+ self.register_buffer("rho_thresh", rho_thresh)
+ self.register_buffer("c_thresh", c_thresh)
+ self.average_adv = average_adv
+ self.actor_network = actor_network
+
+ if isinstance(gamma, torch.Tensor) and gamma.shape != ():
+ raise NotImplementedError(
+ "Per-value gamma is not supported yet. Gamma must be a scalar."
+ )
+
+ @property
+ def in_keys(self):
+ parent_in_keys = super().in_keys
+ extended_in_keys = parent_in_keys + [self.tensor_keys.sample_log_prob]
+ return extended_in_keys
+
+ @_self_set_skip_existing
+ @_self_set_grad_enabled
+ @dispatch
+ def forward(
+ self,
+ tensordict: TensorDictBase,
+ *,
+ params: Optional[List[Tensor]] = None,
+ target_params: Optional[List[Tensor]] = None,
+ ) -> TensorDictBase:
+ """Computes the V-Trace correction given the data in tensordict.
+
+ If a functional module is provided, a nested TensorDict containing the parameters
+ (and if relevant the target parameters) can be passed to the module.
+
+ Args:
+ tensordict (TensorDictBase): A TensorDict containing the data
+ (an observation key, "action", "reward", "done" and "next" tensordict state
+                as returned by the environment) necessary to compute the value estimates and the V-Trace correction.
+ The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are
+ the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+ params (TensorDictBase, optional): A nested TensorDict containing the params
+ to be passed to the functional value network module.
+ target_params (TensorDictBase, optional): A nested TensorDict containing the
+ target params to be passed to the functional value network module.
+
+ Returns:
+            An updated TensorDict with the advantage and value_target entries as defined in the constructor.
+
+ Examples:
+ >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
+ >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
+ >>> actor_net = ProbabilisticActor(
+ ... module=actor_net,
+ ... in_keys=["logits"],
+ ... out_keys=["action"],
+ ... distribution_class=OneHotCategorical,
+ ... return_log_prob=True,
+ ... )
+ >>> module = VTrace(
+ ... gamma=0.98,
+ ... value_network=value_net,
+ ... actor_network=actor_net,
+ ... differentiable=False,
+ ... )
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+ >>> reward = torch.randn(1, 10, 1)
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+ >>> sample_log_prob = torch.randn(1, 10, 1)
+ >>> tensordict = TensorDict({
+ ... "obs": obs,
+ ... "done": done,
+ ... "terminated": terminated,
+ ... "sample_log_prob": sample_log_prob,
+ ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
+ ... }, batch_size=[1, 10])
+ >>> _ = module(tensordict)
+ >>> assert "advantage" in tensordict.keys()
+
+ The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
+
+ Examples:
+ >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
+ >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
+ >>> actor_net = ProbabilisticActor(
+ ... module=actor_net,
+ ... in_keys=["logits"],
+ ... out_keys=["action"],
+ ... distribution_class=OneHotCategorical,
+ ... return_log_prob=True,
+ ... )
+ >>> module = VTrace(
+ ... gamma=0.98,
+ ... value_network=value_net,
+ ... actor_network=actor_net,
+ ... differentiable=False,
+ ... )
+ >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+ >>> reward = torch.randn(1, 10, 1)
+ >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+ >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
+ >>> sample_log_prob = torch.randn(1, 10, 1)
+ >>> tensordict = TensorDict({
+ ... "obs": obs,
+ ... "done": done,
+ ... "terminated": terminated,
+ ... "sample_log_prob": sample_log_prob,
+ ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
+ ... }, batch_size=[1, 10])
+ >>> advantage, value_target = module(
+ ... obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated, sample_log_prob=sample_log_prob
+ ... )
+
+ """
+ if tensordict.batch_dims < 1:
+ raise RuntimeError(
+                "Expected input tensordict to have at least one dimension, got "
+ f"tensordict.batch_size = {tensordict.batch_size}"
+ )
+ reward = tensordict.get(("next", self.tensor_keys.reward))
+ device = reward.device
+ gamma = self.gamma.to(device)
+ steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
+ if steps_to_next_obs is not None:
+ gamma = gamma ** steps_to_next_obs.view_as(reward)
+
+ # Make sure we have the value and next value
+ if self.value_network is not None:
+ if params is not None:
+ params = params.detach()
+ if target_params is None:
+ target_params = params.clone(False)
+ with hold_out_net(self.value_network):
+ # we may still need to pass gradient, but we don't want to assign grads to
+ # value net params
+ value, next_value = _call_value_nets(
+ value_net=self.value_network,
+ data=tensordict,
+ params=params,
+ next_params=target_params,
+ single_call=self.shifted,
+ value_key=self.tensor_keys.value,
+ detach_next=True,
+ )
+ else:
+ value = tensordict.get(self.tensor_keys.value)
+ next_value = tensordict.get(("next", self.tensor_keys.value))
+
+ # Make sure we have the log prob computed at collection time
+ if self.tensor_keys.sample_log_prob not in tensordict.keys():
+ raise ValueError(
+ f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict"
+ )
+ log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value)
+
+ # Compute log prob with current policy
+ with hold_out_net(self.actor_network):
+ log_pi = _call_actor_net(
+ actor_net=self.actor_network,
+ data=tensordict,
+ params=None,
+ log_prob_key=self.tensor_keys.sample_log_prob,
+ ).view_as(value)
+
+ # Compute the V-Trace correction
+ done = tensordict.get(("next", self.tensor_keys.done))
+ terminated = tensordict.get(("next", self.tensor_keys.terminated))
+
+ adv, value_target = vtrace_advantage_estimate(
+ gamma,
+ log_pi,
+ log_mu,
+ value,
+ next_value,
+ reward,
+ done,
+ terminated,
+ rho_thresh=self.rho_thresh,
+ c_thresh=self.c_thresh,
+ time_dim=tensordict.ndim - 1,
+ )
+
+ if self.average_adv:
+ loc = adv.mean()
+ scale = adv.std().clamp_min(1e-5)
+ adv = adv - loc
+ adv = adv / scale
+
+ tensordict.set(self.tensor_keys.advantage, adv)
+ tensordict.set(self.tensor_keys.value_target, value_target)
+
+ return tensordict
+
+
def _deprecate_class(cls, new_cls):
@wraps(cls.__init__)
def new_init(self, *args, **kwargs):
diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py
index 7c33895e965..6c43af02aeb 100644
--- a/torchrl/objectives/value/functional.py
+++ b/torchrl/objectives/value/functional.py
@@ -27,6 +27,7 @@
"vec_td_lambda_return_estimate",
"td_lambda_advantage_estimate",
"vec_td_lambda_advantage_estimate",
+ "vtrace_advantage_estimate",
]
from torchrl.objectives.value.utils import (
@@ -1212,6 +1213,93 @@ def vec_td_lambda_advantage_estimate(
)
+########################################################################
+# V-Trace
+# -----
+
+
+@_transpose_time
+def vtrace_advantage_estimate(
+ gamma: float,
+ log_pi: torch.Tensor,
+ log_mu: torch.Tensor,
+ state_value: torch.Tensor,
+ next_state_value: torch.Tensor,
+ reward: torch.Tensor,
+ done: torch.Tensor,
+ terminated: torch.Tensor | None = None,
+ rho_thresh: Union[float, torch.Tensor] = 1.0,
+ c_thresh: Union[float, torch.Tensor] = 1.0,
+ time_dim: int = -2,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes V-Trace off-policy actor critic targets.
+
+ Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
+ https://arxiv.org/abs/1802.01561 for more context.
+
+ Args:
+ gamma (scalar): exponential mean discount.
+        log_pi (Tensor): log probability of the taken actions under the current (learner) policy.
+        log_mu (Tensor): log probability of the taken actions under the behaviour (data-collection) policy.
+ state_value (Tensor): value function result with state input.
+ next_state_value (Tensor): value function result with next_state input.
+ reward (Tensor): reward of taking actions in the environment.
+ done (Tensor): boolean flag for end of episode.
+ terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states.
+ rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights.
+ c_thresh (Union[float, Tensor]): c clipping parameter for importance weights.
+ time_dim (int): dimension where the time is unrolled. Defaults to -2.
+
+ All tensors (values, reward and done) must have shape
+ ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions.
+ """
+ if not (next_state_value.shape == state_value.shape == reward.shape == done.shape):
+ raise RuntimeError(SHAPE_ERR)
+
+ device = state_value.device
+
+ if not isinstance(rho_thresh, torch.Tensor):
+ rho_thresh = torch.tensor(rho_thresh, device=device)
+ if not isinstance(c_thresh, torch.Tensor):
+ c_thresh = torch.tensor(c_thresh, device=device)
+
+ c_thresh = c_thresh.to(device)
+ rho_thresh = rho_thresh.to(device)
+
+ not_done = (~done).int()
+ not_terminated = not_done if terminated is None else (~terminated).int()
+ *batch_size, time_steps, lastdim = not_done.shape
+ done_discounts = gamma * not_done
+ terminated_discounts = gamma * not_terminated
+
+ rho = (log_pi - log_mu).exp()
+ clipped_rho = rho.clamp_max(rho_thresh)
+ deltas = clipped_rho * (
+ reward + terminated_discounts * next_state_value - state_value
+ )
+ clipped_c = rho.clamp_max(c_thresh)
+
+ vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])]
+ for i in reversed(range(time_steps)):
+ discount_t, c_t, delta_t = (
+ done_discounts[..., i, :],
+ clipped_c[..., i, :],
+ deltas[..., i, :],
+ )
+ vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
+ vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim)
+ vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim])
+ vs = vs_minus_v_xs + state_value
+ vs_t_plus_1 = torch.cat(
+ [vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim
+ )
+ advantages = clipped_rho * (
+ reward + terminated_discounts * vs_t_plus_1 - state_value
+ )
+
+ return advantages, vs
+
+
########################################################################
# Reward to go
# ------------
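
A quick sanity-check sketch for `vtrace_advantage_estimate` (random data, shapes assumed `[B, T, 1]`, default clipping thresholds):

    import torch
    from torchrl.objectives.value.functional import vtrace_advantage_estimate

    B, T = 2, 5
    value = torch.randn(B, T, 1)
    next_value = torch.randn(B, T, 1)
    reward = torch.randn(B, T, 1)
    done = torch.zeros(B, T, 1, dtype=torch.bool)
    terminated = torch.zeros(B, T, 1, dtype=torch.bool)
    log_pi = torch.randn(B, T, 1)  # log-prob of the actions under the current (learner) policy
    log_mu = torch.randn(B, T, 1)  # log-prob of the actions under the behaviour (collection) policy
    adv, value_target = vtrace_advantage_estimate(
        0.99, log_pi, log_mu, value, next_value, reward, done, terminated
    )
    print(adv.shape, value_target.shape)  # both torch.Size([2, 5, 1])
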
diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py
deleted file mode 100644
index 43f5246502f..00000000000
--- a/torchrl/objectives/value/vtrace.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import math
-from typing import Tuple, Union
-
-import torch
-
-
-def _c_val(
- log_pi: torch.Tensor,
- log_mu: torch.Tensor,
- c: Union[float, torch.Tensor] = 1,
-) -> torch.Tensor:
- return (log_pi - log_mu).clamp_max(math.log(c)).exp()
-
-
-def _dv_val(
- rewards: torch.Tensor,
- vals: torch.Tensor,
- gamma: Union[float, torch.Tensor],
- rho_bar: Union[float, torch.Tensor],
- log_pi: torch.Tensor,
- log_mu: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- rho = _c_val(log_pi, log_mu, rho_bar)
- next_vals = torch.cat([vals[:, 1:], torch.zeros_like(vals[:, :1])], 1)
- dv = rho * (rewards + gamma * next_vals - vals)
- return dv, rho
-
-
-def _vtrace(
- rewards: torch.Tensor,
- vals: torch.Tensor,
- log_pi: torch.Tensor,
- log_mu: torch.Tensor,
- gamma: Union[torch.Tensor, float],
- rho_bar: Union[float, torch.Tensor] = 1.0,
- c_bar: Union[float, torch.Tensor] = 1.0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- T = vals.shape[1]
- if not isinstance(gamma, torch.Tensor):
- gamma = torch.full_like(vals, gamma)
-
- dv, rho = _dv_val(rewards, vals, gamma, rho_bar, log_pi, log_mu)
- c = _c_val(log_pi, log_mu, c_bar)
-
- v_out = []
- v_out.append(vals[:, -1] + dv[:, -1])
- for t in range(T - 2, -1, -1):
- _v_out = (
- vals[:, t] + dv[:, t] + gamma[:, t] * c[:, t] * (v_out[-1] - vals[:, t + 1])
- )
- v_out.append(_v_out)
- v_out = torch.stack(list(reversed(v_out)), 1)
- return v_out, rho
diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py
index ee343aa438e..3782de64fa2 100644
--- a/torchrl/trainers/helpers/models.py
+++ b/torchrl/trainers/helpers/models.py
@@ -657,6 +657,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
out_keys=[action_key],
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
+ distribution_kwargs={"tanh_loc": True},
spec=CompositeSpec(**{action_key: proof_environment.action_spec}),
),
)
@@ -703,8 +704,9 @@ def _dreamer_make_actor_real(
SafeProbabilisticModule(
in_keys=["loc", "scale"],
out_keys=[action_key],
- default_interaction_type=InteractionType.RANDOM,
+ default_interaction_type=InteractionType.MODE,
distribution_class=TanhNormal,
+ distribution_kwargs={"tanh_loc": True},
spec=CompositeSpec(
**{action_key: proof_environment.action_spec.to("cpu")}
),
diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py
index 5237b344e56..be6e607c1b5 100644
--- a/tutorials/sphinx-tutorials/rb_tutorial.py
+++ b/tutorials/sphinx-tutorials/rb_tutorial.py
@@ -46,6 +46,7 @@
# replay buffer is a straightforward process, as shown in the following
# example:
#
+import tempfile
from torchrl.data import ReplayBuffer
@@ -175,9 +176,8 @@
######################################################################
# We can also customize the storage location on disk:
#
-buffer_lazymemmap = ReplayBuffer(
- storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/")
-)
+tempdir = tempfile.TemporaryDirectory()
+buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir))
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
print("the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename)
@@ -207,8 +207,9 @@
from torchrl.data import TensorDictReplayBuffer
+tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = TensorDictReplayBuffer(
- storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/"), batch_size=12
+ storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
@@ -248,8 +249,9 @@ class MyData:
batch_size=[1000],
)
+tempdir = tempfile.TemporaryDirectory()
buffer_lazymemmap = TensorDictReplayBuffer(
- storage=LazyMemmapStorage(size, scratch_dir="/tmp/memmap/"), batch_size=12
+ storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12
)
buffer_lazymemmap.extend(data)
print(f"The buffer has {len(buffer_lazymemmap)} elements")
diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py
index ccb1c9d4ea7..ef995030c9d 100644
--- a/tutorials/sphinx-tutorials/torchrl_envs.py
+++ b/tutorials/sphinx-tutorials/torchrl_envs.py
@@ -25,6 +25,7 @@
# will pass the arguments and keyword arguments to the root library builder.
#
# With gym, it means that building an environment is as easy as:
+
# sphinx_gallery_start_ignore
import warnings