Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix empty(recurse) call in _apply_nest #658

Merged
merged 3 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 53 additions & 47 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,11 +1326,9 @@ def entry_class(self, key: NestedKey) -> type:
return data_type

def apply_(self, fn: Callable, *others, **kwargs):
for i, td in enumerate(self.tensordicts):
idx = (slice(None),) * self.stack_dim + (i,)
td._fast_apply(
fn, *[other[idx] for other in others], inplace=True, **kwargs
)
others = (other.unbind(self.stack_dim) for other in others)
for td, *_others in zip(self.tensordicts, *others):
td._fast_apply(fn, *_others, inplace=True, **kwargs)
return self

def _apply_nest(
Expand All @@ -1349,51 +1347,59 @@ def _apply_nest(
prefix: tuple = (),
**constructor_kwargs,
) -> T:
if inplace:
if any(arg for arg in (batch_size, device, names, constructor_kwargs)):
raise ValueError(
"Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True."
)
return self.apply_(fn, *others, named=named, default=default)
else:
if batch_size is not None:
# any op that modifies the batch-size will result in a regular TensorDict
return TensorDict._apply_nest(
self,
fn,
*others,
batch_size=batch_size,
device=device,
names=names,
checked=checked,
call_on_nested=call_on_nested,
default=default,
named=named,
nested_keys=nested_keys,
prefix=prefix,
**constructor_kwargs,
)
others = (other.unbind(self.stack_dim) for other in others)
if inplace and any(
arg for arg in (batch_size, device, names, constructor_kwargs)
):
raise ValueError(
"Cannot pass other arguments to LazyStackedTensorDict.apply when inplace=True."
)
if batch_size is not None:
# any op that modifies the batch-size will result in a regular TensorDict
return TensorDict._apply_nest(
self,
fn,
*others,
batch_size=batch_size,
device=device,
names=names,
checked=checked,
call_on_nested=call_on_nested,
default=default,
named=named,
nested_keys=nested_keys,
prefix=prefix,
inplace=inplace,
**constructor_kwargs,
)

others = (other.unbind(self.stack_dim) for other in others)
results = [
td._apply_nest(
fn,
*oth,
checked=checked,
device=device,
call_on_nested=call_on_nested,
default=default,
named=named,
nested_keys=nested_keys,
prefix=prefix + (i,),
inplace=inplace,
)
for i, (td, *oth) in enumerate(zip(self.tensordicts, *others))
]
if not inplace:
out = LazyStackedTensorDict(
*(
td._apply_nest(
fn,
*oth,
checked=checked,
device=device,
call_on_nested=call_on_nested,
default=default,
named=named,
nested_keys=nested_keys,
prefix=prefix + (i,),
)
for i, (td, *oth) in enumerate(zip(self.tensordicts, *others))
),
*results,
stack_dim=self.stack_dim,
)
if names is not None:
out.names = names
return out
else:
out = self
if names is not None:
out.names = names
else:
out._td_dim_name = self._td_dim_name
return out

def _select(
self,
Expand Down
3 changes: 2 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,8 @@ def _apply_nest(
if default is not NO_DEFAULT:
_others = [_other._get_str(key, default=None) for _other in others]
_others = [
self.empty() if _other is None else _other for _other in _others
self.empty(recurse=True) if _other is None else _other
for _other in _others
]
else:
_others = [
Expand Down
6 changes: 4 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2503,6 +2503,7 @@ def update_(
return self
keys_to_update = [_unravel_key_to_tuple(key) for key in keys_to_update]
if keys_to_update:
named = True

def inplace_update(name, dest, source):
if source is None:
Expand All @@ -2515,8 +2516,9 @@ def inplace_update(name, dest, source):
return dest

else:
named = False

def inplace_update(name, dest, source):
def inplace_update(dest, source):
if source is None:
return dest
return dest.copy_(source, non_blocking=True)
Expand All @@ -2533,7 +2535,7 @@ def inplace_update(name, dest, source):
nested_keys=True,
default=None,
inplace=True,
named=True,
named=named,
)

def update_at_(
Expand Down
Loading