diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml new file mode 100644 index 00000000000..27963a42a24 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -0,0 +1,20 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - scipy + - hydra-core + - minari diff --git a/.github/unittest/linux_libs/scripts_minari/install.sh b/.github/unittest/linux_libs/scripts_minari/install.sh new file mode 100755 index 00000000000..2eb52b8f65e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/install.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. +apt-get update && apt-get install -y git wget gcc g++ +#apt-get update && apt-get install -y git wget freeglut3 freeglut3-dev + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_minari/post_process.sh b/.github/unittest/linux_libs/scripts_minari/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_minari/run-clang-format.py b/.github/unittest/linux_libs/scripts_minari/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. 
+ # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + 
signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_minari/run_test.sh b/.github/unittest/linux_libs/scripts_minari/run_test.sh new file mode 100755 index 00000000000..7741a491f5b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/run_test.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 +ln -s /usr/bin/swig3.0 /usr/bin/swig + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +conda deactivate && conda activate ./env + +# this workflow only tests the libs +python -c "import minari" + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMinari --error-for-skips +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_minari/setup_env.sh b/.github/unittest/linux_libs/scripts_minari/setup_env.sh new file mode 100755 index 
00000000000..5214617c2ac --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/setup_env.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e +set -v + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +apt-get update && apt-get install -y git wget gcc g++ unzip + +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip3 install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 923a3f3dfc1..42ff3c77251 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -87,12 +87,11 @@ jobs: python3 -mpip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - name: Build TorchRL Nightly run: | - rm -r dist || true export CC=clang CXX=clang++ python3 -mpip install wheel + python3 -mpip install ninja python3 setup.py bdist_wheel \ - --package_name torchrl-nightly \ - --python-tag=${{ matrix.python-tag }} + --package_name torchrl-nightly - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/.github/workflows/test-linux-d4rl.yml b/.github/workflows/test-linux-d4rl.yml index 3a0d534cd8e..ef986e34498 100644 --- a/.github/workflows/test-linux-d4rl.yml +++ b/.github/workflows/test-linux-d4rl.yml @@ -21,6 +21,7 @@ jobs: matrix: python_version: ["3.9"] cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: repository: pytorch/rl diff --git a/.github/workflows/test-linux-minari.yml b/.github/workflows/test-linux-minari.yml new file mode 100644 index 00000000000..aa473d5aef2 --- /dev/null +++ b/.github/workflows/test-linux-minari.yml @@ -0,0 +1,42 @@ +name: Minari Tests on Linux + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + workflow_dispatch: + +concurrency: + # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. + # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. 
+ group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +jobs: + unittests: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Data') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + docker-image: "nvidia/cudagl:11.4.0-base" + timeout: 120 + script: | + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="cu117" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + + bash .github/unittest/linux_libs/scripts_minari/setup_env.sh + bash .github/unittest/linux_libs/scripts_minari/install.sh + bash .github/unittest/linux_libs/scripts_minari/run_test.sh + bash .github/unittest/linux_libs/scripts_minari/post_process.sh diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 71b7a481ce0..246c5ee15f0 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -76,12 +76,12 @@ def make(envname=envname, gym_backend=gym_backend): # regular parallel env for device in avail_devices: - def make(envname=envname, gym_backend=gym_backend, device=device): + def make(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") # env_make = EnvCreator(make) - penv = ParallelEnv(num_workers, EnvCreator(make)) + penv = ParallelEnv(num_workers, EnvCreator(make), device=device) with torch.inference_mode(): # warmup penv.rollout(2) @@ -103,13 +103,13 @@ def make(envname=envname, gym_backend=gym_backend, device=device): for device in avail_devices: - def make(envname=envname, gym_backend=gym_backend, device=device): + def make(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") env_make = EnvCreator(make) # penv = SerialEnv(num_workers, env_make) - penv = ParallelEnv(num_workers, env_make) + penv = ParallelEnv(num_workers, env_make, device=device) collector = SyncDataCollector( penv, RandomPolicy(penv.action_spec), @@ -164,14 +164,14 @@ def make_env( for device in avail_devices: # async collector # + torchrl parallel env - def make_env( - envname=envname, gym_backend=gym_backend, device=device - ): + def make_env(envname=envname, gym_backend=gym_backend): with set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") penv = ParallelEnv( - num_workers // num_collectors, EnvCreator(make_env) + num_workers // num_collectors, + EnvCreator(make_env), + device=device, ) collector = MultiaSyncDataCollector( [penv] * num_collectors, @@ -206,10 +206,9 @@ def make_env( envname=envname, num_workers=num_workers, gym_backend=gym_backend, - device=device, ): with set_gym_backend(gym_backend): - penv = GymEnv(envname, num_envs=num_workers, device=device) + penv = GymEnv(envname, num_envs=num_workers, device="cpu") return penv penv = EnvCreator( @@ -247,14 +246,14 @@ def make_env( for device in avail_devices: # sync collector # + torchrl parallel env - def make_env( - envname=envname, gym_backend=gym_backend, device=device - ): + def make_env(envname=envname, gym_backend=gym_backend): with 
set_gym_backend(gym_backend): - return GymEnv(envname, device=device) + return GymEnv(envname, device="cpu") penv = ParallelEnv( - num_workers // num_collectors, EnvCreator(make_env) + num_workers // num_collectors, + EnvCreator(make_env), + device=device, ) collector = MultiSyncDataCollector( [penv] * num_collectors, @@ -289,10 +288,9 @@ def make_env( envname=envname, num_workers=num_workers, gym_backend=gym_backend, - device=device, ): with set_gym_backend(gym_backend): - penv = GymEnv(envname, num_envs=num_workers, device=device) + penv = GymEnv(envname, num_envs=num_workers, device="cpu") return penv penv = EnvCreator( diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 98d2d40cd5c..55ebd12e867 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -60,6 +60,77 @@ The following mean sampling latency improvements over using ListStorage were fou | :class:`LazyMemmapStorage` | 3.44x | +-------------------------------+-----------+ +Replay buffers with a shared storage and regular (RoundRobin) writers can also +be shared between processes on a single node. This allows each worker to read and +write onto the storage. The following code snippet examplifies this feature: + + >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage + >>> import torch + >>> from torch import multiprocessing as mp + >>> from tensordict import TensorDict + >>> + >>> def worker(rb): + ... # Updates the replay buffer with new data + ... td = TensorDict({"a": torch.ones(10)}, [10]) + ... rb.extend(td) + ... + >>> if __name__ == "__main__": + ... rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21)) + ... td = TensorDict({"a": torch.zeros(10)}, [10]) + ... rb.extend(td) + ... + ... proc = mp.Process(target=worker, args=(rb,)) + ... proc.start() + ... proc.join() + ... # the replay buffer now has a length of 20, since the worker updated it + ... assert len(rb) == 20 + ... assert (rb["_data", "a"][:10] == 0).all() # data from main process + ... assert (rb["_data", "a"][10:20] == 1).all() # data from remote process + +Sharing replay buffers across processes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Replay buffers can be shared between processes as long as their components are +sharable. This feature allows for multiple processes to collect data and populate a shared +replay buffer collaboratively, rather than centralizing the data on the main process +which can incur some data transmission overhead. + +Sharable storages include :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` +or any subclass of :class:`~torchrl.data.replay_buffers.storages.TensorStorage` +as long as they are instantiated and their content is stored as memory-mapped +tensors. Stateful writers such as :class:`~torchrl.data.replay_buffers.writers.TensorDictRoundRobinWriter` +are currently not sharable, and the same goes for stateful samplers such as +:class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`. + +A shared replay buffer can be read and extended on any process that has access +to it, as the following example shows: + + >>> import pickle + >>> + >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage + >>> import torch + >>> from torch import multiprocessing as mp + >>> from tensordict import TensorDict + >>> + >>> def worker(rb): + ... td = TensorDict({"a": torch.ones(10)}, [10]) + ... # Extends the shared replay buffer on a subprocess + ... rb.extend(td) + >>> + >>> if __name__ == "__main__": + ... 
rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(21)) + ... td = TensorDict({"a": torch.zeros(10)}, [10]) + ... # extends the replay buffer on the main process + ... rb.extend(td) + ... + ... proc = mp.Process(target=worker, args=(rb,)) + ... proc.start() + ... proc.join() + ... # Checks that the length of the buffer equals the length of both + ... # extensions (local and remote) + ... assert len(rb) == 20 + + Storing trajectories ~~~~~~~~~~~~~~~~~~~~ @@ -103,6 +174,32 @@ can be used: device=None, is_shared=False) +Checkpointing Replay Buffers +---------------------------- + +Each component of the replay buffer can potentially be stateful and, as such, +require a dedicated way of being serialized. +Our replay buffers enjoy two separate APIs for saving their state on disk: +:meth:`~torchrl.data.ReplayBuffer.dumps` and :meth:`~torchrl.data.ReplayBuffer.loads` will save the +data of each component except transforms (storage, writer, sampler) using memory-mapped +tensors and json files for the metadata. This will work across all classes except +:class:`~torchrl.data.replay_buffers.storages.ListStorage`, whose content +cannot be anticipated (and as such does not comply with memory-mapped data +structures such as those that can be found in the tensordict library). +This API guarantees that a buffer that is saved and then loaded back will be in +the exact same state, whether we look at the status of its sampler (e.g., priority trees), +its writer (e.g., max writer heaps), or its storage. +Under the hood, :meth:`~torchrl.data.ReplayBuffer.dumps` will just call the public +`dumps` method in a specific folder for each of its components (except transforms, +which we don't assume to be serializable using memory-mapped tensors in general). + +Whenever saving data using :meth:`~torchrl.data.ReplayBuffer.dumps` is not possible, an +alternative way is to use :meth:`~torchrl.data.ReplayBuffer.state_dict`, which returns a data +structure that can be saved using :func:`torch.save` and loaded using :func:`torch.load` +before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback +of this method is that it will struggle to save big data structures, which is a +common setting when using replay buffers. + Datasets -------- @@ -189,6 +286,7 @@ Here's an example: D4RLExperienceReplay + MinariExperienceReplay OpenMLExperienceReplay TensorSpec diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 8d1e258502e..d859140bb70 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -333,7 +333,11 @@ algorithms, such as DQN, DDPG or Dreamer. DistributionalDQNnet DreamerActor DuelingCnnDQNet + GRUCell + GRU GRUModule + LSTMCell + LSTM LSTMModule ObsDecoder ObsEncoder diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index fba4247e2a7..385e4a53aab 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -147,6 +147,7 @@ def transformed_env_constructor( state_dim_gsde: Optional[int] = None, batch_dims: Optional[int] = 0, obs_norm_state_dict: Optional[dict] = None, + ignore_device: bool = False, ) -> Union[Callable, EnvCreator]: """ Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. @@ -179,6 +180,7 @@ it should be set to 1 (or the number of dims of the batch).
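The checkpointing section added to data.rst above describes ReplayBuffer.dumps/loads and the state_dict fallback without a usage snippet; the following is a minimal sketch of both paths, assuming a memmap-backed buffer and illustrative checkpoint paths, keys and sizes that are not part of this diff.

# Illustrative sketch only: exercises the two checkpointing APIs described in the docs above.
# The checkpoint paths, keys and buffer sizes are assumptions made for this example.
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100), batch_size=8)
rb.extend(TensorDict({"obs": torch.randn(10, 4), "reward": torch.randn(10, 1)}, [10]))

# Path 1: memory-mapped checkpoint of storage, sampler and writer (transforms excluded)
rb.dumps("/tmp/rb_checkpoint")
rb2 = TensorDictReplayBuffer(storage=LazyMemmapStorage(100), batch_size=8)
rb2.loads("/tmp/rb_checkpoint")

# Path 2: state_dict round-trip via torch.save / torch.load
torch.save(rb.state_dict(), "/tmp/rb_state.pt")
rb3 = TensorDictReplayBuffer(storage=LazyMemmapStorage(100), batch_size=8)
rb3.load_state_dict(torch.load("/tmp/rb_state.pt"))
assert len(rb3) == len(rb)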
obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the environment + ignore_device (bool, optional): if True, the device is ignored. """ def make_transformed_env(**kwargs) -> TransformedEnv: @@ -189,14 +191,17 @@ def make_transformed_env(**kwargs) -> TransformedEnv: from_pixels = cfg.from_pixels if custom_env is None and custom_env_maker is None: - if isinstance(cfg.collector_device, str): - device = cfg.collector_device - elif isinstance(cfg.collector_device, Sequence): - device = cfg.collector_device[0] + if not ignore_device: + if isinstance(cfg.collector_device, str): + device = cfg.collector_device + elif isinstance(cfg.collector_device, Sequence): + device = cfg.collector_device[0] + else: + raise ValueError( + "collector_device must be either a string or a sequence of strings" + ) else: - raise ValueError( - "collector_device must be either a string or a sequence of strings" - ) + device = None env_kwargs = { "env_name": env_name, "device": device, @@ -252,19 +257,19 @@ def parallel_env_constructor( kwargs: keyword arguments for the `transformed_env_constructor` method. """ batch_transform = cfg.batch_transform + kwargs.update({"cfg": cfg, "use_env_creator": True}) if cfg.env_per_collector == 1: - kwargs.update({"cfg": cfg, "use_env_creator": True}) make_transformed_env = transformed_env_constructor(**kwargs) return make_transformed_env - kwargs.update({"cfg": cfg, "use_env_creator": True}) make_transformed_env = transformed_env_constructor( - return_transformed_envs=not batch_transform, **kwargs + return_transformed_envs=not batch_transform, ignore_device=True, **kwargs ) parallel_env = ParallelEnv( num_workers=cfg.env_per_collector, create_env_fn=make_transformed_env, create_env_kwargs=None, pin_memory=cfg.pin_memory, + device=cfg.collector_device, ) if batch_transform: kwargs.update( diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 79af6482480..ca98e2cff6e 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -325,7 +325,7 @@ class MyClass: def rollout_consistency_assertion( - rollout, *, done_key="done", observation_key="observation" + rollout, *, done_key="done", observation_key="observation", done_strict=False ): """Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise.""" @@ -335,11 +335,13 @@ def rollout_consistency_assertion( # data resulting from step, when it's not done, after step_mdp r_not_done_tp1 = rollout[:, 1:][~done] torch.testing.assert_close( - r_not_done[observation_key], r_not_done_tp1[observation_key] + r_not_done[observation_key], + r_not_done_tp1[observation_key], + msg=f"Key {observation_key} did not match", ) - if not done.any(): - return + if done_strict and not done.any(): + raise RuntimeError("No done detected, test could not complete.") # data resulting from step, when it's done r_done = rollout[:, :-1]["next"][done] @@ -347,7 +349,10 @@ def rollout_consistency_assertion( r_done_tp1 = rollout[:, 1:][done] assert ( (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1 - ).all(), (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) + ).all(), ( + f"Entries in next tensordict do not match entries in root " + f"tensordict after reset : {(r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) < 1e-1}" + ) def rand_reset(env): diff --git a/test/test_env.py b/test/test_env.py index 6cee7f545d7..aed4e07b0b7 100644 --- a/test/test_env.py +++ 
b/test/test_env.py @@ -354,6 +354,48 @@ def test_mb_env_batch_lock(self, device, seed=0): class TestParallel: + @pytest.mark.skipif( + not torch.cuda.device_count(), reason="No cuda device detected." + ) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("hetero", [True, False]) + @pytest.mark.parametrize("pdevice", [None, "cpu", "cuda"]) + @pytest.mark.parametrize("edevice", ["cpu", "cuda"]) + @pytest.mark.parametrize("bwad", [True, False]) + def test_parallel_devices(self, parallel, hetero, pdevice, edevice, bwad): + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + if not hetero: + env = cls( + 2, lambda: ContinuousActionVecMockEnv(device=edevice), device=pdevice + ) + else: + env1 = lambda: ContinuousActionVecMockEnv(device=edevice) + env2 = lambda: TransformedEnv(ContinuousActionVecMockEnv(device=edevice)) + env = cls(2, [env1, env2], device=pdevice) + + r = env.rollout(2, break_when_any_done=bwad) + if pdevice is not None: + assert env.device.type == torch.device(pdevice).type + assert r.device.type == torch.device(pdevice).type + assert all( + item.device.type == torch.device(pdevice).type + for item in r.values(True, True) + ) + else: + assert env.device.type == torch.device(edevice).type + assert r.device.type == torch.device(edevice).type + assert all( + item.device.type == torch.device(edevice).type + for item in r.values(True, True) + ) + if parallel: + assert ( + env.shared_tensordict_parent.device.type == torch.device(edevice).type + ) + @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1), (32, 0)]) def test_env_with_batch_size(self, num_parallel_env, env_batch_size): diff --git a/test/test_libs.py b/test/test_libs.py index c3379021510..7f42d52d63e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -50,6 +50,7 @@ from torchrl._utils import implement_for from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl.data.datasets.openml import OpenMLExperienceReplay from torchrl.data.replay_buffers import SamplerWithoutReplacement from torchrl.envs import ( @@ -90,6 +91,8 @@ _has_gym_robotics = importlib.util.find_spec("gymnasium_robotics") is not None +_has_minari = importlib.util.find_spec("minari") is not None + if _has_gym: try: import gymnasium as gym @@ -400,21 +403,23 @@ def test_vecenvs_wrapper(self, envname): ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + (["FetchReach-v2"] if _has_gym_robotics else []), ) - @pytest.mark.flaky(reruns=8, reruns_delay=1) def test_vecenvs_env(self, envname): - from _utils_internal import rollout_consistency_assertion - with set_gym_backend("gymnasium"): env = GymEnv(envname, num_envs=2, from_pixels=False) - + env.set_seed(0) assert env.get_library_name(env._env) == "gymnasium" # rollouts can be executed without decorator check_env_specs(env) rollout = env.rollout(100, break_when_any_done=False) for obs_key in env.observation_spec.keys(True, True): rollout_consistency_assertion( - rollout, done_key="done", observation_key=obs_key + rollout, + done_key="done", + observation_key=obs_key, + done_strict="CartPole" in envname, ) + env.close() + del env @implement_for("gym", "0.18", "0.27.0") @pytest.mark.parametrize( @@ -441,30 +446,39 @@ def test_vecenvs_wrapper(self, envname): # noqa: F811 ) assert env.batch_size == torch.Size([2]) check_env_specs(env) + env.close() + del env 
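The benchmark and test changes in this diff move device handling from each sub-environment onto ParallelEnv itself: workers build CPU environments and the parent environment casts the rollout to its own device, as test_parallel_devices above checks. A minimal sketch of that pattern follows; the environment name and device strings are assumptions for illustration only.

# Illustrative sketch of the ParallelEnv device pattern exercised by test_parallel_devices:
# sub-environments are created on CPU and the parent env casts rollouts to its device.
import torch
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv

def make_env():
    return GymEnv("Pendulum-v1", device="cpu")  # worker env stays on CPU

device = "cuda" if torch.cuda.is_available() else "cpu"
penv = ParallelEnv(2, EnvCreator(make_env), device=device)
rollout = penv.rollout(3)
assert rollout.device.type == torch.device(device).type
penv.close()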
@implement_for("gym", "0.18", "0.27.0") @pytest.mark.parametrize( "envname", ["CartPole-v1", "HalfCheetah-v4"], ) - @pytest.mark.flaky(reruns=3, reruns_delay=1) def test_vecenvs_env(self, envname): # noqa: F811 with set_gym_backend("gym"): env = GymEnv(envname, num_envs=2, from_pixels=False) - + env.set_seed(0) assert env.get_library_name(env._env) == "gym" # rollouts can be executed without decorator check_env_specs(env) rollout = env.rollout(100, break_when_any_done=False) for obs_key in env.observation_spec.keys(True, True): rollout_consistency_assertion( - rollout, done_key="done", observation_key=obs_key + rollout, + done_key="done", + observation_key=obs_key, + done_strict="CartPole" in envname, ) + env.close() + del env if envname != "CartPole-v1": with set_gym_backend("gym"): env = GymEnv(envname, num_envs=2, from_pixels=True) + env.set_seed(0) # rollouts can be executed without decorator check_env_specs(env) + env.close() + del env @implement_for("gym", None, "0.18") @pytest.mark.parametrize( @@ -1817,7 +1831,10 @@ class TestD4RL: @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("use_truncated_as_done", [True, False]) @pytest.mark.parametrize("split_trajs", [True, False]) - def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): + def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs, tmpdir): + root1 = tmpdir / "1" + root2 = tmpdir / "2" + root3 = tmpdir / "3" with pytest.warns( UserWarning, match="Using use_truncated_as_done=True" @@ -1829,6 +1846,8 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): terminate_on_end=True, batch_size=2, use_truncated_as_done=use_truncated_as_done, + download="force", + root=root1, ) _ = D4RLExperienceReplay( task, @@ -1837,6 +1856,8 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): terminate_on_end=False, batch_size=2, use_truncated_as_done=use_truncated_as_done, + download="force", + root=root2, ) data_from_env = D4RLExperienceReplay( task, @@ -1844,6 +1865,8 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): from_env=True, batch_size=2, use_truncated_as_done=use_truncated_as_done, + download="force", + root=root3, ) if not use_truncated_as_done: keys = set(data_from_env._storage._storage.keys(True, True)) @@ -1874,7 +1897,9 @@ def test_terminate_on_end(self, task, use_truncated_as_done, split_trajs): assert "truncated" not in leaf_names @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) - def test_direct_download(self, task): + def test_direct_download(self, task, tmpdir): + root1 = tmpdir / "1" + root2 = tmpdir / "2" data_direct = D4RLExperienceReplay( task, split_trajs=False, @@ -1882,6 +1907,8 @@ def test_direct_download(self, task): batch_size=2, use_truncated_as_done=True, direct_download=True, + download="force", + root=root1, ) data_d4rl = D4RLExperienceReplay( task, @@ -1891,6 +1918,8 @@ def test_direct_download(self, task): use_truncated_as_done=True, direct_download=False, terminate_on_end=True, # keep the last time step + download="force", + root=root2, ) keys = set(data_direct._storage._storage.keys(True, True)) keys = keys.intersection(data_d4rl._storage._storage.keys(True, True)) @@ -1961,6 +1990,53 @@ def test_d4rl_iteration(self, task, split_trajs): print(f"terminated test after {time.time()-t0}s") +_MINARI_DATASETS = [] + + +def _minari_selected_datasets(): + if not _has_minari: + return + global _MINARI_DATASETS + import minari + + torch.manual_seed(0) + + 
keys = list(minari.list_remote_datasets()) + indices = torch.randperm(len(keys))[:10] + keys = [keys[idx] for idx in indices] + keys = [ + key + for key in keys + if "=0.4" in minari.list_remote_datasets()[key]["minari_version"] + ] + assert len(keys) > 5 + _MINARI_DATASETS += keys + print("_MINARI_DATASETS", _MINARI_DATASETS) + + +_minari_selected_datasets() + + +@pytest.mark.skipif(not _has_minari, reason="Minari not found") +@pytest.mark.parametrize("split", [False, True]) +@pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS) +class TestMinari: + def test_load(self, selected_dataset, split): + print("dataset", selected_dataset) + data = MinariExperienceReplay( + selected_dataset, batch_size=32, split_trajs=split + ) + t0 = time.time() + for i, sample in enumerate(data): + t1 = time.time() + print(f"sampling time {1000 * (t1-t0): 4.4f}ms") + assert data.metadata["action_space"].is_in(sample["action"]) + assert data.metadata["observation_space"].is_in(sample["observation"]) + t0 = time.time() + if i == 10: + break + + @pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( "dataset", diff --git a/test/test_modules.py b/test/test_modules.py index ee1884c5573..cdd8987022d 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -17,6 +17,10 @@ from torchrl.modules import ( CEMPlanner, DTActor, + GRU, + GRUCell, + LSTM, + LSTMCell, LSTMNet, MultiAgentConvNet, MultiAgentMLP, @@ -1186,6 +1190,209 @@ def test_onlinedtactor(self, batch_dims, T=5): assert (dtactor.log_std_max > sig.log()).all() +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +def test_python_lstm_cell(device, bias): + + lstm_cell1 = LSTMCell(10, 20, device=device, bias=bias) + lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias) + + lstm_cell1.load_state_dict(lstm_cell2.state_dict()) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip( + lstm_cell1.named_parameters(), lstm_cell2.named_parameters() + ): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + # Run loop + input = torch.randn(2, 3, 10, device=device) + h0 = torch.randn(3, 20, device=device) + c0 = torch.randn(3, 20, device=device) + with torch.no_grad(): + for i in range(input.size()[0]): + h1, c1 = lstm_cell1(input[i], (h0, c0)) + h2, c2 = lstm_cell2(input[i], (h0, c0)) + + # Make sure the final hidden states have the same shape + assert h1.shape == h2.shape + assert c1.shape == c2.shape + torch.testing.assert_close(h1, h2) + torch.testing.assert_close(c1, c2) + h0 = h1 + c0 = c1 + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +def test_python_gru_cell(device, bias): + + gru_cell1 = GRUCell(10, 20, device=device, bias=bias) + gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias) + + gru_cell2.load_state_dict(gru_cell1.state_dict()) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip( + gru_cell1.named_parameters(), gru_cell2.named_parameters() + ): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert (v1 == v2).all() + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + # Run loop + input = torch.randn(2, 3, 10, device=device) + h0 = torch.zeros(3, 20, device=device) + with torch.no_grad(): + for i in range(input.size()[0]): + 
print(i) + h1 = gru_cell1(input[i], h0) + h2 = gru_cell2(input[i], h0) + + # Make sure the final hidden states have the same shape + assert h1.shape == h2.shape + torch.testing.assert_close(h1, h2) + h0 = h1 + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("batch_first", [True, False]) +@pytest.mark.parametrize("dropout", [0.0, 0.5]) +@pytest.mark.parametrize("num_layers", [1, 2]) +def test_python_lstm(device, bias, dropout, batch_first, num_layers): + B = 5 + T = 3 + lstm1 = LSTM( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, + ) + lstm2 = nn.LSTM( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, + ) + + lstm2.load_state_dict(lstm1.state_dict()) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + if batch_first: + input = torch.randn(B, T, 10, device=device) + else: + input = torch.randn(T, B, 10, device=device) + + h0 = torch.randn(num_layers, 5, 20, device=device) + c0 = torch.randn(num_layers, 5, 20, device=device) + + # Test without hidden states + with torch.no_grad(): + output1, (h1, c1) = lstm1(input) + output2, (h2, c2) = lstm2(input) + + assert h1.shape == h2.shape + assert c1.shape == c2.shape + assert output1.shape == output2.shape + if dropout == 0.0: + torch.testing.assert_close(output1, output2) + torch.testing.assert_close(h1, h2) + torch.testing.assert_close(c1, c2) + + # Test with hidden states + with torch.no_grad(): + output1, (h1, c1) = lstm1(input, (h0, c0)) + output2, (h2, c2) = lstm1(input, (h0, c0)) + + assert h1.shape == h2.shape + assert c1.shape == c2.shape + assert output1.shape == output2.shape + if dropout == 0.0: + torch.testing.assert_close(output1, output2) + torch.testing.assert_close(h1, h2) + torch.testing.assert_close(c1, c2) + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("batch_first", [True, False]) +@pytest.mark.parametrize("dropout", [0.0, 0.5]) +@pytest.mark.parametrize("num_layers", [1, 2]) +def test_python_gru(device, bias, dropout, batch_first, num_layers): + B = 5 + T = 3 + gru1 = GRU( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, + ) + gru2 = nn.GRU( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, + ) + gru2.load_state_dict(gru1.state_dict()) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip(gru1.named_parameters(), gru2.named_parameters()): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + torch.testing.assert_close(v1, v2) + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + if batch_first: + input = torch.randn(B, T, 10, device=device) + else: + input = torch.randn(T, B, 10, device=device) + + h0 = torch.randn(num_layers, 5, 20, device=device) + + # Test without hidden states + with torch.no_grad(): + output1, h1 = gru1(input) + output2, h2 = gru2(input) + + assert h1.shape == h2.shape + assert output1.shape == output2.shape + if dropout 
== 0.0: + torch.testing.assert_close(output1, output2) + torch.testing.assert_close(h1, h2) + + # Test with hidden states + with torch.no_grad(): + output1, h1 = gru1(input, h0) + output2, h2 = gru2(input, h0) + + assert h1.shape == h2.shape + assert output1.shape == output2.shape + if dropout == 0.0: + torch.testing.assert_close(output1, output2) + torch.testing.assert_close(h1, h2) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_rb.py b/test/test_rb.py index c68c623300b..f740e07e8ca 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -18,6 +18,7 @@ from packaging.version import parse from tensordict import is_tensorclass, tensorclass from tensordict.tensordict import assert_allclose_td, TensorDict, TensorDictBase +from torch import multiprocessing as mp from torchrl.data import ( PrioritizedReplayBuffer, RemoteTensorDictReplayBuffer, @@ -41,6 +42,7 @@ from torchrl.data.replay_buffers.writers import ( RoundRobinWriter, TensorDictMaxValueWriter, + TensorDictRoundRobinWriter, ) from torchrl.envs.transforms.transforms import ( BinarizeReward, @@ -80,7 +82,9 @@ @pytest.mark.parametrize( "sampler", [samplers.RandomSampler, samplers.PrioritizedSampler] ) -@pytest.mark.parametrize("writer", [writers.RoundRobinWriter]) +@pytest.mark.parametrize( + "writer", [writers.RoundRobinWriter, writers.TensorDictMaxValueWriter] +) @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage]) @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: @@ -104,7 +108,9 @@ def _get_datum(self, rb_type): elif ( rb_type is TensorDictReplayBuffer or rb_type is RemoteTensorDictReplayBuffer ): - data = TensorDict({"a": torch.randint(100, (1,))}, []) + data = TensorDict( + {"a": torch.randint(100, (1,)), "next": {"reward": torch.randn(1)}}, [] + ) else: raise NotImplementedError(rb_type) return data @@ -119,6 +125,7 @@ def _get_data(self, rb_type, size): { "a": torch.randint(100, (size,)), "b": TensorDict({"c": torch.randint(100, (size,))}, [size]), + "next": {"reward": torch.randn(size, 1)}, }, [size], ) @@ -136,6 +143,12 @@ def test_add(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_datum(rb_type) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.add(data) + return rb.add(data) s = rb.sample(1) assert s.ndim, s @@ -153,7 +166,22 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size): writer = writer() writer.register_storage(storage) batch1 = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(batch1) and isinstance(storage, TensorStorage) + cond = ( + OLD_TORCH + and not isinstance(writer, TensorDictMaxValueWriter) + and size < len(batch1) + and isinstance(storage, TensorStorage) + ) + + if isinstance(batch1, torch.Tensor) and isinstance( + writer, TensorDictMaxValueWriter + ): + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + writer.extend(batch1) + return + with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -165,13 +193,19 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size): assert writer._cursor == 5 # Added more data than storage max size elif size < 5: - assert 
writer._cursor == 5 - size + # if Max writer, we don't necessarily overwrite existing values so + # we just check that the cursor is before the threshold + if isinstance(writer, TensorDictMaxValueWriter): + assert writer._cursor <= 5 - size + else: + assert writer._cursor == 5 - size # Added as data as storage max size else: assert writer._cursor == 0 - batch2 = self._get_data(rb_type, size=size - 1) - writer.extend(batch2) - assert writer._cursor == size - 1 + if not isinstance(writer, TensorDictMaxValueWriter): + batch2 = self._get_data(rb_type, size=size - 1) + writer.extend(batch2) + assert writer._cursor == size - 1 def test_extend(self, rb_type, sampler, writer, storage, size): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: @@ -183,7 +217,21 @@ def test_extend(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb._storage, TensorStorage) + ) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return + length = min(rb._storage.max_size, len(rb) + data.shape[0]) + if writer is TensorDictMaxValueWriter: + data["next", "reward"][-length:] = 1_000_000 with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -207,7 +255,10 @@ def test_extend(self, rb_type, sampler, writer, storage, size): raise RuntimeError("did not find match") data2 = self._get_data(rb_type, size=2 * size + 2) cond = ( - OLD_TORCH and size < len(data2) and isinstance(rb._storage, TensorStorage) + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data2) + and isinstance(rb._storage, TensorStorage) ) with pytest.warns( UserWarning, @@ -225,7 +276,18 @@ def test_sample(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb._storage, TensorStorage) + ) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -261,7 +323,18 @@ def test_index(self, rb_type, sampler, writer, storage, size): rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) data = self._get_data(rb_type, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb._storage, TensorStorage) + ) + if isinstance(data, torch.Tensor) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return with pytest.warns( UserWarning, match="A cursor of length superior to the storage capacity was provided", @@ -388,11 +461,124 @@ class TC: with pytest.warns( 
DeprecationWarning, match="Support for Memmap device other than CPU" ): + # this is rather brittle and will fail with some indices + # when both device (storage and data) don't match (eg, range()) storage.set(0, data) else: storage.set(0, data) assert storage.get(0).device.type == device_storage.type + @pytest.mark.parametrize("storage_in", ["tensor", "memmap"]) + @pytest.mark.parametrize("storage_out", ["tensor", "memmap"]) + @pytest.mark.parametrize("init_out", [True, False]) + def test_storage_state_dict(self, storage_in, storage_out, init_out): + buffer_size = 100 + if storage_in == "memmap": + storage_in = LazyMemmapStorage(buffer_size, device="cpu") + elif storage_in == "tensor": + storage_in = LazyTensorStorage(buffer_size, device="cpu") + if storage_out == "memmap": + storage_out = LazyMemmapStorage(buffer_size, device="cpu") + elif storage_out == "tensor": + storage_out = LazyTensorStorage(buffer_size, device="cpu") + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, prefetch=3, storage=storage_in, batch_size=3 + ) + # fill replay buffer with random data + transition = TensorDict( + { + "observation": torch.ones(1, 4), + "action": torch.ones(1, 2), + "reward": torch.ones(1, 1), + "dones": torch.ones(1, 1), + "next": {"observation": torch.ones(1, 4)}, + }, + batch_size=1, + ) + for _ in range(3): + replay_buffer.extend(transition) + + state_dict = replay_buffer.state_dict() + + new_replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=3, + storage=storage_out, + batch_size=state_dict["_batch_size"], + ) + if init_out: + new_replay_buffer.extend(transition) + + new_replay_buffer.load_state_dict(state_dict) + s = new_replay_buffer.sample() + assert (s.exclude("index") == 1).all() + + @pytest.mark.parametrize("device_data", get_default_devices()) + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + @pytest.mark.parametrize("data_type", ["tensor", "tc", "td"]) + @pytest.mark.parametrize("isinit", [True, False]) + def test_storage_dumps_loads( + self, device_data, storage_type, data_type, isinit, tmpdir + ): + dir_rb = tmpdir / "rb" + dir_save = tmpdir / "save" + dir_rb.mkdir() + dir_save.mkdir() + torch.manual_seed(0) + + @tensorclass + class TC: + tensor: torch.Tensor + td: TensorDict + text: str + + if data_type == "tensor": + data = torch.randint(10, (3,), device=device_data) + elif data_type == "td": + data = TensorDict( + { + "a": torch.randint(10, (3,), device=device_data), + "b": TensorDict( + {"c": torch.randint(10, (3,), device=device_data)}, + batch_size=[3], + ), + }, + batch_size=[3], + device=device_data, + ) + elif data_type == "tc": + data = TC( + tensor=torch.randint(10, (3,), device=device_data), + td=TensorDict( + {"c": torch.randint(10, (3,), device=device_data)}, batch_size=[3] + ), + text="some text", + batch_size=[3], + device=device_data, + ) + else: + raise NotImplementedError + if storage_type in (LazyMemmapStorage,): + storage = storage_type(max_size=10, scratch_dir=dir_rb) + else: + storage = storage_type(max_size=10) + # We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index + storage.set(range(3), data.cpu()) + storage.dumps(dir_save) + # check we can dump twice + storage.dumps(dir_save) + storage_recover = storage_type(max_size=10) + if isinit: + storage_recover.set(range(3), data.cpu().zero_()) + storage_recover.loads(dir_save) + if data_type == "tensor": + torch.testing.assert_close(storage._storage, storage_recover._storage) + else: + 
assert_allclose_td(storage._storage, storage_recover._storage) + if data == "tc": + assert storage._storage.text == storage_recover._storage.text + @pytest.mark.parametrize("max_size", [1000]) @pytest.mark.parametrize("shape", [[3, 4]]) @@ -486,7 +672,8 @@ def test_prototype_prb(priority_key, contiguous, device): "_idx": torch.arange(3).view(3, 1), }, batch_size=[3], - ).to(device) + device=device, + ) rb.extend(td1) s = rb.sample() assert s.batch_size == torch.Size([5]) @@ -501,7 +688,8 @@ def test_prototype_prb(priority_key, contiguous, device): "_idx": torch.arange(5).view(5, 1), }, batch_size=[5], - ).to(device) + device=device, + ) rb.extend(td2) s = rb.sample() assert s.batch_size == torch.Size([5]) @@ -795,6 +983,14 @@ def test_index(self, rbtype, storage, size, prefetch): b = b.all() assert b + def test_index_nonfull(self, rbtype, storage, size, prefetch): + # checks that indexing the buffer before it's full gives the accurate view of the data + rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) + data = self._get_data(rbtype, size=size - 1) + rb.extend(data) + assert len(rb[: size - 1]) == size - 1 + assert len(rb[size - 2 :]) == 1 + def test_multi_loops(): """Tests that one can iterate multiple times over a buffer without rep.""" @@ -1164,125 +1360,188 @@ def test_replay_buffer_iter(size, drop_last): assert i == (size - 1) // 3 -class TestStateDict: - @pytest.mark.parametrize("storage_in", ["tensor", "memmap"]) - @pytest.mark.parametrize("storage_out", ["tensor", "memmap"]) - @pytest.mark.parametrize("init_out", [True, False]) - def test_load_state_dict(self, storage_in, storage_out, init_out): - buffer_size = 100 - if storage_in == "memmap": - storage_in = LazyMemmapStorage(buffer_size, device="cpu") - elif storage_in == "tensor": - storage_in = LazyTensorStorage(buffer_size, device="cpu") - if storage_out == "memmap": - storage_out = LazyMemmapStorage(buffer_size, device="cpu") - elif storage_out == "tensor": - storage_out = LazyTensorStorage(buffer_size, device="cpu") - - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, prefetch=3, storage=storage_in, batch_size=3 +@pytest.mark.parametrize("size", [20, 25, 30]) +@pytest.mark.parametrize("batch_size", [1, 10, 15]) +@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) +@pytest.mark.parametrize("device", get_default_devices()) +class TestMaxValueWriter: + def test_max_value_writer(self, size, batch_size, reward_ranges, device): + torch.manual_seed(0) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(size, device=device), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key"), ) - # fill replay buffer with random data - transition = TensorDict( + + max_reward1, max_reward2, max_reward3 = reward_ranges + + td = TensorDict( { - "observation": torch.ones(1, 4), - "action": torch.ones(1, 2), - "reward": torch.ones(1, 1), - "dones": torch.ones(1, 1), - "next": {"observation": torch.ones(1, 4)}, + "key": torch.clamp_max(torch.rand(size), max=max_reward1), + "obs": torch.rand(size), }, - batch_size=1, + batch_size=size, + device=device, ) - for _ in range(3): - replay_buffer.extend(transition) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward1).all() + assert (0 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) - state_dict = replay_buffer.state_dict() - - new_replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - prefetch=3, - 
storage=storage_out, - batch_size=state_dict["_batch_size"], + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, ) - if init_out: - new_replay_buffer.extend(transition) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward2).all() + assert (max_reward1 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) - new_replay_buffer.load_state_dict(state_dict) - s = new_replay_buffer.sample() - assert (s.exclude("index") == 1).all() + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + for sample in td: + rb.add(sample) -@pytest.mark.parametrize("size", [20, 25, 30]) -@pytest.mark.parametrize("batch_size", [1, 10, 15]) -@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) -@pytest.mark.parametrize("device", get_default_devices()) -def test_max_value_writer(size, batch_size, reward_ranges, device): - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(size, device=device), - sampler=SamplerWithoutReplacement(), - batch_size=batch_size, - writer=TensorDictMaxValueWriter(rank_key="key"), - ) + sample = rb.sample() + assert (sample.get("key") <= max_reward3).all() + assert (max_reward2 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) - max_reward1, max_reward2, max_reward3 = reward_ranges + # Finally, test the case when no obs should be added + td = TensorDict( + { + "key": torch.zeros(size), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") != 0).all() - td = TensorDict( - { - "key": torch.clamp_max(torch.rand(size), max=max_reward1), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") <= max_reward1).all() - assert (0 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) + def test_max_value_writer_serialize( + self, size, batch_size, reward_ranges, device, tmpdir + ): + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(size, device=device), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key"), + ) - td = TensorDict( - { - "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") <= max_reward2).all() - assert (max_reward1 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) + max_reward1, max_reward2, max_reward3 = reward_ranges - td = TensorDict( - { - "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) + td = TensorDict( + { + "key": torch.clamp_max(torch.rand(size), max=max_reward1), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + rb._writer.dumps(tmpdir) + # check we can dump twice + rb._writer.dumps(tmpdir) + other = TensorDictMaxValueWriter(rank_key="key") + other.loads(tmpdir) + assert len(rb._writer._current_top_values) == len(other._current_top_values) + torch.testing.assert_close( + 
torch.tensor(rb._writer._current_top_values), + torch.tensor(other._current_top_values), + ) - for sample in td: - rb.add(sample) - sample = rb.sample() - assert (sample.get("key") <= max_reward3).all() - assert (max_reward2 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) +class TestMultiProc: + @staticmethod + def worker(rb, q0, q1): + td = TensorDict({"a": torch.ones(10), "next": {"reward": torch.ones(10)}}, [10]) + rb.extend(td) + q0.put("extended") + extended = q1.get(timeout=5) + assert extended == "extended" + assert len(rb) == 21, len(rb) + assert (rb["_data", "a"][:9] == 2).all() + q0.put("finish") + + def exec_multiproc_rb( + self, + storage_type=LazyMemmapStorage, + init=True, + writer_type=TensorDictRoundRobinWriter, + sampler_type=RandomSampler, + ): + rb = TensorDictReplayBuffer( + storage=storage_type(21), writer=writer_type(), sampler=sampler_type() + ) + if init: + td = TensorDict( + {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10] + ) + rb.extend(td) + q0 = mp.Queue(1) + q1 = mp.Queue(1) + proc = mp.Process(target=self.worker, args=(rb, q0, q1)) + proc.start() + try: + extended = q0.get(timeout=100) + assert extended == "extended" + assert len(rb) == 20 + assert (rb["_data", "a"][10:20] == 1).all() + td = TensorDict({"a": torch.zeros(10) + 2}, [10]) + rb.extend(td) + q1.put("extended") + finish = q0.get(timeout=5) + assert finish == "finish" + finally: + proc.join() + + def test_multiproc_rb(self): + return self.exec_multiproc_rb() + + def test_error_list(self): + # list storage cannot be shared + with pytest.raises(RuntimeError, match="Cannot share a storage of type"): + self.exec_multiproc_rb(storage_type=ListStorage) + + def test_error_nonshared(self): + # non shared tensor storage cannot be shared + with pytest.raises( + RuntimeError, match="The storage must be place in shared memory" + ): + self.exec_multiproc_rb(storage_type=LazyTensorStorage) + + def test_error_maxwriter(self): + # TensorDictMaxValueWriter cannot be shared + with pytest.raises(RuntimeError, match="cannot be shared between processes"): + self.exec_multiproc_rb(writer_type=TensorDictMaxValueWriter) + + def test_error_prb(self): + # PrioritizedSampler cannot be shared + with pytest.raises(RuntimeError, match="cannot be shared between processes"): + self.exec_multiproc_rb( + sampler_type=lambda: PrioritizedSampler(21, alpha=1.1, beta=0.5) + ) - # Finally, test the case when no obs should be added - td = TensorDict( - { - "key": torch.zeros(size), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") != 0).all() + def test_error_noninit(self): + # list storage cannot be shared + with pytest.raises(RuntimeError, match="it has not been initialized yet"): + self.exec_multiproc_rb(init=False) if __name__ == "__main__": diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ea7b204076e..8d5920680ac 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -1672,7 +1672,8 @@ def test_noncontiguous(self): lstm_module(padded) @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) - def test_singel_step(self, shape): + @pytest.mark.parametrize("python_based", [True, False]) + def test_single_step(self, shape, python_based): td = TensorDict( { "observation": torch.zeros(*shape, 3), @@ -1686,6 +1687,7 @@ def test_singel_step(self, shape): batch_first=True, in_keys=["observation", "hidden0", "hidden1"], 
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], + python_based=python_based, ) td = lstm_module(td) td_next = step_mdp(td, keep_other=True) @@ -1697,7 +1699,8 @@ def test_singel_step(self, shape): @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) @pytest.mark.parametrize("t", [1, 10]) - def test_single_step_vs_multi(self, shape, t): + @pytest.mark.parametrize("python_based", [True, False]) + def test_single_step_vs_multi(self, shape, t, python_based): td = TensorDict( { "observation": torch.arange(t, dtype=torch.float32) @@ -1713,6 +1716,7 @@ def test_single_step_vs_multi(self, shape, t): batch_first=True, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], + python_based=python_based, ) lstm_module_ms = lstm_module_ss.set_recurrent_mode() lstm_module_ms(td) @@ -1732,7 +1736,8 @@ def test_single_step_vs_multi(self, shape, t): ) @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) - def test_multi_consecutive(self, shape): + @pytest.mark.parametrize("python_based", [False, True]) + def test_multi_consecutive(self, shape, python_based): t = 20 td = TensorDict( { @@ -1754,6 +1759,7 @@ def test_multi_consecutive(self, shape): batch_first=True, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], + python_based=python_based, ) lstm_module_ms = lstm_module_ss.set_recurrent_mode() lstm_module_ms(td) @@ -1769,11 +1775,13 @@ def test_multi_consecutive(self, shape): lstm_module_ss(td_ss) td_ss = step_mdp(td_ss, keep_other=True) td_ss["observation"][:] = _t + 1 + # import ipdb; ipdb.set_trace() # assert fails when python_based is True, why? torch.testing.assert_close( td_ss["intermediate"], td["intermediate"][..., -1, :] ) - def test_lstm_parallel_env(self): + @pytest.mark.parametrize("python_based", [True, False]) + def test_lstm_parallel_env(self, python_based): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv device = "cuda" if torch.cuda.device_count() else "cpu" @@ -1785,6 +1793,7 @@ def test_lstm_parallel_env(self): in_key="observation", out_key="features", device=device, + python_based=python_based, ) def create_transformed_env(): @@ -1938,7 +1947,8 @@ def test_noncontiguous(self): gru_module(padded) @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) - def test_singel_step(self, shape): + @pytest.mark.parametrize("python_based", [True, False]) + def test_single_step(self, shape, python_based): td = TensorDict( { "observation": torch.zeros(*shape, 3), @@ -1952,6 +1962,7 @@ def test_singel_step(self, shape): batch_first=True, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")], + python_based=python_based, ) td = gru_module(td) td_next = step_mdp(td, keep_other=True) @@ -1961,7 +1972,8 @@ def test_singel_step(self, shape): @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) @pytest.mark.parametrize("t", [1, 10]) - def test_single_step_vs_multi(self, shape, t): + @pytest.mark.parametrize("python_based", [True, False]) + def test_single_step_vs_multi(self, shape, t, python_based): td = TensorDict( { "observation": torch.arange(t, dtype=torch.float32) @@ -1977,6 +1989,7 @@ def test_single_step_vs_multi(self, shape, t): batch_first=True, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")], + python_based=python_based, ) gru_module_ms = gru_module_ss.set_recurrent_mode() gru_module_ms(td) @@ -1994,7 +2007,8 @@ def 
test_single_step_vs_multi(self, shape, t): torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :]) @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) - def test_multi_consecutive(self, shape): + @pytest.mark.parametrize("python_based", [True, False]) + def test_multi_consecutive(self, shape, python_based): t = 20 td = TensorDict( { @@ -2016,6 +2030,7 @@ def test_multi_consecutive(self, shape): batch_first=True, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")], + python_based=python_based, ) gru_module_ms = gru_module_ss.set_recurrent_mode() gru_module_ms(td) @@ -2035,7 +2050,8 @@ def test_multi_consecutive(self, shape): td_ss["intermediate"], td["intermediate"][..., -1, :] ) - def test_gru_parallel_env(self): + @pytest.mark.parametrize("python_based", [True, False]) + def test_gru_parallel_env(self, python_based): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv device = "cuda" if torch.cuda.device_count() else "cpu" @@ -2047,6 +2063,7 @@ def test_gru_parallel_env(self): in_key="observation", out_key="features", device=device, + python_based=python_based, ) def create_transformed_env(): diff --git a/test/test_transforms.py b/test/test_transforms.py index da8bc12c126..cff1d33b34a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -9,6 +9,7 @@ import itertools import pickle +import re import sys from copy import copy from functools import partial @@ -4878,6 +4879,39 @@ def test_sum_reward(self, keys, device): def test_transform_inverse(self): raise pytest.skip("No inverse for RewardSum") + @pytest.mark.parametrize("in_keys", [["reward"], ["reward_1", "reward_2"]]) + @pytest.mark.parametrize( + "out_keys", [["episode_reward"], ["episode_reward_1", "episode_reward_2"]] + ) + @pytest.mark.parametrize("reset_keys", [["_reset"], ["_reset1", "_reset2"]]) + def test_keys_length_errors(self, in_keys, reset_keys, out_keys, batch=10): + reset_dict = { + reset_key: torch.zeros(batch, dtype=torch.bool) for reset_key in reset_keys + } + reward_sum_dict = {out_key: torch.randn(batch) for out_key in out_keys} + reset_dict.update(reward_sum_dict) + td = TensorDict(reset_dict, []) + + if len(in_keys) != len(out_keys): + with pytest.raises( + ValueError, + match="RewardSum expects the same number of input and output keys", + ): + RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys) + else: + t = RewardSum(in_keys=in_keys, reset_keys=reset_keys, out_keys=out_keys) + + if len(in_keys) != len(reset_keys): + with pytest.raises( + ValueError, + match=re.escape( + f"Could not match the env reset_keys {reset_keys} with the in_keys {in_keys}" + ), + ): + t.reset(td) + else: + t.reset(td) + class TestReward2Go(TransformBase): @pytest.mark.parametrize("device", get_default_devices()) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0df292f7b93..3990975c20d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -280,11 +280,10 @@ def _get_policy_and_device( device = torch.device(device) if device is not None else policy_device get_weights_fn = None if policy_device != device: - param_and_buf = dict(policy.named_parameters()) - param_and_buf.update(dict(policy.named_buffers())) + param_and_buf = TensorDict.from_module(policy, as_module=True) def get_weights_fn(param_and_buf=param_and_buf): - return TensorDict(param_and_buf, []).apply(lambda x: x.data) + return param_and_buf.data policy_cast = 
deepcopy(policy).requires_grad_(False).to(device) # here things may break bc policy.to("cuda") gives us weights on cuda:0 (same @@ -308,9 +307,9 @@ def update_policy_weights_( """ if policy_weights is not None: - self.policy_weights.apply(lambda x: x.data).update_(policy_weights) + self.policy_weights.data.update_(policy_weights) elif self.get_weights_fn is not None: - self.policy_weights.apply(lambda x: x.data).update_(self.get_weights_fn()) + self.policy_weights.data.update_(self.get_weights_fn()) def __iter__(self) -> Iterator[TensorDictBase]: return self.iterator() @@ -559,10 +558,7 @@ def __init__( ) if isinstance(self.policy, nn.Module): - self.policy_weights = TensorDict(dict(self.policy.named_parameters()), []) - self.policy_weights.update( - TensorDict(dict(self.policy.named_buffers()), []) - ) + self.policy_weights = TensorDict.from_module(self.policy, as_module=True) else: self.policy_weights = TensorDict({}, []) @@ -1200,9 +1196,9 @@ def device_err_msg(device_name, devices_list): ) self._policy_dict[_device] = _policy if isinstance(_policy, nn.Module): - param_dict = dict(_policy.named_parameters()) - param_dict.update(_policy.named_buffers()) - self._policy_weights_dict[_device] = TensorDict(param_dict, []) + self._policy_weights_dict[_device] = TensorDict.from_module( + _policy, as_module=True + ) else: self._policy_weights_dict[_device] = TensorDict({}, []) @@ -1288,11 +1284,9 @@ def frames_per_batch_worker(self): def update_policy_weights_(self, policy_weights=None) -> None: for _device in self._policy_dict: if policy_weights is not None: - self._policy_weights_dict[_device].apply(lambda x: x.data).update_( - policy_weights - ) + self._policy_weights_dict[_device].data.update_(policy_weights) elif self._get_weights_fn_dict[_device] is not None: - self._policy_weights_dict[_device].update_( + self._policy_weights_dict[_device].data.update_( self._get_weights_fn_dict[_device]() ) diff --git a/torchrl/data/datasets/__init__.py b/torchrl/data/datasets/__init__.py index 81a668648d0..85b8e064917 100644 --- a/torchrl/data/datasets/__init__.py +++ b/torchrl/data/datasets/__init__.py @@ -1,2 +1,3 @@ from .d4rl import D4RLExperienceReplay +from .minari_data import MinariExperienceReplay from .openml import OpenMLExperienceReplay diff --git a/torchrl/data/datasets/d4rl.py b/torchrl/data/datasets/d4rl.py index b5fd63696a3..38fce4a6b7c 100644 --- a/torchrl/data/datasets/d4rl.py +++ b/torchrl/data/datasets/d4rl.py @@ -6,22 +6,27 @@ import importlib import os +import tempfile import urllib import warnings + +from pathlib import Path from typing import Callable import numpy as np import torch -from tensordict import PersistentTensorDict +from tensordict import PersistentTensorDict, TensorDict from tensordict.tensordict import make_tensordict from torchrl.collectors.utils import split_trajectories from torchrl.data.datasets.d4rl_infos import D4RL_DATASETS + +from torchrl.data.datasets.utils import _get_root_dir from torchrl.data.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import Sampler -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, TensorStorage from torchrl.data.replay_buffers.writers import Writer @@ -93,7 +98,14 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): Otherwise, only the ``terminated`` key is used. Defaults to ``True``. terminate_on_end (bool, optional): Set ``done=True`` on the last timestep in a trajectory. 
Default is ``False``, and will discard the - last timestep in each trajectory. + last timestep in each trajectory. This is to be used only with + ``direct_download=False``. + root (Path or str, optional): The D4RL dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/d4rl`. + download (bool, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. **env_kwargs (key-value pairs): additional kwargs for :func:`d4rl.qlearning_dataset`. @@ -103,7 +115,7 @@ class D4RLExperienceReplay(TensorDictReplayBuffer): >>> from torchrl.envs import ObservationNorm >>> data = D4RLExperienceReplay("maze2d-umaze-v1", 128) >>> # we can append transforms to the dataset - >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0)) + >>> data.append_transform(ObservationNorm(loc=-1, scale=1.0, in_keys=["observation"])) >>> data.sample(128) """ @@ -136,9 +148,16 @@ def __init__( use_truncated_as_done: bool = True, direct_download: bool = None, terminate_on_end: bool = None, + download: bool = True, + root: str | Path | None = None, **env_kwargs, ): self.use_truncated_as_done = use_truncated_as_done + if root is None: + root = _get_root_dir("d4rl") + self.root = root + self.name = name + dataset = None if not from_env and direct_download is None: self._import_d4rl() @@ -155,44 +174,66 @@ def __init__( category=DeprecationWarning, ) from_env = True - self.from_env = from_env - if terminate_on_end is None: - # we use the default of d4rl - terminate_on_end = False - self._import_d4rl() - - if not self._has_d4rl: - raise ImportError("Could not import d4rl") from self.D4RL_ERR - - if from_env: - dataset = self._get_dataset_from_env(name, env_kwargs) else: - if self.use_truncated_as_done: - warnings.warn( - "Using use_truncated_as_done=True + terminate_on_end=True " - "with from_env=False may not have the intended effect " - "as the timeouts (truncation) " - "can be absent from the static dataset." - ) - env_kwargs.update({"terminate_on_end": terminate_on_end}) - dataset = self._get_dataset_direct(name, env_kwargs) + warnings.warn( + "You are using the D4RL library for collecting data. " + "We advise against this use, as D4RL formatting can be " + "inconsistent. " + "To download the D4RL data without the D4RL library, use " + "direct_download=True in the dataset constructor. " + "Recurring to `direct_download=False` will soon be deprecated." + ) + self.from_env = from_env else: if from_env is None: from_env = False self.from_env = from_env - if terminate_on_end is False: - raise ValueError( - "Using terminate_on_end=False is not compatible with direct_download=True." 
- ) - dataset = self._get_dataset_direct_download(name, env_kwargs) - # Fill unknown next states with 0 - dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 - if split_trajs: - dataset = split_trajectories(dataset) - dataset["next", "done"][:, -1] = True + if (download == "force") or (download and not self._is_downloaded()): + if not direct_download: + if terminate_on_end is None: + # we use the default of d4rl + terminate_on_end = False + self._import_d4rl() + + if not self._has_d4rl: + raise ImportError("Could not import d4rl") from self.D4RL_ERR + + if from_env: + dataset = self._get_dataset_from_env(name, env_kwargs) + else: + if self.use_truncated_as_done: + warnings.warn( + "Using use_truncated_as_done=True + terminate_on_end=True " + "with from_env=False may not have the intended effect " + "as the timeouts (truncation) " + "can be absent from the static dataset." + ) + env_kwargs.update({"terminate_on_end": terminate_on_end}) + dataset = self._get_dataset_direct(name, env_kwargs) + else: + if terminate_on_end is False: + raise ValueError( + "Using terminate_on_end=False is not compatible with direct_download=True." + ) + dataset = self._get_dataset_direct_download(name, env_kwargs) + # Fill unknown next states with 0 + dataset["next", "observation"][dataset["next", "done"].squeeze()] = 0 + + if split_trajs: + dataset = split_trajectories(dataset) + dataset["next", "done"][:, -1] = True + + storage = LazyMemmapStorage( + dataset.shape[0], scratch_dir=Path(self.root) / name + ) + elif self._is_downloaded(): + storage = TensorStorage(TensorDict.load_memmap(Path(self.root) / name)) + else: + raise RuntimeError( + f"The dataset could not be found in {Path(self.root) / name}." + ) - storage = LazyMemmapStorage(dataset.shape[0]) super().__init__( batch_size=batch_size, storage=storage, @@ -203,7 +244,12 @@ def __init__( prefetch=prefetch, transform=transform, ) - self.extend(dataset) + if dataset is not None: + # if dataset has just been downloaded + self.extend(dataset) + + def _is_downloaded(self): + return os.path.exists(Path(self.root) / self.name) def _get_dataset_direct_download(self, name, env_kwargs): """Directly download and use a D4RL dataset.""" @@ -214,10 +260,12 @@ def _get_dataset_direct_download(self, name, env_kwargs): url = D4RL_DATASETS.get(name, None) if url is None: raise KeyError(f"Env {name} not found.") - h5path = _download_dataset_from_url(url) - # h5path_parent = Path(h5path).parent - dataset = PersistentTensorDict.from_h5(h5path) - dataset = dataset.to_tensordict() + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["D4RL_DATASET_DIR"] = tmpdir + h5path = _download_dataset_from_url(url, tmpdir) + # h5path_parent = Path(h5path).parent + dataset = PersistentTensorDict.from_h5(h5path) + dataset = dataset.to_tensordict() with dataset.unlock_(): dataset = self._process_data_from_env(dataset) return dataset @@ -233,15 +281,17 @@ def _get_dataset_direct(self, name, env_kwargs): import gym env = GymWrapper(gym.make(name)) - dataset = d4rl.qlearning_dataset(env._env, **env_kwargs) - - dataset = make_tensordict( - { - k: torch.from_numpy(item) - for k, item in dataset.items() - if isinstance(item, np.ndarray) - } - ) + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["D4RL_DATASET_DIR"] = tmpdir + dataset = d4rl.qlearning_dataset(env._env, **env_kwargs) + + dataset = make_tensordict( + { + k: torch.from_numpy(item) + for k, item in dataset.items() + if isinstance(item, np.ndarray) + } + ) dataset = dataset.unflatten_keys("/") if 
"metadata" in dataset.keys(): metadata = dataset.get("metadata") @@ -302,14 +352,16 @@ def _get_dataset_from_env(self, name, env_kwargs): # we do a local import to avoid circular import issues from torchrl.envs.libs.gym import GymWrapper - env = GymWrapper(gym.make(name)) - dataset = make_tensordict( - { - k: torch.from_numpy(item) - for k, item in env.get_dataset().items() - if isinstance(item, np.ndarray) - } - ) + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["D4RL_DATASET_DIR"] = tmpdir + env = GymWrapper(gym.make(name)) + dataset = make_tensordict( + { + k: torch.from_numpy(item) + for k, item in env.get_dataset().items() + if isinstance(item, np.ndarray) + } + ) dataset = dataset.unflatten_keys("/") dataset = self._process_data_from_env(dataset, env) return dataset @@ -394,8 +446,8 @@ def _shift_reward_done(self, dataset): dataset[key][0] = 0 -def _download_dataset_from_url(dataset_url): - dataset_filepath = _filepath_from_url(dataset_url) +def _download_dataset_from_url(dataset_url, dataset_path): + dataset_filepath = _filepath_from_url(dataset_url, dataset_path) if not os.path.exists(dataset_filepath): print("Downloading dataset:", dataset_url, "to", dataset_filepath) urllib.request.urlretrieve(dataset_url, dataset_filepath) @@ -404,23 +456,20 @@ def _download_dataset_from_url(dataset_url): return dataset_filepath -def _filepath_from_url(dataset_url): +def _filepath_from_url(dataset_url, dataset_path): _, dataset_name = os.path.split(dataset_url) - dataset_filepath = os.path.join(DATASET_PATH, dataset_name) + dataset_filepath = os.path.join(dataset_path, dataset_name) return dataset_filepath -def _set_dataset_path(path): - global DATASET_PATH - DATASET_PATH = path - os.makedirs(path, exist_ok=True) - - -_set_dataset_path( - os.environ.get( - "D4RL_DATASET_DIR", os.path.expanduser("~/.cache/torchrl/data/d4rl/datasets") - ) -) +# def _set_dataset_path(path): +# global DATASET_PATH +# DATASET_PATH = path +# os.makedirs(path, exist_ok=True) +# +# +# _set_dataset_path( +# os.environ.get(_get_root_dir("d4rl"))) if __name__ == "__main__": data = D4RLExperienceReplay("kitchen-partial-v0", batch_size=128) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py new file mode 100644 index 00000000000..492ac0fff58 --- /dev/null +++ b/torchrl/data/datasets/minari_data.py @@ -0,0 +1,455 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import importlib.util +import json +import os.path +import shutil +import tempfile + +from collections import defaultdict +from contextlib import nullcontext +from dataclasses import asdict +from pathlib import Path +from typing import Callable + +import torch + +from tensordict import PersistentTensorDict, TensorDict +from torchrl._utils import KeyDependentDefaultDict +from torchrl.data.datasets.utils import _get_root_dir +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import Sampler +from torchrl.data.replay_buffers.storages import TensorStorage +from torchrl.data.replay_buffers.writers import Writer +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, +) + +_has_tqdm = importlib.util.find_spec("tqdm", None) is not None + +_NAME_MATCH = KeyDependentDefaultDict(lambda key: key) +_NAME_MATCH["observations"] = "observation" +_NAME_MATCH["rewards"] = "reward" +_NAME_MATCH["truncations"] = "truncated" +_NAME_MATCH["terminations"] = "terminated" +_NAME_MATCH["actions"] = "action" +_NAME_MATCH["infos"] = "info" + + +_DTYPE_DIR = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "int64": torch.int64, + "int32": torch.int32, + "uint8": torch.uint8, +} + + +class MinariExperienceReplay(TensorDictReplayBuffer): + """Minari Experience replay dataset. + + Args: + dataset_id (str): + batch_size (int): + + Keyword Args: + root (Path or str, optional): The Minari dataset root directory. + The actual dataset memory-mapped files will be saved under + `/`. If none is provided, it defaults to + ``~/.cache/torchrl/minari`. + download (bool or str, optional): Whether the dataset should be downloaded if + not found. Defaults to ``True``. Download can also be passed as "force", + in which case the downloaded data will be overwritten. + sampler (Sampler, optional): the sampler to be used. If none is provided + a default RandomSampler() will be used. + writer (Writer, optional): the writer to be used. If none is provided + a default RoundRobinWriter() will be used. + collate_fn (callable, optional): merges a list of samples to form a + mini-batch of Tensor(s)/outputs. Used when using batched + loading from a map-style dataset. + pin_memory (bool): whether pin_memory() should be called on the rb + samples. + prefetch (int, optional): number of next batches to be prefetched + using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. + split_trajs (bool, optional): if ``True``, the trajectories will be split + along the first dimension and padded to have a matching shape. + To split the trajectories, the ``"done"`` signal will be used, which + is recovered via ``done = truncated | terminated``. In other words, + it is assumed that any ``truncated`` or ``terminated`` signal is + equivalent to the end of a trajectory. For some datasets from + ``D4RL``, this may not be true. It is up to the user to make + accurate choices regarding this usage of ``split_trajs``. + Defaults to ``False``. + + .. note:: + Text data is currenrtly discarded from the wrapped dataset, as there is not + PyTorch native way of representing text data. + If this feature is required, please post an issue on TorchRL's GitHub + repository. 
+ + Examples: + >>> from torchrl.data.datasets.minari_data import MinariExperienceReplay + >>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force") + >>> for sample in data: + ... print(sample) + ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([32, 28]), device=cpu, dtype=torch.float32, is_shared=False), + index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), + info: TensorDict( + fields={ + success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + info: TensorDict( + fields={ + success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), + reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False), + state: TensorDict( + fields={ + door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False), + qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), + qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), + state: TensorDict( + fields={ + door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False), + qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), + qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False)}, + batch_size=torch.Size([32]), + device=cpu, + is_shared=False) + + """ + + def __init__( + self, + dataset_id, + batch_size: int, + *, + root: str | Path | None = None, + download: bool = True, + sampler: Sampler | None = None, + writer: Writer | None = None, + collate_fn: Callable | None = None, + pin_memory: bool = False, + prefetch: int | None = None, + transform: "torchrl.envs.Transform" | None = None, # noqa-F821 + split_trajs: bool = False, + **env_kwargs, + ): + self.dataset_id = dataset_id + if root is None: + root = _get_root_dir("minari") + os.makedirs(root, exist_ok=True) + self.root = root + self.split_trajs = split_trajs + self.download = download + if self.download == "force" or (self.download and not self._is_downloaded()): + if self.download == "force": + try: + shutil.rmtree(self.data_path_root) + if self.data_path != self.data_path_root: + shutil.rmtree(self.data_path) + except FileNotFoundError: + pass + storage = self._download_and_preproc() + elif self.split_trajs and not os.path.exists(self.data_path): + storage = self._make_split() + else: + storage = self._load() + storage = TensorStorage(storage) + super().__init__( + storage=storage, + sampler=sampler, + writer=writer, + collate_fn=collate_fn, + pin_memory=pin_memory, + prefetch=prefetch, 
+ batch_size=batch_size, + ) + + def available_datasets(self): + import minari + + return minari.list_remote_datasets().keys() + + def _is_downloaded(self): + return os.path.exists(self.data_path_root) + + @property + def data_path(self): + if self.split_trajs: + return Path(self.root) / (self.dataset_id + "_split") + return self.data_path_root + + @property + def data_path_root(self): + return Path(self.root) / self.dataset_id + + @property + def metadata_path(self): + return Path(self.root) / self.dataset_id / "env_metadata.json" + + def _download_and_preproc(self): + import minari + + if _has_tqdm: + from tqdm import tqdm + + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["MINARI_DATASETS_PATH"] = tmpdir + minari.download_dataset(dataset_id=self.dataset_id) + parent_dir = Path(tmpdir) / self.dataset_id / "data" + + td_data = TensorDict({}, []) + total_steps = 0 + print("first read through data to create data structure...") + h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") + # populate the tensordict + episode_dict = {} + for episode_key, episode in h5_data.items(): + episode_num = int(episode_key[len("episode_") :]) + episode_len = episode["actions"].shape[0] + episode_dict[episode_num] = (episode_key, episode_len) + # Get the total number of steps for the dataset + total_steps += episode_len + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ("observations", "state", "infos"): + if ( + not val.shape + ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: + if val.is_empty(): + continue + val = _patch_info(val) + td_data.set(("next", match), torch.zeros_like(val[0])) + td_data.set(match, torch.zeros_like(val[0])) + if key not in ("terminations", "truncations", "rewards"): + td_data.set(match, torch.zeros_like(val[0])) + else: + td_data.set( + ("next", match), + torch.zeros_like(val[0].unsqueeze(-1)), + ) + + # give it the proper size + td_data["next", "done"] = ( + td_data["next", "truncated"] | td_data["next", "terminated"] + ) + if "terminated" in td_data.keys(): + td_data["done"] = td_data["truncated"] | td_data["terminated"] + td_data = td_data.expand(total_steps) + # save to designated location + print(f"creating tensordict data in {self.data_path_root}: ", end="\t") + td_data = td_data.memmap_like(self.data_path_root) + print("tensordict structure:", td_data) + + print(f"Reading data from {max(*episode_dict)} episodes") + index = 0 + with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: + # iterate over episodes and populate the tensordict + for episode_num in sorted(episode_dict): + episode_key, steps = episode_dict[episode_num] + episode = h5_data.get(episode_key) + idx = slice(index, (index + steps)) + data_view = td_data[idx] + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ( + "observations", + "state", + "infos", + ): + if not val.shape or steps != val.shape[0] - 1: + if val.is_empty(): + continue + val = _patch_info(val) + if steps != val.shape[0] - 1: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}." + ) + data_view["next", match].copy_(val[1:]) + data_view[match].copy_(val[:-1]) + elif key not in ("terminations", "truncations", "rewards"): + if steps is None: + steps = val.shape[0] + else: + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." 
+ ) + data_view[match].copy_(val) + else: + if steps is None: + steps = val.shape[0] + else: + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." + ) + data_view[("next", match)].copy_(val.unsqueeze(-1)) + data_view["next", "done"].copy_( + data_view["next", "terminated"] | data_view["next", "truncated"] + ) + if "done" in data_view.keys(): + data_view["done"].copy_( + data_view["terminated"] | data_view["truncated"] + ) + if pbar is not None: + pbar.update(steps) + pbar.set_description( + f"index={index} - episode num {episode_num}" + ) + index += steps + h5_data.close() + # Add a "done" entry + if self.split_trajs: + with td_data.unlock_(): + from torchrl.objectives.utils import split_trajectories + + td_data = split_trajectories(td_data).memmap_(self.data_path) + with open(self.metadata_path, "w") as metadata_file: + dataset = minari.load_dataset(self.dataset_id) + self.metadata = asdict(dataset.spec) + self.metadata["observation_space"] = _spec_to_dict( + self.metadata["observation_space"] + ) + self.metadata["action_space"] = _spec_to_dict( + self.metadata["action_space"] + ) + json.dump(self.metadata, metadata_file) + self._load_and_proc_metadata() + return td_data + + def _make_split(self): + from torchrl.collectors.utils import split_trajectories + + self._load_and_proc_metadata() + td_data = TensorDict.load_memmap(self.data_path_root) + td_data = split_trajectories(td_data).memmap_(self.data_path) + return td_data + + def _load(self): + self._load_and_proc_metadata() + return TensorDict.load_memmap(self.data_path) + + def _load_and_proc_metadata(self): + with open(self.metadata_path, "r") as file: + self.metadata = json.load(file) + self.metadata["observation_space"] = _proc_spec( + self.metadata["observation_space"] + ) + self.metadata["action_space"] = _proc_spec(self.metadata["action_space"]) + + +def _proc_spec(spec): + if spec is None: + return + if spec["type"] == "Dict": + return CompositeSpec( + {key: _proc_spec(subspec) for key, subspec in spec["subspaces"].items()} + ) + elif spec["type"] == "Box": + if all(item == -float("inf") for item in spec["low"]) and all( + item == float("inf") for item in spec["high"] + ): + return UnboundedContinuousTensorSpec( + spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]] + ) + return BoundedTensorSpec( + shape=spec["shape"], + low=torch.tensor(spec["low"]), + high=torch.tensor(spec["high"]), + dtype=_DTYPE_DIR[spec["dtype"]], + ) + elif spec["type"] == "Discrete": + return DiscreteTensorSpec( + spec["n"], shape=spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]] + ) + else: + raise NotImplementedError(f"{type(spec)}") + + +def _spec_to_dict(spec): + from torchrl.envs.libs.gym import gym_backend + + if isinstance(spec, gym_backend("spaces").Dict): + return { + "type": "Dict", + "subspaces": {key: _spec_to_dict(val) for key, val in spec.items()}, + } + if isinstance(spec, gym_backend("spaces").Box): + return { + "type": "Box", + "low": spec.low.tolist(), + "high": spec.high.tolist(), + "dtype": str(spec.dtype), + "shape": tuple(spec.shape), + } + if isinstance(spec, gym_backend("spaces").Discrete): + return { + "type": "Discrete", + "dtype": str(spec.dtype), + "n": int(spec.n), + "shape": tuple(spec.shape), + } + if isinstance(spec, gym_backend("spaces").Text): + return + raise NotImplementedError(f"{type(spec)}, {str(spec)}") + + +def _patch_info(info_td): + # Some info dicts have tensors with one less element than others + # We explicitely assume that the missing 
item is in the first position because + # it wasn't given at reset time. + # An alternative explanation could be that the last element is missing because + # deemed useless for training... + unique_shapes = defaultdict(list) + for subkey, subval in info_td.items(): + unique_shapes[subval.shape[0]].append(subkey) + if len(unique_shapes) == 1: + unique_shapes[subval.shape[0] + 1] = [] + if len(unique_shapes) != 2: + raise RuntimeError( + f"Unique shapes in a sub-tensordict can only be of length 2, got shapes {unique_shapes}." + ) + val_td = info_td.to_tensordict() + min_shape = min(*unique_shapes) # can only be found at root + max_shape = min_shape + 1 + val_td_sel = val_td.select(*unique_shapes[min_shape]) + val_td_sel = val_td_sel.apply( + lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0), batch_size=[min_shape + 1] + ) + val_td_sel.update(val_td.select(*unique_shapes[max_shape])) + return val_td_sel diff --git a/torchrl/data/datasets/utils.py b/torchrl/data/datasets/utils.py new file mode 100644 index 00000000000..b88e3aee14e --- /dev/null +++ b/torchrl/data/datasets/utils.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os + + +def _get_root_dir(dataset: str): + return os.path.join(os.path.expanduser("~"), ".cache", "torchrl", dataset) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index cfc6c90bb2c..8b7acdd9d10 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -4,9 +4,11 @@ # LICENSE file in the root directory of this source tree. import collections +import json import threading import warnings from concurrent.futures import ThreadPoolExecutor +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -230,6 +232,7 @@ def state_dict(self) -> Dict[str, Any]: "_storage": self._storage.state_dict(), "_sampler": self._sampler.state_dict(), "_writer": self._writer.state_dict(), + "_transforms": self._transform.state_dict(), "_batch_size": self._batch_size, } @@ -237,8 +240,80 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._storage.load_state_dict(state_dict["_storage"]) self._sampler.load_state_dict(state_dict["_sampler"]) self._writer.load_state_dict(state_dict["_writer"]) + self._transform.load_state_dict(state_dict["_transforms"]) self._batch_size = state_dict["_batch_size"] + def dumps(self, path): + """Saves the replay buffer on disk at the specified path. + + Args: + path (Path or str): path where to save the replay buffer. + + Examples: + >>> import tempfile + >>> import tqdm + >>> from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + >>> from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler + >>> import torch + >>> from tensordict import TensorDict + >>> # Build and populate the replay buffer + >>> S = 1_000_000 + >>> sampler = PrioritizedSampler(S, 1.1, 1.0) + >>> # sampler = RandomSampler() + >>> storage = LazyMemmapStorage(S) + >>> rb = TensorDictReplayBuffer(storage=storage, sampler=sampler) + >>> + >>> for _ in tqdm.tqdm(range(100)): + ... td = TensorDict({"obs": torch.randn(100, 3, 4), "next": {"obs": torch.randn(100, 3, 4)}, "td_error": torch.rand(100)}, [100]) + ... rb.extend(td) + ... sample = rb.sample(32) + ... 
rb.update_tensordict_priority(sample) + >>> # save and load the buffer + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... rb.dumps(tmpdir) + ... + ... sampler = PrioritizedSampler(S, 1.1, 1.0) + ... # sampler = RandomSampler() + ... storage = LazyMemmapStorage(S) + ... rb_load = TensorDictReplayBuffer(storage=storage, sampler=sampler) + ... rb_load.loads(tmpdir) + ... assert len(rb) == len(rb_load) + + """ + path = Path(path).absolute() + path.mkdir(exist_ok=True) + self._storage.dumps(path / "storage") + self._sampler.dumps(path / "sampler") + self._writer.dumps(path / "writer") + # fall back on state_dict for transforms + transform_sd = self._transform.state_dict() + if transform_sd: + torch.save(transform_sd, path / "transform.t") + with open(path / "buffer_metadata.json", "w") as file: + json.dump({"batch_size": self._batch_size}, file) + + def loads(self, path): + """Loads a replay buffer state at the given path. + + The buffer should have matching components and be saved using :meth:`~.dumps`. + + Args: + path (Path or str): path where the replay buffer was saved. + + See :meth:`~.dumps` for more info. + + """ + path = Path(path).absolute() + self._storage.loads(path / "storage") + self._sampler.loads(path / "sampler") + self._writer.loads(path / "writer") + # fall back on state_dict for transforms + if (path / "transform.t").exists(): + self._transform.load_state_dict(torch.load(path / "transform.t")) + with open(path / "buffer_metadata.json", "r") as file: + metadata = json.load(file) + self._batch_size = metadata["batch_size"] + def add(self, data: Any) -> int: """Add a single element to the replay buffer. diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 16660aff90f..fde6ed9b69e 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -2,14 +2,19 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import json import warnings from abc import ABC, abstractmethod -from copy import deepcopy +from copy import copy, deepcopy +from multiprocessing.context import get_spawning_popen +from pathlib import Path from typing import Any, Dict, Tuple, Union import numpy as np import torch +from tensordict import MemoryMappedTensor + from ..._extension import EXTENSION_WARNING try: @@ -68,6 +73,14 @@ def ran_out(self) -> bool: def _empty(self): ... + @abstractmethod + def dumps(self, path): + ... + + @abstractmethod + def loads(self, path): + ... + class RandomSampler(Sampler): """A uniformly random sampler for composable replay buffers. @@ -87,6 +100,14 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] def _empty(self): pass + def dumps(self, path): + # no op + ... + + def loads(self, path): + # no op + ... + class SamplerWithoutReplacement(Sampler): """A data-consuming sampler that ensures that the same sample is not present in consecutive batches. 
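[Editor's note] For context on the `dumps`/`loads` pair added here: a minimal sketch of the round trip, assuming a buffer backed by `LazyMemmapStorage` with the default `RandomSampler` (whose `dumps`/`loads` are deliberate no-ops, so only the storage, writer and buffer metadata reach disk). Keys and sizes below are illustrative, not part of the change.

```python
import tempfile

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

# build and fill a small buffer so the lazy storage is initialized
rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(100), batch_size=8)
rb.extend(TensorDict({"obs": torch.randn(10, 3)}, [10]))

with tempfile.TemporaryDirectory() as tmpdir:
    # storage, sampler and writer each get their own sub-directory,
    # plus a buffer_metadata.json holding the batch size
    rb.dumps(tmpdir)

    # a fresh buffer with matching components restores the state
    rb_restored = TensorDictReplayBuffer(storage=LazyMemmapStorage(100), batch_size=8)
    rb_restored.loads(tmpdir)
    assert len(rb_restored) == len(rb)
```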
@@ -114,6 +135,29 @@ def __init__(self, drop_last: bool = False): self.drop_last = drop_last self._ran_out = False + def dumps(self, path): + path = Path(path) + path.mkdir(exist_ok=True) + + with open(path / "sampler_metadata.json", "w") as file: + json.dump( + { + "len_storage": self.len_storage, + "_sample_list": self._sample_list, + "drop_last": self.drop_last, + "_ran_out": self._ran_out, + }, + file, + ) + + def loads(self, path): + with open(path / "sampler_metadata.json", "r") as file: + metadata = json.load(file) + self._sample_list = metadata["_sample_list"] + self.len_storage = metadata["len_storage"] + self.drop_last = metadata["drop_last"] + self._ran_out = metadata["_ran_out"] + def _single_sample(self, len_storage, batch_size): index = self._sample_list[:batch_size] self._sample_list = self._sample_list[batch_size:] @@ -208,6 +252,14 @@ def __init__( self.dtype = dtype self._init() + def __getstate__(self): + if get_spawning_popen() is not None: + raise RuntimeError( + f"Samplers of type {type(self)} cannot be shared between processes." + ) + state = copy(self.__dict__) + return state + def _init(self): if self.dtype in (torch.float, torch.FloatType, torch.float32): self._sum_tree = SumSegmentTreeFp32(self._max_capacity) @@ -276,11 +328,15 @@ def _add_or_extend(self, index: Union[int, torch.Tensor]) -> None: def add(self, index: int) -> None: super().add(index) - self._add_or_extend(index) + if index is not None: + # some writers don't systematically write data and can return None + self._add_or_extend(index) def extend(self, index: torch.Tensor) -> None: super().extend(index) - self._add_or_extend(index) + if index is not None: + # some writers don't systematically write data and can return None + self._add_or_extend(index) def update_priority( self, index: Union[int, torch.Tensor], priority: Union[float, torch.Tensor] @@ -339,3 +395,74 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._max_priority = state_dict["_max_priority"] self._sum_tree = state_dict.pop("_sum_tree") self._min_tree = state_dict.pop("_min_tree") + + def dumps(self, path): + path = Path(path).absolute() + path.mkdir(exist_ok=True) + try: + mm_st = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "sumtree.memmap", + ) + mm_mt = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "mintree.memmap", + ) + except FileNotFoundError: + mm_st = MemoryMappedTensor.empty( + (self._max_capacity,), + dtype=torch.float64, + filename=path / "sumtree.memmap", + ) + mm_mt = MemoryMappedTensor.empty( + (self._max_capacity,), + dtype=torch.float64, + filename=path / "mintree.memmap", + ) + mm_st.copy_( + torch.tensor([self._sum_tree[i] for i in range(self._max_capacity)]) + ) + mm_mt.copy_( + torch.tensor([self._min_tree[i] for i in range(self._max_capacity)]) + ) + with open(path / "sampler_metadata.json", "w") as file: + json.dump( + { + "_alpha": self._alpha, + "_beta": self._beta, + "_eps": self._eps, + "_max_priority": self._max_priority, + "_max_capacity": self._max_capacity, + }, + file, + ) + + def loads(self, path): + path = Path(path).absolute() + with open(path / "sampler_metadata.json", "r") as file: + metadata = json.load(file) + self._alpha = metadata["_alpha"] + self._beta = metadata["_beta"] + self._eps = metadata["_eps"] + self._max_priority = metadata["_max_priority"] + _max_capacity = metadata["_max_capacity"] + if _max_capacity != self._max_capacity: + raise RuntimeError( 
+ f"max capacity of loaded metadata ({_max_capacity}) differs from self._max_capacity ({self._max_capacity})." + ) + mm_st = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "sumtree.memmap", + ) + mm_mt = MemoryMappedTensor.from_filename( + shape=(self._max_capacity,), + dtype=torch.float64, + filename=path / "mintree.memmap", + ) + for i, elt in enumerate(mm_st.tolist()): + self._sum_tree[i] = elt + for i, elt in enumerate(mm_mt.tolist()): + self._min_tree[i] = elt diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 9c8417b9c97..4e01eeffb67 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -4,17 +4,22 @@ # LICENSE file in the root directory of this source tree. import abc +import json import os import warnings from collections import OrderedDict from copy import copy +from multiprocessing.context import get_spawning_popen +from pathlib import Path from typing import Any, Dict, Sequence, Union +import numpy as np import torch from tensordict import is_tensorclass from tensordict.memmap import MemmapTensor, MemoryMappedTensor from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase -from tensordict.utils import expand_right +from tensordict.utils import _STRDTYPE2DTYPE, expand_right +from torch import multiprocessing as mp from torchrl._utils import _CKPT_BACKEND, implement_for, VERBOSE from torchrl.data.replay_buffers.utils import INT_CLASSES @@ -52,6 +57,14 @@ def set(self, cursor: int, data: Any): def get(self, index: int) -> Any: ... + @abc.abstractmethod + def dumps(self, path): + ... + + @abc.abstractmethod + def loads(self, path): + ... + def attach(self, buffer: Any) -> None: """This function attaches a sampler to this storage. @@ -107,8 +120,23 @@ def __init__(self, max_size: int): super().__init__(max_size) self._storage = [] + def dumps(self, path): + raise NotImplementedError( + "ListStorage doesn't support serialization via `dumps` - `loads` API." + ) + + def loads(self, path): + raise NotImplementedError( + "ListStorage doesn't support serialization via `dumps` - `loads` API." + ) + def set(self, cursor: Union[int, Sequence[int], slice], data: Any): if not isinstance(cursor, INT_CLASSES): + if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or ( + isinstance(cursor, np.ndarray) and cursor.size <= 1 + ): + self.set(int(cursor), data) + return if isinstance(cursor, slice): self._storage[cursor] = data return @@ -166,6 +194,14 @@ def load_state_dict(self, state_dict): def _empty(self): self._storage = [] + def __getstate__(self): + if get_spawning_popen() is not None: + raise RuntimeError( + f"Cannot share a storage of type {type(self)} between processes." + ) + state = copy(self.__dict__) + return state + class TensorStorage(Storage): """A storage for tensors and tensordicts. 
@@ -259,6 +295,123 @@ def __init__(self, storage, max_size=None, device="cpu"): ) self._storage = storage + def dumps(self, path): + path = Path(path) + path.mkdir(exist_ok=True) + + if not self.initialized: + raise RuntimeError("Cannot save a non-initialized storage.") + if isinstance(self._storage, torch.Tensor): + try: + MemoryMappedTensor.from_filename( + shape=self._storage.shape, + filename=path / "storage.memmap", + dtype=self._storage.dtype, + ).copy_(self._storage) + except FileNotFoundError: + MemoryMappedTensor.from_tensor( + self._storage, filename=path / "storage.memmap", copy_existing=True + ) + is_tensor = True + dtype = str(self._storage.dtype) + shape = list(self._storage.shape) + else: + # try to load the path and overwrite. + try: + saved = TensorDict.load_memmap(path) + except FileNotFoundError: + # otherwise create a new one + saved = self._storage.memmap_like(path) + saved.update_(self._storage) + is_tensor = False + dtype = None + shape = None + + with open(path / "storage_metadata.json", "w") as file: + json.dump( + { + "is_tensor": is_tensor, + "dtype": dtype, + "shape": shape, + "len": self._len, + }, + file, + ) + + def loads(self, path): + with open(path / "storage_metadata.json", "r") as file: + metadata = json.load(file) + is_tensor = metadata["is_tensor"] + shape = metadata["shape"] + dtype = metadata["dtype"] + _len = metadata["len"] + if dtype is not None: + shape = torch.Size(shape) + dtype = _STRDTYPE2DTYPE[dtype] + if is_tensor: + _storage = MemoryMappedTensor.from_filename( + path / "storage.memmap", shape=shape, dtype=dtype + ).clone() + else: + _storage = TensorDict.load_memmap(path) + if not self.initialized: + self._storage = _storage + self.initialized = True + else: + self._storage.copy_(_storage) + self._len = _len + + @property + def _len(self): + _len_value = self.__dict__.get("_len_value", None) + if _len_value is None: + _len_value = self._len_value = mp.Value("i", 0) + return _len_value.value + + @_len.setter + def _len(self, value): + _len_value = self.__dict__.get("_len_value", None) + if _len_value is None: + _len_value = self._len_value = mp.Value("i", 0) + _len_value.value = value + + def __getstate__(self): + state = copy(self.__dict__) + if get_spawning_popen() is None: + len = self._len + del state["_len_value"] + state["len__context"] = len + elif not self.initialized: + # check that the storage is initialized + raise RuntimeError( + f"Cannot share a storage of type {type(self)} between processed if " + f"it has not been initialized yet. Populate the buffer with " + f"some data in the main process before passing it to the other " + f"subprocesses (or create the buffer explicitely with a TensorStorage)." + ) + else: + # check that the content is shared, otherwise tell the user we can't help + storage = self._storage + STORAGE_ERR = "The storage must be place in shared memory or memmapped before being shared between processes." 
+ if is_tensor_collection(storage): + if not storage.is_memmap() and not storage.is_shared(): + raise RuntimeError(STORAGE_ERR) + else: + if ( + not isinstance(storage, MemoryMappedTensor) + and not storage.is_shared() + ): + raise RuntimeError(STORAGE_ERR) + + return state + + def __setstate__(self, state): + len = state.pop("len__context", None) + if len is not None: + _len_value = mp.Value("i", len) + state["_len_value"] = _len_value + self.__dict__.update(state) + def state_dict(self) -> Dict[str, Any]: _storage = self._storage if isinstance(_storage, torch.Tensor): @@ -387,14 +540,18 @@ def set( # noqa: F811 self._storage[cursor] = data def get(self, index: Union[int, Sequence[int], slice]) -> Any: + if self._len < self.max_size: + storage = self._storage[: self._len] + else: + storage = self._storage if not self.initialized: raise RuntimeError( "Cannot get an item from an unitialized LazyMemmapStorage" ) - out = self._storage[index] + out = storage[index] if is_tensor_collection(out): out = _reset_batch_size(out) - return out.unlock_() + return out # .unlock_() return out def __len__(self): diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index cf78e0a0d99..702898b5292 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -4,11 +4,18 @@ # LICENSE file in the root directory of this source tree. import heapq +import json from abc import ABC, abstractmethod +from copy import copy +from multiprocessing.context import get_spawning_popen +from pathlib import Path from typing import Any, Dict, Sequence import numpy as np import torch +from tensordict import is_tensor_collection, MemoryMappedTensor +from tensordict.utils import _STRDTYPE2DTYPE +from torch import multiprocessing as mp from .storages import Storage @@ -36,6 +43,14 @@ def extend(self, data: Sequence) -> torch.Tensor: def _empty(self): ... + @abstractmethod + def dumps(self, path): + ... + + @abstractmethod + def loads(self, path): + ... 
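[Editor's note] The shared `_len` property above, like the writer `_cursor` property later in this diff, keeps its counter in a `multiprocessing.Value` so that every process observes updates, while `__getstate__`/`__setstate__` swap it for a plain integer when the object is pickled outside of process spawning. A self-contained illustration of that pattern (the class and attribute names here are invented for the example):

```python
from copy import copy
from multiprocessing.context import get_spawning_popen

from torch import multiprocessing as mp


class SharedCounter:
    """Illustrative stand-in for the shared _len / _cursor pattern in this diff."""

    def __init__(self):
        self.count = 0  # goes through the setter and creates the shared Value

    @property
    def count(self):
        value = self.__dict__.get("_count_value", None)
        if value is None:
            # lazily create the shared integer on first access
            value = self._count_value = mp.Value("i", 0)
        return value.value

    @count.setter
    def count(self, new_value):
        value = self.__dict__.get("_count_value", None)
        if value is None:
            value = self._count_value = mp.Value("i", 0)
        value.value = new_value

    def __getstate__(self):
        state = copy(self.__dict__)
        if get_spawning_popen() is None:
            # ordinary pickling / deepcopy: store a plain int instead of the
            # shared Value, which cannot be pickled outside of spawning
            state["count__context"] = self.count
            del state["_count_value"]
        return state

    def __setstate__(self, state):
        count = state.pop("count__context", None)
        if count is not None:
            state["_count_value"] = mp.Value("i", count)
        self.__dict__.update(state)
```

Bumping the counter through the property before writing the data, as the writer changes below do, is what keeps two workers from claiming the same slot.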
+ def state_dict(self) -> Dict[str, Any]: return {} @@ -50,16 +65,31 @@ def __init__(self, **kw) -> None: super().__init__(**kw) self._cursor = 0 + def dumps(self, path): + path = Path(path).absolute() + path.mkdir(exist_ok=True) + with open(path / "metadata.json", "w") as file: + json.dump({"cursor": self._cursor}, file) + + def loads(self, path): + path = Path(path).absolute() + with open(path / "metadata.json", "r") as file: + metadata = json.load(file) + self._cursor = metadata["cursor"] + def add(self, data: Any) -> int: ret = self._cursor - self._storage[self._cursor] = data + _cursor = self._cursor + # we need to update the cursor first to avoid race conditions between workers self._cursor = (self._cursor + 1) % self._storage.max_size + self._storage[_cursor] = data return ret def extend(self, data: Sequence) -> torch.Tensor: cur_size = self._cursor batch_size = len(data) index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size + # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % self._storage.max_size self._storage[index] = data return index @@ -73,21 +103,52 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _empty(self): self._cursor = 0 + @property + def _cursor(self): + _cursor_value = self.__dict__.get("_cursor_value", None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + return _cursor_value.value + + @_cursor.setter + def _cursor(self, value): + _cursor_value = self.__dict__.get("_cursor_value", None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + _cursor_value.value = value + + def __getstate__(self): + state = copy(self.__dict__) + if get_spawning_popen() is None: + cursor = self._cursor + del state["_cursor_value"] + state["cursor__context"] = cursor + return state + + def __setstate__(self, state): + cursor = state.pop("cursor__context", None) + if cursor is not None: + _cursor_value = mp.Value("i", cursor) + state["_cursor_value"] = _cursor_value + self.__dict__.update(state) + class TensorDictRoundRobinWriter(RoundRobinWriter): """A RoundRobin Writer class for composable, tensordict-based replay buffers.""" def add(self, data: Any) -> int: ret = self._cursor + # we need to update the cursor first to avoid race conditions between workers + self._cursor = (ret + 1) % self._storage.max_size data["index"] = ret - self._storage[self._cursor] = data - self._cursor = (self._cursor + 1) % self._storage.max_size + self._storage[ret] = data return ret def extend(self, data: Sequence) -> torch.Tensor: cur_size = self._cursor batch_size = len(data) index = np.arange(cur_size, batch_size + cur_size) % self._storage.max_size + # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % self._storage.max_size # storage must convert the data to the appropriate format if needed data["index"] = index @@ -144,6 +205,10 @@ def __init__(self, rank_key=None, **kwargs) -> None: def get_insert_index(self, data: Any) -> int: """Returns the index where the data should be inserted, or ``None`` if it should not be inserted.""" + if not is_tensor_collection(data): + raise RuntimeError( + f"{type(self)} expects data to be a tensor collection (tensordict or tensorclass). Found a {type(data)} instead." 
+ ) if data.batch_dims > 1: raise RuntimeError( "Expected input tensordict to have no more than 1 dimension, got" @@ -151,7 +216,7 @@ def get_insert_index(self, data: Any) -> int: ) ret = None - rank_data = data.get(("_data", self._rank_key)) + rank_data = data.get("_data", default=data).get(self._rank_key) # If time dimension, sum along it. rank_data = rank_data.sum(-1).item() @@ -161,7 +226,6 @@ def get_insert_index(self, data: Any) -> int: # If the buffer is not full, add the data if len(self._current_top_values) < self._storage.max_size: - ret = self._cursor self._cursor = (self._cursor + 1) % self._storage.max_size @@ -209,14 +273,65 @@ def extend(self, data: Sequence) -> None: # Replace the data in the storage all at once if len(data_to_replace) > 0: keys, values = zip(*data_to_replace.items()) - index = data.get("index") + index = data.get("index", None) + dtype = index.dtype if index is not None else torch.long + device = index.device if index is not None else data.device values = list(values) - keys = index[values] = torch.tensor( - keys, dtype=index.dtype, device=index.device - ) - data.set("index", index) - self._storage[keys] = data[values] + keys = torch.tensor(keys, dtype=dtype, device=device) + if index is not None: + index[values] = keys + data.set("index", index) + self._storage.set(keys, data[values]) + return keys.long() + return None def _empty(self) -> None: self._cursor = 0 self._current_top_values = [] + + def __getstate__(self): + if get_spawning_popen() is not None: + raise RuntimeError( + f"Writers of type {type(self)} cannot be shared between processes." + ) + state = copy(self.__dict__) + return state + + def dumps(self, path): + path = Path(path).absolute() + path.mkdir(exist_ok=True) + t = torch.tensor(self._current_top_values) + try: + MemoryMappedTensor.from_filename( + filename=path / "current_top_values.memmap", + shape=t.shape, + dtype=t.dtype, + ).copy_(t) + except FileNotFoundError: + MemoryMappedTensor.from_tensor( + t, filename=path / "current_top_values.memmap" + ) + with open(path / "metadata.json", "w") as file: + json.dump( + { + "cursor": self._cursor, + "rank_key": self._rank_key, + "dtype": str(t.dtype), + "shape": list(t.shape), + }, + file, + ) + + def loads(self, path): + path = Path(path).absolute() + with open(path / "metadata.json", "r") as file: + metadata = json.load(file) + self._cursor = metadata["cursor"] + self._rank_key = metadata["rank_key"] + shape = torch.Size(metadata["shape"]) + dtype = metadata["dtype"] + self._current_top_values = MemoryMappedTensor.from_filename( + filename=path / "current_top_values.memmap", + dtype=_STRDTYPE2DTYPE[dtype], + shape=shape, + ).tolist() diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f0e132eb092..ac0a136c7f9 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -122,11 +122,16 @@ class _BatchedEnv(EnvBase): memmap (bool): whether or not the returned tensordict will be placed in memory map. policy_proof (callable, optional): if provided, it'll be used to get the list of tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc. - device (str, int, torch.device): for consistency, this argument is kept. However this - argument should not be passed, as the device will be inferred from the environments. - It is assumed that all environments will run on the same device as a common shared - tensordict will be used to pass data from process to process. 
The device can be - changed after instantiation using :obj:`env.to(device)`. + device (str, int, torch.device): The device of the batched environment can be passed. + If not, it is inferred from the env. In this case, it is assumed that + the device of all environments match. If it is provided, it can differ + from the sub-environment device(s). In that case, the data will be + automatically cast to the appropriate device during collection. + This can be used to speed up collection in case casting to device + introduces an overhead (eg, numpy-based environents etc.): by using + a ``"cuda"`` device for the batched environment but a ``"cpu"`` + device for the nested environments, one can keep the overhead to a + minimum. num_threads (int, optional): number of threads for this process. Defaults to the number of workers. This parameter has no effect for the :class:`~SerialEnv` class. @@ -162,14 +167,7 @@ def __init__( num_threads: int = None, num_sub_threads: int = 1, ): - if device is not None: - raise ValueError( - "Device setting for batched environment can't be done at initialization. " - "The device will be inferred from the constructed environment. " - "It can be set through the `to(device)` method." - ) - - super().__init__(device=None) + super().__init__(device=device) self.is_closed = True if num_threads is None: num_threads = num_workers + 1 # 1 more thread for this proc @@ -218,7 +216,7 @@ def __init__( "memmap and shared memory are mutually exclusive features." ) self._batch_size = None - self._device = None + self._device = torch.device(device) if device is not None else device self._dummy_env_str = None self._seeds = None self.__dict__["_input_spec"] = None @@ -273,7 +271,9 @@ def _set_properties(self): self._properties_set = True if self._single_task: self._batch_size = meta_data.batch_size - device = self._device = meta_data.device + device = meta_data.device + if self._device is None: + self._device = device input_spec = meta_data.specs["input_spec"].to(device) output_spec = meta_data.specs["output_spec"].to(device) @@ -289,8 +289,18 @@ def _set_properties(self): self._batch_locked = meta_data.batch_locked else: self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size]) - device = self._device = meta_data[0].device - # TODO: check that all action_spec and reward spec match (issue #351) + devices = set() + for _meta_data in meta_data: + device = _meta_data.device + devices.add(device) + if self._device is None: + if len(devices) > 1: + raise ValueError( + f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. " + f"Please indicate a device to be used for collection." 
+ ) + device = list(devices)[0] + self._device = device input_spec = [] for md in meta_data: @@ -413,7 +423,7 @@ def _create_td(self) -> None: *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, ) - self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) + self.shared_tensordict_parent = shared_tensordict_parent else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ @@ -421,7 +431,7 @@ def _create_td(self) -> None: *self._selected_keys, *(unravel_key(("next", key)) for key in self._env_output_keys), strict=False, - ).to(self.device) + ) for tensordict in shared_tensordict_parent ] shared_tensordict_parent = torch.stack( @@ -440,13 +450,11 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self.device.type == "cpu": + if self.shared_tensordict_parent.device.type == "cpu": if self._share_memory: - for td in self.shared_tensordicts: - td.share_memory_() + self.shared_tensordict_parent.share_memory_() elif self._memmap: - for td in self.shared_tensordicts: - td.memmap_() + self.shared_tensordict_parent.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -483,7 +491,6 @@ def close(self) -> None: self.__dict__["_input_spec"] = None self.__dict__["_output_spec"] = None self._properties_set = False - self.event = None self._shutdown_workers() self.is_closed = True @@ -507,11 +514,6 @@ def to(self, device: DEVICE_TYPING): if device == self.device: return self self._device = device - self.meta_data = ( - self.meta_data.to(device) - if self._single_task - else [meta_data.to(device) for meta_data in self.meta_data] - ) if not self.is_closed: warn( "Casting an open environment to another device requires closing and re-opening it. " @@ -543,7 +545,7 @@ def _start_workers(self) -> None: for idx in range(_num_workers): env = self.create_env_fn[idx](**self.create_env_kwargs[idx]) - self._envs.append(env.to(self.device)) + self._envs.append(env) self.is_closed = False @_check_start @@ -603,29 +605,39 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if tensordict_.is_empty(): tensordict_ = None else: - # reset will do modifications in-place. 
We want the original - # tensorict to be unchaned, so we clone it - tensordict_ = tensordict_.clone(False) + env_device = _env.device + if env_device != self.device: + tensordict_ = tensordict_.to(env_device) + else: + tensordict_ = tensordict_.clone(False) else: tensordict_ = None + _td = _env.reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( _td.select(*self._selected_reset_keys_filt, strict=False) ) selected_output_keys = self._selected_reset_keys_filt + device = self.device if self._single_task: # select + clone creates 2 tds, but we can create one only out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in selected_output_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) - return out + _set_single_key( + self.shared_tensordict_parent, out, key, clone=True, device=device + ) else: - return self.shared_tensordict_parent.select( + out = self.shared_tensordict_parent.select( *selected_output_keys, strict=False, - ).clone() + ) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) + return out def _reset_proc_data(self, tensordict, tensordict_reset): # since we call `reset` directly, all the postproc has been completed @@ -643,19 +655,29 @@ def _step( for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. - out_td = self._envs[i]._step(tensordict_in[i]) + env_device = self._envs[i].device + if env_device != self.device: + data_in = tensordict_in[i].to(env_device, non_blocking=True) + else: + data_in = tensordict_in[i] + out_td = self._envs[i]._step(data_in) next_td[i].update_(out_td.select(*self._env_output_keys, strict=False)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps + device = self.device if self._single_task: out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True, device=device) else: # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() + out = next_td.select(*self._selected_step_keys, strict=False) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) return out def __getattr__(self, attr: str) -> Any: @@ -710,6 +732,32 @@ class ParallelEnv(_BatchedEnv): """ __doc__ += _BatchedEnv.__doc__ + __doc__ += """ + + .. note:: + The choice of the devices where ParallelEnv needs to be executed can + drastically influence its performance. The rule of thumbs is: + + - If the base environment (backend, e.g., Gym) is executed on CPU, the + sub-environments should be executed on CPU and the data should be + passed via shared physical memory. + - If the base environment is (or can be) executed on CUDA, the sub-environments + should be placed on CUDA too. + - If a CUDA device is available and the policy is to be executed on CUDA, + the ParallelEnv device should be set to CUDA. 
+ + Therefore, supposing a CUDA device is available, we have the following scenarios: + + >>> # The sub-envs are executed on CPU, but the policy is on GPU + >>> env = ParallelEnv(N, MyEnv(..., device="cpu"), device="cuda") + >>> # The sub-envs are executed on CUDA + >>> env = ParallelEnv(N, MyEnv(..., device="cuda"), device="cuda") + >>> # this will create the exact same environment + >>> env = ParallelEnv(N, MyEnv(..., device="cuda")) + >>> # If no cuda device is available + >>> env = ParallelEnv(N, MyEnv(..., device="cpu")) + + """ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator @@ -722,39 +770,39 @@ def _start_workers(self) -> None: self.parent_channels = [] self._workers = [] - self._events = [] - if self.device.type == "cuda": + func = _run_worker_pipe_shared_mem + if self.shared_tensordict_parent.device.type == "cuda": self.event = torch.cuda.Event() else: self.event = None + self._events = [ctx.Event() for _ in range(_num_workers)] + kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)] with clear_mpi_env_vars(): for idx in range(_num_workers): if self._verbose: print(f"initiating worker {idx}") # No certainty which module multiprocessing_context is parent_pipe, child_pipe = ctx.Pipe() - event = ctx.Event() - self._events.append(event) env_fun = self.create_env_fn[idx] if not isinstance(env_fun, EnvCreator): env_fun = CloudpickleWrapper(env_fun) - + kwargs[idx].update( + { + "parent_pipe": parent_pipe, + "child_pipe": child_pipe, + "env_fun": env_fun, + "env_fun_kwargs": self.create_env_kwargs[idx], + "shared_tensordict": self.shared_tensordicts[idx], + "_selected_input_keys": self._selected_input_keys, + "_selected_reset_keys": self._selected_reset_keys, + "_selected_step_keys": self._selected_step_keys, + "has_lazy_inputs": self.has_lazy_inputs, + } + ) process = _ProcessNoWarn( - target=_run_worker_pipe_shared_mem, + target=func, num_threads=self.num_sub_threads, - args=( - parent_pipe, - child_pipe, - env_fun, - self.create_env_kwargs[idx], - self.device, - event, - self.shared_tensordicts[idx], - self._selected_input_keys, - self._selected_reset_keys, - self._selected_step_keys, - self.has_lazy_inputs, - ), + kwargs=kwargs[idx], ) process.daemon = True process.start() @@ -834,10 +882,16 @@ def step_and_maybe_reset( # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - tensordict.set("next", self.shared_tensordict_parent.get("next").clone()) - tensordict_ = self.shared_tensordict_parent.exclude( - "next", *self.reset_keys - ).clone() + next_td = self.shared_tensordict_parent.get("next") + tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys) + device = self.device + if self.shared_tensordict_parent.device == device: + next_td = next_td.clone() + tensordict_ = tensordict_.clone() + else: + next_td = next_td.to(device, non_blocking=True) + tensordict_ = tensordict_.to(device, non_blocking=True) + tensordict.set("next", next_td) return tensordict, tensordict_ @_check_start @@ -880,15 +934,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps next_td = self.shared_tensordict_parent.get("next") + device = self.device if self._single_task: out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in 
self._selected_step_keys: - _set_single_key(next_td, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True, device=device) else: # strict=False ensures that non-homogeneous keys are still there - out = next_td.select(*self._selected_step_keys, strict=False).clone() + out = next_td.select(*self._selected_step_keys, strict=False) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) return out @_check_start @@ -944,19 +1003,26 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: event.clear() selected_output_keys = self._selected_reset_keys_filt + device = self.device if self._single_task: # select + clone creates 2 tds, but we can create one only out = TensorDict( - {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + {}, batch_size=self.shared_tensordict_parent.shape, device=device ) for key in selected_output_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) - return out + _set_single_key( + self.shared_tensordict_parent, out, key, clone=True, device=device + ) else: - return self.shared_tensordict_parent.select( + out = self.shared_tensordict_parent.select( *selected_output_keys, strict=False, - ).clone() + ) + if out.device == device: + out = out.clone() + else: + out = out.to(device, non_blocking=True) + return out @_check_start def _shutdown_workers(self) -> None: @@ -981,6 +1047,7 @@ def _shutdown_workers(self) -> None: del self.parent_channels self._cuda_events = None self._events = None + self.event = None @_check_start def set_seed( @@ -1063,7 +1130,6 @@ def _run_worker_pipe_shared_mem( child_pipe: connection.Connection, env_fun: Union[EnvBase, Callable], env_fun_kwargs: Dict[str, Any], - device: DEVICE_TYPING = None, mp_event: mp.Event = None, shared_tensordict: TensorDictBase = None, _selected_input_keys=None, @@ -1072,13 +1138,11 @@ def _run_worker_pipe_shared_mem( has_lazy_inputs: bool = False, verbose: bool = False, ) -> None: - if device is None: - device = torch.device("cpu") + device = shared_tensordict.device if device.type == "cuda": event = torch.cuda.Event() else: event = None - parent_pipe.close() pid = os.getpid() if not isinstance(env_fun, EnvBase): @@ -1089,7 +1153,6 @@ def _run_worker_pipe_shared_mem( "env_fun_kwargs must be empty if an environment is passed to a process." 
) env = env_fun - env = env.to(device) del env_fun i = -1 @@ -1144,7 +1207,8 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - next_td = env._step(shared_tensordict) + env_input = shared_tensordict + next_td = env._step(env_input) next_shared_tensordict.update_(next_td) if event is not None: event.record() @@ -1155,7 +1219,8 @@ def _run_worker_pipe_shared_mem( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - td, root_next_td = env.step_and_maybe_reset(shared_tensordict.clone(False)) + env_input = shared_tensordict + td, root_next_td = env.step_and_maybe_reset(env_input) next_shared_tensordict.update_(td.get("next")) root_shared_tensordict.update_(root_next_td) if event is not None: @@ -1208,3 +1273,10 @@ def _run_worker_pipe_shared_mem( else: # don't send env through pipe child_pipe.send(("_".join([cmd, "done"]), None)) + + +def _update_cuda(t_dest, t_source): + if t_source is None: + return + t_dest.copy_(t_source.pin_memory(), non_blocking=True) + return diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 017280579f5..8e3214f3692 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -1160,7 +1160,7 @@ def _read_obs(self, obs, key, tensor, index): # Simplest case: there is one observation, # presented as a np.ndarray. The key should be pixels or observation. # We just write that value at its location in the tensor - tensor[index] = torch.as_tensor(obs, device=tensor.device) + tensor[index] = torch.tensor(obs, device=tensor.device) elif isinstance(obs, dict): if key not in obs: raise KeyError( @@ -1171,13 +1171,13 @@ def _read_obs(self, obs, key, tensor, index): # if the obs is a dict, we expect that the key points also to # a value in the obs. We retrieve this value and write it in the # tensor - tensor[index] = torch.as_tensor(subobs, device=tensor.device) + tensor[index] = torch.tensor(subobs, device=tensor.device) elif isinstance(obs, (list, tuple)): # tuples are stacked along the first dimension when passing gym spaces # to torchrl specs. As such, we can simply stack the tuple and set it # at the relevant index (assuming stacking can be achieved) - tensor[index] = torch.as_tensor(obs, device=tensor.device) + tensor[index] = torch.tensor(obs, device=tensor.device) else: raise NotImplementedError( f"Observations of type {type(obs)} are not supported yet." 
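[Editor's note: for context on the switch from torch.as_tensor to torch.tensor in _read_obs above: as_tensor may return a view that shares memory with the incoming numpy buffer, whereas tensor always allocates a fresh copy. A small self-contained illustration, plain numpy/torch with no gym involved.]

    import numpy as np
    import torch

    obs = np.zeros(3)
    aliased = torch.as_tensor(obs)   # may share memory with obs
    copied = torch.tensor(obs)       # always allocates a fresh copy

    obs[0] = 1.0
    print(aliased[0].item())   # 1.0 -- the view tracks the numpy buffer
    print(copied[0].item())    # 0.0 -- the copy is unaffected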
@@ -1186,11 +1186,12 @@ def _read_obs(self, obs, key, tensor, index): def __call__(self, info_dict, tensordict): terminal_obs = info_dict.get(self.backend_key[self.backend], None) for key, item in self.info_spec.items(True, True): - final_obs = item.zero() + final_obs_buffer = item.zero() if terminal_obs is not None: for i, obs in enumerate(terminal_obs): - self._read_obs(obs, key[-1], final_obs, index=i) - tensordict.set(key, final_obs) + # writes final_obs inplace with terminal_obs content + self._read_obs(obs, key[-1], final_obs_buffer, index=i) + tensordict.set(key, final_obs_buffer) return tensordict diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index b42a7d6be97..dbe097aa312 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -45,22 +45,8 @@ def _get_envs(): import vmas all_scenarios = vmas.scenarios + vmas.mpe_scenarios + vmas.debug_scenarios - # TODO heterogenous spaces - # For now torchrl does not support heterogenous spaces (Tple(Box)) so many OpenAI MPE scenarios do not work - heterogenous_spaces_scenarios = [ - "simple_adversary", - "simple_crypto", - "simple_push", - "simple_speaker_listener", - "simple_tag", - "simple_world_comm", - ] - - return [ - scenario - for scenario in all_scenarios - if scenario not in heterogenous_spaces_scenarios - ] + + return all_scenarios @set_gym_backend("gym") diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3e6d597dffd..de8baf2e403 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4692,6 +4692,7 @@ def __init__( """Initialises the transform. Filters out non-reward input keys and defines output keys.""" super().__init__(in_keys=in_keys, out_keys=out_keys) self._reset_keys = reset_keys + self._keys_checked = False @property def in_keys(self): @@ -4770,9 +4771,7 @@ def _check_match(reset_keys, in_keys): return False return True - if len(reset_keys) != len(self.in_keys) or not _check_match( - reset_keys, self.in_keys - ): + if not _check_match(reset_keys, self.in_keys): raise ValueError( f"Could not match the env reset_keys {reset_keys} with the {type(self)} in_keys {self.in_keys}. " f"Please provide the reset_keys manually. Reset entries can be " @@ -4781,6 +4780,14 @@ def _check_match(reset_keys, in_keys): ) reset_keys = copy(reset_keys) self._reset_keys = reset_keys + + if not self._keys_checked and len(reset_keys) != len(self.in_keys): + raise ValueError( + f"Could not match the env reset_keys {reset_keys} with the in_keys {self.in_keys}. " + "Please make sure that these have the same length." 
+ ) + self._keys_checked = True + return reset_keys @reset_keys.setter diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 06eec73be97..9a2a71f24bd 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -237,7 +237,11 @@ def step_mdp( def _set_single_key( - source: TensorDictBase, dest: TensorDictBase, key: str | tuple, clone: bool = False + source: TensorDictBase, + dest: TensorDictBase, + key: str | tuple, + clone: bool = False, + device=None, ): # key should be already unraveled if isinstance(key, str): @@ -253,7 +257,9 @@ def _set_single_key( source = val dest = new_val else: - if clone: + if device is not None and val.device != device: + val = val.to(device, non_blocking=True) + elif clone: val = val.clone() dest._set_str(k, val, inplace=False, validated=True) # This is a temporary solution to understand if a key is heterogeneous @@ -262,7 +268,7 @@ def _set_single_key( if re.match(r"Found more than one unique shape in the tensors", str(err)): # this is a het key for s_td, d_td in zip(source.tensordicts, dest.tensordicts): - _set_single_key(s_td, d_td, k, clone) + _set_single_key(s_td, d_td, k, clone=clone, device=device) break else: raise err diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 16d621f2bec..fa3fbc6286f 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -56,8 +56,12 @@ DistributionalQValueModule, EGreedyModule, EGreedyWrapper, + GRU, + GRUCell, GRUModule, LMHeadActorValueOperator, + LSTM, + LSTMCell, LSTMModule, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 7605238f99a..302cfcaf2cc 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -31,6 +31,6 @@ SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import GRUModule, LSTMModule +from .rnn import GRU, GRUCell, GRUModule, LSTM, LSTMCell, LSTMModule from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 22be1432edf..b705e33474e 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -6,6 +6,7 @@ from typing import Optional, Tuple import torch +import torch.nn.functional as F from tensordict import TensorDictBase, unravel_key_list from tensordict.nn import TensorDictModuleBase as ModuleBase @@ -13,7 +14,8 @@ from tensordict.tensordict import NO_DEFAULT from tensordict.utils import prod -from torch import nn +from torch import nn, Tensor +from torch.nn.modules.rnn import RNNCellBase from torchrl.data import UnboundedContinuousTensorSpec from torchrl.objectives.value.functional import ( @@ -23,6 +25,294 @@ from torchrl.objectives.value.utils import _get_num_per_traj_init +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is fully coded in Python. + + .. note:: + This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`. 
+ + Examples: + >>> import torch + >>> from torchrl.modules.tensordict_module.rnn import LSTMCell + >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu") + >>> B = 2 + >>> N_IN = 10 + >>> N_OUT = 20 + >>> V = 4 # vector size + >>> lstm_cell = LSTMCell(input_size=N_IN, hidden_size=N_OUT, device=device) + + # single call + >>> x = torch.randn(B, 10, device=device) + >>> h0 = torch.zeros(B, 20, device=device) + >>> c0 = torch.zeros(B, 20, device=device) + >>> with torch.no_grad(): + ... (h1, c1) = lstm_cell(x, (h0, c0)) + + # vectorised call - not possible with nn.LSTMCell + >>> def call_lstm(x, h, c): + ... h_out, c_out = lstm_cell(x, (h, c)) + ... return h_out, c_out + >>> batched_call = torch.vmap(call_lstm) + >>> x = torch.randn(V, B, 10, device=device) + >>> h0 = torch.zeros(V, B, 20, device=device) + >>> c0 = torch.zeros(V, B, 20, device=device) + >>> with torch.no_grad(): + ... (h1, c1) = batched_call(x, h0, c0) + """ + + __doc__ += nn.LSTMCell.__doc__ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tensor]: + if input.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None: + for idx, value in enumerate(hx): + if value.dim() not in (1, 2): + raise ValueError( + f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = self.lstm_cell(input, hx[0], hx[1]) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + def lstm_cell(self, x, hx, cx): + x = x.view(-1, x.size(1)) + + gates = F.linear(x, self.weight_ih, self.bias_ih) + F.linear( + hx, self.weight_hh, self.bias_hh + ) + + i_gate, f_gate, g_gate, o_gate = gates.chunk(4, 1) + + i_gate = i_gate.sigmoid() + f_gate = f_gate.sigmoid() + g_gate = g_gate.tanh() + o_gate = o_gate.sigmoid() + + cy = cx * f_gate + i_gate * g_gate + + hy = o_gate * cy.tanh() + + return hy, cy + + +# copy LSTM +class LSTMBase(nn.RNNBase): + """A Base module for LSTM. Inheriting from LSTMBase enables compatibility with torch.compile.""" + + def __init__(self, *args, **kwargs): + return super().__init__("LSTM", *args, **kwargs) + + +for attr in nn.LSTM.__dict__: + if attr != "__init__": + setattr(LSTMBase, attr, getattr(nn.LSTM, attr)) + + +class LSTM(LSTMBase): + """A PyTorch module for executing multiple steps of a multi-layer LSTM. The module behaves exactly like :class:`torch.nn.LSTM`, but this implementation is exclusively coded in Python. + + .. note:: + This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`. 
+ + Examples: + >>> import torch + >>> from torchrl.modules.tensordict_module.rnn import LSTM + + >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu") + >>> B = 2 + >>> T = 4 + >>> N_IN = 10 + >>> N_OUT = 20 + >>> N_LAYERS = 2 + >>> V = 4 # vector size + >>> lstm = LSTM( + ... input_size=N_IN, + ... hidden_size=N_OUT, + ... device=device, + ... num_layers=N_LAYERS, + ... ) + + # single call + >>> x = torch.randn(B, T, N_IN, device=device) + >>> h0 = torch.zeros(N_LAYERS, B, N_OUT, device=device) + >>> c0 = torch.zeros(N_LAYERS, B, N_OUT, device=device) + >>> with torch.no_grad(): + ... h1, c1 = lstm(x, (h0, c0)) + + # vectorised call - not possible with nn.LSTM + >>> def call_lstm(x, h, c): + ... h_out, c_out = lstm(x, (h, c)) + ... return h_out, c_out + >>> batched_call = torch.vmap(call_lstm) + >>> x = torch.randn(V, B, T, 10, device=device) + >>> h0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device) + >>> c0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device) + >>> with torch.no_grad(): + ... h1, c1 = batched_call(x, h0, c0) + """ + + __doc__ += nn.LSTM.__doc__ + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + batch_first: bool = True, + bias: bool = True, + dropout: float = 0.0, + bidirectional: float = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: + + if bidirectional is True: + raise NotImplementedError( + "Bidirectional LSTMs are not supported yet in this implementation." + ) + + super().__init__( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + proj_size=proj_size, + device=device, + dtype=dtype, + ) + + @staticmethod + def _lstm_cell(x, hx, cx, weight_ih, bias_ih, weight_hh, bias_hh): + + gates = F.linear(x, weight_ih, bias_ih) + F.linear(hx, weight_hh, bias_hh) + + i_gate, f_gate, g_gate, o_gate = gates.chunk(4, 1) + + i_gate = i_gate.sigmoid() + f_gate = f_gate.sigmoid() + g_gate = g_gate.tanh() + o_gate = o_gate.sigmoid() + + cy = cx * f_gate + i_gate * g_gate + + hy = o_gate * cy.tanh() + + return hy, cy + + def _lstm(self, x, hx): + + if self.batch_first is False: + x = x.permute( + 1, 0, 2 + ) # Change (seq_len, batch, features) to (batch, seq_len, features) + + # should check self.batch_first + bs, seq_len, input_size = x.size() + h_t, c_t = [list(h.unbind(0)) for h in hx] + + outputs = [] + for t in range(seq_len): + + x_t = x[:, t, :] + + for layer in range(self.num_layers): + # Retrieve weights + weights = self._all_weights[layer] + weight_ih = getattr(self, weights[0]) + weight_hh = getattr(self, weights[1]) + if self.bias is True: + bias_ih = getattr(self, weights[2]) + bias_hh = getattr(self, weights[3]) + else: + bias_ih = None + bias_hh = None + + # Run cell + h_t[layer], c_t[layer] = self._lstm_cell( + x_t, h_t[layer], c_t[layer], weight_ih, bias_ih, weight_hh, bias_hh + ) + + # Apply dropout if in training mode + if layer < self.num_layers - 1 and self.dropout: + x_t = F.dropout(h_t[layer], p=self.dropout, training=self.training) + else: # No dropout after the last layer + x_t = h_t[layer] + + outputs.append(x_t) + + outputs = torch.stack(outputs, dim=1) + if self.batch_first is False: + outputs = outputs.permute( + 1, 0, 2 + ) # Change back (batch, seq_len, features) to (seq_len, batch, features) + + return outputs, (torch.stack(h_t, 0), torch.stack(c_t, 0)) + + def forward(self, input, hx=None): # noqa: F811 + real_hidden_size = self.proj_size 
if self.proj_size > 0 else self.hidden_size + if input.dim() != 3: + raise ValueError( + f"LSTM: Expected input to be 3D, got {input.dim()}D instead" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + if hx is None: + h_zeros = torch.zeros( + self.num_layers, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + else: + self.check_forward_args(input, hx, batch_sizes=None) + result = self._lstm(input, hx) + output = result[0] + hidden = result[1] + return output, hidden + + class LSTMModule(ModuleBase): """An embedder for an LSTM module. @@ -62,6 +352,7 @@ class LSTMModule(ModuleBase): dropout: If non-zero, introduces a `Dropout` layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 + python_based: If ``True``, will use a full Python implementation of the LSTM cell. Default: ``False`` Keyword Args: in_key (str or tuple of str): the input key of the module. Exclusive use @@ -142,6 +433,7 @@ def __init__( dropout=0, proj_size=0, bidirectional=False, + python_based=False, *, in_key=None, in_keys=None, @@ -165,17 +457,30 @@ def __init__( raise ValueError("The input lstm must have batch_first=True.") if bidirectional: raise ValueError("The input lstm cannot be bidirectional.") - lstm = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - bias=bias, - dropout=dropout, - proj_size=proj_size, - device=device, - batch_first=True, - bidirectional=False, - ) + if python_based is True: + lstm = LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + dropout=dropout, + proj_size=proj_size, + device=device, + batch_first=True, + bidirectional=False, + ) + else: + lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + dropout=dropout, + proj_size=proj_size, + device=device, + batch_first=True, + bidirectional=False, + ) if not ((in_key is None) ^ (in_keys is None)): raise ValueError( f"Either in_keys or in_key must be specified but not both or none. Got {in_keys} and {in_key} respectively." @@ -413,6 +718,283 @@ def _lstm( return tuple(out) +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell that performs the same operation as nn.LSTMCell but is fully coded in Python. + + .. note:: + This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`. + + Examples: + >>> import torch + >>> from torchrl.modules.tensordict_module.rnn import GRUCell + >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu") + >>> B = 2 + >>> N_IN = 10 + >>> N_OUT = 20 + >>> V = 4 # vector size + >>> gru_cell = GRUCell(input_size=N_IN, hidden_size=N_OUT, device=device) + + # single call + >>> x = torch.randn(B, 10, device=device) + >>> h0 = torch.zeros(B, 20, device=device) + >>> with torch.no_grad(): + ... h1 = gru_cell(x, h0) + + # vectorised call - not possible with nn.GRUCell + >>> def call_gru(x, h): + ... h_out = gru_cell(x, h) + ... return h_out + >>> batched_call = torch.vmap(call_gru) + >>> x = torch.randn(V, B, 10, device=device) + >>> h0 = torch.zeros(V, B, 20, device=device) + >>> with torch.no_grad(): + ... 
h1 = batched_call(x, h0) + """ + + __doc__ += nn.GRUCell.__doc__ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() not in (1, 2): + raise ValueError( + f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = self.gru_cell(input, hx) + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + def gru_cell(self, x, hx): + + x = x.view(-1, x.size(1)) + + gate_x = F.linear(x, self.weight_ih, self.bias_ih) + gate_h = F.linear(hx, self.weight_hh, self.bias_hh) + + i_r, i_i, i_n = gate_x.chunk(3, 1) + h_r, h_i, h_n = gate_h.chunk(3, 1) + + resetgate = F.sigmoid(i_r + h_r) + inputgate = F.sigmoid(i_i + h_i) + newgate = F.tanh(i_n + (resetgate * h_n)) + + hy = newgate + inputgate * (hx - newgate) + + return hy + + +# copy GRU +class GRUBase(nn.RNNBase): + """A Base module for GRU. Inheriting from GRUBase enables compatibility with torch.compile.""" + + def __init__(self, *args, **kwargs): + return super().__init__("GRU", *args, **kwargs) + + +for attr in nn.GRU.__dict__: + if attr != "__init__": + setattr(GRUBase, attr, getattr(nn.GRU, attr)) + + +class GRU(GRUBase): + """A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like :class:`torch.nn.GRU`, but this implementation is exclusively coded in Python. + + .. note:: + This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`. + + Examples: + >>> import torch + >>> from torchrl.modules.tensordict_module.rnn import GRU + + >>> device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu") + >>> B = 2 + >>> T = 4 + >>> N_IN = 10 + >>> N_OUT = 20 + >>> N_LAYERS = 2 + >>> V = 4 # vector size + >>> gru = GRU( + ... input_size=N_IN, + ... hidden_size=N_OUT, + ... device=device, + ... num_layers=N_LAYERS, + ... ) + + # single call + >>> x = torch.randn(B, T, N_IN, device=device) + >>> h0 = torch.zeros(N_LAYERS, B, N_OUT, device=device) + >>> with torch.no_grad(): + ... h1 = gru(x, h0) + + # vectorised call - not possible with nn.GRU + >>> def call_gru(x, h): + ... h_out = gru(x, h) + ... return h_out + >>> batched_call = torch.vmap(call_gru) + >>> x = torch.randn(V, B, T, 10, device=device) + >>> h0 = torch.zeros(V, N_LAYERS, B, N_OUT, device=device) + >>> with torch.no_grad(): + ... h1 = batched_call(x, h0) + """ + + __doc__ += nn.GRU.__doc__ + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = True, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: + + if bidirectional is True: + raise NotImplementedError( + "Bidirectional LSTMs are not supported yet in this implementation." 
+ ) + + super().__init__( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=False, + device=device, + dtype=dtype, + ) + + @staticmethod + def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh): + x = x.view(-1, x.size(1)) + + gate_x = F.linear(x, weight_ih, bias_ih) + gate_h = F.linear(hx, weight_hh, bias_hh) + + i_r, i_i, i_n = gate_x.chunk(3, 1) + h_r, h_i, h_n = gate_h.chunk(3, 1) + + resetgate = (i_r + h_r).sigmoid() + inputgate = (i_i + h_i).sigmoid() + newgate = (i_n + (resetgate * h_n)).tanh() + + hy = newgate + inputgate * (hx - newgate) + + return hy + + def _gru(self, x, hx): + + if not self.batch_first: + x = x.permute( + 1, 0, 2 + ) # Change (seq_len, batch, features) to (batch, seq_len, features) + + bs, seq_len, input_size = x.size() + h_t = list(hx.unbind(0)) + + outputs = [] + + for t in range(seq_len): + x_t = x[:, t, :] + + for layer in range(self.num_layers): + + # Retrieve weights + weights = self._all_weights[layer] + weight_ih = getattr(self, weights[0]) + weight_hh = getattr(self, weights[1]) + if self.bias is True: + bias_ih = getattr(self, weights[2]) + bias_hh = getattr(self, weights[3]) + else: + bias_ih = None + bias_hh = None + + h_t[layer] = self._gru_cell( + x_t, + h_t[layer], + weight_ih, + bias_ih, + weight_hh, + bias_hh, + ) + + # Apply dropout if in training mode and not the last layer + if layer < self.num_layers - 1 and self.dropout: + x_t = F.dropout(h_t[layer], p=self.dropout, training=self.training) + else: + x_t = h_t[layer] + + outputs.append(x_t) + + outputs = torch.stack(outputs, dim=1) + if self.batch_first is False: + outputs = outputs.permute( + 1, 0, 2 + ) # Change back (batch, seq_len, features) to (seq_len, batch, features) + + return outputs, torch.stack(h_t, 0) + + def forward(self, input, hx=None): # noqa: F811 + if input.dim() != 3: + raise ValueError( + f"GRU: Expected input to be 3D, got {input.dim()}D instead" + ) + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + if hx is None: + hx = torch.zeros( + self.num_layers, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + + self.check_forward_args(input, hx, batch_sizes=None) + result = self._gru(input, hx) + + output = result[0] + hidden = result[1] + + return output, hidden + + class GRUModule(ModuleBase): """An embedder for an GRU module. @@ -446,6 +1028,7 @@ class GRUModule(ModuleBase): GRU layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 proj_size: If ``> 0``, will use GRU with projections of corresponding size. Default: 0 + python_based: If ``True``, will use a full Python implementation of the GRU cell. Default: ``False`` Keyword Args: in_key (str or tuple of str): the input key of the module. 
Exclusive use @@ -552,6 +1135,7 @@ def __init__( batch_first=True, dropout=0, bidirectional=False, + python_based=False, *, in_key=None, in_keys=None, @@ -575,16 +1159,29 @@ def __init__( raise ValueError("The input gru must have batch_first=True.") if bidirectional: raise ValueError("The input gru cannot be bidirectional.") - gru = nn.GRU( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - bias=bias, - dropout=dropout, - device=device, - batch_first=True, - bidirectional=False, - ) + + if python_based is True: + gru = GRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + dropout=dropout, + device=device, + batch_first=True, + bidirectional=False, + ) + else: + gru = nn.GRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + dropout=dropout, + device=device, + batch_first=True, + bidirectional=False, + ) if not ((in_key is None) ^ (in_keys is None)): raise ValueError( f"Either in_keys or in_key must be specified but not both or none. Got {in_keys} and {in_key} respectively." diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index f3aff0da1d2..1c43d536fe8 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -160,7 +160,7 @@ def _call_actor_net( log_prob_key: NestedKey, ): # TODO: extend to handle time dimension (and vmap?) - log_pi = actor_net(data.select(actor_net.in_keys)).get(log_prob_key) + log_pi = actor_net(data.select(*actor_net.in_keys)).get(log_prob_key) return log_pi
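[Editor's note: to close the loop on the pure-Python recurrent modules introduced in rnn.py, a quick sanity-check sketch under the assumption that the Python GRU keeps the nn.RNNBase parameter layout (it does in this patch, since it inherits from GRUBase): weights copied from nn.GRU should give matching outputs up to floating-point tolerance. The same idea applies to LSTM, and LSTMModule/GRUModule expose these implementations through the new python_based flag.]

    import torch
    from torch import nn
    from torchrl.modules import GRU  # pure-Python implementation added in this patch

    torch.manual_seed(0)
    ref = nn.GRU(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
    py = GRU(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
    # identical parameter names and shapes, so the state dict transfers directly
    py.load_state_dict(ref.state_dict())

    x = torch.randn(3, 5, 4)     # (batch, time, features)
    h0 = torch.zeros(2, 3, 8)    # (num_layers, batch, hidden)
    out_ref, h_ref = ref(x, h0)
    out_py, h_py = py(x, h0)
    print(torch.allclose(out_ref, out_py, atol=1e-5))   # expected: True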