Skip to content

Commit

Permalink
[Feature] nested_keys option in named_apply (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 27, 2024
1 parent c72d500 commit 169b259
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 4 deletions.
13 changes: 12 additions & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ def device(self, value: DeviceType) -> None:
for t in self.tensordicts:
t.device = value

def clear_device_(self) -> T:
    """Clear the device of every stacked tensordict in-place.

    Delegates to ``clear_device_`` on each constituent tensordict.

    Returns:
        self, enabling method chaining.
    """
    for sub_td in self.tensordicts:
        sub_td.clear_device_()
    return self

@property
def batch_size(self) -> torch.Size:
return self._batch_size
Expand Down Expand Up @@ -1340,6 +1345,8 @@ def _apply_nest(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
**constructor_kwargs,
) -> T:
if inplace:
Expand All @@ -1362,6 +1369,8 @@ def _apply_nest(
call_on_nested=call_on_nested,
default=default,
named=named,
nested_keys=nested_keys,
prefix=prefix,
**constructor_kwargs,
)
others = (other.unbind(self.stack_dim) for other in others)
Expand All @@ -1375,8 +1384,10 @@ def _apply_nest(
call_on_nested=call_on_nested,
default=default,
named=named,
nested_keys=nested_keys,
prefix=prefix + (i,),
)
for td, *oth in zip(self.tensordicts, *others)
for i, (td, *oth) in enumerate(zip(self.tensordicts, *others))
),
stack_dim=self.stack_dim,
)
Expand Down
9 changes: 8 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ def _apply_nest(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
**constructor_kwargs,
) -> T:
if inplace:
Expand Down Expand Up @@ -679,13 +681,18 @@ def _apply_nest(
device=device,
checked=checked,
named=named,
nested_keys=nested_keys,
default=default,
prefix=prefix + (key,),
**constructor_kwargs,
)
else:
_others = [_other._get_str(key, default=default) for _other in others]
if named:
item_trsf = fn(key, item, *_others)
if nested_keys:
item_trsf = fn(unravel_key(prefix + (key,)), item, *_others)
else:
item_trsf = fn(key, item, *_others)
else:
item_trsf = fn(item, *_others)
if item_trsf is not None:
Expand Down
9 changes: 9 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3616,6 +3616,7 @@ def named_apply(
self,
fn: Callable,
*others: T,
nested_keys: bool = False,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
Expand All @@ -3637,6 +3638,9 @@ def named_apply(
unnamed inputs as the number of tensordicts, including self.
If other tensordicts have missing entries, a default value
can be passed through the ``default`` keyword argument.
nested_keys (bool, optional): if ``True``, the complete path
to the leaf will be used. Defaults to ``False``, i.e. only the last
string is passed to the function.
batch_size (sequence of int, optional): if provided,
the resulting TensorDict will have the desired batch_size.
The :obj:`batch_size` argument should match the batch_size after
Expand Down Expand Up @@ -3732,6 +3736,7 @@ def named_apply(
checked=False,
default=default,
named=True,
nested_keys=nested_keys,
**constructor_kwargs,
)

Expand All @@ -3748,6 +3753,8 @@ def _apply_nest(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
**constructor_kwargs,
) -> T:
...
Expand All @@ -3763,6 +3770,7 @@ def _fast_apply(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
nested_keys: bool = False,
**constructor_kwargs,
) -> T:
"""A faster apply method.
Expand All @@ -3783,6 +3791,7 @@ def _fast_apply(
call_on_nested=call_on_nested,
named=named,
default=default,
nested_keys=nested_keys,
**constructor_kwargs,
)

Expand Down
27 changes: 25 additions & 2 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.

import argparse

import functools
import gc
import json
import logging
Expand All @@ -18,7 +20,6 @@
from tensordict.nn import TensorDictParams
from tensordict.tensorclass import NonTensorData


try:
import torchsnapshot

Expand Down Expand Up @@ -2999,7 +3000,11 @@ def named_plus(name, x):
if inplace:
assert td_1 is td
for key in td_1.keys(True, True):
if "a" in key:
if isinstance(key, tuple):
subkey = key[-1]
else:
subkey = key
if "a" in subkey:
assert (td_c[key] + 1 == td_1[key]).all()
else:
assert (td_c[key] == td_1[key]).all()
Expand All @@ -3010,6 +3015,24 @@ def named_plus(name, x):
assert (td_c[key] + 1 != td[key]).any()
assert (td_1[key] == td[key] + 1).all()

def test_named_apply_complete(self, td_name, device):
    """With ``nested_keys=True``, named_apply passes the full path of each
    leaf to the function; with ``nested_keys=False`` only the last string,
    so colliding leaf names collapse into fewer distinct names."""
    td = getattr(self, td_name)(device)
    td.unlock_()
    # "a" conflicts with root key with the same name
    first_key = next(iter(td.keys()))
    td.set(("some", "a"), td.get(first_key))

    full_paths = set()
    leaf_names = set()

    def record(name, value, keys):
        keys.add(name)

    td.named_apply(functools.partial(record, keys=full_paths), nested_keys=True)
    td.named_apply(
        functools.partial(record, keys=leaf_names), nested_keys=False
    )
    # Full paths are unique per leaf, so their count matches the leaf count;
    # bare leaf names deduplicate, so there are strictly fewer of them.
    assert len(full_paths) == len(list(td.keys(True, True)))
    assert len(full_paths) > len(leaf_names)

def test_nested_dict_init(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit 169b259

Please sign in to comment.