Skip to content

Commit

Permalink
[Feature] Expose call_on_nested to apply and named_apply (#768)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 30, 2024
1 parent 0c72dd7 commit 1f78271
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4299,6 +4299,7 @@ def apply(
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
propagate_lock: bool = False,
call_on_nested: bool = False,
**constructor_kwargs,
) -> T | None:
"""Applies a callable to all values stored in the tensordict and sets them in a new tensordict.
Expand Down Expand Up @@ -4336,6 +4337,26 @@ def apply(
Defaults to ``False`` for backward compatibility.
propagate_lock (bool, optional): if ``True``, a locked tensordict will produce
another locked tensordict. Defaults to ``False``.
call_on_nested (bool, optional): if ``True``, the function will be called on first-level tensors
and containers (TensorDict or tensorclass). In this scenario, ``func`` is responsible of
propagating its calls to nested levels. This allows a fine-grained behaviour
when propagating the calls to nested tensordicts.
If ``False``, the function will only be called on leaves, and ``apply`` will take care of dispatching
the function to all leaves.
>>> td = TensorDict({"a": {"b": [0.0, 1.0]}, "c": [1.0, 2.0]})
>>> def mean_tensor_only(val):
... if is_tensor_collection(val):
... raise RuntimeError("Unexpected!")
... return val.mean()
>>> td_mean = td.apply(mean_tensor_only)
>>> def mean_any(val):
... if is_tensor_collection(val):
... # Recurse
... return val.apply(mean_any, call_on_nested=True)
... return val.mean()
>>> td_mean = td.apply(mean_any, call_on_nested=True)
**constructor_kwargs: additional keyword arguments to be passed to the
TensorDict constructor.
Expand Down Expand Up @@ -4394,6 +4415,7 @@ def apply(
checked=False,
default=default,
filter_empty=filter_empty,
call_on_nested=call_on_nested,
**constructor_kwargs,
)
if propagate_lock and not inplace and self.is_locked and result is not None:
Expand All @@ -4412,6 +4434,7 @@ def named_apply(
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
propagate_lock: bool = False,
call_on_nested: bool = False,
**constructor_kwargs,
) -> T | None:
"""Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new atensordict.
Expand Down Expand Up @@ -4449,6 +4472,26 @@ def named_apply(
``False`` for backward compatibility.
propagate_lock (bool, optional): if ``True``, a locked tensordict will produce
another locked tensordict. Defaults to ``False``.
call_on_nested (bool, optional): if ``True``, the function will be called on first-level tensors
and containers (TensorDict or tensorclass). In this scenario, ``func`` is responsible of
propagating its calls to nested levels. This allows a fine-grained behaviour
when propagating the calls to nested tensordicts.
If ``False``, the function will only be called on leaves, and ``apply`` will take care of dispatching
the function to all leaves.
>>> td = TensorDict({"a": {"b": [0.0, 1.0]}, "c": [1.0, 2.0]})
>>> def mean_tensor_only(val):
... if is_tensor_collection(val):
... raise RuntimeError("Unexpected!")
... return val.mean()
>>> td_mean = td.apply(mean_tensor_only)
>>> def mean_any(val):
... if is_tensor_collection(val):
... # Recurse
... return val.apply(mean_any, call_on_nested=True)
... return val.mean()
>>> td_mean = td.apply(mean_any, call_on_nested=True)
**constructor_kwargs: additional keyword arguments to be passed to the
TensorDict constructor.
Expand Down Expand Up @@ -4534,6 +4577,7 @@ def named_apply(
named=True,
nested_keys=nested_keys,
filter_empty=filter_empty,
call_on_nested=call_on_nested,
**constructor_kwargs,
)
if propagate_lock and not inplace and self.is_locked and result is not None:
Expand Down
2 changes: 2 additions & 0 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def apply(
inplace: bool = False,
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
call_on_nested: bool = False,
**constructor_kwargs,
) -> TensorDictBase | None:
...
Expand All @@ -500,6 +501,7 @@ def named_apply(
inplace: bool = False,
default: Any = NO_DEFAULT,
filter_empty: bool | None = None,
call_on_nested: bool = False,
**constructor_kwargs,
) -> TensorDictBase | None:
...
Expand Down

0 comments on commit 1f78271

Please sign in to comment.