From 4404abe08891473ade68d27e0c7fb5f366226beb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 9 Dec 2024 14:38:00 -0800 Subject: [PATCH] [Feature] NonTensorStack.data ghstack-source-id: 86065377cc1cd7c7283ed0a468f5d5602d60526d Pull Request resolved: https://github.com/pytorch/tensordict/pull/1132 --- tensordict/_lazy.py | 6 +++--- tensordict/base.py | 6 +++--- tensordict/tensorclass.py | 11 ++++++++++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index ae5c642b3..202ac8667 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -2130,10 +2130,10 @@ def __getitem__(self, index: IndexType) -> Any: if index_key: leaf = self._get_tuple(index_key, NO_DEFAULT) if is_non_tensor(leaf): - result = getattr(leaf, "data", NO_DEFAULT) - if result is NO_DEFAULT: + # Only lazy stacks of non tensors are actually tensordict instances + if isinstance(leaf, TensorDictBase): return leaf.tolist() - return result + return leaf.data return leaf split_index = self._split_index(index) converted_idx = split_index["index_dict"] diff --git a/tensordict/base.py b/tensordict/base.py index 9b6818a98..0f2df179d 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6301,10 +6301,10 @@ def _get_tuple(self, key, default): ... def _get_tuple_maybe_non_tensor(self, key, default): result = self._get_tuple(key, default) if is_non_tensor(result): - result_data = getattr(result, "data", NO_DEFAULT) - if result_data is NO_DEFAULT: + # Only lazy stacks of non tensors are actually tensordict instances + if isinstance(result, TensorDictBase): return result.tolist() - return result_data + return result.data return result def get_at( diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 89b0c1d42..f0f8a1cf4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -3450,7 +3450,16 @@ def update_at_( @property def data(self): - raise AttributeError + """Attempts to return the unique value in the stack. + + Raises a ValueError if there is more than one unique value. + """ + try: + return NonTensorData._stack_non_tensor( + self.tensordicts, raise_if_non_unique=True + ).data + except ValueError: + raise AttributeError("Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead.") _register_tensor_class(NonTensorStack)