From 0f44acc2d3a37135ccd73ef93c38107264fdb96d Mon Sep 17 00:00:00 2001
From: vmoens
Date: Wed, 21 Feb 2024 14:08:14 -0800
Subject: [PATCH] amend

---
 docs/source/reference/tensorclass.rst |  1 +
 tensordict/__init__.py                |  3 +-
 tensordict/_lazy.py                   | 38 +++++++-----------
 tensordict/_td.py                     | 17 ++++----
 tensordict/_torch_func.py             |  5 +--
 tensordict/base.py                    | 28 ++++++++++----
 tensordict/tensorclass.py             | 56 ++++++++++++++++++++++-----
 tensordict/utils.py                   |  9 ++---
 8 files changed, 100 insertions(+), 57 deletions(-)

diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst
index 17518df1d..8e6e4e907 100644
--- a/docs/source/reference/tensorclass.rst
+++ b/docs/source/reference/tensorclass.rst
@@ -273,3 +273,4 @@ Here is an example:
 
     tensorclass
     NonTensorData
+    NonTensorStack
diff --git a/tensordict/__init__.py b/tensordict/__init__.py
index c71661ceb..5e6dc8761 100644
--- a/tensordict/__init__.py
+++ b/tensordict/__init__.py
@@ -16,7 +16,7 @@
 from tensordict.memmap import MemoryMappedTensor
 from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership
 from tensordict.persistent import PersistentTensorDict
-from tensordict.tensorclass import NonTensorData, tensorclass
+from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass
 from tensordict.utils import (
     assert_allclose_td,
     is_batchedtensor,
@@ -43,6 +43,7 @@
     "TensorDict",
     "TensorDictBase",
     "merge_tensordicts",
+    "NonTensorStack",
     "set_transfer_ownership",
     "pad_sequence",
     "is_memmap",
diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 6ff283228..8059b6819 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -718,7 +718,7 @@ def _legacy_unsqueeze(self, dim: int) -> T:
         else:
             dim = dim - 1
         stack_dim = self.stack_dim
-        return LazyStackedTensorDict(
+        return type(self)(
             *(tensordict.unsqueeze(dim) for tensordict in self.tensordicts),
             stack_dim=stack_dim,
         )
@@ -756,7 +756,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T:
         else:
             dim = dim - 1
         stack_dim = self.stack_dim
-        return LazyStackedTensorDict(
+        return type(self)(
             *(tensordict.squeeze(dim) for tensordict in self.tensordicts),
             stack_dim=stack_dim,
         )
@@ -1236,7 +1236,7 @@ def contiguous(self) -> T:
         return out
 
     def empty(self, recurse=False) -> T:
-        return LazyStackedTensorDict(
+        return type(self)(
             *[td.empty(recurse=recurse) for td in self.tensordicts],
             stack_dim=self.stack_dim,
         )
@@ -1245,12 +1245,12 @@ def _clone(self, recurse: bool = True) -> T:
         if recurse:
             # This could be optimized using copy but we must be careful with
             # metadata (_is_shared etc)
-            result = LazyStackedTensorDict(
+            result = type(self)(
                 *[td._clone() for td in self.tensordicts],
                 stack_dim=self.stack_dim,
             )
         else:
-            result = LazyStackedTensorDict(
+            result = type(self)(
                 *[td._clone(recurse=False) for td in self.tensordicts],
                 stack_dim=self.stack_dim,
             )
@@ -1274,7 +1274,7 @@ def to(self, *args, **kwargs) -> T:
         if device is not None and dtype is None and device == self.device:
             return result
 
-        return LazyStackedTensorDict(
+        return type(self)(
             *[td.to(*args, **kwargs) for td in self.tensordicts],
             stack_dim=self.stack_dim,
             hook_out=self.hook_out,
@@ -1403,7 +1403,7 @@ def _apply_nest(
         if filter_empty and all(r is None for r in results):
             return
         if not inplace:
-            out = LazyStackedTensorDict(
+            out = type(self)(
                 *results,
                 stack_dim=self.stack_dim,
             )
@@ -1429,7 +1429,7 @@ def _select(
         ]
         if inplace:
             return self
-        result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
+        result = type(self)(*tensordicts, stack_dim=self.stack_dim)
         return result
 
     def _exclude(
@@ -1442,7 +1442,7 @@ def _exclude(
         if inplace:
             self.tensordicts = tensordicts
             return self
-        result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
+        result = type(self)(*tensordicts, stack_dim=self.stack_dim)
         return result
 
     def __setitem__(self, index: IndexType, value: T) -> T:
@@ -2336,9 +2336,9 @@ def _transpose(self, dim0, dim1):
             # example: shape = [5, 4, 3, 2, 1], stack_dim=1, dim0=1, dim1=4
             # resulting shape: [5, 1, 3, 2, 4]
             if dim1 == dim0 + 1:
-                result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1)
+                result = type(self)(*self.tensordicts, stack_dim=dim1)
             else:
-                result = LazyStackedTensorDict(
+                result = type(self)(
                     *(td.transpose(dim0, dim1 - 1) for td in self.tensordicts),
                     stack_dim=dim1,
                 )
@@ -2346,16 +2346,16 @@ def _transpose(self, dim0, dim1):
             # example: shape = [5, 4, 3, 2, 1], stack_dim=3, dim0=1, dim1=3
             # resulting shape: [5, 2, 3, 4, 1]
             if dim0 + 1 == dim1:
-                result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0)
+                result = type(self)(*self.tensordicts, stack_dim=dim0)
             else:
-                result = LazyStackedTensorDict(
+                result = type(self)(
                     *(td.transpose(dim0 + 1, dim1) for td in self.tensordicts),
                     stack_dim=dim0,
                 )
         else:
             dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1
             dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1
-            result = LazyStackedTensorDict(
+            result = type(self)(
                 *(td.transpose(dim0, dim1) for td in self.tensordicts),
                 stack_dim=self.stack_dim,
             )
@@ -2448,16 +2448,6 @@ def _unsqueeze(self, dim):
     _to_module = TensorDict._to_module
 
 
-class StackNonTensor(LazyStackedTensorDict):
-    """A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable."""
-
-    def tolist(self):
-        if self.stack_dim == 0:
-            return [td.tolist() for td in self.tensordicts]
-        else:
-            return [td.tolist() for td in self.unbind(0)]
-
-
 class _CustomOpTensorDict(TensorDictBase):
     """Encodes lazy operations on tensors contained in a TensorDict."""
 
diff --git a/tensordict/_td.py b/tensordict/_td.py
index 74f099faa..43562aa89 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -29,6 +29,7 @@
     _register_tensor_class,
     BEST_ATTEMPT_INPLACE,
     CompatibleType,
+    is_non_tensor,
     is_tensor_collection,
     NO_DEFAULT,
     T,
@@ -308,10 +309,9 @@ def is_empty(self):
             if _is_tensor_collection(type(item)):
                 if not item.is_empty():
                     return False
-                from tensordict._lazy import StackNonTensor
-                from tensordict.tensorclass import NonTensorData
+                from tensordict.tensorclass import NonTensorData, NonTensorStack
 
-                if isinstance(item, (NonTensorData, StackNonTensor)):
+                if isinstance(item, (NonTensorData, NonTensorStack)):
                     return False
             else:
                 return False
@@ -680,7 +680,11 @@ def make_result():
 
         any_set = False
         for key, item in self.items():
-            if not call_on_nested and _is_tensor_collection(item.__class__):
+            if (
+                not call_on_nested
+                and _is_tensor_collection(item.__class__)
+                and not is_non_tensor(item)
+            ):
                 if default is not NO_DEFAULT:
                     _others = [_other._get_str(key, default=None) for _other in others]
                     _others = [
@@ -2434,11 +2438,10 @@ def get(
 
     def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
         out = super()._get_non_tensor(key, default=default)
-        from tensordict._lazy import StackNonTensor
-        from tensordict.tensorclass import NonTensorData
+        from tensordict.tensorclass import NonTensorData, NonTensorStack
 
         if isinstance(out, _SubTensorDict) and isinstance(
-            out._source, (NonTensorData, StackNonTensor)
+            out._source, (NonTensorData, NonTensorStack)
         ):
             return out._source
         return out
diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index 47ad18077..d27cbb1d7 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -379,11 +379,10 @@ def _stack(
     if not list_of_tensordicts:
         raise RuntimeError("list_of_tensordicts cannot be empty")
 
-    from tensordict._lazy import StackNonTensor
-    from tensordict.tensorclass import NonTensorData
+    from tensordict.tensorclass import NonTensorData, NonTensorStack
 
     if all(
-        isinstance(td, (NonTensorData, StackNonTensor)) for td in list_of_tensordicts
+        isinstance(td, (NonTensorData, NonTensorStack)) for td in list_of_tensordicts
     ):
         return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
 
diff --git a/tensordict/base.py b/tensordict/base.py
index 253836712..3fe1ac63b 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -246,9 +246,9 @@ def __getitem__(self, index: IndexType) -> T:
                 if isinstance(result, NonTensorData):
                     return result.data
 
-                from ._lazy import StackNonTensor
+                from tensordict.tensorclass import NonTensorStack
 
-                if isinstance(result, StackNonTensor):
+                if isinstance(result, NonTensorStack):
                     return result.tolist()
                 return result
 
@@ -1659,6 +1659,14 @@ def cuda(self, device: int = None) -> T:
             return self.to(torch.device("cuda"))
         return self.to(f"cuda:{device}")
 
+    @property
+    def is_cuda(self):
+        return self.device is not None and self.device.type == "cuda"
+
+    @property
+    def is_cpu(self):
+        return self.device is not None and self.device.type == "cpu"
+
     # Serialization functionality
     def state_dict(
         self,
@@ -2317,9 +2325,9 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
         if isinstance(value, NonTensorData):
             return value.data
 
-        from ._lazy import StackNonTensor
+        from tensordict.tensorclass import NonTensorStack
 
-        if isinstance(value, StackNonTensor):
+        if isinstance(value, NonTensorStack):
             return value.tolist()
         return value
 
@@ -5380,16 +5388,22 @@ def is_tensor_collection(datatype: type | Any) -> bool:
     return _is_tensor_collection(datatype)
 
 
+def is_non_tensor(data):
+    """Checks if an item is a non-tensor."""
+    from tensordict.tensorclass import NonTensorData, NonTensorStack
+
+    return isinstance(data, (NonTensorData, NonTensorStack))
+
+
 def _default_is_leaf(cls: Type) -> bool:
     return not _is_tensor_collection(cls)
 
 
 def _is_leaf_nontensor(cls: Type) -> bool:
-    from tensordict._lazy import StackNonTensor
-    from tensordict.tensorclass import NonTensorData
+    from tensordict.tensorclass import NonTensorData, NonTensorStack
 
     if issubclass(cls, KeyedJaggedTensor):
         return False
     if _is_tensor_collection(cls):
-        return issubclass(cls, (NonTensorData, StackNonTensor))
+        return issubclass(cls, (NonTensorData, NonTensorStack))
     return issubclass(cls, torch.Tensor)
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index b7016bcff..90b00ecb8 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -15,7 +15,7 @@
 import re
 import sys
 import warnings
-from copy import copy
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import indent
@@ -24,6 +24,7 @@
 import tensordict as tensordict_lib
 
 import torch
+from tensordict import LazyStackedTensorDict
 from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase
 from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict._torch_func import TD_HANDLED_FUNCTIONS
@@ -475,7 +476,7 @@ def wrapper(self, item: str) -> Any:
     return wrapper
 
 
-SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts")
+SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names")
 
 
 def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable:
@@ -489,12 +490,10 @@ def wrapper(self, key: str, value: Any) -> None:  # noqa: D417
 
         """
         __dict__ = self.__dict__
-        if (
-            "_tensordict" not in __dict__
-            or "_non_tensordict" not in __dict__
-            or key in SET_ATTRIBUTES
-        ):
+        if "_tensordict" not in __dict__ or "_non_tensordict" not in __dict__:
             return setattr_(self, key, value)
+        if key in SET_ATTRIBUTES:
+            return setattr(self._tensordict, key, value)
 
         out = self.set(key, value)
         if out is not self:
@@ -714,6 +713,9 @@ def _set(self, key: NestedKey, value: Any, inplace: bool = False):
     __dict__ = self.__dict__
     if __dict__["_tensordict"].is_locked:
         raise RuntimeError(_LOCK_ERROR)
+    if key in ("batch_size", "names", "device"):
+        # handled by setattr
+        return
     expected_keys = self.__dataclass_fields__
     if key not in expected_keys:
         raise AttributeError(
@@ -1344,9 +1346,7 @@ def _check_equal(a, b):
                 device=first.device,
             )
 
-        from tensordict._lazy import StackNonTensor
-
-        return StackNonTensor(*list_of_non_tensor, stack_dim=dim)
+        return NonTensorStack(*list_of_non_tensor, stack_dim=dim)
 
     @classmethod
     def __torch_function__(
@@ -1410,3 +1410,39 @@ def tolist(self):
         if not self.batch_size:
             return self.data
         return [ntd.tolist() for ntd in self.unbind(0)]
+
+    def copy_(self, src: NonTensorData | NonTensorStack, non_blocking: bool = False):
+        if isinstance(src, NonTensorStack):
+            raise RuntimeError(
+                "Cannot update a NonTensorData with a NonTensorStack object."
+            )
+        if not isinstance(src, NonTensorData):
+            raise RuntimeError(
+                "NonTensorData.copy_ requires the source to be a NonTensorData object."
+            )
+        self._non_tensordict["data"] = src.data
+
+    def clone(self, recurse: bool = True):
+        if recurse:
+            return type(self)(
+                data=deepcopy(self.data),
+                batch_size=self.batch_size,
+                device=self.device,
+                names=self.names,
+            )
+        return type(self)(
+            data=self.data,
+            batch_size=self.batch_size,
+            device=self.device,
+            names=self.names,
+        )
+
+
+class NonTensorStack(LazyStackedTensorDict):
+    """A thin wrapper around LazyStackedTensorDict to make stack on non-tensor data easily recognizable."""
+
+    def tolist(self):
+        if self.stack_dim == 0:
+            return [td.tolist() for td in self.tensordicts]
+        else:
+            return [td.tolist() for td in self.unbind(0)]
diff --git a/tensordict/utils.py b/tensordict/utils.py
index f572c3ff4..b02e367c8 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -656,10 +656,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) ->
     elif isinstance(tensor, KeyedJaggedTensor):
         tensor = setitem_keyedjaggedtensor(tensor, index, value)
         return tensor
-    from tensordict._lazy import StackNonTensor
-    from tensordict.tensorclass import NonTensorData
+    from tensordict.tensorclass import NonTensorData, NonTensorStack
 
-    if isinstance(tensor, (NonTensorData, StackNonTensor)):
+    if isinstance(tensor, (NonTensorData, NonTensorStack)):
         if (
             isinstance(value, NonTensorData)
             and isinstance(tensor, NonTensorData)
@@ -676,9 +675,9 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) ->
             tensor = _set_item(tensor, idx, tensor_idx, validated=True)
            return tensor
         if isinstance(tensor, NonTensorData):
-            tensor = StackNonTensor(*[tensor[0]] * tensor.shape[0], stack_dim=0)
+            tensor = NonTensorStack(*[tensor[0]] * tensor.shape[0], stack_dim=0)
         elif tensor.stack_dim != 0:
-            tensor = StackNonTensor(*tensor.unbind(0), stack_dim=0)
+            tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0)
         tensor[index] = value
         return tensor
     else:
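
Illustrative usage (not part of the patch): a minimal sketch of what the change enables, namely that NonTensorStack replaces the former StackNonTensor, is exported from the top-level tensordict package, is produced when NonTensorData instances with differing payloads are stacked, and is recognized by the new is_non_tensor helper added to tensordict/base.py. The constructor arguments below (e.g. batch_size=[]) follow the usual tensorclass conventions and are assumptions for the sketch, not taken from this diff.

import torch

from tensordict import NonTensorData, NonTensorStack  # NonTensorStack export added above
from tensordict.base import is_non_tensor  # helper added above

# Stacking identical payloads collapses back into a single NonTensorData ...
same = torch.stack([NonTensorData(data="a", batch_size=[]) for _ in range(3)])
assert isinstance(same, NonTensorData) and same.data == "a"

# ... while differing payloads fall back to a NonTensorStack, whose tolist()
# recovers the underlying Python objects.
mixed = torch.stack([NonTensorData(data=s, batch_size=[]) for s in ("a", "b", "c")])
assert isinstance(mixed, NonTensorStack)
assert is_non_tensor(mixed)
assert mixed.tolist() == ["a", "b", "c"]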