From 44fd0b2c8668082e81ad15efe47c01e877d7ee7d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 11 Feb 2025 22:07:54 +0000
Subject: [PATCH] [Doc] Solve ref issues in docstrings

ghstack-source-id: 09823fa85a94115291e7434478776fb0834f9b39
Pull Request resolved: https://github.com/pytorch/rl/pull/2776
(cherry picked from commit f5445a4bd77b7046e7e62e0d5622d13c5c0ee799)
---
 .github/workflows/docs.yml | 2 +-
 docs/source/content_generation.py | 4 +
 docs/source/reference/envs.rst | 3 +-
 torchrl/_utils.py | 4 +-
 torchrl/collectors/collectors.py | 9 +-
 torchrl/collectors/distributed/generic.py | 4 +-
 torchrl/collectors/distributed/ray.py | 2 +
 torchrl/collectors/distributed/rpc.py | 4 +-
 torchrl/collectors/distributed/sync.py | 4 +-
 torchrl/data/datasets/common.py | 2 +-
 torchrl/data/datasets/openx.py | 4 +-
 torchrl/data/map/tdstorage.py | 2 +-
 torchrl/data/map/tree.py | 19 ++-
 torchrl/data/postprocs/postprocs.py | 1 +
 torchrl/data/replay_buffers/replay_buffers.py | 32 ++--
 torchrl/data/replay_buffers/samplers.py | 4 +-
 torchrl/data/replay_buffers/storages.py | 8 +-
 torchrl/data/rlhf/dataset.py | 4 +-
 torchrl/data/rlhf/utils.py | 2 +
 torchrl/data/tensor_specs.py | 18 +--
 torchrl/envs/batched_envs.py | 6 +-
 torchrl/envs/common.py | 12 +-
 torchrl/envs/custom/llm.py | 2 -
 torchrl/envs/gym_like.py | 35 ++--
 torchrl/envs/libs/brax.py | 4 +-
 torchrl/envs/libs/dm_control.py | 4 +-
 torchrl/envs/libs/envpool.py | 4 +-
 torchrl/envs/libs/gym.py | 36 ++---
 torchrl/envs/libs/habitat.py | 2 +-
 torchrl/envs/libs/jumanji.py | 4 +-
 torchrl/envs/libs/openml.py | 2 +-
 torchrl/envs/libs/openspiel.py | 16 +-
 torchrl/envs/libs/robohive.py | 2 +-
 torchrl/envs/libs/unity_mlagents.py | 4 +-
 torchrl/envs/model_based/common.py | 12 +-
 torchrl/envs/transforms/rlhf.py | 2 +-
 torchrl/envs/transforms/transforms.py | 25 +--
 torchrl/modules/models/exploration.py | 4 +-
 torchrl/modules/models/model_based.py | 1 +
 torchrl/modules/models/models.py | 2 +-
 torchrl/modules/models/multiagent.py | 8 +-
 torchrl/modules/planners/mppi.py | 1 +
 torchrl/modules/tensordict_module/actors.py | 17 +-
 .../modules/tensordict_module/exploration.py | 6 +-
 .../tensordict_module/probabilistic.py | 6 +-
 torchrl/objectives/cql.py | 2 +-
 torchrl/objectives/crossq.py | 2 +-
 torchrl/objectives/decision_transformer.py | 2 +-
 torchrl/objectives/sac.py | 2 +-
 torchrl/record/loggers/csv.py | 15 +-
 tutorials/sphinx-tutorials-save/README.rst | 4 +
 tutorials/sphinx-tutorials/multiagent_ppo.py | 2 +-
 tutorials/sphinx-tutorials/pendulum.py | 5 +-
 tutorials/sphinx-tutorials/rb_tutorial.py | 152 +++++++++---------
 54 files changed, 281 insertions(+), 254 deletions(-)
 create mode 100644 tutorials/sphinx-tutorials-save/README.rst

diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 6c99cb9af05..5d6f06123d0 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -112,7 +112,7 @@ jobs:
         cd ./docs
         # timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
         # bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
-        PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
+        PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build -v
        cd ..
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}" diff --git a/docs/source/content_generation.py b/docs/source/content_generation.py index e24dbd33a04..29e1afff29d 100644 --- a/docs/source/content_generation.py +++ b/docs/source/content_generation.py @@ -83,6 +83,10 @@ def generate_tutorial_references(tutorial_path: str, file_type: str) -> None: for f in os.listdir(tutorial_path) if f.endswith((".py", ".rst", ".png")) ] + # Make rb_tutorial.py the first one + file_paths = [p for p in file_paths if p.endswith("rb_tutorial.py")] + [ + p for p in file_paths if not p.endswith("rb_tutorial.py") + ] for file_path in file_paths: shutil.copyfile(file_path, os.path.join(target_path, Path(file_path).name)) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 8b64b87f8bd..0610d1229d7 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -212,6 +212,7 @@ component (sub-environments or agents) should be reset. This allows to reset some but not all of the components. The ``"_reset"`` key has two distinct functionalities: + 1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may not be present in the input tensordict. TorchRL's convention is that the absence of the ``"_reset"`` key at a given ``"done"`` level indicates @@ -899,7 +900,7 @@ to be able to create this other composition: Hash InitTracker KLRewardTransform - LineariseReward + LineariseRewards NoopResetEnv ObservationNorm ObservationTransform diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 5fd2acf6b41..10f303d6885 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -513,7 +513,7 @@ def reset(cls, setters_dict: Dict[str, implement_for] = None): """Resets the setters in setter_dict. ``setter_dict`` is a copy of implementations. We just need to iterate through its - values and call :meth:`~.module_set` for each. + values and call :meth:`module_set` for each. """ if VERBOSE: @@ -888,7 +888,7 @@ def _standardize( exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: (). mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None. std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None. - eps (float): epsilon to be used for numerical stability. Default: float32 resolution. + eps (:obj:`float`): epsilon to be used for numerical stability. Default: float32 resolution. """ if eps is None: diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 1b57270bb3e..a7552fa2d1e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -339,10 +339,12 @@ class SyncDataCollector(DataCollectorBase): instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. 
Keyword Args: @@ -1462,6 +1464,7 @@ class _MultiDataCollector(DataCollectorBase): ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. @@ -1548,7 +1551,7 @@ class _MultiDataCollector(DataCollectorBase): reset_when_done (bool, optional): if ``True`` (default), an environment that return a ``True`` value in its ``"done"`` or ``"truncated"`` entry will be reset at the corresponding indices. - update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()` + update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()` will be called before (sync) or after (async) each data collection. Defaults to ``False``. preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers @@ -2774,10 +2777,12 @@ class aSyncDataCollector(MultiaSyncDataCollector): instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: @@ -2863,7 +2868,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): reset_when_done (bool, optional): if ``True`` (default), an environment that return a ``True`` value in its ``"done"`` or ``"truncated"`` entry will be reset at the corresponding indices. - update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()` + update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()` will be called before (sync) or after (async) each data collection. Defaults to ``False``. preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 5e49ad95f49..0e67370bb25 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -262,10 +262,12 @@ class DistributedDataCollector(DataCollectorBase): instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: @@ -341,7 +343,7 @@ class DistributedDataCollector(DataCollectorBase): collecting data. 
Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. - collector_class (type or str, optional): a collector class for the remote node. Can be + collector_class (Type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, :class:`~torchrl.collectors.MultiaSyncDataCollector` diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 1f088c2c404..715e41f50fd 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -135,10 +135,12 @@ class RayCollector(DataCollectorBase): instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 0c4922778ef..03fb2048a85 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -110,10 +110,12 @@ class RPCDataCollector(DataCollectorBase): instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: @@ -190,7 +192,7 @@ class RPCDataCollector(DataCollectorBase): ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``. - collector_class (type or str, optional): a collector class for the remote node. Can be + collector_class (Type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, :class:`~torchrl.collectors.MultiaSyncDataCollector` diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index b90111763d7..6aa66dfbdd2 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -143,10 +143,12 @@ class DistributedSyncDataCollector(DataCollectorBase): instances) it will be wrapped in a `nn.Module` first. Then, the collector will try to assess if these modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. 
+ - If the policy forward signature matches any of ``forward(self, tensordict)``, ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or any typing with a single argument typed as a subclass of ``TensorDictBase``) then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. Keyword Args: @@ -222,7 +224,7 @@ class DistributedSyncDataCollector(DataCollectorBase): collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``. - collector_class (type or str, optional): a collector class for the remote node. Can be + collector_class (Type or str, optional): a collector class for the remote node. Can be :class:`~torchrl.collectors.SyncDataCollector`, :class:`~torchrl.collectors.MultiSyncDataCollector`, :class:`~torchrl.collectors.MultiaSyncDataCollector` diff --git a/torchrl/data/datasets/common.py b/torchrl/data/datasets/common.py index 6668cb03872..11113266ff6 100644 --- a/torchrl/data/datasets/common.py +++ b/torchrl/data/datasets/common.py @@ -72,7 +72,7 @@ def preprocess( Args and Keyword Args are forwarded to :meth:`~tensordict.TensorDictBase.map`. - The dataset can subsequently be deleted using :meth:`~.delete`. + The dataset can subsequently be deleted using :meth:`delete`. Keyword Args: dest (path or equivalent): a path to the location of the new dataset. diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py index 2dbf0720a37..01f5fdf98ce 100644 --- a/torchrl/data/datasets/openx.py +++ b/torchrl/data/datasets/openx.py @@ -66,7 +66,7 @@ class for more information on how to interact with non-tensor data sampling strategy. If the ``batch_size`` is ``None`` (default), iterating over the dataset will deliver trajectories one at a time *whereas* calling - :meth:`~.sample` will *still* require a batch-size to be provided. + :meth:`sample` will *still* require a batch-size to be provided. Keyword Args: shuffle (bool, optional): if ``True``, trajectories are delivered in a @@ -115,7 +115,7 @@ class for more information on how to interact with non-tensor data replacement (bool, optional): if ``False``, sampling will be done without replacement. Defaults to ``True`` for downloaded datasets, ``False`` for streamed datasets. - pad (bool, float or None): if ``True``, trajectories of insufficient length + pad (bool, :obj:`float` or None): if ``True``, trajectories of insufficient length given the `slice_len` or `num_slices` arguments will be padded with 0s. If another value is provided, it will be used for padding. If ``False`` or ``None`` (default) any encounter with a trajectory of diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index 9413033bac4..34d4bb8d0fa 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -193,7 +193,7 @@ def from_tensordict_pair( in the storage. Defaults to ``None`` (all keys are registered). max_size (int, optional): the maximum number of elements in the storage. Ignored if the ``storage_constructor`` is passed. Defaults to ``1000``. - storage_constructor (type, optional): a type of tensor storage. + storage_constructor (Type, optional): a type of tensor storage. Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`. 
Other options include :class:`~tensordict.nn.storage.FixedStorage`. hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`. diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index c09db75aa5b..d7fd72869dd 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -50,7 +50,7 @@ class Tree(TensorClass["nocast"]): node_id (int): A unique identifier for this node. rollout (TensorDict): Rollout data following the observation encoded in this node, in a TED format. If there are multiple actions taken at this node, subtrees are stored in the corresponding - entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method. + entry. Rollouts can be reconstructed using the :meth:`rollout_from_path` method. node (TensorDict): Data defining this node (e.g., observations) before the next branching. Entries usually matches the ``in_keys`` in ``MCTSForest.node_map``. subtree (Tree): A stack of subtrees produced when actions are taken. @@ -215,7 +215,7 @@ def node_observation(self) -> torch.Tensor | TensorDictBase: """Returns the observation associated with this particular node. This is the observation (or bag of observations) that defines the node before a branching occurs. - If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the + If the node contains a :meth:`rollout` attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``. If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance @@ -232,7 +232,7 @@ def node_observations(self) -> torch.Tensor | TensorDictBase: """Returns the observations associated with this particular node in a TensorDict format. This is the observation (or bag of observations) that defines the node before a branching occurs. - If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the + If the node contains a :meth:`rollout` attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``. If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance @@ -442,8 +442,11 @@ def num_vertices(self, *, count_repeat: bool = False) -> int: """Returns the number of unique vertices in the Tree. Keyword Args: - count_repeat (bool, optional): Determines whether to count repeated vertices. + count_repeat (bool, optional): Determines whether to count repeated + vertices. + - If ``False``, counts each unique vertex only once. + - If ``True``, counts vertices multiple times if they appear in different paths. Defaults to ``False``. @@ -629,16 +632,16 @@ class MCTSForest: ``node_map.max_size``. If none of these are provided, defaults to `1000`. done_keys (list of NestedKey, optional): the done keys of the environment. If not provided, defaults to ``("done", "terminated", "truncated")``. - The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + The :meth:`get_keys_from_env` can be used to automatically determine the keys. action_keys (list of NestedKey, optional): the action keys of the environment. If not provided, defaults to ``("action",)``. - The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. 
+ The :meth:`get_keys_from_env` can be used to automatically determine the keys. reward_keys (list of NestedKey, optional): the reward keys of the environment. If not provided, defaults to ``("reward",)``. - The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + The :meth:`get_keys_from_env` can be used to automatically determine the keys. observation_keys (list of NestedKey, optional): the observation keys of the environment. If not provided, defaults to ``("observation",)``. - The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + The :meth:`get_keys_from_env` can be used to automatically determine the keys. excluded_keys (list of NestedKey, optional): a list of keys to exclude from the data storage. consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk. Defaults to ``False``. diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index a471a7b4d5f..7814c6cce14 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -182,6 +182,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: environment (i.e. before multi-step); - The "reward" values will be replaced by the newly computed rewards. + The ``"done"`` key can have either the shape of the tensordict OR the shape of the tensordict followed by a singleton dimension OR the shape of the tensordict followed by other diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index dfac5a14bbf..4e0ee36cd4a 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -92,7 +92,7 @@ class ReplayBuffer: prefetch (int, optional): number of next batches to be prefetched using multithreading. Defaults to None (no prefetching). transform (Transform, optional): Transform to be executed when - :meth:`~.sample` is called. + :meth:`sample` is called. To chain transforms use the :class:`~torchrl.envs.Compose` class. Transforms should be used with :class:`tensordict.TensorDict` content. A generic callable can also be passed if the replay buffer @@ -104,21 +104,21 @@ class ReplayBuffer: ``batch_size`` argument, or at sampling time. The former should be preferred whenever the batch-size is consistent across the experiment. If the batch-size is likely to change, it can be - passed to the :meth:`~.sample` method. This option is + passed to the :meth:`sample` method. This option is incompatible with prefetching (since this requires to know the batch-size in advance) as well as with samplers that have a ``drop_last`` argument. dim_extend (int, optional): indicates the dim to consider for - extension when calling :meth:`~.extend`. Defaults to ``storage.ndim-1``. + extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``. When using ``dim_extend > 0``, we recommend using the ``ndim`` argument in the storage instantiation if that argument is available, to let storages know that the data is multi-dimensional and keep consistent notions of storage-capacity and batch-size during sampling. - .. note:: This argument has no effect on :meth:`~.add` and - therefore should be used with caution when both :meth:`~.add` - and :meth:`~.extend` are used in a codebase. For example: + .. note:: This argument has no effect on :meth:`add` and + therefore should be used with caution when both :meth:`add` + and :meth:`extend` are used in a codebase. 
For example: >>> data = torch.zeros(3, 4) >>> rb = ReplayBuffer( @@ -541,12 +541,12 @@ def dumps(self, path): def loads(self, path): """Loads a replay buffer state at the given path. - The buffer should have matching components and be saved using :meth:`~.dumps`. + The buffer should have matching components and be saved using :meth:`dumps`. Args: path (Path or str): path where the replay buffer was saved. - See :meth:`~.dumps` for more info. + See :meth:`dumps` for more info. """ path = Path(path).absolute() @@ -566,15 +566,15 @@ def loads(self, path): self._batch_size = metadata["batch_size"] def save(self, *args, **kwargs): - """Alias for :meth:`~.dumps`.""" + """Alias for :meth:`dumps`.""" return self.dumps(*args, **kwargs) def dump(self, *args, **kwargs): - """Alias for :meth:`~.dumps`.""" + """Alias for :meth:`dumps`.""" return self.dumps(*args, **kwargs) def load(self, *args, **kwargs): - """Alias for :meth:`~.loads`.""" + """Alias for :meth:`loads`.""" return self.loads(*args, **kwargs) def register_save_hook(self, hook: Callable[[Any], Any]): @@ -931,21 +931,21 @@ class PrioritizedReplayBuffer(ReplayBuffer): ``batch_size`` argument, or at sampling time. The former should be preferred whenever the batch-size is consistent across the experiment. If the batch-size is likely to change, it can be - passed to the :meth:`~.sample` method. This option is + passed to the :meth:`sample` method. This option is incompatible with prefetching (since this requires to know the batch-size in advance) as well as with samplers that have a ``drop_last`` argument. dim_extend (int, optional): indicates the dim to consider for - extension when calling :meth:`~.extend`. Defaults to ``storage.ndim-1``. + extension when calling :meth:`extend`. Defaults to ``storage.ndim-1``. When using ``dim_extend > 0``, we recommend using the ``ndim`` argument in the storage instantiation if that argument is available, to let storages know that the data is multi-dimensional and keep consistent notions of storage-capacity and batch-size during sampling. - .. note:: This argument has no effect on :meth:`~.add` and - therefore should be used with caution when both :meth:`~.add` - and :meth:`~.extend` are used in a codebase. For example: + .. note:: This argument has no effect on :meth:`add` and + therefore should be used with caution when both :meth:`add` + and :meth:`extend` are used in a codebase. For example: >>> data = torch.zeros(3, 4) >>> rb = ReplayBuffer( diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index fa92d84295a..12019de3080 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -119,7 +119,7 @@ class RandomSampler(Sampler): Args: batch_size (int, optional): if provided, the batch size to be used by - the replay buffer when calling :meth:`~.ReplayBuffer.sample`. + the replay buffer when calling :meth:`ReplayBuffer.sample`. """ @@ -1848,7 +1848,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler): samples if they follow another of higher priority, and transitions with a high priority but closer to the end of a trajectory may never be sampled if they cannot be used as start points. Currently, it is the user responsibility to aggregate priorities across items of a trajectory using - :meth:`~.update_priority`. + :meth:`update_priority`. 
Args: alpha (:obj:`float`): exponent α determines how much prioritization is used, diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 344814e728c..8be38376a6c 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -194,15 +194,15 @@ def flatten(self): ) def save(self, *args, **kwargs): - """Alias for :meth:`~.dumps`.""" + """Alias for :meth:`dumps`.""" return self.dumps(*args, **kwargs) def dump(self, *args, **kwargs): - """Alias for :meth:`~.dumps`.""" + """Alias for :meth:`dumps`.""" return self.dumps(*args, **kwargs) def load(self, *args, **kwargs): - """Alias for :meth:`~.loads`.""" + """Alias for :meth:`loads`.""" return self.loads(*args, **kwargs) def __getstate__(self): @@ -1342,7 +1342,7 @@ class StorageEnsemble(Storage): transforms of the same length as storages. .. warning:: - This class signatures for :meth:`~.get` does not match other storages, as + This class signatures for :meth:`get` does not match other storages, as it will return a tuple ``(buffer_id, samples)`` rather than just the samples. .. warning:: diff --git a/torchrl/data/rlhf/dataset.py b/torchrl/data/rlhf/dataset.py index c1411b81a09..773665bf4fd 100644 --- a/torchrl/data/rlhf/dataset.py +++ b/torchrl/data/rlhf/dataset.py @@ -57,7 +57,7 @@ class TokenizedDatasetLoader: num_workers (int, optional): number of workers for :meth:`datasets.dataset.map` which is called during tokenization. Defaults to ``max(os.cpu_count() // 2, 1)``. - tokenizer_class (type, optional): A tokenizer class, such as + tokenizer_class (Type, optional): A tokenizer class, such as :class:`~transformers.AutoTokenizer` (default). tokenizer_model_name (str, optional): The model from which the vocabulary should be gathered. Defaults to ``"gpt2"``. @@ -182,7 +182,7 @@ def _tokenize( """Preprocesses a text dataset from ``datasets``. Args: - dataset (datasets.Dataset): a dataset loaded using :meth:`~.load_dataset`. + dataset (datasets.Dataset): a dataset loaded using :meth:`load_dataset`. excluded_features (sequence of str, optional): the features to exclude once tokenization is complete. Defaults to ``{"text", "prompt", "label", "valid_sample"}``. diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 775eb652c4b..26fb710f94c 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -259,6 +259,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): Returns: A :class:`~tensordict.TensorDict` with the following keys: + - ``"action"``: the sequence of actions (generated tokens) - ``"input_ids"``: the input_ids passed to the generative model at each time step. @@ -280,6 +281,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): training - ``("next", "reward_kl")``: The KL term from the reward. This is mainly for debugging and logging, it is not used in training. + """ rollout_generated = self._get_rollout_generated(generated, batch) rollout_attention_mask = (rollout_generated != self.EOS_TOKEN_ID).bool() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index cadb7638f45..6be03d0f167 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -146,7 +146,7 @@ def _validate_iterable( Args: idx (Iterable[Any]): Iterable, may contain nested iterables - expected_type (type): Required item type in the Iterable (e.g. int) + expected_type (Type): Required item type in the Iterable (e.g. 
int) iterable_classname (str): Iterable type as a string (e.g. 'List'). Logging purpose only. """ for item in idx: @@ -864,7 +864,7 @@ def is_in(self, val: torch.Tensor | TensorDictBase) -> bool: def contains(self, item: torch.Tensor | TensorDictBase) -> bool: """If the value ``val`` could have been generated by the ``TensorSpec``, returns ``True``, otherwise ``False``. - See :meth:`~.is_in` for more information. + See :meth:`is_in` for more information. """ return self.is_in(item) @@ -943,7 +943,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: def sample(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: """Returns a random tensor in the space defined by the spec. - See :meth:`~.rand` for details. + See :meth:`rand` for details. """ return self.rand(shape=shape) @@ -968,7 +968,7 @@ def zero(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: ) def zeros(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: - """Proxy to :meth:`~.zero`.""" + """Proxy to :meth:`zero`.""" return self.zero(shape=shape) def one(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: @@ -990,7 +990,7 @@ def one(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: return self.zero(shape) + 1 def ones(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: - """Proxy to :meth:`~.one`.""" + """Proxy to :meth:`one`.""" return self.one(shape=shape) @abc.abstractmethod @@ -1494,7 +1494,7 @@ class OneHot(TensorSpec): elements will be mapped in a register to a series of unique one-hot binary vectors). mask (torch.Tensor or None): mask some of the possible outcomes when a - sample is taken. See :meth:`~.update_mask` for more information. + sample is taken. See :meth:`update_mask` for more information. Examples: >>> from torchrl.data.tensor_specs import OneHot @@ -2894,7 +2894,7 @@ class MultiOneHot(OneHot): the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. mask (torch.Tensor or None): mask some of the possible outcomes when a - sample is taken. See :meth:`~.update_mask` for more information. + sample is taken. See :meth:`update_mask` for more information. Examples: >>> ts = MultiOneHot((3,2,3)) @@ -3350,7 +3350,7 @@ class Categorical(TensorSpec): device (str, int or torch.device, optional): the device of the tensors. dtype (str or torch.dtype, optional): the dtype of the tensors. mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken. - See :meth:`~.update_mask` for more information. + See :meth:`update_mask` for more information. Examples: >>> categ = Categorical(3) @@ -4040,7 +4040,7 @@ class MultiCategorical(Categorical): remove_singleton (bool, optional): if ``True``, singleton samples (of size [1]) will be squeezed. Defaults to ``True``. mask (torch.Tensor or None): mask some of the possible outcomes when a - sample is taken. See :meth:`~.update_mask` for more information. + sample is taken. See :meth:`update_mask` for more information. Examples: >>> ts = MultiCategorical((3, 2, 3)) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 5475f42c61a..963c640c322 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -195,7 +195,7 @@ class BatchedEnvBase(EnvBase): .. 
note:: One can pass keyword arguments to each sub-environments using the following - technique: every keyword argument in :meth:`~.reset` will be passed to each + technique: every keyword argument in :meth:`reset` will be passed to each environment except for the ``list_of_kwargs`` argument which, if present, should contain a list of the same length as the number of workers with the worker-specific keyword arguments stored in a dictionary. @@ -1354,8 +1354,8 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): >>> env = ParallelEnv(N, MyEnv(..., device="cpu")) .. warning:: - ParallelEnv disable gradients in all operations (:meth:`~.step`, - :meth:`~.reset` and :meth:`~.step_and_maybe_reset`) because gradients + ParallelEnv disable gradients in all operations (:meth:`step`, + :meth:`reset` and :meth:`step_and_maybe_reset`) because gradients cannot be passed through :class:`multiprocessing.Pipe` objects. Only :class:`~torchrl.envs.SerialEnv` will support backpropagation. diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 14be04ef985..4690772db7c 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -301,7 +301,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): run_type_checks (bool, optional): If ``True``, type-checks will occur at every reset and every step. Defaults to ``False``. allow_done_after_reset (bool, optional): if ``True``, an environment can - be done after a call to :meth:`~.reset` is made. Defaults to ``False``. + be done after a call to :meth:`reset` is made. Defaults to ``False``. spec_locked (bool, optional): if ``True``, the specs are locked and can only be modified if :meth:`~torchrl.envs.EnvBase.set_spec_lock_` is called. @@ -702,7 +702,7 @@ def cardinality(self, tensordict: TensorDictBase | None = None) -> int: - The action cardinality may depend on the action mask; - The shape can be dynamic, as in ``Unbound(shape=(-1))``. - In these cases, the :meth:`~.cardinality` should be overwritten, + In these cases, the :meth:`cardinality` should be overwritten, Args: tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality. @@ -2130,9 +2130,9 @@ def register_gym( This method is designed with the following scopes in mind: - - Incorporate a TorchRL-first environment in a framework that uses Gym; - - Incorporate another environment (eg, DeepMind Control, Brax, Jumanji, ...) - in a framework that uses Gym. + - Incorporate a TorchRL-first environment in a framework that uses Gym; + - Incorporate another environment (eg, DeepMind Control, Brax, Jumanji, ...) + in a framework that uses Gym. Args: id (str): the name of the environment. Should follow the @@ -3365,7 +3365,7 @@ def step_and_maybe_reset( """Runs a step in the environment and (partially) resets it if needed. Args: - tensordict (TensorDictBase): an input data structure for the :meth:`~.step` + tensordict (TensorDictBase): an input data structure for the :meth:`step` method. This method allows to easily code non-stopping rollout functions. diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 4b8b1a5f21b..92f265e85d2 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -32,8 +32,6 @@ class LLMHashingEnv(EnvBase): .. figure:: /_static/img/rollout-llm.png :alt: Data collection loop with our LLM environment. - .. seealso:: the :ref:`Beam Search ` tutorial gives a practical example of how this env can be used. - Args: vocab_size (int): The size of the vocabulary. 
Can be omitted if the tokenizer is passed. diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index bb5a4ddea43..857b3b96b2f 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -157,7 +157,7 @@ class GymLikeEnv(_EnvWrapper): where the outputs are the observation, reward and done state respectively. In this implementation, the info output is discarded (but specific keys can be read - by updating info_dict_reader, see :meth:`~.set_info_dict_reader` method). + by updating info_dict_reader, see :meth:`set_info_dict_reader` method). By default, the first output is written at the "observation" key-value pair in the output tensordict, unless the first output is a dictionary. In that case, each observation output will be put at the corresponding @@ -215,10 +215,11 @@ def read_done( Defaults to ``None``. Returns: a tuple with 4 boolean / tensor values, - - a terminated state, - - a truncated state, - - a done state, - - a boolean value indicating whether the frame_skip loop should be broken. + + - a terminated state, + - a truncated state, + - a done state, + - a boolean value indicating whether the frame_skip loop should be broken. """ if truncated is not None and done is None: @@ -418,16 +419,16 @@ def _output_transform( These three concepts have different usage: - - ``"terminated"`` indicated the final stage of a Markov Decision - Process. It means that one should not pay attention to the - upcoming observations (eg., in value functions) as they should be - regarded as not valid. - - ``"truncated"`` means that the environment has reached a stage where - we decided to stop the collection for some reason but the next - observation should not be discarded. If it were not for this - arbitrary decision, the collection could have proceeded further. - - ``"done"`` is either one or the other. It is to be interpreted as - "a reset should be called before the next step is undertaken". + - ``"terminated"`` indicated the final stage of a Markov Decision + Process. It means that one should not pay attention to the + upcoming observations (eg., in value functions) as they should be + regarded as not valid. + - ``"truncated"`` means that the environment has reached a stage where + we decided to stop the collection for some reason but the next + observation should not be discarded. If it were not for this + arbitrary decision, the collection could have proceeded further. + - ``"done"`` is either one or the other. It is to be interpreted as + "a reset should be called before the next step is undertaken". """ ... @@ -461,7 +462,7 @@ def set_info_dict_reader( .. note:: Automatically registering an info_dict reader should be done via - :meth:`~.auto_register_info_dict`, which will ensure that the env + :meth:`auto_register_info_dict`, which will ensure that the env specs are properly constructed. Examples: @@ -524,7 +525,7 @@ def auto_register_info_dict( Keyword Args: info_dict_reader (BaseInfoDictReader, optional): the info_dict_reader, if it is known in advance. - Unlike :meth:`~.set_info_dict_reader`, this method will create the primers necessary to get + Unlike :meth:`set_info_dict_reader`, this method will create the primers necessary to get :func:`~torchrl.envs.utils.check_env_specs` to run. 
Examples: diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 9542b8e71ff..8c55fa3255f 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -67,7 +67,7 @@ class BraxWrapper(_EnvWrapper): In ``brax``, this indicates the number of vectorized environments. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: @@ -451,7 +451,7 @@ class BraxEnv(BraxWrapper): In ``brax``, this indicates the number of vectorized environments. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 2ca62e106f6..ba1fdcfc9ae 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -150,7 +150,7 @@ class DMControlWrapper(GymLikeEnv): rewards, actions and infos. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: @@ -378,7 +378,7 @@ class DMControlEnv(DMControlWrapper): rewards, actions and infos. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index 599645dfdfc..a4339820b9f 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -44,7 +44,7 @@ class MultiThreadedEnvWrapper(_EnvWrapper): device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``torch.device("cpu")``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: @@ -342,7 +342,7 @@ class MultiThreadedEnv(MultiThreadedEnvWrapper): device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``torch.device("cpu")``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Examples: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index f0626265486..89e58a046a2 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -858,7 +858,7 @@ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta): rewards, actions and infos. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. 
convert_actions_to_numpy (bool, optional): if ``True``, actions will be converted from tensors to numpy arrays and moved to CPU before being passed to the @@ -908,7 +908,7 @@ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta): .. note:: info dictionaries will be read using :class:`~torchrl.envs.gym_like.default_info_dict_reader` if no other reader is provided. To provide another reader, refer to - :meth:`~.set_info_dict_reader`. To automatically register the info_dict + :meth:`set_info_dict_reader`. To automatically register the info_dict content, refer to :meth:`torchrl.envs.GymLikeEnv.auto_register_info_dict`. For parallel (Vectorized) environments, the info dictionary reader is automatically set and should not be set manually. @@ -917,13 +917,13 @@ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta): The following spaces are accounted for provided that they can be represented by a torch.Tensor, a nested tensor and/or within a tensordict: - - spaces.Box - - spaces.Sequence - - spaces.Tuple - - spaces.Discrete - - spaces.MultiBinary - - spaces.MultiDiscrete - - spaces.Dict + - spaces.Box + - spaces.Sequence + - spaces.Tuple + - spaces.Discrete + - spaces.MultiBinary + - spaces.MultiDiscrete + - spaces.Dict Some considerations should be made when working with gym spaces. For instance, a tuple of spaces can only be supported if the spaces are semantically identical (same dtype and same number of dimensions). @@ -1522,7 +1522,7 @@ class GymEnv(GymWrapper): rewards, actions and infos. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: @@ -1581,20 +1581,20 @@ class GymEnv(GymWrapper): .. note:: info dictionaries will be read using :class:`~torchrl.envs.gym_like.default_info_dict_reader` if no other reader is provided. To provide another reader, refer to - :meth:`~.set_info_dict_reader`. To automatically register the info_dict + :meth:`set_info_dict_reader`. To automatically register the info_dict content, refer to :meth:`torchrl.envs.GymLikeEnv.auto_register_info_dict`. .. note:: Gym spaces are not completely covered. The following spaces are accounted for provided that they can be represented by a torch.Tensor, a nested tensor and/or within a tensordict: - - spaces.Box - - spaces.Sequence - - spaces.Tuple - - spaces.Discrete - - spaces.MultiBinary - - spaces.MultiDiscrete - - spaces.Dict + - spaces.Box + - spaces.Sequence + - spaces.Tuple + - spaces.Discrete + - spaces.MultiBinary + - spaces.MultiDiscrete + - spaces.Dict Some considerations should be made when working with gym spaces. For instance, a tuple of spaces can only be supported if the spaces are semantically identical (same dtype and same number of dimensions). diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 4180c42b2dc..999277a2db8 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -80,7 +80,7 @@ class HabitatEnv(GymEnv): rewards, actions and infos. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. 
Attributes: diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 002a376dd2e..30c279fd3a6 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -169,7 +169,7 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender): device (torch.device, optional): if provided, the device on which the data is to be cast. Defaults to ``torch.device("cpu")``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. jit (bool, optional): whether the step and reset method should be wrapped in `jit`. Defaults to ``False``. @@ -776,7 +776,7 @@ class JumanjiEnv(JumanjiWrapper): With ``jumanji``, this indicates the number of vectorized environments. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 55b246bd902..a35be60c1b9 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -47,7 +47,7 @@ class OpenMLEnv(EnvBase): device (torch.device or compatible, optional): the device where the input and output data is to be expected. Defaults to ``"cpu"``. batch_size (torch.Size or compatible, optional): the batch size of the environment, - ie. the number of elements samples and returned when a :meth:`~.reset` is + ie. the number of elements samples and returned when a :meth:`reset` is called. Defaults to an empty batch size, ie. one element is sampled at a time. diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py index 8d2d76f453f..3a2ab55cd13 100644 --- a/torchrl/envs/libs/openspiel.py +++ b/torchrl/envs/libs/openspiel.py @@ -52,7 +52,7 @@ class OpenSpielWrapper(_EnvWrapper): batch_size (torch.Size, optional): the batch size of the environment. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to group agents in tensordicts for input/output. See @@ -64,8 +64,8 @@ class OpenSpielWrapper(_EnvWrapper): (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. return_state (bool, optional): if ``True``, "state" is included in the - output of :meth:`~.reset` and :meth:`~step`. The state can be given - to :meth:`~.reset` to reset to that state, rather than resetting to + output of :meth:`reset` and :meth:`~step`. The state can be given + to :meth:`reset` to reset to that state, rather than resetting to the initial state. Defaults to ``False``. @@ -113,7 +113,7 @@ class OpenSpielWrapper(_EnvWrapper): >>> print(env.available_envs) ['2048', 'add_noise', 'amazons', 'backgammon', ...] - :meth:`~.reset` can restore a specific state, rather than the initial + :meth:`reset` can restore a specific state, rather than the initial state, as long as ``return_state=True``. >>> import pyspiel @@ -521,7 +521,7 @@ class OpenSpielEnv(OpenSpielWrapper): batch_size (torch.Size, optional): the batch size of the environment. Defaults to ``torch.Size([])``. 
allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to group agents in tensordicts for input/output. See @@ -533,8 +533,8 @@ class OpenSpielEnv(OpenSpielWrapper): (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. return_state (bool, optional): if ``True``, "state" is included in the - output of :meth:`~.reset` and :meth:`~step`. The state can be given - to :meth:`~.reset` to reset to that state, rather than resetting to + output of :meth:`reset` and :meth:`~step`. The state can be given + to :meth:`reset` to reset to that state, rather than resetting to the initial state. Defaults to ``False``. @@ -580,7 +580,7 @@ class OpenSpielEnv(OpenSpielWrapper): >>> print(env.available_envs) ['2048', 'add_noise', 'amazons', 'backgammon', ...] - :meth:`~.reset` can restore a specific state, rather than the initial state, + :meth:`reset` can restore a specific state, rather than the initial state, as long as ``return_state=True``. >>> from torchrl.envs import OpenSpielEnv diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index ab090a6e837..2a4e04f7d71 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -110,7 +110,7 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild): ``RoboHiveEnv`` since vectorized environments are not supported within the class. To execute more than one environment at a time, see :class:`~torchrl.envs.ParallelEnv`. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. Attributes: diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index 5aeabc4d0aa..6565453b075 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -54,7 +54,7 @@ class UnityMLAgentsWrapper(_EnvWrapper): batch_size (torch.Size, optional): the batch size of the environment. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to group agents in tensordicts for input/output. See @@ -539,7 +539,7 @@ class UnityMLAgentsEnv(UnityMLAgentsWrapper): batch_size (torch.Size, optional): the batch size of the environment. Defaults to ``torch.Size([])``. allow_done_after_reset (bool, optional): if ``True``, it is tolerated - for envs to be ``done`` just after :meth:`~.reset` is called. + for envs to be ``done`` just after :meth:`reset` is called. Defaults to ``False``. group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to group agents in tensordicts for input/output. 
See diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index f0b2b90fa82..202960eec07 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -84,12 +84,12 @@ class ModelBasedEnvBase(EnvBase): Properties: - - observation_spec (Composite): sampling spec of the observations; - - action_spec (TensorSpec): sampling spec of the actions; - - reward_spec (TensorSpec): sampling spec of the rewards; - - input_spec (Composite): sampling spec of the inputs; - - batch_size (torch.Size): batch_size to be used by the env. If not set, the env accept tensordicts of all batch sizes. - - device (torch.device): device where the env input and output are expected to live + observation_spec (Composite): sampling spec of the observations; + action_spec (TensorSpec): sampling spec of the actions; + reward_spec (TensorSpec): sampling spec of the rewards; + input_spec (Composite): sampling spec of the inputs; + batch_size (torch.Size): batch_size to be used by the env. If not set, the env accept tensordicts of all batch sizes. + device (torch.device): device where the env input and output are expected to live Args: world_model (nn.Module): model that generates world states and its corresponding rewards; diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index ab469ecec13..f237183752a 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -36,7 +36,7 @@ class KLRewardTransform(Transform): .. note:: If the parameters are not differentiable (default), they will *not* follow the module when dtype or device casting operations will be called - (such as :meth:`~.cuda`, :meth:`~.to` etc.). When ``requires_grad=True``, + (such as :meth:`cuda`, :meth:`to` etc.). When ``requires_grad=True``, casting operations will work as expected. Examples: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6372a0e4294..9bba4fdede9 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -313,10 +313,10 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform. For any operation that relates exclusively to the parent env (e.g. FrameSkip), - modify the _step method instead. :meth:`~._call` should only be overwritten + modify the _step method instead. :meth:`_call` should only be overwritten if a modification of the input tensordict is needed. - :meth:`~._call` will be called by :meth:`TransformedEnv.step` and + :meth:`_call` will be called by :meth:`TransformedEnv.step` and :meth:`TransformedEnv.reset`. """ @@ -351,13 +351,13 @@ def _step( ) -> TensorDictBase: """The parent method of a transform during the ``env.step`` execution. - This method should be overwritten whenever the :meth:`~._step` needs to be - adapted. Unlike :meth:`~._call`, it is assumed that :meth:`~._step` + This method should be overwritten whenever the :meth:`_step` needs to be + adapted. Unlike :meth:`_call`, it is assumed that :meth:`_step` will execute some operation with the parent env or that it requires access to the content of the tensordict at time ``t`` and not only ``t+1`` (the ``"next"`` entry in the input tensordict). - :meth:`~._step` will only be called by :meth:`TransformedEnv.step` and + :meth:`_step` will only be called by :meth:`TransformedEnv.step` and not by :meth:`TransformedEnv.reset`. 
Args: @@ -413,7 +413,7 @@ def transform_output_spec(self, output_spec: Composite) -> Composite: """Transforms the output spec such that the resulting spec matches transform mapping. This method should generally be left untouched. Changes should be implemented using - :meth:`~.transform_observation_spec`, :meth:`~.transform_reward_spec` and :meth:`~.transformfull_done_spec`. + :meth:`transform_observation_spec`, :meth:`transform_reward_spec` and :meth:`transform_full_done_spec`. Args: output_spec (TensorSpec): spec before the transform @@ -1555,8 +1555,8 @@ class ClipTransform(Transform): Args: in_keys (list of NestedKeys): input entries (read) out_keys (list of NestedKeys): input entries (write) - in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls. - out_keys_inv (list of NestedKeys): input entries (write) during :meth:`~.inv` calls. + in_keys_inv (list of NestedKeys): input entries (read) during :meth:`inv` calls. + out_keys_inv (list of NestedKeys): input entries (write) during :meth:`inv` calls. Keyword Args: low (scalar, optional): the lower bound of the clipped space. @@ -2322,7 +2322,7 @@ class UnsqueezeTransform(Transform): in_keys (list of NestedKeys): input entries (read). out_keys (list of NestedKeys): input entries (write). Defaults to ``in_keys`` if not provided. - in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls. + in_keys_inv (list of NestedKeys): input entries (read) during :meth:`inv` calls. out_keys_inv (list of NestedKeys): input entries (write) during :meth:`~.inv` calls. Defaults to ``in_keys_in`` if not provided. """ @@ -3163,6 +3163,7 @@ def make_rb_transform_and_sampler( Returns: A tuple containing: + - transform (Transform): A transform that stacks frames on-the-fly during sampling. - sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained. @@ -5672,7 +5673,7 @@ class TensorDictPrimer(Transform): random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. - default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random + default_value (:obj:`float`, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float, all elements of the tensors will be set to that value. If it is a callable, this callable is expected to return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value` @@ -8045,7 +8046,7 @@ class Reward2GoTransform(Transform): and not to the collector or within an environment. Args: - gamma (float or torch.Tensor): the discount factor. Defaults to 1.0. + gamma (:obj:`float` or torch.Tensor): the discount factor. Defaults to 1.0. in_keys (sequence of NestedKey): the entries to rename. Defaults to ``("next", "reward")`` if none is provided. out_keys (sequence of NestedKey): the entries to rename. Defaults to @@ -9295,7 +9296,7 @@ class AutoResetTransform(Transform): ``False`` overrides any subsequent filling keyword argument. This argumet can also be passed with the constructor method by passing a ``auto_reset_replace`` argument: ``env = FooEnv(..., auto_reset=True, auto_reset_replace=False)``. 
- fill_float (float or str, optional): The filling value for floating point tensors + fill_float (:obj:`float` or str, optional): The filling value for floating point tensors that terminate an episode. A value of ``None`` means no replacement (values are just placed as they are in the ``"next"`` entry even if they are not valid). fill_int (int, optional): The filling value for signed integer tensors diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 54c451ebb82..23b44dbbe98 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -617,13 +617,13 @@ class ConsistentDropoutModule(TensorDictModuleBase): Keyword Args: input_shape (tuple, optional): the shape of the input (non-batchted), used to generate the - tensordict primers with :meth:`~.make_tensordict_primer`. + tensordict primers with :meth:`make_tensordict_primer`. input_dtype (torch.dtype, optional): the dtype of the input for the primer. If none is pased, ``torch.get_default_dtype`` is assumed. .. note:: To use this class within a policy, one needs the mask to be reset at reset time. This can be achieved through a :class:`~torchrl.envs.TensorDictPrimer` transform that can be obtained - with :meth:`~.make_tensordict_primer`. See this method for more information. + with :meth:`make_tensordict_primer`. See this method for more information. Examples: >>> from tensordict import TensorDict diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 58d0362a118..60d4dd020ef 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -228,6 +228,7 @@ def forward(self, tensordict): The rollout requires a belief and posterior state primer. At each step, two probability distributions are built and sampled from: + - A prior distribution p(s_{t+1} | s_t, a_t, b_t) where b_t is a deterministic transform of the form b_t(s_{t-1}, a_{t-1}). The previous state s_t is sampled according to the posterior diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index a92d59fbaae..bb170e7ec38 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1519,7 +1519,7 @@ class OnlineDTActor(nn.Module): action_dim (int): action dimension. transformer_config (Dict or :class:`DecisionTransformer.DTConfig`): config for the GPT2 transformer. - Defaults to :meth:`~.default_config`. + Defaults to :meth:`default_config`. device (torch.device, optional): device to use. Defaults to None. Examples: diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index e352101ee55..71b5c254d0a 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -25,7 +25,7 @@ class MultiAgentNetBase(nn.Module): """A base class for multi-agent networks. .. note:: to initialize the MARL module parameters with the `torch.nn.init` - module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + module, please refer to :meth:`get_stateful_net` and :meth:`from_stateful_net` methods. """ @@ -187,7 +187,7 @@ def get_stateful_net(self, copy: bool = True): If the parameters are modified in-place (recommended) there is no need to copy the parameters back into the MARL module. 
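As a rough illustration of the note above (not part of this patch; the ``MultiAgentMLP`` keyword arguments are assumed from its documentation and the sizes are made up), the parameters can be re-initialized with ``torch.nn.init`` through the stateful view:

import torch
from torchrl.modules import MultiAgentMLP

mlp = MultiAgentMLP(
    n_agent_inputs=4,
    n_agent_outputs=2,
    n_agents=3,
    centralised=False,
    share_params=False,
    depth=2,
)
stateful_net = mlp.get_stateful_net()
for param in stateful_net.parameters():
    if param.ndim >= 2:
        torch.nn.init.orthogonal_(param)  # in-place re-initialization
    else:
        torch.nn.init.zeros_(param)
# Only needed if parameters were replaced out-of-place rather than modified in-place:
mlp.from_stateful_net(stateful_net)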
- See :meth:`~.from_stateful_net` for details on how to re-populate the MARL model with + See :meth:`from_stateful_net` for details on how to re-populate the MARL model with parameters that have been re-initialized out-of-place. Examples: @@ -230,7 +230,7 @@ def get_stateful_net(self, copy: bool = True): def from_stateful_net(self, stateful_net: nn.Module): """Populates the parameters given a stateful version of the network. - See :meth:`~.get_stateful_net` for details on how to gather a stateful version of the network. + See :meth:`get_stateful_net` for details on how to gather a stateful version of the network. Args: stateful_net (nn.Module): the stateful network from which the params should be @@ -326,7 +326,7 @@ class MultiAgentMLP(MultiAgentNetBase): **kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs. .. note:: to initialize the MARL module parameters with the `torch.nn.init` - module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + module, please refer to :meth:`get_stateful_net` and :meth:`from_stateful_net` methods. Examples: diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index 002094fb5d2..31c95650d25 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -15,6 +15,7 @@ class MPPIPlanner(MPCPlannerBase): """MPPI Planner Module. Reference: + - Model predictive path integral control using covariance variable importance sampling. (Williams, G., Aldrich, A., and Theodorou, E. A.) https://arxiv.org/abs/1509.01149 - Temporal Difference Learning for Model Predictive Control diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 6175bc8bf0c..7445423fa37 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1766,7 +1766,7 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): **not** modify the tensordict in-place. .. note:: If the action, observation or reward-to-go key is not standard, - the method :meth:`~.set_tensor_keys` should be used, e.g. + the method :meth:`set_tensor_keys` should be used, e.g. >>> dt_inference_wrapper.set_tensor_keys(action="foo", observation="bar", return_to_go="baz") @@ -1992,10 +1992,10 @@ class TanhModule(TensorDictModuleBase): If a Composite is provided, its key(s) must match the key(s) in out_keys. Otherwise, the key(s) of out_keys are assumed and the same spec is used for all outputs. - low (float, np.ndarray or torch.Tensor): the lower bound of the space. + low (:obj:`float`, np.ndarray or torch.Tensor): the lower bound of the space. If none is provided and no spec is provided, -1 is assumed. If a spec is provided, the minimum value of the spec will be retrieved. - high (float, np.ndarray or torch.Tensor): the higher bound of the space. + high (:obj:`float`, np.ndarray or torch.Tensor): the higher bound of the space. If none is provided and no spec is provided, 1 is assumed. If a spec is provided, the maximum value of the spec will be retrieved. clamp (bool, optional): if ``True``, the outputs will be clamped to be @@ -2149,11 +2149,12 @@ class LMHeadActorValueOperator(ActorValueOperator): """Builds an Actor-Value operator from an huggingface-like *LMHeadModel. 
This method: - - takes as input an huggingface-like *LMHeadModel - - extracts the final linear layer uses it as a base layer of the actor_head and - adds the sampling layer - - uses the common transformer as common model - - adds a linear critic + + - takes as input an huggingface-like *LMHeadModel + - extracts the final linear layer uses it as a base layer of the actor_head and + adds the sampling layer + - uses the common transformer as common model + - adds a linear critic Args: base_model (nn.Module): a torch model composed by a `.transformer` model and `.lm_head` linear layer diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index d43634dce32..c8baaa44938 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -57,7 +57,7 @@ class EGreedyModule(TensorDictModuleBase): device (torch.device, optional): the device of the exploration module. .. note:: - It is crucial to incorporate a call to :meth:`~.step` in the training loop + It is crucial to incorporate a call to :meth:`step` in the training loop to update the exploration factor. Since it is not easy to capture this omission no warning or exception will be raised if this is ommitted! @@ -270,7 +270,7 @@ class AdditiveGaussianModule(TensorDictModuleBase): .. note:: It is - crucial to incorporate a call to :meth:`~.step` in the training loop + crucial to incorporate a call to :meth:`step` in the training loop to update the exploration factor. Since it is not easy to capture this omission no warning or exception will be raised if this is ommitted! @@ -424,7 +424,7 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): .. note:: It is - crucial to incorporate a call to :meth:`~.step` in the training loop + crucial to incorporate a call to :meth:`step` in the training loop to update the exploration factor. Since it is not easy to capture this omission no warning or exception will be raised if this is ommitted! diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 65d84f78301..ecb03c54615 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -36,8 +36,8 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): argument and the ``interaction_type()`` global function. `SafeProbabilisticModule` can be used to construct the distribution - (through the :meth:`~.get_dist` method) and/or sampling from this distribution - (through a regular :meth:`~.__call__` to the module). + (through the :meth:`get_dist` method) and/or sampling from this distribution + (through a regular :meth:`__call__` to the module). A `SafeProbabilisticModule` instance has two main features: @@ -45,7 +45,7 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): - It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed. - When the :meth:`~.__call__` and :meth:`~.forward` method are called, a distribution is + When the :meth:`__call__` and :meth:`~.forward` method are called, a distribution is created, and a value computed (depending on the ``interaction_type`` value, 'dist.mean', 'dist.mode', 'dist.median' attributes could be used, as well as the 'dist.rsample', 'dist.sample' method). 
The sampling step is skipped if the supplied diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index c1a89df3545..1f84a2e0900 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -72,7 +72,7 @@ class CQLLoss(LossModule): initial value. Otherwise, alpha will be optimized to match the 'target_entropy' value. Default is ``False``. - target_entropy (float or str, optional): Target entropy for the + target_entropy (:obj:`float` or str, optional): Target entropy for the stochastic policy. Default is "auto", where target entropy is computed as :obj:`-prod(n_actions)`. delay_actor (bool, optional): Whether to separate the target actor diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 8f5df2b65cb..f9675b7af56 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -80,7 +80,7 @@ class CrossQLoss(LossModule): initial value. Otherwise, alpha will be optimized to match the 'target_entropy' value. Default is ``False``. - target_entropy (float or str, optional): Target entropy for the + target_entropy (:obj:`float` or str, optional): Target entropy for the stochastic policy. Default is "auto", where target entropy is computed as :obj:`-prod(n_actions)`. priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index c83ae56a137..d3958bd0c4d 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -39,7 +39,7 @@ class OnlineDTLoss(LossModule): initial value. Otherwise, alpha will be optimized to match the 'target_entropy' value. Default is ``False``. - target_entropy (float or str, optional): Target entropy for the + target_entropy (:obj:`float` or str, optional): Target entropy for the stochastic policy. Default is "auto", where target entropy is computed as :obj:`-prod(n_actions)`. samples_mc_entropy (int): number of samples to estimate the entropy diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 3b8c951e29d..4b78e4a19a6 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -105,7 +105,7 @@ class SACLoss(LossModule): initial value. Otherwise, alpha will be optimized to match the 'target_entropy' value. Default is ``False``. - target_entropy (float or str, optional): Target entropy for the + target_entropy (:obj:`float` or str, optional): Target entropy for the stochastic policy. Default is "auto", where target entropy is computed as :obj:`-prod(n_actions)`. delay_actor (bool, optional): Whether to separate the target actor diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 0052a6149db..b2ee1e24743 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -54,13 +54,14 @@ def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs """Writes a video on a file on disk. The video format can be one of - - `"pt"`: uses :func:`~torch.save` to save the video tensor); - - `"memmap"`: saved the file as memory-mapped array (reading this file will require - the dtype and shape to be known at read time); - - `"mp4"`: saves the file as an `.mp4` file using torchvision :func:`~torchvision.io.write_video` - API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``. - These include ``preset``, ``crf`` and others. 
- See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for some more information of the video format options. + + - `"pt"`: uses :func:`~torch.save` to save the video tensor); + - `"memmap"`: saved the file as memory-mapped array (reading this file will require + the dtype and shape to be known at read time); + - `"mp4"`: saves the file as an `.mp4` file using torchvision :func:`~torchvision.io.write_video` + API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``. + These include ``preset``, ``crf`` and others. + See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for some more information of the video format options. """ if global_step is None: diff --git a/tutorials/sphinx-tutorials-save/README.rst b/tutorials/sphinx-tutorials-save/README.rst new file mode 100644 index 00000000000..7995a1fbb2e --- /dev/null +++ b/tutorials/sphinx-tutorials-save/README.rst @@ -0,0 +1,4 @@ +README Tutos +============ + +Check the tutorials on torchrl documentation: https://pytorch.org/rl diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index f57e328f582..7d590db5aec 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -164,7 +164,7 @@ # Sampling frames_per_batch = 6_000 # Number of team frames collected per training iteration -n_iters = 10 # Number of sampling and training iterations +n_iters = 5 # Number of sampling and training iterations total_frames = frames_per_batch * n_iters # Training diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index dd5de6a5f99..426768b5b19 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -868,8 +868,9 @@ def simple_rollout(steps=100): # which demonstrates that the pendulum is upward and still as desired. # batch_size = 32 -pbar = tqdm.tqdm(range(20_000 // batch_size)) -scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 20_000) +n_iter = 1000 # set to 20_000 for a proper training +pbar = tqdm.tqdm(range(n_iter // batch_size)) +scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_iter) logs = defaultdict(list) for _ in pbar: diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 2a852f0e364..3bc5adc7ce4 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -7,6 +7,7 @@ .. _rb_tuto: """ + ###################################################################### # Replay buffers are a central piece of any RL or control algorithm. # Supervised learning methods are usually characterized by a training loop @@ -55,24 +56,13 @@ # example: # -# sphinx_gallery_start_ignore -import warnings - -warnings.filterwarnings("ignore") -from torch import multiprocessing +###################################################################### +# .. warning:: This tutorial build is temporarily disabled. 
-# TorchRL prefers spawn method, that restricts creation of ``~torchrl.envs.ParallelEnv`` inside -# `__main__` method call, but for the easy of reading the code switch to fork -# which is also a default spawn method in Google's Colaboratory -try: - is_sphinx = __sphinx_build__ -except NameError: - is_sphinx = False +exit(0) -try: - multiprocessing.set_start_method("spawn" if is_sphinx else "fork") -except RuntimeError: - pass +# sphinx_gallery_start_ignore +import gc # sphinx_gallery_end_ignore @@ -200,27 +190,26 @@ print("samples", sample["a"], sample["b", "c"]) ###################################################################### -# A :class:`~torchrl.data.LazyMemmapStorage` is created in the same manner: -# - -buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size)) -buffer_lazymemmap.extend(data) -print(f"The buffer has {len(buffer_lazymemmap)} elements") -sample = buffer_lazytensor.sample(5) -print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"]) - -###################################################################### +# A :class:`~torchrl.data.LazyMemmapStorage` is created in the same manner. # We can also customize the storage location on disk: # -tempdir = tempfile.TemporaryDirectory() -buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size, scratch_dir=tempdir)) -buffer_lazymemmap.extend(data) -print(f"The buffer has {len(buffer_lazymemmap)} elements") -print("the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename) -print( - "the ('b', 'c') tensor is stored in", - buffer_lazymemmap._storage._storage["b", "c"].filename, -) + +with tempfile.TemporaryDirectory() as tempdir: + buffer_lazymemmap = ReplayBuffer( + storage=LazyMemmapStorage(size, scratch_dir=tempdir) + ) + buffer_lazymemmap.extend(data) + print(f"The buffer has {len(buffer_lazymemmap)} elements") + print( + "the 'a' tensor is stored in", buffer_lazymemmap._storage._storage["a"].filename + ) + print( + "the ('b', 'c') tensor is stored in", + buffer_lazymemmap._storage._storage["b", "c"].filename, + ) + sample = buffer_lazytensor.sample(5) + print("samples: a=", sample["a"], "\n('b', 'c'):", sample["b", "c"]) + del buffer_lazymemmap ###################################################################### @@ -247,14 +236,15 @@ from torchrl.data import TensorDictReplayBuffer -tempdir = tempfile.TemporaryDirectory() -buffer_lazymemmap = TensorDictReplayBuffer( - storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12 -) -buffer_lazymemmap.extend(data) -print(f"The buffer has {len(buffer_lazymemmap)} elements") -sample = buffer_lazymemmap.sample() -print("sample:", sample) +with tempfile.TemporaryDirectory() as tempdir: + buffer_lazymemmap = TensorDictReplayBuffer( + storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12 + ) + buffer_lazymemmap.extend(data) + print(f"The buffer has {len(buffer_lazymemmap)} elements") + sample = buffer_lazymemmap.sample() + print("sample:", sample) + del buffer_lazymemmap ###################################################################### # Our sample now has an extra ``"index"`` key that indicates what indices @@ -289,13 +279,10 @@ class MyData: batch_size=[10], ) -tempdir = tempfile.TemporaryDirectory() -buffer_lazymemmap = ReplayBuffer( - storage=LazyMemmapStorage(size, scratch_dir=tempdir), batch_size=12 -) -buffer_lazymemmap.extend(data) -print(f"The buffer has {len(buffer_lazymemmap)} elements") -sample = buffer_lazymemmap.sample() +buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), 
batch_size=12) +buffer_lazy.extend(data) +print(f"The buffer has {len(buffer_lazy)} elements") +sample = buffer_lazy.sample() print("sample:", sample) @@ -322,8 +309,8 @@ class MyData: ###################################################################### -# Let's build our replay buffer on disk: -rb = ReplayBuffer(storage=LazyMemmapStorage(size)) +# Let's build our replay buffer on RAM: +rb = ReplayBuffer(storage=LazyTensorStorage(size)) data = { "a": torch.randn(3), "b": {"c": (torch.zeros(2), [torch.ones(1)])}, @@ -394,9 +381,9 @@ def assert0(x): batch_size=[200], ) -buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) -buffer_lazymemmap.extend(data) -buffer_lazymemmap.sample() +buffer_lazy = ReplayBuffer(storage=LazyTensorStorage(size), batch_size=128) +buffer_lazy.extend(data) +buffer_lazy.sample() ###################################################################### @@ -408,11 +395,11 @@ def assert0(x): # using prioritized samplers): -buffer_lazymemmap = ReplayBuffer( - storage=LazyMemmapStorage(size), batch_size=128, prefetch=10 +buffer_lazy = ReplayBuffer( + storage=LazyTensorStorage(size), batch_size=128, prefetch=10 ) # creates a queue of 10 elements to be prefetched in the background -buffer_lazymemmap.extend(data) -print(buffer_lazymemmap.sample()) +buffer_lazy.extend(data) +print(buffer_lazy.sample()) ###################################################################### @@ -423,11 +410,12 @@ def assert0(x): # dataloader, as long as the batch-size is predefined: -for i, data in enumerate(buffer_lazymemmap): +for i, data in enumerate(buffer_lazy): if i == 3: print(data) break +del buffer_lazy ###################################################################### # Due to the fact that our sampling technique is entirely random and does not @@ -439,8 +427,8 @@ def assert0(x): from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement -buffer_lazymemmap = ReplayBuffer( - storage=LazyMemmapStorage(size), batch_size=32, sampler=SamplerWithoutReplacement() +buffer_lazy = ReplayBuffer( + storage=LazyTensorStorage(size), batch_size=32, sampler=SamplerWithoutReplacement() ) ###################################################################### # we create a data that is big enough to get a couple of samples @@ -452,11 +440,12 @@ def assert0(x): batch_size=[16], ) -buffer_lazymemmap.extend(data) -for _i, _ in enumerate(buffer_lazymemmap): +buffer_lazy.extend(data) +for _i, _ in enumerate(buffer_lazy): continue print(f"A total of {_i+1} batches have been collected") +del buffer_lazy ###################################################################### # Dynamic batch-size @@ -465,13 +454,14 @@ def assert0(x): # In contrast to what we have seen earlier, the ``batch_size`` keyword # argument can be omitted and passed directly to the ``sample`` method: - -buffer_lazymemmap = ReplayBuffer( - storage=LazyMemmapStorage(size), sampler=SamplerWithoutReplacement() +buffer_lazy = ReplayBuffer( + storage=LazyTensorStorage(size), sampler=SamplerWithoutReplacement() ) -buffer_lazymemmap.extend(data) -print("sampling 3 elements:", buffer_lazymemmap.sample(3)) -print("sampling 5 elements:", buffer_lazymemmap.sample(5)) +buffer_lazy.extend(data) +print("sampling 3 elements:", buffer_lazy.sample(3)) +print("sampling 5 elements:", buffer_lazy.sample(5)) + +del buffer_lazy ###################################################################### # Prioritized Replay buffers @@ -597,8 +587,8 @@ def assert0(x): # higher indices should occur more frequently: from 
matplotlib import pyplot as plt -plt.hist(sample["index"].numpy()) - +fig = plt.hist(sample["index"].numpy()) +plt.show() ###################################################################### # Once we have worked with our sample, we update the priority key using @@ -614,10 +604,9 @@ def assert0(x): ###################################################################### # Now, higher indices should occur less frequently: sample = rb.sample() -from matplotlib import pyplot as plt - -plt.hist(sample["index"].numpy()) +fig = plt.hist(sample["index"].numpy()) +plt.show() ###################################################################### # Using transforms @@ -700,6 +689,7 @@ def assert0(x): print(data) break +collector.shutdown() ###################################################################### # We create a replay buffer with the same transform as the environment. @@ -721,7 +711,7 @@ def assert0(x): Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) -rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) +rb = TensorDictReplayBuffer(storage=LazyTensorStorage(1000), transform=t, batch_size=16) rb.extend(data) @@ -769,6 +759,8 @@ def assert0(x): print(data) break +collector.shutdown() + ###################################################################### # The buffer transform looks pretty much like the environment one, but with # extra ``("next", ...)`` keys like before: @@ -783,7 +775,7 @@ def assert0(x): UnsqueezeTransform(-4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ) -rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(size), transform=t, batch_size=16) +rb = TensorDictReplayBuffer(storage=LazyTensorStorage(size), transform=t, batch_size=16) data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) rb.add(data_exclude) @@ -819,7 +811,7 @@ def assert0(x): # compatible with tensordict-structured data): the number of slices or their # length and some information about where the separation between the # episodes can be found (e.g. :ref:`recall that ` with a -# :ref:`DataCollector `, the trajectory id is stored in +# :ref:`DataCollector `, the trajectory id is stored in # ``("collector", "traj_ids")``). In this simple example, we construct a data # with 4 consecutive short trajectories and sample 4 slices out of it, each of # length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps). @@ -828,7 +820,7 @@ def assert0(x): from torchrl.data import SliceSampler rb = TensorDictReplayBuffer( - storage=LazyMemmapStorage(size), + storage=LazyTensorStorage(size), sampler=SliceSampler(traj_key="episode", num_slices=4), batch_size=8, ) @@ -853,6 +845,8 @@ def assert0(x): print("episode are grouped", sample["episode"]) print("steps are successive", sample["steps"]) +gc.collect() + ###################################################################### # Conclusion # ----------