[Feature] named_apply and default value in apply #584

Merged: 2 commits, merged Nov 29, 2023

14 changes: 11 additions & 3 deletions tensordict/_lazy.py
@@ -1084,10 +1084,12 @@ def entry_class(self, key: NestedKey) -> type:
return LazyStackedTensorDict
return data_type

def apply_(self, fn: Callable, *others):
def apply_(self, fn: Callable, *others, **kwargs):
for i, td in enumerate(self.tensordicts):
idx = (slice(None),) * self.stack_dim + (i,)
td._fast_apply(fn, *[other[idx] for other in others], inplace=True)
td._fast_apply(
fn, *[other[idx] for other in others], inplace=True, **kwargs
)
return self

def _apply_nest(
@@ -1100,14 +1102,16 @@ def _apply_nest(
inplace: bool = False,
checked: bool = False,
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
**constructor_kwargs,
) -> T:
if inplace:
if any(arg for arg in (batch_size, device, names, constructor_kwargs)):
raise ValueError(
"Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True."
)
return self.apply_(fn, *others)
return self.apply_(fn, *others, named=named, default=default)
else:
if batch_size is not None:
# any op that modifies the batch-size will result in a regular TensorDict
@@ -1120,6 +1124,8 @@
names=names,
checked=checked,
call_on_nested=call_on_nested,
default=default,
named=named,
**constructor_kwargs,
)
others = (other.unbind(self.stack_dim) for other in others)
@@ -1131,6 +1137,8 @@
checked=checked,
device=device,
call_on_nested=call_on_nested,
default=default,
named=named,
)
for td, *oth in zip(self.tensordicts, *others)
),
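
The lazy-stack changes above only thread the new ``named`` and ``default`` keyword arguments down to each stacked tensordict. A minimal sketch of what this enables, assuming the ``LazyStackedTensorDict(*tds, stack_dim=...)`` constructor that the diff implies (illustrative values only):

>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> td = LazyStackedTensorDict(
...     TensorDict({"a": torch.zeros(3)}, [3]),
...     TensorDict({"a": torch.ones(3)}, [3]),
...     stack_dim=0,
... )
>>> # fn receives the leaf name first because named=True is forwarded to each sub-tensordict
>>> out = td.named_apply(lambda name, x: x + 1 if name == "a" else x)
>>> assert (out["a"][0] == 1).all() and (out["a"][1] == 2).all()
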
21 changes: 19 additions & 2 deletions tensordict/_td.py
@@ -507,6 +507,8 @@ def _apply_nest(
inplace: bool = False,
checked: bool = False,
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
**constructor_kwargs,
) -> T:
if inplace:
@@ -535,19 +537,34 @@
out.unlock_()

for key, item in self.items():
_others = [_other._get_str(key, default=NO_DEFAULT) for _other in others]
if not call_on_nested and _is_tensor_collection(item.__class__):
if default is not NO_DEFAULT:
_others = [_other._get_str(key, default=None) for _other in others]
_others = [
self.empty() if _other is None else _other for _other in _others
]
else:
_others = [
_other._get_str(key, default=NO_DEFAULT) for _other in others
]

item_trsf = item._apply_nest(
fn,
*_others,
inplace=inplace,
batch_size=batch_size,
device=device,
checked=checked,
named=named,
default=default,
**constructor_kwargs,
)
else:
item_trsf = fn(item, *_others)
_others = [_other._get_str(key, default=default) for _other in others]
if named:
item_trsf = fn(key, item, *_others)
else:
item_trsf = fn(item, *_others)
if item_trsf is not None:
if isinstance(self, _SubTensorDict):
out.set(key, item_trsf, inplace=inplace)
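
The new branch above is what gives ``default`` its semantics: when a nested tensordict is missing from one of the ``others``, an empty placeholder with the same batch size is substituted so the recursion can continue, and missing leaves then fall back to ``default``. A small sketch of the intended behaviour, inferred from this diff alone (illustrative names and values):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3), "nested": {"b": torch.ones(3)}}, [3])
>>> other = TensorDict({"a": torch.ones(3)}, [3])  # "nested" is entirely absent here
>>> out = td.apply(lambda x, y: x + y, other, default=0.0)
>>> assert (out["a"] == 2).all()            # both entries present
>>> assert (out["nested", "b"] == 1).all()  # 0.0 substituted for the missing leaf
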
174 changes: 168 additions & 6 deletions tensordict/base.py
@@ -2888,14 +2888,12 @@ def apply(
device: torch.device | None = None,
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
**constructor_kwargs,
) -> T:
"""Applies a callable to all values stored in the tensordict and sets them in a new tensordict.

The apply method will return an :class:`~tensordict.TensorDict` instance,
regardless of the input type. To keep the same type, one can execute

>>> out = td.clone(False).update(td.apply(...))
The callable signature must be ``Callable[Tuple[Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``.

Args:
fn (Callable): function to be applied to the tensors in the
@@ -2904,6 +2902,8 @@ def apply(
tensordict instances should have a structure matching the one
of self. The ``fn`` argument should receive as many
unnamed inputs as the number of tensordicts, including self.
If other tensordicts have missing entries, a default value
can be passed through the ``default`` keyword argument.
batch_size (sequence of int, optional): if provided,
the resulting TensorDict will have the desired batch_size.
The :obj:`batch_size` argument should match the batch_size after
@@ -2913,6 +2913,9 @@
batch_size is modified.
inplace (bool, optional): if True, changes are made in-place.
Default is False. This is a keyword only argument.
default (Any, optional): default value for missing entries in the
other tensordicts. If not provided, missing entries will
raise a `KeyError`.
**constructor_kwargs: additional keyword arguments to be passed to the
TensorDict constructor.

@@ -2925,12 +2928,163 @@
... "b": {"c": torch.ones(3)}},
... batch_size=[3])
>>> td_1 = td.apply(lambda x: x+1)
>>> assert (td["a"] == 0).all()
>>> assert (td["b", "c"] == 2).all()
>>> assert (td_1["a"] == 0).all()
>>> assert (td_1["b", "c"] == 2).all()
>>> td_2 = td.apply(lambda x, y: x+y, td)
>>> assert (td_2["a"] == -2).all()
>>> assert (td_2["b", "c"] == 2).all()

.. note::
If ``None`` is returned by the function, the entry is ignored. This
can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def filter(tensor):
... if tensor == 1:
... return tensor
>>> td.apply(filter)
TensorDict(
fields={
1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
b: TensorDict(
fields={
1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)

.. note::
The apply method will return an :class:`~tensordict.TensorDict` instance,
regardless of the input type. To keep the same type, one can execute

>>> out = td.clone(False).update(td.apply(...))


"""
return self._apply_nest(
fn,
*others,
batch_size=batch_size,
device=device,
names=names,
inplace=inplace,
checked=False,
default=default,
**constructor_kwargs,
)
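
The docstring documents ``default`` but its examples do not exercise it; a short hedged illustration of the flat case, under the same assumptions as the diff (values are illustrative only):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3), "b": torch.zeros(3)}, [3])
>>> other = TensorDict({"a": torch.ones(3)}, [3])  # "b" is missing
>>> out = td.apply(lambda x, y: x + y, other, default=1.0)
>>> assert (out["a"] == 2).all()  # 1 + 1
>>> assert (out["b"] == 1).all()  # 0 + 1.0, the default
>>> # without ``default``, the missing "b" entry raises a KeyError instead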

def named_apply(
self,
fn: Callable,
*others: T,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
**constructor_kwargs,
) -> T:
"""Applies a key-conditioned callable to all values stored in the tensordict and sets them in a new atensordict.

The callable signature must be ``Callable[Tuple[str, Tensor, ...], Optional[Union[Tensor, TensorDictBase]]]``.

Args:
fn (Callable): function to be applied to the (name, tensor) pairs in the
tensordict. For each leaf, only its leaf name will be used (not
the full `NestedKey`).
*others (TensorDictBase instances, optional): if provided, these
tensordict instances should have a structure matching the one
of self. The ``fn`` argument should receive as many
unnamed inputs as the number of tensordicts, including self.
If other tensordicts have missing entries, a default value
can be passed through the ``default`` keyword argument.
batch_size (sequence of int, optional): if provided,
the resulting TensorDict will have the desired batch_size.
The :obj:`batch_size` argument should match the batch_size after
the transformation. This is a keyword only argument.
device (torch.device, optional): the resulting device, if any.
names (list of str, optional): the new dimension names, in case the
batch_size is modified.
inplace (bool, optional): if True, changes are made in-place.
Default is False. This is a keyword only argument.
default (Any, optional): default value for missing entries in the
other tensordicts. If not provided, missing entries will
raise a `KeyError`.
**constructor_kwargs: additional keyword arguments to be passed to the
TensorDict constructor.

Returns:
a new tensordict with transformed tensors.

Example:
>>> td = TensorDict({
... "a": -torch.ones(3),
... "nested": {"a": torch.ones(3), "b": torch.zeros(3)}},
... batch_size=[3])
>>> def name_filter(name, tensor):
... if name == "a":
... return tensor
>>> td.named_apply(name_filter)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> def name_filter(name, *tensors):
... if name == "a":
... r = 0
... for tensor in tensors:
... r = r + tensor
...         return r
>>> out = td.named_apply(name_filter, td)
>>> print(out)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> print(out["a"])
tensor([-2., -2., -2.])

.. note::
If ``None`` is returned by the function, the entry is ignored. This
can be used to filter the data in the tensordict:

>>> td = TensorDict({"1": 1, "2": 2, "b": {"2": 2, "1": 1}}, [])
>>> def name_filter(name, tensor):
... if name == "1":
... return tensor
>>> td.named_apply(name_filter)
TensorDict(
fields={
1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
b: TensorDict(
fields={
1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)

"""
return self._apply_nest(
fn,
@@ -2940,6 +3094,8 @@ def apply(
names=names,
inplace=inplace,
checked=False,
default=default,
named=True,
**constructor_kwargs,
)
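
A short sketch combining the two features introduced by this PR, ``named_apply`` and ``default`` (names and values below are illustrative only):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3), "b": torch.zeros(3)}, [3])
>>> other = TensorDict({"a": torch.full((3,), 2.0)}, [3])  # "b" is missing
>>> def combine(name, x, y):
...     # ``y`` comes from ``other``; it falls back to ``default`` when the key is absent
...     return x + y if name == "a" else x - y
>>> out = td.named_apply(combine, other, default=10.0)
>>> assert (out["a"] == 3).all()    # 1 + 2
>>> assert (out["b"] == -10).all()  # 0 - 10.0, the default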

@@ -2954,6 +3110,8 @@ def _apply_nest(
inplace: bool = False,
checked: bool = False,
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
**constructor_kwargs,
) -> T:
...
@@ -2967,6 +3125,8 @@ def _fast_apply(
names: Sequence[str] | None = None,
inplace: bool = False,
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
**constructor_kwargs,
) -> T:
"""A faster apply method.
@@ -2985,6 +3145,8 @@ def _fast_apply(
inplace=inplace,
checked=True,
call_on_nested=call_on_nested,
named=named,
default=default,
**constructor_kwargs,
)

15 changes: 15 additions & 0 deletions tensordict/nn/params.py
@@ -451,6 +451,21 @@ def apply(
device: torch.device | None = None,
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
**constructor_kwargs,
) -> TensorDictBase:
...

@_unlock_and_set(inplace=True)
def named_apply(
self,
fn: Callable,
*others: TensorDictBase,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
inplace: bool = False,
default: Any = NO_DEFAULT,
**constructor_kwargs,
) -> TensorDictBase:
...