[Doc] Solve ref issues in docstrings
ghstack-source-id: 09823fa85a94115291e7434478776fb0834f9b39
Pull Request resolved: #2776

(cherry picked from commit f5445a4)
vmoens committed Feb 11, 2025
1 parent dfde953 commit 44fd0b2
Showing 54 changed files with 281 additions and 254 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -112,7 +112,7 @@ jobs:
cd ./docs
# timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
# bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
-PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build
+PYOPENGL_PLATFORM=egl MUJOCO_GL=egl TORCHRL_CONSOLE_STREAM=stdout sphinx-build ./source _local_build -v
cd ..
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
4 changes: 4 additions & 0 deletions docs/source/content_generation.py
@@ -83,6 +83,10 @@ def generate_tutorial_references(tutorial_path: str, file_type: str) -> None:
for f in os.listdir(tutorial_path)
if f.endswith((".py", ".rst", ".png"))
]
# Make rb_tutorial.py the first one
file_paths = [p for p in file_paths if p.endswith("rb_tutorial.py")] + [
p for p in file_paths if not p.endswith("rb_tutorial.py")
]

for file_path in file_paths:
shutil.copyfile(file_path, os.path.join(target_path, Path(file_path).name))
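The addition above is the standard stable-partition idiom: entries matching a suffix are moved to the front while every other entry keeps its relative order. A minimal standalone sketch (the file names are hypothetical, not the repository's actual tutorial list):

```python
# Move every path ending in the given suffix to the front, preserving the
# relative order of the remaining entries (a stable partition).
def promote_first(file_paths, suffix="rb_tutorial.py"):
    return [p for p in file_paths if p.endswith(suffix)] + [
        p for p in file_paths if not p.endswith(suffix)
    ]

paths = ["a.py", "rb_tutorial.py", "b.rst"]
assert promote_first(paths) == ["rb_tutorial.py", "a.py", "b.rst"]
```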
3 changes: 2 additions & 1 deletion docs/source/reference/envs.rst
@@ -212,6 +212,7 @@ component (sub-environments or agents) should be reset.
This allows resetting some but not all of the components.

The ``"_reset"`` key has two distinct functionalities:

1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
not be present in the input tensordict. TorchRL's convention is that the
absence of the ``"_reset"`` key at a given ``"done"`` level indicates
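The convention described here can be sketched with plain dicts standing in for tensordicts (the helper name is hypothetical): when ``"_reset"`` is absent at a given ``"done"`` level, every component at that level is reset.

```python
# TorchRL's convention, sketched: absence of the "_reset" entry at a given
# "done" level means all components at that level should be reset.
def components_to_reset(input_td, num_components):
    mask = input_td.get("_reset")
    if mask is None:  # no "_reset" key => reset everything
        return [True] * num_components
    return list(mask)

assert components_to_reset({}, 3) == [True, True, True]
assert components_to_reset({"_reset": [True, False, False]}, 3) == [True, False, False]
```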
@@ -899,7 +900,7 @@ to be able to create this other composition:
Hash
InitTracker
KLRewardTransform
-LineariseReward
+LineariseRewards
NoopResetEnv
ObservationNorm
ObservationTransform
4 changes: 2 additions & 2 deletions torchrl/_utils.py
@@ -513,7 +513,7 @@ def reset(cls, setters_dict: Dict[str, implement_for] = None):
"""Resets the setters in setter_dict.
``setter_dict`` is a copy of implementations. We just need to iterate through its
-values and call :meth:`~.module_set` for each.
+values and call :meth:`module_set` for each.
"""
if VERBOSE:
@@ -888,7 +888,7 @@ def _standardize(
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
-eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
+eps (:obj:`float`): epsilon to be used for numerical stability. Default: float32 resolution.
"""
if eps is None:
9 changes: 7 additions & 2 deletions torchrl/collectors/collectors.py
@@ -339,10 +339,12 @@ class SyncDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
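The signature test these bullets describe can be approximated with `inspect` (a simplified sketch, not the collector's actual implementation; `TensorDictBase` below is a local stand-in for `tensordict.TensorDictBase`):

```python
import inspect

class TensorDictBase:  # local stand-in for tensordict.TensorDictBase
    pass

def looks_tensordict_compatible(policy):
    """Rough heuristic: True when forward takes a single argument named
    "tensordict"/"td" or annotated with a TensorDictBase subclass."""
    params = list(inspect.signature(policy.forward).parameters.values())
    if len(params) != 1:
        return False
    param = params[0]
    if param.name in ("tensordict", "td"):
        return True
    return isinstance(param.annotation, type) and issubclass(
        param.annotation, TensorDictBase
    )

class Policy:  # forward(self, tensordict) -> would be left unwrapped
    def forward(self, tensordict):
        return tensordict

class Other:  # matches nothing -> would be wrapped in a TensorDictModule
    def forward(self, obs):
        return obs

assert looks_tensordict_compatible(Policy())
assert not looks_tensordict_compatible(Other())
```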
Keyword Args:
@@ -1462,6 +1464,7 @@ class _MultiDataCollector(DataCollectorBase):
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such:
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
@@ -1548,7 +1551,7 @@ class _MultiDataCollector(DataCollectorBase):
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
-update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()`
+update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
@@ -2774,10 +2777,12 @@ class aSyncDataCollector(MultiaSyncDataCollector):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -2863,7 +2868,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
reset_when_done (bool, optional): if ``True`` (default), an environment
that returns a ``True`` value in its ``"done"`` or ``"truncated"``
entry will be reset at the corresponding indices.
-update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()`
+update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/generic.py
@@ -262,10 +262,12 @@ class DistributedDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -341,7 +343,7 @@ class DistributedDataCollector(DataCollectorBase):
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
-collector_class (type or str, optional): a collector class for the remote node. Can be
+collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
2 changes: 2 additions & 0 deletions torchrl/collectors/distributed/ray.py
@@ -135,10 +135,12 @@ class RayCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/rpc.py
@@ -110,10 +110,12 @@ class RPCDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -190,7 +192,7 @@ class RPCDataCollector(DataCollectorBase):
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
-collector_class (type or str, optional): a collector class for the remote node. Can be
+collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
4 changes: 3 additions & 1 deletion torchrl/collectors/distributed/sync.py
@@ -143,10 +143,12 @@ class DistributedSyncDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
@@ -222,7 +224,7 @@ class DistributedSyncDataCollector(DataCollectorBase):
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
or ``torchrl.envs.utils.ExplorationType.MEAN``.
-collector_class (type or str, optional): a collector class for the remote node. Can be
+collector_class (Type or str, optional): a collector class for the remote node. Can be
:class:`~torchrl.collectors.SyncDataCollector`,
:class:`~torchrl.collectors.MultiSyncDataCollector`,
:class:`~torchrl.collectors.MultiaSyncDataCollector`
2 changes: 1 addition & 1 deletion torchrl/data/datasets/common.py
@@ -72,7 +72,7 @@ def preprocess(
Args and Keyword Args are forwarded to :meth:`~tensordict.TensorDictBase.map`.
-The dataset can subsequently be deleted using :meth:`~.delete`.
+The dataset can subsequently be deleted using :meth:`delete`.
Keyword Args:
dest (path or equivalent): a path to the location of the new dataset.
4 changes: 2 additions & 2 deletions torchrl/data/datasets/openx.py
@@ -66,7 +66,7 @@ class for more information on how to interact with non-tensor data
sampling strategy.
If the ``batch_size`` is ``None`` (default), iterating over the
dataset will deliver trajectories one at a time *whereas* calling
-:meth:`~.sample` will *still* require a batch-size to be provided.
+:meth:`sample` will *still* require a batch-size to be provided.
Keyword Args:
shuffle (bool, optional): if ``True``, trajectories are delivered in a
@@ -115,7 +115,7 @@ class for more information on how to interact with non-tensor data
replacement (bool, optional): if ``False``, sampling will be done
without replacement. Defaults to ``True`` for downloaded datasets,
``False`` for streamed datasets.
-pad (bool, float or None): if ``True``, trajectories of insufficient length
+pad (bool, :obj:`float` or None): if ``True``, trajectories of insufficient length
given the `slice_len` or `num_slices` arguments will be padded with
0s. If another value is provided, it will be used for padding. If
``False`` or ``None`` (default) any encounter with a trajectory of
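The ``pad`` behavior documented above can be sketched on plain lists (a hypothetical helper, not the dataset's actual padding code): short trajectories are right-padded with 0s or with the supplied value, and with ``pad=False``/``None`` a short trajectory would be an error instead.

```python
# Sketch of the `pad` argument: pad trajectories shorter than slice_len.
def pad_trajectory(traj, slice_len, pad=True):
    if len(traj) >= slice_len:
        return traj[:slice_len]
    if pad is False or pad is None:
        raise ValueError("trajectory shorter than slice_len and pad disabled")
    fill = 0 if pad is True else pad  # True => pad with 0s, else the value
    return traj + [fill] * (slice_len - len(traj))

assert pad_trajectory([1, 2], 4) == [1, 2, 0, 0]
assert pad_trajectory([1, 2], 4, pad=-1.0) == [1, 2, -1.0, -1.0]
```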
2 changes: 1 addition & 1 deletion torchrl/data/map/tdstorage.py
@@ -193,7 +193,7 @@ def from_tensordict_pair(
in the storage. Defaults to ``None`` (all keys are registered).
max_size (int, optional): the maximum number of elements in the storage. Ignored if the
``storage_constructor`` is passed. Defaults to ``1000``.
-storage_constructor (type, optional): a type of tensor storage.
+storage_constructor (Type, optional): a type of tensor storage.
Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`.
Other options include :class:`~tensordict.nn.storage.FixedStorage`.
hash_module (Callable, optional): a hash function to use in the :class:`~torchrl.data.map.QueryModule`.
19 changes: 11 additions & 8 deletions torchrl/data/map/tree.py
@@ -50,7 +50,7 @@ class Tree(TensorClass["nocast"]):
node_id (int): A unique identifier for this node.
rollout (TensorDict): Rollout data following the observation encoded in this node, in a TED format.
If there are multiple actions taken at this node, subtrees are stored in the corresponding
-entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method.
+entry. Rollouts can be reconstructed using the :meth:`rollout_from_path` method.
node (TensorDict): Data defining this node (e.g., observations) before the next branching.
Entries usually match the ``in_keys`` in ``MCTSForest.node_map``.
subtree (Tree): A stack of subtrees produced when actions are taken.
@@ -215,7 +215,7 @@ def node_observation(self) -> torch.Tensor | TensorDictBase:
"""Returns the observation associated with this particular node.
This is the observation (or bag of observations) that defines the node before a branching occurs.
-If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the
+If the node contains a :meth:`rollout` attribute, the node observation is typically identical to the
observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``.
If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance
@@ -232,7 +232,7 @@ def node_observations(self) -> torch.Tensor | TensorDictBase:
"""Returns the observations associated with this particular node in a TensorDict format.
This is the observation (or bag of observations) that defines the node before a branching occurs.
-If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the
+If the node contains a :meth:`rollout` attribute, the node observation is typically identical to the
observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``.
If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance
@@ -442,8 +442,11 @@ def num_vertices(self, *, count_repeat: bool = False) -> int:
"""Returns the number of unique vertices in the Tree.
Keyword Args:
-count_repeat (bool, optional): Determines whether to count repeated vertices.
+count_repeat (bool, optional): Determines whether to count repeated
+    vertices.
- If ``False``, counts each unique vertex only once.
- If ``True``, counts vertices multiple times if they appear in different paths.
Defaults to ``False``.
@@ -629,16 +632,16 @@ class MCTSForest:
``node_map.max_size``. If none of these are provided, defaults to `1000`.
done_keys (list of NestedKey, optional): the done keys of the environment. If not provided,
defaults to ``("done", "terminated", "truncated")``.
-The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
+The :meth:`get_keys_from_env` can be used to automatically determine the keys.
action_keys (list of NestedKey, optional): the action keys of the environment. If not provided,
defaults to ``("action",)``.
-The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
+The :meth:`get_keys_from_env` can be used to automatically determine the keys.
reward_keys (list of NestedKey, optional): the reward keys of the environment. If not provided,
defaults to ``("reward",)``.
-The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
+The :meth:`get_keys_from_env` can be used to automatically determine the keys.
observation_keys (list of NestedKey, optional): the observation keys of the environment. If not provided,
defaults to ``("observation",)``.
-The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
+The :meth:`get_keys_from_env` can be used to automatically determine the keys.
excluded_keys (list of NestedKey, optional): a list of keys to exclude from the data storage.
consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk.
Defaults to ``False``.
1 change: 1 addition & 0 deletions torchrl/data/postprocs/postprocs.py
@@ -182,6 +182,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
environment (i.e. before multi-step);
- The "reward" values will be replaced by the newly computed
rewards.
The ``"done"`` key can have either the shape of the tensordict
OR the shape of the tensordict followed by a singleton
dimension OR the shape of the tensordict followed by other
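The accepted ``"done"`` shapes described above can be sketched as a shape check (a hypothetical helper, shown on plain tuples): for a tensordict batch shape like ``(B,)``, the done entry may be ``(B,)`` or the batch shape followed by extra dimensions such as a trailing singleton ``(B, 1)``.

```python
# Sketch: is `done_shape` compatible with `batch_shape` per the rule above?
def done_shape_ok(batch_shape, done_shape):
    if done_shape == batch_shape:
        return True  # exactly the tensordict shape
    # batch shape followed by extra dims (e.g. a trailing singleton)
    return (
        done_shape[: len(batch_shape)] == batch_shape
        and len(done_shape) > len(batch_shape)
    )

assert done_shape_ok((4,), (4,))
assert done_shape_ok((4,), (4, 1))
assert not done_shape_ok((4,), (3, 1))
```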