Commit afe7dda: amend

vmoens committed Jan 22, 2024 (1 parent: 7e026f4)
Showing 5 changed files with 48 additions and 27 deletions.
37 changes: 21 additions & 16 deletions tensordict/_lazy.py
@@ -1515,7 +1515,12 @@ def __getitem__(self, index: IndexType) -> T:
         if isinstance(index, (tuple, str)):
             index_key = _unravel_key_to_tuple(index)
             if index_key:
-                return self._get_tuple(index_key, NO_DEFAULT)
+                result = self._get_tuple(index_key, NO_DEFAULT)
+                from .tensorclass import NonTensorData
+
+                if isinstance(result, NonTensorData):
+                    return result.data
+                return result
         split_index = self._split_index(index)
         converted_idx = split_index["index_dict"]
         isinteger = split_index["isinteger"]
@@ -1527,22 +1532,22 @@ def __getitem__(self, index: IndexType) -> T:
         if has_bool:
             mask_unbind = split_index["individual_masks"]
             cat_dim = split_index["mask_loc"] - num_single
-            out = []
+            result = []
             if mask_unbind[0].ndim == 0:
                 # we can return a stack
                 for (i, _idx), mask in zip(converted_idx.items(), mask_unbind):
                     if mask.any():
                         if mask.all() and self.tensordicts[i].ndim == 0:
-                            out.append(self.tensordicts[i])
+                            result.append(self.tensordicts[i])
                         else:
-                            out.append(self.tensordicts[i][_idx])
-                            out[-1] = out[-1].squeeze(cat_dim)
-                return LazyStackedTensorDict.lazy_stack(out, cat_dim)
+                            result.append(self.tensordicts[i][_idx])
+                            result[-1] = result[-1].squeeze(cat_dim)
+                return LazyStackedTensorDict.lazy_stack(result, cat_dim)
             else:
                 for i, _idx in converted_idx.items():
                     self_idx = (slice(None),) * split_index["mask_loc"] + (i,)
-                    out.append(self[self_idx][_idx])
-                return torch.cat(out, cat_dim)
+                    result.append(self[self_idx][_idx])
+                return torch.cat(result, cat_dim)
         elif is_nd_tensor:
             new_stack_dim = self.stack_dim - num_single + num_none
             return LazyStackedTensorDict.lazy_stack(
@@ … @@
             ) in (
                 converted_idx.items()
             ):  # for convenience but there's only one element
-                out = self.tensordicts[i]
+                result = self.tensordicts[i]
                 if _idx is not None and _idx != ():
-                    out = out[_idx]
-                return out
+                    result = result[_idx]
+                return result
         else:
-            out = []
+            result = []
             new_stack_dim = self.stack_dim - num_single + num_none - num_squash
             for i, _idx in converted_idx.items():
-                out.append(self.tensordicts[i][_idx])
-            out = LazyStackedTensorDict.lazy_stack(out, new_stack_dim)
-            out._td_dim_name = self._td_dim_name
-            return out
+                result.append(self.tensordicts[i][_idx])
+            result = LazyStackedTensorDict.lazy_stack(result, new_stack_dim)
+            result._td_dim_name = self._td_dim_name
+            return result

     def __eq__(self, other):
         if is_tensorclass(other):
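
The same unwrapping applies to lazy stacks; a minimal sketch, assuming the stacked tensordicts hold identical non-tensor values so the stack resolves to a single NonTensorData:

>>> from tensordict import LazyStackedTensorDict, TensorDict
>>> td = TensorDict({}, [])
>>> td["meta"] = "hello"                    # stored as NonTensorData
>>> stacked = LazyStackedTensorDict.lazy_stack([td, td.clone()], 0)
>>> stacked["meta"]                         # unwrapped by the new __getitem__
'hello'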
11 changes: 10 additions & 1 deletion tensordict/base.py
@@ -207,6 +207,9 @@ def __getitem__(self, index: IndexType) -> T:
         The index can be a (nested) key or any valid shape index given the
         tensordict batch size.
 
+        If the index is a nested key and the result is a :class:`~tensordict.NonTensorData`
+        object, the content of the non-tensor is returned.
+
         Examples:
             >>> td = TensorDict({"root": torch.arange(2), ("nested", "entry"): torch.arange(2)}, [2])
             >>> td["root"]
@@ -232,7 +235,13 @@ def __getitem__(self, index: IndexType) -> T:
         # _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey
         idx_unravel = _unravel_key_to_tuple(index)
         if idx_unravel:
-            return self._get_tuple(idx_unravel, NO_DEFAULT)
+            result = self._get_tuple(idx_unravel, NO_DEFAULT)
+            from .tensorclass import NonTensorData
+
+            if isinstance(result, NonTensorData):
+                return result.data
+            return result
+
         if (istuple and not index) or (not istuple and index is Ellipsis):
             # empty tuple returns self
             return self
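
The user-facing effect, which the new test below exercises: plain indexing now returns the content of a non-tensor entry, while .get() still returns the NonTensorData wrapper. A minimal doctest-style sketch:

>>> from tensordict import NonTensorData, TensorDict
>>> td = TensorDict({}, [])
>>> td["nested", "str"] = "a string!"       # non-tensor values are stored as NonTensorData
>>> td["nested", "str"]                     # __getitem__ now returns the content itself
'a string!'
>>> isinstance(td.get(("nested", "str")), NonTensorData)   # .get() keeps the wrapper
True
>>> td.get_non_tensor(("nested", "str"))
'a string!'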
8 changes: 4 additions & 4 deletions tensordict/tensorclass.py
@@ -27,7 +27,7 @@
 from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase
 from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict._torch_func import TD_HANDLED_FUNCTIONS
-from tensordict.base import _register_tensor_class
+from tensordict.base import _ACCEPTED_CLASSES, _register_tensor_class
 from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor
 
 from tensordict.utils import (
@@ -40,7 +40,6 @@
     NestedKey,
 )
 from torch import Tensor
-from tensordict.base import _ACCEPTED_CLASSES
 
 T = TypeVar("T", bound=TensorDictBase)
 PY37 = sys.version_info < (3, 8)
@@ -1306,11 +1305,12 @@ def to_dict(self):
     def _stack_non_tensor(cls, list_of_non_tensor, dim=0):
         # checks have been performed previously, so we're sure the list is non-empty
         first = list_of_non_tensor[0]
+
         def _check_equal(a, b):
             if isinstance(a, _ACCEPTED_CLASSES) or isinstance(b, _ACCEPTED_CLASSES):
-                return (a==b).all()
+                return (a == b).all()
             try:
-                iseq = a==b
+                iseq = a == b
             except Exception:
                 iseq = False
             return iseq
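
Why the `_ACCEPTED_CLASSES` branch is needed: for tensor-like payloads, `a == b` compares elementwise and has no unambiguous truth value, so the result must be reduced with `.all()`. For illustration:

>>> import torch
>>> a, b = torch.zeros(3), torch.zeros(3)
>>> a == b                  # elementwise result, not a single bool
tensor([True, True, True])
>>> bool((a == b).all())    # the reduction _check_equal relies on
True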
13 changes: 7 additions & 6 deletions tensordict/utils.py
@@ -1443,17 +1443,18 @@ def _set_max_batch_size(source: T, batch_dims=None):
     """Updates a tensordict with its maximium batch size."""
     from tensordict import NonTensorData
 
-    tensor_data = list(source.values())
+    tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)]
 
-    for val in tensor_data:
-        from tensordict.base import _is_tensor_collection
-
-        if _is_tensor_collection(val.__class__) and not isinstance(val, NonTensorData):
-            _set_max_batch_size(val, batch_dims=batch_dims)
     batch_size = []
     if not tensor_data:  # when source is empty
         source.batch_size = batch_size
         return
+
+    for val in tensor_data:
+        from tensordict.base import _is_tensor_collection
+
+        if _is_tensor_collection(val.__class__):
+            _set_max_batch_size(val, batch_dims=batch_dims)
     curr_dim = 0
     while True:
         if tensor_data[0].dim() > curr_dim:
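For intuition, the batch size this helper infers is the longest leading shape shared by all remaining entries, now computed after non-tensor values are filtered out. A standalone sketch of that rule (hypothetical code, not the library's API):

>>> import torch
>>> vals = [torch.zeros(3, 4, 5), torch.zeros(3, 4, 2)]   # stand-ins for the tensor entries
>>> batch_size, dim = [], 0
>>> while all(v.dim() > dim for v in vals) and len({v.shape[dim] for v in vals}) == 1:
...     batch_size.append(vals[0].shape[dim])
...     dim += 1
>>> batch_size   # shared leading dims only; NonTensorData entries play no part
[3, 4]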
6 changes: 6 additions & 0 deletions test/test_tensordict.py
@@ -3097,6 +3097,12 @@ def test_non_tensor_data(self, td_name, device):
         assert td.get_non_tensor(("this", "will")) == "succeed"
         assert isinstance(td.get(("this", "will")), NonTensorData)
 
+        with td.unlock_():
+            td["this", "other", "tensor"] = "success"
+            assert td["this", "other", "tensor"] == "success"
+            assert isinstance(td.get(("this", "other", "tensor")), NonTensorData)
+            assert td.get_non_tensor(("this", "other", "tensor")) == "success"
+
     def test_non_tensor_data_flatten_keys(self, td_name, device):
         td = getattr(self, td_name)(device)
         with td.unlock_():
