Skip to content

Commit

Permalink
[BugFix] Improve update_ (#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 3, 2024
1 parent cde67c3 commit b1c761c
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 14 deletions.
8 changes: 8 additions & 0 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,6 +2039,14 @@ def update_(
if input_dict_or_td is self:
# no op
return self
if not is_tensor_collection(input_dict_or_td):
input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)
if input_dict_or_td.batch_dims <= self.stack_dim:
raise RuntimeError(
f"Built tensordict with ndim={input_dict_or_td.ndim} does not have enough dims."
)
if input_dict_or_td.batch_size[self.stack_dim] != len(self.tensordicts):
raise ValueError("cannot update stacked tensordicts with different shapes.")
for td_dest, td_source in zip(
Expand Down
46 changes: 34 additions & 12 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2501,18 +2501,40 @@ def update_(
if keys_to_update is not None:
if len(keys_to_update) == 0:
return self
keys_to_update = unravel_key_list(keys_to_update)
for key, value in input_dict_or_td.items():
firstkey, *nextkeys = _unravel_key_to_tuple(key)
if keys_to_update and not any(
firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
for ktu in keys_to_update
):
continue
if clone:
value = value.clone()
self.set_((firstkey, *nextkeys), value)
return self
keys_to_update = [_unravel_key_to_tuple(key) for key in keys_to_update]
if keys_to_update:

def inplace_update(name, dest, source):
if source is None:
return dest
name = _unravel_key_to_tuple(name)
for key in keys_to_update:
if key == name[: len(key)]:
return dest.copy_(source, non_blocking=True)
else:
return dest

else:

def inplace_update(name, dest, source):
if source is None:
return dest
return dest.copy_(source, non_blocking=True)

if not is_tensor_collection(input_dict_or_td):
from tensordict import TensorDict

input_dict_or_td = TensorDict.from_dict(
input_dict_or_td, batch_dims=self.batch_dims
)
return self._apply_nest(
inplace_update,
input_dict_or_td,
nested_keys=True,
default=None,
inplace=True,
named=True,
)

def update_at_(
self,
Expand Down
7 changes: 5 additions & 2 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,8 +1466,11 @@ def _set_max_batch_size(source: T, batch_dims=None):
else:
source.batch_size = batch_size
return
for tensor in tensor_data[1:]:
if tensor.dim() <= curr_dim or tensor.size(curr_dim) != curr_dim_size:
for leaf in tensor_data[1:]:
# if we have a nested empty tensordict we can modify its batch size at will
if _is_tensor_collection(type(leaf)) and leaf.is_empty():
continue
if (leaf.dim() <= curr_dim) or (leaf.size(curr_dim) != curr_dim_size):
source.batch_size = batch_size
return
if batch_dims is None or len(batch_size) < batch_dims:
Expand Down
29 changes: 29 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,35 @@ def test_unbind_td(self, device):
td_unbind[0].batch_size == td[:, 0].batch_size
), f"got {td_unbind[0].batch_size} and {td[:, 0].batch_size}"

@pytest.mark.parametrize("stack", [True, False])
@pytest.mark.parametrize("todict", [True, False])
def test_update_(self, stack, todict):
def make(val, todict=False, stack=False):
if todict:
return make(val, stack=stack).to_dict()
if stack:
return LazyStackedTensorDict.lazy_stack([make(val), make(val)])
return TensorDict({"a": {"b": val, "c": {}}, "d": {"e": val, "f": val}}, [])

td1 = make(1, stack=stack)
td2 = make(2, stack=stack, todict=todict)

# plain update_
td1.update_(td2)
assert (td1 == 2).all()

td1 = make(1, stack=stack)
for key in (("a",), "a"):
td1.update_(td2, keys_to_update=[key])
assert (td1.select("a") == 2).all()
assert (td1.exclude("a") == 1).all()

td1 = make(1, stack=stack)
for key in (("a", "b"), (("a",), ((("b"),),))):
td1.update_(td2, keys_to_update=[key])
assert (td1.select(("a", "b")) == 2).all()
assert (td1.exclude(("a", "b")) == 1).all()

def test_update_nested_dict(self):
t = TensorDict({"a": {"d": [[[0]] * 3] * 2}}, [2, 3])
assert ("a", "d") in t.keys(include_nested=True)
Expand Down

0 comments on commit b1c761c

Please sign in to comment.