[Refactor] Improve functional call efficiency #567

Merged 8 commits on Nov 23, 2023
80 changes: 48 additions & 32 deletions tensordict/_lazy.py
@@ -19,7 +19,7 @@
import torch
from functorch import dim as ftdim
from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict._tensordict import _unravel_key_to_tuple, unravel_key
from tensordict.base import (
_ACCEPTED_CLASSES,
_is_tensor_collection,
@@ -133,6 +133,8 @@ class LazyStackedTensorDict(TensorDictBase):

"""

_is_vmapped: bool = False

@classmethod
def __torch_function__(
cls,
@@ -362,7 +364,7 @@ def _set_str(
if not validated:
value = self._validate_value(value)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
values = value.unbind(self.stack_dim)
for tensordict, item in zip(self.tensordicts, values):
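
This hunk and the three below it replace the `self.hook_in is not None` test with the new `_is_vmapped` flag, so the write path asks "am I inside vmap?" directly instead of inferring it from whether a hook happens to be set. A minimal sketch of the pattern with made-up names (not tensordict internals):

```python
# Illustrative only, not tensordict code: a boolean set once when the container
# enters vmap replaces repeated attribute checks on every write, and keeps
# "a hook is attached" distinct from "we are currently inside vmap".
class Container:
    _is_vmapped: bool = False  # class-level default: ordinary instances read False
    hook_in = None
    hook_out = None

    def __init__(self):
        self._data = {}

    def set(self, key, value):
        if self._is_vmapped:
            # rewrap the value only while vmapped
            value = self.hook_in(value)
        self._data[key] = value
```
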
@@ -397,7 +399,7 @@ def _set_tuple(
if not validated:
value = self._validate_value(value)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
values = value.unbind(self.stack_dim)
for tensordict, item in zip(self.tensordicts, values):
@@ -554,7 +556,7 @@ def _set_at_str(self, key, value, index, *, validated):
if not validated:
value = self._validate_value(value, check_shape=False)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
split_index = self._split_index(index)
converted_idx = split_index["index_dict"]
Expand Down Expand Up @@ -649,7 +651,7 @@ def _set_at_tuple(self, key, value, idx, *, validated):
if not validated:
value = self._validate_value(value, check_shape=False)
validated = True
if self.hook_in is not None:
if self._is_vmapped:
value = self.hook_in(value)
item = td._get_str(key, NO_DEFAULT)
item[idx] = value
Expand Down Expand Up @@ -778,10 +780,22 @@ def _get_str(
# then it's a LazyStackedTD
out.hook_out = self.hook_out
out.hook_in = self.hook_in
out._is_vmapped = self._is_vmapped
incr = 0 if not self._is_vmapped else 1
out._batch_size = (
self._batch_size
+ out.batch_size[(len(self._batch_size) + incr) :]
)
else:
# then it's a tensorclass
out._tensordict.hook_out = self.hook_out
out._tensordict.hook_in = self.hook_in
out._tensordict._is_vmapped = self._is_vmapped
incr = 0 if not self._is_vmapped else 1
out._tensordict._batch_size = (
self._batch_size
+ out._tensordict.batch_size[(len(self._batch_size) + incr) :]
)
elif self.hook_out is not None:
out = self.hook_out(out)
return out
@@ -802,7 +816,7 @@ def _get_str(
def _get_tuple(self, key, default):
first = self._get_str(key[0], None)
if first is None:
return self._default_get(first, default)
return self._default_get(key[0], default)
if len(key) == 1:
return first
try:
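
The one-line change above passes the missing key, rather than the `None` returned by the failed lookup, to `_default_get`. A hedged sketch of the user-facing effect (the exact error text is an assumption):

```python
import torch
from tensordict import TensorDict

# torch.stack over TensorDicts builds a LazyStackedTensorDict
td = torch.stack([TensorDict({"a": torch.zeros(3)}, [3]) for _ in range(2)], dim=0)

# with a default, the default comes back for the absent nested key
fallback = torch.full((2, 3), -1.0)
assert (td.get(("missing", "sub"), fallback) == -1).all()

# without a default, the KeyError should now name "missing" instead of None
try:
    td.get(("missing", "sub"))
except KeyError as err:
    print(err)
```
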
@@ -850,7 +864,7 @@ def _cached_add_batch_dims(cls, td, in_dim, vmap_level):
# we return a stack with hook_out, and hack the batch_size and names
# Per se it is still a LazyStack but the stacking dim is "hidden" from
# the outside
out = td.clone(False)
out = td.copy()

def hook_out(tensor, in_dim=in_dim, vmap_level=vmap_level):
return _add_batch_dim(tensor, in_dim, vmap_level)
@@ -869,6 +883,7 @@ def hook_in(

out.hook_out = hook_out
out.hook_in = hook_in
out._is_vmapped = True
out._batch_size = torch.Size(
[dim for i, dim in enumerate(out._batch_size) if i != out.stack_dim]
)
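
For context, a minimal sketch of the path these hooks serve, assuming vmap over TensorDict inputs works as in tensordict's functorch integration at the time of this PR. The write inside the mapped function is what exercises `hook_in` via the new `_is_vmapped` flag:

```python
import torch
from tensordict import TensorDict

# a LazyStackedTensorDict with batch_size [4]
td = torch.stack([TensorDict({"a": torch.randn(3)}, []) for _ in range(4)], dim=0)

def fn(t):
    # inside vmap the stack dimension is hidden: t.batch_size is torch.Size([])
    t["b"] = t["a"] * 2  # goes through hook_in because _is_vmapped is True
    return t["b"]

out = torch.vmap(fn)(td)
print(out.shape)  # torch.Size([4, 3])
```
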
@@ -1570,7 +1585,7 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
isinstance(input_dict_or_td, LazyStackedTensorDict)
and input_dict_or_td.stack_dim == self.stack_dim
):
if not input_dict_or_td.shape[self.stack_dim] == len(self.tensordicts):
if len(input_dict_or_td.tensordicts) != len(self.tensordicts):
raise ValueError(
"cannot update stacked tensordicts with different shapes."
)
@@ -1580,36 +1595,37 @@ def update(self, input_dict_or_td: T, clone: bool = False, **kwargs: Any) -> T:
td_dest.update(td_source, clone=clone, **kwargs)
return self

keys = self.keys(False)
inplace = kwargs.get("inplace", False)
for key, value in input_dict_or_td.items():
if clone and hasattr(value, "clone"):
value = value.clone()
else:
elif clone:
value = tree_map(torch.clone, value)
key = unravel_key(key)
if isinstance(key, tuple):
key, subkey = key[0], key[1:]
else:
subkey = ()
# the key must be a string by now. Let's check if it is present
if key in keys:
target_class = self.entry_class(key)
if _is_tensor_collection(target_class):
if isinstance(value, dict):
value_unbind = TensorDict(
value, self.batch_size, _run_checks=False
).unbind(self.stack_dim)
else:
value_unbind = value.unbind(self.stack_dim)
for t, _value in zip(self.tensordicts, value_unbind):
if len(subkey):
t.update({key: {subkey: _value}}, clone=clone, **kwargs)
else:
t.update({key: _value}, clone=clone, **kwargs)
continue
if len(subkey):
self.set((key, *subkey), value, **kwargs)
# we must check that the target is not a leaf
target = self._get_str(key[0], default=None)
if is_tensor_collection(target):
target.update({key[1:]: value}, inplace=inplace, clone=clone)
elif target is None:
self._set_tuple(key, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key[0]) is {type(target)} but expected a tensor collection."
)
else:
self.set(key, value, **kwargs)
target = self._get_str(key, default=None)
if is_tensor_collection(target) and (
is_tensor_collection(value) or isinstance(value, dict)
):
target.update(value, inplace=inplace, clone=clone)
elif target is None or not is_tensor_collection(value):
self._set_str(key, value, inplace=inplace, validated=False)
else:
raise TypeError(
f"Type mismatch: self.get(key) is {type(target)} but value is of type {type(value)}."
)

return self

def update_(
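
A hedged usage sketch of the refactored `LazyStackedTensorDict.update`, built from the code above rather than from the PR's tests: string keys go through `_set_str`, tuple keys through the target collection or `_set_tuple`, and a leaf/collection mismatch raises `TypeError`.

```python
import torch
from tensordict import TensorDict

td = torch.stack([TensorDict({"a": torch.zeros(3)}, [3]) for _ in range(2)], dim=0)

# plain string key onto an existing leaf
td.update({"a": torch.ones(2, 3)})

# tuple key whose first level does not exist yet
td.update({("nested", "b"): torch.ones(2, 3)})

assert (td["a"] == 1).all()
assert td["nested", "b"].shape == torch.Size([2, 3])
```
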
29 changes: 17 additions & 12 deletions tensordict/_td.py
@@ -271,11 +271,11 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
# For batch-size it is a minor issue (unlikely that a td with batch-size
# is passed with to_module) but for the device it could be a problem.
if swap_dest is None:
swap = self.empty()
swap.clear_device_()
swap = TensorDict({}, batch_size=[])
else:
swap = swap_dest
memo[id(module)] = swap
_swap = {}

for key, value in self.items():
if isinstance(value, (Tensor, ftdim.Tensor)):
@@ -320,8 +320,13 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
swap = swap.to(device=local_out.device)

if return_swap:
assert local_out is not None, key
swap._set_str(key, local_out, inplace=False, validated=True)
_swap[key] = local_out
if return_swap:
if isinstance(swap, TensorDict):
# this is very ad-hoc but faster than calling _set_str every time
swap._tensordict.update(_swap)
else:
swap.update(_swap)
return swap

def __ne__(self, other: object) -> T | bool:
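
For context, a minimal sketch of the functional-call pattern that `to_module` serves, assuming `TensorDict.from_module`/`to_module` behave as in the tensordict docs of this period:

```python
import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(3, 4)
params = TensorDict.from_module(module)

# swap in zeroed parameters, run the module, then restore the originals
zeros = params.apply(torch.zeros_like)
swap = zeros.to_module(module)          # returns the previous parameters
assert module(torch.randn(3)).abs().sum() == 0
swap.to_module(module)                  # put the original parameters back
```
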
@@ -1242,12 +1247,13 @@ def _set_str(
inplace: bool,
validated: bool,
) -> T:
best_attempt = inplace is BEST_ATTEMPT_INPLACE
inplace = self._convert_inplace(inplace, key)
if inplace is not False:
best_attempt = inplace is BEST_ATTEMPT_INPLACE
inplace = self._convert_inplace(inplace, key)
if not validated:
value = self._validate_value(value, check_shape=True)
if not inplace:
if self.is_locked:
if self._is_locked:
raise RuntimeError(_LOCK_ERROR)
self._tensordict[key] = value
else:
@@ -1703,14 +1709,13 @@ def contiguous(self) -> T:
def empty(self, recurse=False) -> T:
if not recurse:
return TensorDict(
device=self.device,
batch_size=self.batch_size,
device=self._device,
batch_size=self._batch_size,
source={},
# names=self.names if self._has_names() else None,
names=self._td_dim_names,
_run_checks=False,
_is_memmap=self._is_memmap,
_is_shared=self._is_shared,
_is_memmap=False,
_is_shared=False,
)
return super().empty(recurse=recurse)

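
A hedged illustration of the `empty()` change above, assuming (as the `flatten_keys` hunk below suggests) that shared-memory and memory-mapped tensordicts are treated as locked: the fresh container no longer inherits those flags, so it can be written to.

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, [3]).share_memory_()
empty = td.empty()           # same batch size and device, no entries
empty["b"] = torch.ones(3)   # would be refused if the locked/shared state carried over
```
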
7 changes: 5 additions & 2 deletions tensordict/base.py
@@ -3125,8 +3125,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.lock_()
if last_op == self.__class__.to_module.__name__:
if is_tensor_collection(out):
with out.unlock_():
return self.to_module(*args, **kwargs, swap_dest=out)
return self.to_module(*args, **kwargs, swap_dest=out)
else:
raise RuntimeError(
"to_module cannot be used as a decorator when return_swap=False."
@@ -3520,6 +3519,10 @@ def flatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
result._set_str(
leaf_flat, self.get(leaf), validated=True, inplace=False
)
shared = result._is_shared = self._is_shared
mmap = result._is_memmap = self._is_memmap
if shared or mmap:
result._is_locked = True
return result

@cache # noqa: B019
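
A hedged illustration of the flag propagation added above: the flattened result of a shared (or memory-mapped) tensordict is marked locked, mirroring its source.

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": {"b": torch.zeros(3)}}, [3]).share_memory_()
flat = td.flatten_keys(".")
assert list(flat.keys()) == ["a.b"]
assert flat.is_locked  # shared source, so the flattened view is locked too
```
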