
Commit

amend
vmoens committed Feb 21, 2024
1 parent 20dc288 commit 0f44acc
Showing 8 changed files with 100 additions and 57 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/tensorclass.rst
@@ -273,3 +273,4 @@ Here is an example:

tensorclass
NonTensorData
NonTensorStack
3 changes: 2 additions & 1 deletion tensordict/__init__.py
@@ -16,7 +16,7 @@
from tensordict.memmap import MemoryMappedTensor
from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership
from tensordict.persistent import PersistentTensorDict
from tensordict.tensorclass import NonTensorData, tensorclass
from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass
from tensordict.utils import (
assert_allclose_td,
is_batchedtensor,
@@ -43,6 +43,7 @@
"TensorDict",
"TensorDictBase",
"merge_tensordicts",
"NonTensorStack",
"set_transfer_ownership",
"pad_sequence",
"is_memmap",
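Illustrative sketch (not part of this diff; assumes the re-export above): NonTensorStack can now be imported from the package root and built directly from NonTensorData entries, as _stack_non_tensor does internally.

from tensordict import NonTensorData, NonTensorStack  # NonTensorStack is now a top-level export

# Build a stack of two non-tensor payloads; tolist() recovers the raw data.
stack = NonTensorStack(
    NonTensorData("a", batch_size=[]),
    NonTensorData("b", batch_size=[]),
    stack_dim=0,
)
print(stack.tolist())  # expected: ['a', 'b']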
38 changes: 14 additions & 24 deletions tensordict/_lazy.py
@@ -718,7 +718,7 @@ def _legacy_unsqueeze(self, dim: int) -> T:
else:
dim = dim - 1
stack_dim = self.stack_dim
return LazyStackedTensorDict(
return type(self)(
*(tensordict.unsqueeze(dim) for tensordict in self.tensordicts),
stack_dim=stack_dim,
)
@@ -756,7 +756,7 @@ def _legacy_squeeze(self, dim: int | None = None) -> T:
else:
dim = dim - 1
stack_dim = self.stack_dim
return LazyStackedTensorDict(
return type(self)(
*(tensordict.squeeze(dim) for tensordict in self.tensordicts),
stack_dim=stack_dim,
)
@@ -1236,7 +1236,7 @@ def contiguous(self) -> T:
return out

def empty(self, recurse=False) -> T:
return LazyStackedTensorDict(
return type(self)(
*[td.empty(recurse=recurse) for td in self.tensordicts],
stack_dim=self.stack_dim,
)
@@ -1245,12 +1245,12 @@ def _clone(self, recurse: bool = True) -> T:
if recurse:
# This could be optimized using copy but we must be careful with
# metadata (_is_shared etc)
result = LazyStackedTensorDict(
result = type(self)(
*[td._clone() for td in self.tensordicts],
stack_dim=self.stack_dim,
)
else:
result = LazyStackedTensorDict(
result = type(self)(
*[td._clone(recurse=False) for td in self.tensordicts],
stack_dim=self.stack_dim,
)
@@ -1274,7 +1274,7 @@ def to(self, *args, **kwargs) -> T:
if device is not None and dtype is None and device == self.device:
return result

return LazyStackedTensorDict(
return type(self)(
*[td.to(*args, **kwargs) for td in self.tensordicts],
stack_dim=self.stack_dim,
hook_out=self.hook_out,
@@ -1403,7 +1403,7 @@ def _apply_nest(
if filter_empty and all(r is None for r in results):
return
if not inplace:
out = LazyStackedTensorDict(
out = type(self)(
*results,
stack_dim=self.stack_dim,
)
@@ -1429,7 +1429,7 @@ def _select(
]
if inplace:
return self
result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
result = type(self)(*tensordicts, stack_dim=self.stack_dim)
return result

def _exclude(
@@ -1442,7 +1442,7 @@ def _exclude(
if inplace:
self.tensordicts = tensordicts
return self
result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
result = type(self)(*tensordicts, stack_dim=self.stack_dim)
return result

def __setitem__(self, index: IndexType, value: T) -> T:
@@ -2336,26 +2336,26 @@ def _transpose(self, dim0, dim1):
# example: shape = [5, 4, 3, 2, 1], stack_dim=1, dim0=1, dim1=4
# resulting shape: [5, 1, 3, 2, 4]
if dim1 == dim0 + 1:
result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1)
result = type(self)(*self.tensordicts, stack_dim=dim1)
else:
result = LazyStackedTensorDict(
result = type(self)(
*(td.transpose(dim0, dim1 - 1) for td in self.tensordicts),
stack_dim=dim1,
)
elif dim1 == self.stack_dim:
# example: shape = [5, 4, 3, 2, 1], stack_dim=3, dim0=1, dim1=3
# resulting shape: [5, 2, 3, 4, 1]
if dim0 + 1 == dim1:
result = LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0)
result = type(self)(*self.tensordicts, stack_dim=dim0)
else:
result = LazyStackedTensorDict(
result = type(self)(
*(td.transpose(dim0 + 1, dim1) for td in self.tensordicts),
stack_dim=dim0,
)
else:
dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1
dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1
result = LazyStackedTensorDict(
result = type(self)(
*(td.transpose(dim0, dim1) for td in self.tensordicts),
stack_dim=self.stack_dim,
)
@@ -2448,16 +2448,6 @@ def _unsqueeze(self, dim):
_to_module = TensorDict._to_module


class StackNonTensor(LazyStackedTensorDict):
"""A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable."""

def tolist(self):
if self.stack_dim == 0:
return [td.tolist() for td in self.tensordicts]
else:
return [td.tolist() for td in self.unbind(0)]


class _CustomOpTensorDict(TensorDictBase):
"""Encodes lazy operations on tensors contained in a TensorDict."""

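Illustrative sketch (not part of this diff; behavior assumed from the type(self) changes above): building results with type(self)(...) instead of a hard-coded LazyStackedTensorDict lets subclasses such as NonTensorStack survive structural operations like clone().

import torch
from tensordict import NonTensorData, NonTensorStack

# Stacking distinct non-tensor payloads is expected to produce a NonTensorStack.
stack = torch.stack(
    [NonTensorData("x", batch_size=[]), NonTensorData("y", batch_size=[])]
)
assert isinstance(stack, NonTensorStack)
# clone() now routes through type(self)(...), so the subclass should be
# preserved rather than decaying to a plain LazyStackedTensorDict.
assert isinstance(stack.clone(), NonTensorStack)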
17 changes: 10 additions & 7 deletions tensordict/_td.py
@@ -29,6 +29,7 @@
_register_tensor_class,
BEST_ATTEMPT_INPLACE,
CompatibleType,
is_non_tensor,
is_tensor_collection,
NO_DEFAULT,
T,
@@ -308,10 +309,9 @@ def is_empty(self):
if _is_tensor_collection(type(item)):
if not item.is_empty():
return False
from tensordict._lazy import StackNonTensor
from tensordict.tensorclass import NonTensorData
from tensordict.tensorclass import NonTensorData, NonTensorStack

if isinstance(item, (NonTensorData, StackNonTensor)):
if isinstance(item, (NonTensorData, NonTensorStack)):
return False
else:
return False
@@ -680,7 +680,11 @@ def make_result():

any_set = False
for key, item in self.items():
if not call_on_nested and _is_tensor_collection(item.__class__):
if (
not call_on_nested
and _is_tensor_collection(item.__class__)
and not is_non_tensor(item)
):
if default is not NO_DEFAULT:
_others = [_other._get_str(key, default=None) for _other in others]
_others = [
@@ -2434,11 +2438,10 @@ def get(

def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
out = super()._get_non_tensor(key, default=default)
from tensordict._lazy import StackNonTensor
from tensordict.tensorclass import NonTensorData
from tensordict.tensorclass import NonTensorData, NonTensorStack

if isinstance(out, _SubTensorDict) and isinstance(
out._source, (NonTensorData, StackNonTensor)
out._source, (NonTensorData, NonTensorStack)
):
return out._source
return out
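Illustrative sketch (not part of this diff; assumes the is_empty() change above): a TensorDict whose only entry is non-tensor data should no longer report itself as empty.

from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData

td = TensorDict({"note": NonTensorData("keep me", batch_size=[])}, batch_size=[])
assert not td.is_empty()                    # non-tensor entries now count as content
assert TensorDict({}, batch_size=[]).is_empty()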
5 changes: 2 additions & 3 deletions tensordict/_torch_func.py
@@ -379,11 +379,10 @@ def _stack(
if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")

from tensordict._lazy import StackNonTensor
from tensordict.tensorclass import NonTensorData
from tensordict.tensorclass import NonTensorData, NonTensorStack

if all(
isinstance(td, (NonTensorData, StackNonTensor)) for td in list_of_tensordicts
isinstance(td, (NonTensorData, NonTensorStack)) for td in list_of_tensordicts
):
return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)

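Illustrative sketch (not part of this diff; assumes the _stack dispatch above): stacking NonTensorData objects goes through NonTensorData._stack_non_tensor, which is expected to collapse identical payloads into a single NonTensorData and turn differing payloads into a NonTensorStack.

import torch
from tensordict import NonTensorData, NonTensorStack

same = torch.stack([NonTensorData("x", batch_size=[]) for _ in range(3)])
assert isinstance(same, NonTensorData)      # identical payloads collapse

mixed = torch.stack(
    [NonTensorData("x", batch_size=[]), NonTensorData("y", batch_size=[])]
)
assert isinstance(mixed, NonTensorStack)    # differing payloads stay stacked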
28 changes: 21 additions & 7 deletions tensordict/base.py
@@ -246,9 +246,9 @@ def __getitem__(self, index: IndexType) -> T:

if isinstance(result, NonTensorData):
return result.data
from ._lazy import StackNonTensor
from tensordict.tensorclass import NonTensorStack

if isinstance(result, StackNonTensor):
if isinstance(result, NonTensorStack):
return result.tolist()
return result

@@ -1659,6 +1659,14 @@ def cuda(self, device: int = None) -> T:
return self.to(torch.device("cuda"))
return self.to(f"cuda:{device}")

@property
def is_cuda(self):
return self.device is not None and self.device.type == "cuda"

@property
def is_cpu(self):
return self.device is not None and self.device.type == "cpu"

# Serialization functionality
def state_dict(
self,
@@ -2317,9 +2325,9 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):

if isinstance(value, NonTensorData):
return value.data
from ._lazy import StackNonTensor
from tensordict.tensorclass import NonTensorStack

if isinstance(value, StackNonTensor):
if isinstance(value, NonTensorStack):
return value.tolist()
return value

@@ -5380,16 +5388,22 @@ def is_tensor_collection(datatype: type | Any) -> bool:
return _is_tensor_collection(datatype)


def is_non_tensor(data):
"""Checks if an item is a non-tensor."""
from tensordict.tensorclass import NonTensorData, NonTensorStack

return isinstance(data, (NonTensorData, NonTensorStack))


def _default_is_leaf(cls: Type) -> bool:
return not _is_tensor_collection(cls)


def _is_leaf_nontensor(cls: Type) -> bool:
from tensordict._lazy import StackNonTensor
from tensordict.tensorclass import NonTensorData
from tensordict.tensorclass import NonTensorData, NonTensorStack

if issubclass(cls, KeyedJaggedTensor):
return False
if _is_tensor_collection(cls):
return issubclass(cls, (NonTensorData, StackNonTensor))
return issubclass(cls, (NonTensorData, NonTensorStack))
return issubclass(cls, torch.Tensor)
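Illustrative sketch (not part of this diff) of the additions above: the is_non_tensor() helper, the is_cuda/is_cpu properties, and non-tensor unwrapping on item access.

import torch
from tensordict import NonTensorData, TensorDict
from tensordict.base import is_non_tensor

td = TensorDict(
    {"x": torch.zeros(3), "msg": NonTensorData("hello", batch_size=[])},
    batch_size=[],
)
assert is_non_tensor(td.get("msg"))       # NonTensorData (and NonTensorStack) qualify
assert not is_non_tensor(td)              # an ordinary TensorDict does not
assert td["msg"] == "hello"               # __getitem__ is expected to unwrap the payload
assert not td.is_cuda and not td.is_cpu   # device is None until set explicitly
assert td.to("cpu").is_cpu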
56 changes: 46 additions & 10 deletions tensordict/tensorclass.py
@@ -15,7 +15,7 @@
import re
import sys
import warnings
from copy import copy
from copy import copy, deepcopy
from dataclasses import dataclass
from pathlib import Path
from textwrap import indent
@@ -24,6 +24,7 @@
import tensordict as tensordict_lib

import torch
from tensordict import LazyStackedTensorDict
from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase
from tensordict._tensordict import _unravel_key_to_tuple
from tensordict._torch_func import TD_HANDLED_FUNCTIONS
@@ -475,7 +476,7 @@ def wrapper(self, item: str) -> Any:
return wrapper


SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts")
SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names")


def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable:
@@ -489,12 +490,10 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
"""
__dict__ = self.__dict__
if (
"_tensordict" not in __dict__
or "_non_tensordict" not in __dict__
or key in SET_ATTRIBUTES
):
if "_tensordict" not in __dict__ or "_non_tensordict" not in __dict__:
return setattr_(self, key, value)
if key in SET_ATTRIBUTES:
return setattr(self._tensordict, key, value)

out = self.set(key, value)
if out is not self:
@@ -714,6 +713,9 @@ def _set(self, key: NestedKey, value: Any, inplace: bool = False):
__dict__ = self.__dict__
if __dict__["_tensordict"].is_locked:
raise RuntimeError(_LOCK_ERROR)
if key in ("batch_size", "names", "device"):
# handled by setattr
return
expected_keys = self.__dataclass_fields__
if key not in expected_keys:
raise AttributeError(
@@ -1344,9 +1346,7 @@ def _check_equal(a, b):
device=first.device,
)

from tensordict._lazy import StackNonTensor

return StackNonTensor(*list_of_non_tensor, stack_dim=dim)
return NonTensorStack(*list_of_non_tensor, stack_dim=dim)

@classmethod
def __torch_function__(
Expand Down Expand Up @@ -1410,3 +1410,39 @@ def tolist(self):
if not self.batch_size:
return self.data
return [ntd.tolist() for ntd in self.unbind(0)]

def copy_(self, src: NonTensorData | NonTensorStack, non_blocking: bool = False):
if isinstance(src, NonTensorStack):
raise RuntimeError(
"Cannot update a NonTensorData with a NonTensorStack object."
)
if not isinstance(src, NonTensorData):
raise RuntimeError(
"NonTensorData.copy_ requires the source to be a NonTensorData object."
)
self._non_tensordict["data"] = src.data

def clone(self, recurse: bool = True):
if recurse:
return type(self)(
data=deepcopy(self.data),
batch_size=self.batch_size,
device=self.device,
names=self.names,
)
return type(self)(
data=self.data,
batch_size=self.batch_size,
device=self.device,
names=self.names,
)


class NonTensorStack(LazyStackedTensorDict):
"""A thin wrapper aroung LazyStackedTensorDict to make stack on non-tensor data easily recognizable."""

def tolist(self):
if self.stack_dim == 0:
return [td.tolist() for td in self.tensordicts]
else:
return [td.tolist() for td in self.unbind(0)]
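Illustrative sketch (not part of this diff) of the new NonTensorData methods above: copy_() swaps the payload in place, and clone() deep-copies it unless recurse=False.

from tensordict import NonTensorData

src = NonTensorData({"lr": 1e-3}, batch_size=[])
dst = NonTensorData({"lr": 0.0}, batch_size=[])
dst.copy_(src)                        # in-place payload update
assert dst.data == {"lr": 1e-3}

deep = src.clone()                    # recurse=True: payload is deep-copied
assert deep.data == src.data and deep.data is not src.data
shallow = src.clone(recurse=False)    # payload object is shared
assert shallow.data is src.data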
