Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 25, 2024
1 parent 6ca5e95 commit d186639
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
10 changes: 6 additions & 4 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,8 @@ def is_empty(self):
if _is_tensor_collection(type(item)):
if not item.is_empty():
return False
from tensordict.tensorclass import NonTensorData, NonTensorStack

if isinstance(item, (NonTensorData, NonTensorStack)):
if is_non_tensor(item):
return False
else:
return False
Expand Down Expand Up @@ -693,7 +692,7 @@ def make_result():
if (
not call_on_nested
and _is_tensor_collection(item.__class__)
and not is_non_tensor(item)
# and not is_non_tensor(item)
):
if default is not NO_DEFAULT:
_others = [_other._get_str(key, default=None) for _other in others]
Expand Down Expand Up @@ -2467,7 +2466,10 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):

def _get_str(self, key, default):
if key in self.keys() and _is_tensor_collection(self.entry_class(key)):
return _SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx)
data = self._source._get_str(key, NO_DEFAULT)
if is_non_tensor(data):
return data[self.idx]
return _SubTensorDict(data, self.idx)
return self._source._get_at_str(key, self.idx, default=default)

def _get_tuple(self, key, default):
Expand Down
5 changes: 1 addition & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2321,12 +2321,9 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
return subtd._get_non_tensor(key[1:], default=default)
value = self._get_str(key, default=default)

from .tensorclass import NonTensorData

from .tensorclass import NonTensorData, NonTensorStack
if isinstance(value, NonTensorData):
return value.data
from tensordict.tensorclass import NonTensorStack

if isinstance(value, NonTensorStack):
return value.tolist()
return value
Expand Down

0 comments on commit d186639

Please sign in to comment.