From c95a703944f3b761987c173c6fbcdb1f2c2fb712 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 23 Nov 2024 20:23:27 +0100 Subject: [PATCH] [Feature,Refactor] Refactor from_dict, add from_any, from_dataclass ghstack-source-id: eb25fe4b201fd7f27d60b140278820c0d5d51eb8 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1102 --- docs/source/reference/tensorclass.rst | 1 + tensordict/__init__.py | 1 + tensordict/_lazy.py | 27 +++- tensordict/_td.py | 177 ++++++++++++++++++++++---- tensordict/base.py | 125 +++++++++++++++++- tensordict/functional.py | 7 +- tensordict/nn/common.py | 4 +- tensordict/nn/params.py | 8 +- tensordict/persistent.py | 28 +++- tensordict/tensorclass.py | 149 ++++++++++++++++++++-- tensordict/tensorclass.pyi | 7 + tensordict/utils.py | 13 +- test/_utils_internal.py | 58 ++++++--- test/test_tensorclass.py | 63 ++++++++- test/test_tensordict.py | 112 +++++++++++++--- 15 files changed, 695 insertions(+), 85 deletions(-) diff --git a/docs/source/reference/tensorclass.rst b/docs/source/reference/tensorclass.rst index 17dceff06..ea55aef40 100644 --- a/docs/source/reference/tensorclass.rst +++ b/docs/source/reference/tensorclass.rst @@ -282,6 +282,7 @@ Here is an example: TensorClass NonTensorData NonTensorStack + from_dataclass Auto-casting ------------ diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 364a11f5a..7fc9d349d 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -43,6 +43,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.persistent import PersistentTensorDict from tensordict.tensorclass import ( + from_dataclass, NonTensorData, NonTensorStack, tensorclass, diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index eb4248671..73c316981 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -329,15 +329,38 @@ def _reduce_get_metadata(self): @classmethod def from_dict( cls, - input_dict, + input_dict: List[Dict[NestedKey, Any]], + *other, + auto_batch_size: bool = False, batch_size=None, device=None, batch_dims=None, stack_dim_name=None, stack_dim=0, ): + # if batch_size is not None: + # batch_size = list(batch_size) + # if stack_dim is None: + # stack_dim = 0 + # n = batch_size.pop(stack_dim) + # if n != len(input_dict): + # raise ValueError( + # "The number of dicts and the corresponding batch-size must match, " + # f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}." 
+ # ) + # batch_size = torch.Size(batch_size) return LazyStackedTensorDict( - *(input_dict[str(i)] for i in range(len(input_dict))), + *( + TensorDict.from_dict( + input_dict[str(i)], + *other, + auto_batch_size=auto_batch_size, + device=device, + batch_dims=batch_dims, + batch_size=batch_size, + ) + for i in range(len(input_dict)) + ), stack_dim=stack_dim, stack_dim_name=stack_dim_name, ) diff --git a/tensordict/_td.py b/tensordict/_td.py index 7895fae4e..07a98cdfb 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -615,7 +615,7 @@ def __ne__(self, other: object) -> T | bool: if is_tensorclass(other): return other != self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -639,7 +639,7 @@ def __xor__(self, other: object) -> T | bool: if is_tensorclass(other): return other ^ self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -663,7 +663,7 @@ def __or__(self, other: object) -> T | bool: if is_tensorclass(other): return other | self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -687,7 +687,7 @@ def __eq__(self, other: object) -> T | bool: if is_tensorclass(other): return other == self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -709,7 +709,7 @@ def __ge__(self, other: object) -> T | bool: if is_tensorclass(other): return other <= self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -731,7 +731,7 @@ def __gt__(self, other: object) -> T | bool: if is_tensorclass(other): return other < self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -753,7 +753,7 @@ def __le__(self, other: object) -> T | bool: if is_tensorclass(other): return other >= self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -775,7 +775,7 @@ def __lt__(self, other: object) -> T | bool: if is_tensorclass(other): return other > self if isinstance(other, (dict,)): - other = self.from_dict_instance(other) + other = self.from_dict_instance(other, auto_batch_size=False) if _is_tensor_collection(type(other)): keys1 = set(self.keys()) keys2 = set(other.keys()) @@ -1957,8 +1957,46 @@ def _unsqueeze(tensor): @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool | None = None, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): + if others: + if batch_size is 
not None:
+                raise TypeError(
+                    "conflicting batch size values. Please use the keyword argument only."
+                )
+            if device is not None:
+                raise TypeError(
+                    "conflicting device values. Please use the keyword argument only."
+                )
+            if batch_dims is not None:
+                raise TypeError(
+                    "conflicting batch_dims values. Please use the keyword argument only."
+                )
+            if names is not None:
+                raise TypeError(
+                    "conflicting names values. Please use the keyword argument only."
+                )
+            warn(
+                "All positional arguments after input_dict will be deprecated in v0.8. Please use keyword arguments instead.",
+                category=DeprecationWarning,
+            )
+            batch_size, *others = others
+            if len(others):
+                device, *others = others
+            if len(others):
+                batch_dims, *others = others
+            if len(others):
+                names, *others = others
+            if len(others):
+                raise TypeError("Too many positional arguments.")
+
         if batch_dims is not None and batch_size is not None:
             raise ValueError(
                 "Cannot pass both batch_size and batch_dims to `from_dict`."
@@ -1967,12 +2005,12 @@ def from_dict(
         batch_size_set = torch.Size(()) if batch_size is None else batch_size
         input_dict = dict(input_dict)
         for key, value in list(input_dict.items()):
-            if isinstance(value, (dict,)):
-                # we don't know if another tensor of smaller size is coming
-                # so we can't be sure that the batch-size will still be valid later
-                input_dict[key] = TensorDict.from_dict(
-                    value, batch_size=[], device=device, batch_dims=None
-                )
+            # we don't know if another tensor of smaller size is coming
+            # so we can't be sure that the batch-size will still be valid later
+            input_dict[key] = TensorDict.from_any(
+                value,
+                auto_batch_size=False,
+            )
         # regular __init__ breaks because a tensor may have the same batch-size as the tensordict
         out = cls(
             input_dict,
@@ -1981,7 +2019,19 @@ def from_dict(
             names=names,
         )
         if batch_size is None:
-            _set_max_batch_size(out, batch_dims)
+            if auto_batch_size is None and batch_dims is None:
+                warn(
+                    "The batch-size was not provided and auto_batch_size isn't set either. "
+                    "Currently, from_dict will set auto_batch_size=True, but this behaviour "
+                    "will change in v0.8, when auto_batch_size will default to False. "
+                    "To silence this warning, pass auto_batch_size explicitly.",
+                    category=DeprecationWarning,
+                )
+                auto_batch_size = True
+            elif auto_batch_size is None:
+                auto_batch_size = True
+            if auto_batch_size:
+                _set_max_batch_size(out, batch_dims)
         else:
             out.batch_size = batch_size
         return out
@@ -1998,8 +2048,46 @@ def _from_dict_validated(
         )

     def from_dict_instance(
-        self, input_dict, batch_size=None, device=None, batch_dims=None, names=None
+        self,
+        input_dict,
+        *others,
+        auto_batch_size: bool | None = None,
+        batch_size=None,
+        device=None,
+        batch_dims=None,
+        names=None,
     ):
+        if others:
+            if batch_size is not None:
+                raise TypeError(
+                    "conflicting batch size values. Please use the keyword argument only."
+                )
+            if device is not None:
+                raise TypeError(
+                    "conflicting device values. Please use the keyword argument only."
+                )
+            if batch_dims is not None:
+                raise TypeError(
+                    "conflicting batch_dims values. Please use the keyword argument only."
+                )
+            if names is not None:
+                raise TypeError(
+                    "conflicting names values. Please use the keyword argument only."
+                )
+            warn(
+                "All positional arguments after input_dict will be deprecated in v0.8. 
Please use keyword arguments instead.", + category=DeprecationWarning, + ) + batch_size, *others = others + if len(others): + device, *others = others + if len(others): + batch_dims, *others = others + if len(others): + names, *others = others + if len(others): + raise TypeError("Too many positional arguments.") + if batch_dims is not None and batch_size is not None: raise ValueError( "Cannot pass both batch_size and batch_dims to `from_dict`." @@ -2014,14 +2102,25 @@ def from_dict_instance( cur_value = self.get(key, None) if cur_value is not None: input_dict[key] = cur_value.from_dict_instance( - value, batch_size=[], device=device, batch_dims=None + value, + device=device, + auto_batch_size=False, ) continue - # we don't know if another tensor of smaller size is coming - # so we can't be sure that the batch-size will still be valid later - input_dict[key] = TensorDict.from_dict( - value, batch_size=[], device=device, batch_dims=None + else: + # we don't know if another tensor of smaller size is coming + # so we can't be sure that the batch-size will still be valid later + input_dict[key] = TensorDict.from_dict( + value, + device=device, + auto_batch_size=False, + ) + else: + input_dict[key] = TensorDict.from_any( + value, + auto_batch_size=False, ) + out = TensorDict.from_dict( input_dict, batch_size=batch_size_set, @@ -2029,7 +2128,19 @@ def from_dict_instance( names=names, ) if batch_size is None: - _set_max_batch_size(out, batch_dims) + if auto_batch_size is None and batch_dims is None: + warn( + "The batch-size was not provided and auto_batch_size isn't set either. " + "Currently, from_dict will call set auto_batch_size=True but this behaviour " + "will be changed in v0.8 and auto_batch_size will be False onward. " + "To silence this warning, pass auto_batch_size directly.", + category=DeprecationWarning, + ) + auto_batch_size = True + elif auto_batch_size is None: + auto_batch_size = True + if auto_batch_size: + _set_max_batch_size(out, batch_dims) else: out.batch_size = batch_size return out @@ -3857,7 +3968,14 @@ def expand(self, *args: int, inplace: bool = False) -> T: @classmethod def from_dict( - cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None + cls, + input_dict, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, + names=None, ): raise NotImplementedError(f"from_dict not implemented for {cls.__name__}.") @@ -4273,6 +4391,12 @@ def _items( (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict._source.keys() ) + from tensordict.persistent import PersistentTensorDict + + if isinstance(tensordict, PersistentTensorDict): + return ( + (key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict.keys() + ) raise NotImplementedError(type(tensordict)) def _keys(self) -> _TensorDictKeysView: @@ -4697,7 +4821,9 @@ def from_modules( ) -def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None): +def from_dict( + input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None +): """Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`. If ``batch_size`` is not specified, returns the maximum batch size possible. 
@@ -4762,6 +4888,7 @@ def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None):
     """
     return TensorDict.from_dict(
         input_dict,
+        *others,
         batch_size=batch_size,
         device=device,
         batch_dims=batch_dims,
diff --git a/tensordict/base.py b/tensordict/base.py
index 39729eba4..6c600b11f 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -12,6 +12,7 @@
 import enum
 import gc
 import importlib
+import importlib.util
 import os.path
 import queue
 import uuid
@@ -54,6 +55,7 @@
     _CloudpickleWrapper,
     _DTYPE2STRDTYPE,
     _GENERIC_NESTED_ERR,
+    _is_dataclass as is_dataclass,
     _is_non_tensor,
     _is_number,
     _is_tensorclass,
@@ -112,6 +114,8 @@
 except ImportError:
     from tensordict.utils import Buffer

+_has_h5 = importlib.util.find_spec("h5py") is not None
+

 # NO_DEFAULT is used as a placeholder whenever the default is not provided.
 # Using None is not an option since `td.get(key)` is a valid usage.
@@ -120,7 +124,6 @@ class _NoDefault(enum.IntEnum):


 NO_DEFAULT = _NoDefault.ZERO
-assert not NO_DEFAULT

 T = TypeVar("T", bound="TensorDictBase")
@@ -1133,6 +1136,8 @@ def auto_device_(self) -> T:
     def from_dict(
         cls,
         input_dict,
+        *,
+        auto_batch_size: bool | None = None,
         batch_size: torch.Size | None = None,
         device: torch.device | None = None,
         batch_dims: int | None = None,
@@ -1148,6 +1153,10 @@
         Args:
             input_dict (dictionary, optional): a dictionary to use as a data source
                 (nested keys compatible).
+
+        Keyword Args:
+            auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically.
+                Defaults to ``None``: if unset and no batch size is passed, the batch size is inferred
+                and a :class:`DeprecationWarning` is raised; from v0.8 onward the default will be ``False``.
             batch_size (iterable of int, optional): a batch size for the tensordict.
             device (torch.device or compatible type, optional): a device for the TensorDict.
             batch_dims (int, optional): the ``batch_dims`` (i.e. number of leading dimensions
@@ -1207,12 +1216,15 @@ def _from_dict_validated(cls, *args, **kwargs):

         By default, falls back on :meth:`~.from_dict`.
         """
+        kwargs.setdefault("auto_batch_size", True)
         return cls.from_dict(*args, **kwargs)

     @abc.abstractmethod
     def from_dict_instance(
         self,
         input_dict,
+        *others,
+        auto_batch_size: bool | None = None,
         batch_size=None,
         device=None,
         batch_dims=None,
@@ -4284,7 +4296,6 @@ def _view_and_pad(tensor):
             elif k[-1].startswith("<NJT>"):
                 # NJT/NT always comes before offsets/shapes
                 nt = oldv
-                assert not v.numel()
                 nt_lengths = None
                 del flat_dict[k]
             elif k[-1].startswith("<NJT_LENGTHS>"):
@@ -9837,6 +9848,113 @@ def dict_to_namedtuple(dictionary):

         return dict_to_namedtuple(self.to_dict(retain_none=False))

+    @classmethod
+    def from_any(cls, obj, *, auto_batch_size: bool = False):
+        """Converts any object to a TensorDict, recursively.
+
+        Support includes:
+
+        - dataclasses through :meth:`~.from_dataclass` (dataclasses will be converted to TensorDict instances, not
+          tensorclasses).
+        - namedtuples through :meth:`~.from_namedtuple`
+        - dicts through :meth:`~.from_dict`
+        - tuples through :meth:`~.from_tuple`
+        - numpy's structured arrays through :meth:`~.from_struct_array`
+        - h5 objects through :meth:`~.from_h5`
+
+        Keyword Args:
+            auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically.
+                Defaults to ``False``.
+
+        """
+        if is_tensor_collection(obj):
+            return obj
+        if isinstance(obj, dict):
+            return cls.from_dict(obj, auto_batch_size=auto_batch_size)
+        if isinstance(obj, np.ndarray) and obj.dtype.names is not None:
+            return cls.from_struct_array(obj, auto_batch_size=auto_batch_size)
+        if isinstance(obj, tuple):
+            if is_namedtuple(obj):
+                return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size)
+            return cls.from_tuple(obj, auto_batch_size=auto_batch_size)
+        if isinstance(obj, list):
+            return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size)
+        if is_dataclass(obj):
+            return cls.from_dataclass(obj, auto_batch_size=auto_batch_size)
+        if _has_h5:
+            import h5py
+
+            if isinstance(obj, h5py.File):
+                from tensordict.persistent import PersistentTensorDict
+
+                obj = PersistentTensorDict(group=obj)
+                if auto_batch_size:
+                    obj.auto_batch_size_()
+                return obj
+        return obj
+
+    @classmethod
+    def from_tuple(cls, obj, *, auto_batch_size: bool = False):
+        from tensordict import TensorDict
+
+        result = TensorDict({str(i): cls.from_any(item) for i, item in enumerate(obj)})
+        if auto_batch_size:
+            result.auto_batch_size_()
+        return result
+
+    @classmethod
+    def from_dataclass(
+        cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False
+    ):
+        """Converts a dataclass into a TensorDict instance.
+
+        Args:
+            dataclass: The dataclass instance to be converted.
+
+        Keyword Args:
+            auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the
+                resulting TensorDict. Defaults to ``False``.
+            as_tensorclass (bool, optional): If ``True``, delegates the conversion to the free function
+                :func:`~tensordict.from_dataclass` and returns a tensor-compatible class
+                (:func:`~tensordict.tensorclass`) or instance instead of a ``TensorDict``. Defaults to ``False``.
+
+        Returns:
+            A TensorDict instance derived from the provided dataclass, unless ``as_tensorclass`` is ``True``, in
+            which case a tensor-compatible class or instance is returned.
+
+        Raises:
+            TypeError: If the provided input is not a dataclass instance.
+
+        .. warning:: This method is distinct from the free function :func:`~tensordict.from_dataclass` and serves
+            a different purpose. While the free function returns a tensor-compatible class or instance, this
+            method returns a TensorDict instance.
+
+        .. note::
+
+            - This method creates a new TensorDict instance with keys corresponding to the fields of the input dataclass.
+            - Each key in the resulting TensorDict is initialized using the ``cls.from_any`` method.
+            - The ``auto_batch_size`` option allows for automatic batch size determination and application to the
+              resulting TensorDict.
+
+        """
+        if as_tensorclass:
+            from tensordict.tensorclass import from_dataclass
+
+            return from_dataclass(dataclass, auto_batch_size=auto_batch_size)
+        from dataclasses import fields
+
+        from tensordict import TensorDict
+
+        if not is_dataclass(dataclass):
+            raise TypeError(
+                f"Expected a dataclass input, got a {type(dataclass)} input instead."
+ ) + source = {} + for field in fields(dataclass): + source[field.name] = cls.from_any(getattr(dataclass, field.name)) + result = TensorDict(source) + if auto_batch_size: + result.auto_batch_size_() + return result + @classmethod def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): """Converts a namedtuple to a TensorDict recursively. @@ -9885,8 +10003,7 @@ def namedtuple_to_dict(namedtuple_obj): "indices": namedtuple_obj.indices, } for key, value in namedtuple_obj.items(): - if is_namedtuple(value): - namedtuple_obj[key] = namedtuple_to_dict(value) + namedtuple_obj[key] = cls.from_any(value) return dict(namedtuple_obj) result = TensorDict(namedtuple_to_dict(named_tuple)) diff --git a/tensordict/functional.py b/tensordict/functional.py index a40095141..2699f36bb 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -437,6 +437,7 @@ def make_tensordict( input_dict: dict[str, CompatibleType] | None = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, + auto_batch_size: bool | None = None, **kwargs: CompatibleType, # source ) -> TensorDict: """Returns a TensorDict created from the keyword arguments or an input dictionary. @@ -453,6 +454,8 @@ def make_tensordict( (incompatible with nested keys). batch_size (iterable of int, optional): a batch size for the tensordict. device (torch.device or compatible type, optional): a device for the TensorDict. + auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. + Defaults to ``False``. Examples: >>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)} @@ -500,4 +503,6 @@ def make_tensordict( """ if input_dict is not None: kwargs.update(input_dict) - return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device) + return TensorDict.from_dict( + kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size + ) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 0b55d1cef..ffedba9ad 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -297,9 +297,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: f"The key {expected_key} wasn't found in the keyword arguments " f"but is expected to execute that function." ) + batch_size = torch.Size([]) if not self.auto_batch_size else None tensordict = make_tensordict( tensordict_values, - batch_size=torch.Size([]) if not self.auto_batch_size else None, + batch_size=batch_size, + auto_batch_size=False, ) if _self is not None: out = func(_self, tensordict, *args, **kwargs) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 00d984330..bc07b7689 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -928,7 +928,13 @@ def _exclude( @_carry_over def from_dict_instance( - self, input_dict, batch_size=None, device=None, batch_dims=None + self, + input_dict, + *, + auto_batch_size: bool = False, + batch_size=None, + device=None, + batch_dims=None, ): ... @_carry_over diff --git a/tensordict/persistent.py b/tensordict/persistent.py index d5f59110a..332023587 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -207,12 +207,25 @@ def from_h5(cls, filename, mode="r"): return out @classmethod - def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs): + def from_dict( + cls, + input_dict, + filename, + *others, + auto_batch_size: bool = False, + batch_size=None, + device=None, + **kwargs, + ): """Converts a dictionary or a TensorDict to a h5 file. 
        Args:
            input_dict (dict, TensorDict or compatible): data to be stored as h5.
            filename (str or path): path to the h5 file.
+
+        Keyword Args:
+            auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically.
+                Defaults to ``False``.
            batch_size (tensordict batch-size, optional): if provided, batch size
                of the tensordict. If not, the batch size will be gathered from the
                input structure (if present) or determined automatically.
@@ -225,6 +238,19 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs):
            A :class:`PersistentTensorDict` instance linked to the newly created file.

        """
+        if others:
+            if batch_size is not None:
+                raise TypeError(
+                    "conflicting batch size values. Please use the keyword argument only."
+                )
+            if device is not None and len(others) == 2:
+                raise TypeError(
+                    "conflicting device values. Please use the keyword argument only."
+                )
+            if len(others) > 2:
+                raise TypeError("Too many positional arguments.")
+            warnings.warn(
+                "All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.",
+                category=DeprecationWarning,
+            )
+            if len(others) == 2:
+                batch_size, device = others
+            else:
+                batch_size = others[0]
+
        import h5py

        file = h5py.File(filename, "w", locking=cls.LOCKING)
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 2556729e5..e1c8e77b4 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -27,6 +27,7 @@
 from textwrap import indent
 from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar
+from warnings import warn

 import numpy as np
 import orjson as json
@@ -45,6 +46,7 @@
     CompatibleType,
 )
 from tensordict.utils import (  # @manual=//pytorch/tensordict:_C
+    _is_dataclass as is_dataclass,
     _is_json_serializable,
     _is_tensorclass,
     _LOCK_ERROR,
@@ -237,6 +239,12 @@ def __subclasscheck__(self, subclass):
     "floor_",
     "frac",
     "frac_",
+    "from_any",
+    "from_dataclass",
+    "to_namedtuple",
+    "from_namedtuple",
+    "from_pytree",
+    "to_pytree",
     "gather",
     "isfinite",
     "isnan",
@@ -379,6 +387,100 @@ def __call__(self, cls: T) -> T:
     return clz


+def from_dataclass(
+    obj: Any,
+    *,
+    auto_batch_size: bool = False,
+    frozen: bool = False,
+    autocast: bool = False,
+    nocast: bool = False,
+) -> Any:
+    """Converts a dataclass instance or a type into a tensorclass instance or type, respectively.
+
+    This function takes a dataclass instance or a dataclass type and converts it into a tensor-compatible class,
+    optionally applying various configurations such as auto-batching, immutability, and type casting.
+
+    Args:
+        obj (Any): The dataclass instance or type to be converted. If a type is provided, a new class is returned.
+
+    Keyword Args:
+        auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the resulting object. Defaults to ``False``.
+        frozen (bool, optional): If ``True``, the resulting class or instance will be immutable. Defaults to ``False``.
+        autocast (bool, optional): If ``True``, enables automatic type casting for the resulting class or instance. Defaults to ``False``.
+        nocast (bool, optional): If ``True``, disables any type casting for the resulting class or instance. Defaults to ``False``.
+
+    Returns:
+        A tensor-compatible class or instance derived from the provided dataclass.
+
+    Raises:
+        TypeError: If the provided input is not a dataclass instance or type.
+
+    Examples:
+        >>> from dataclasses import dataclass
+        >>> import torch
+        >>> from tensordict.tensorclass import from_dataclass
+        >>>
+        >>> @dataclass
+        ... class X:
+        ...     a: int
+        ...     b: torch.Tensor
+        ...
+        >>> x = X(0, 0)
+        >>> x2 = from_dataclass(x)
+        >>> print(x2)
+        X(
+            a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
+            b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+        >>> X2 = from_dataclass(X, autocast=True)
+        >>> print(X2(a=0, b=0))
+        X(
+            a=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
+            b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+
+    .. note:: If a dataclass type is provided, a new class is returned with the specified configurations.
+        If a dataclass instance is provided, a new instance of the tensor-compatible class is returned.
+        The ``auto_batch_size``, ``frozen``, ``autocast``, and ``nocast`` options allow for flexible configuration
+        of the resulting class or instance.
+
+    .. warning:: Whereas :meth:`~tensordict.TensorDict.from_dataclass` will return a :class:`~tensordict.TensorDict`
+        instance by default, this function will return a tensorclass instance or type.
+
+    """
+    from dataclasses import asdict, make_dataclass
+
+    if isinstance(obj, type):
+        if is_tensorclass(obj):
+            return obj
+        cls = make_dataclass(
+            obj.__name__ + "_tc", fields=obj.__dataclass_fields__, bases=obj.__bases__
+        )
+        clz = _tensorclass(cls, frozen=frozen)
+        clz._type_hints = get_type_hints(obj)
+        clz._autocast = autocast
+        clz._nocast = nocast
+        clz._frozen = frozen
+        return clz
+
+    if not is_dataclass(obj):
+        raise TypeError(f"Expected a dataclass instance, got a {type(obj)} input instead.")
+    name = obj.__class__.__name__ + "_tc"
+    clz = _tensorclass(
+        make_dataclass(name, fields=obj.__dataclass_fields__), frozen=frozen
+    )
+    clz._autocast = autocast
+    clz._nocast = nocast
+    clz._frozen = frozen
+    result = clz(**asdict(obj))
+    if auto_batch_size:
+        result = result.auto_batch_size_()
+    return result
+
+
 @dataclass_transform()
 def tensorclass(
     cls: T = None,
@@ -532,6 +634,8 @@ def __torch_function__(

     _is_non_tensor = getattr(cls, "_is_non_tensor", False)

+    # Breaks some tests, don't do that:
+    # if not dataclasses.is_dataclass(cls):
     cls = dataclass(cls, frozen=frozen)
     _TENSORCLASS_MEMO[cls] = True

@@ -1266,7 +1370,7 @@ def _update(
         non_blocking: bool = False,
     ):
         if isinstance(input_dict_or_td, dict):
-            input_dict_or_td = self.from_dict(input_dict_or_td)
+            input_dict_or_td = self.from_dict(input_dict_or_td, auto_batch_size=False)

         if is_tensorclass(input_dict_or_td):
             non_tensordict = {
@@ -1478,7 +1582,15 @@ def _to_dict(self, *, retain_none: bool = True) -> dict:
     return td_dict


-def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
+def _from_dict(
+    cls,
+    input_dict,
+    *,
+    auto_batch_size: bool | None = None,
+    batch_size=None,
+    device=None,
+    batch_dims=None,
+):
     # we pass through a tensordict because keys could be passed as NestedKeys
     # We can't assume all keys are strings, otherwise calling cls(**kwargs)
     # would work ok
@@ -1492,7 +1604,11 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
         non_tensordict=input_dict,
     )
     td = TensorDict.from_dict(
-        input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims
+        input_dict,
+        batch_size=batch_size,
+        device=device,
+        batch_dims=batch_dims,
+        auto_batch_size=auto_batch_size,
     )
     non_tensordict = {}

@@ -1500,7 +1616,13 @@ def _from_dict(


 def _from_dict_instance(
-    self, input_dict, batch_size=None, 
device=None, batch_dims=None
+    self,
+    input_dict,
+    *,
+    auto_batch_size: bool | None = None,
+    batch_size=None,
+    device=None,
+    batch_dims=None,
 ):
     if batch_dims is not None and batch_size is not None:
         raise ValueError("Cannot pass both batch_size and batch_dims to `from_dict`.")
@@ -1510,7 +1632,7 @@ def _from_dict_instance(
     # TODO: this is a bit slow and will be a bottleneck every time td[idx] = dict(subtd)
     # is called when there are non tensor data in it
     if not _is_tensor_collection(type(input_dict)):
-        input_tdict = TensorDict.from_dict(input_dict)
+        input_tdict = TensorDict.from_dict(input_dict, auto_batch_size=auto_batch_size)
     else:
         input_tdict = input_dict
     trsf_dict = {}
@@ -1538,7 +1660,19 @@ def _from_dict_instance(
     )
     # check that
     if batch_size is None:
-        out._tensordict.auto_batch_size_()
+        if auto_batch_size is None and batch_dims is None:
+            warn(
+                "The batch-size was not provided and auto_batch_size isn't set either. "
+                "Currently, from_dict will set auto_batch_size=True, but this behaviour "
+                "will change in v0.8, when auto_batch_size will default to False. "
+                "To silence this warning, pass auto_batch_size explicitly.",
+                category=DeprecationWarning,
+            )
+            auto_batch_size = True
+        elif auto_batch_size is None:
+            auto_batch_size = True
+        if auto_batch_size:
+            out.auto_batch_size_()
     return out
@@ -1658,7 +1792,7 @@ def _is_castable(datatype):

         if isinstance(value, dict):
             if _is_tensor_collection(target_cls):
-                cast_val = target_cls.from_dict(value)
+                cast_val = target_cls.from_dict(value, auto_batch_size=False)
                 self._tensordict.set(
                     key, cast_val, inplace=inplace, non_blocking=non_blocking
                 )
@@ -2483,7 +2617,6 @@ def __post_init__(self):
             data_inner = data.tolist()
             del _tensordict["data"]
             _non_tensordict["data"] = data_inner
-        # assert _tensordict.is_empty(), self._tensordict

     # TODO: this will probably fail with dynamo at some point, + it's terrible.
     # Make sure it's patched properly at init time
diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi
index 75678b4b6..a77ef185a 100644
--- a/tensordict/tensorclass.pyi
+++ b/tensordict/tensorclass.pyi
@@ -209,9 +209,16 @@ class TensorClass:
     def auto_batch_size_(self, batch_dims: int | None = None) -> T: ...
     def auto_device_(self) -> T: ...
     @classmethod
+    def from_dataclass(
+        cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False
+    ): ...
+    @classmethod
+    def from_any(cls, obj, *, auto_batch_size: bool = False): ...
+ @classmethod def from_dict( cls, input_dict, + *, batch_size: torch.Size | None = None, device: torch.device | None = None, batch_dims: int | None = None, diff --git a/tensordict/utils.py b/tensordict/utils.py index cdc0756f8..81ab2fa0c 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -20,6 +20,7 @@ from collections import defaultdict from collections.abc import KeysView from copy import copy +from dataclasses import _FIELDS, GenericAlias from functools import wraps from importlib import import_module from numbers import Number @@ -858,7 +859,7 @@ def is_tensorclass(obj: type | Any) -> bool: def _is_tensorclass(cls: type) -> bool: - out = _TENSORCLASS_MEMO.get(cls, None) + out = _TENSORCLASS_MEMO.get(cls) if out is None: out = getattr(cls, "_is_tensorclass", False) if not is_dynamo_compiling(): @@ -2813,3 +2814,13 @@ def _mismatch_keys(keys1, keys2): if sub2 is not None: main.append(sub2) raise KeyError(r" ".join(main)) + + +def _is_dataclass(obj): + """Like dataclasses.is_dataclass but compatible with compile.""" + cls = ( + obj + if isinstance(obj, type) and not isinstance(obj, GenericAlias) + else type(obj) + ) + return hasattr(cls, _FIELDS) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 8879f0e68..ad1a194cd 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -53,7 +53,8 @@ class TestTensorDictsBase: TYPES_DEVICES = [] TYPES_DEVICES_NOLAZY = [] - def td(self, device): + @classmethod + def td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -68,7 +69,8 @@ def td(self, device): TYPES_DEVICES += [["td", device]] TYPES_DEVICES_NOLAZY += [["td", device]] - def nested_td(self, device): + @classmethod + def nested_td(cls, device): return TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -86,7 +88,8 @@ def nested_td(self, device): TYPES_DEVICES += [["nested_td", device]] TYPES_DEVICES_NOLAZY += [["nested_td", device]] - def nested_tensorclass(self, device): + @classmethod + def nested_tensorclass(cls, device): nested_class = MyClass( X=torch.randn(4, 3, 2, 1), @@ -119,8 +122,9 @@ def nested_tensorclass(self, device): TYPES_DEVICES += [["nested_tensorclass", device]] TYPES_DEVICES_NOLAZY += [["nested_tensorclass", device]] + @classmethod @set_lazy_legacy(True) - def nested_stacked_td(self, device): + def nested_stacked_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -140,8 +144,9 @@ def nested_stacked_td(self, device): TYPES_DEVICES += [["nested_stacked_td", device]] TYPES_DEVICES_NOLAZY += [["nested_stacked_td", device]] + @classmethod @set_lazy_legacy(True) - def stacked_td(self, device): + def stacked_td(cls, device): td1 = TensorDict( source={ "a": torch.randn(4, 3, 1, 5), @@ -165,7 +170,8 @@ def stacked_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["stacked_td", device]] - def idx_td(self, device): + @classmethod + def idx_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -180,7 +186,8 @@ def idx_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["idx_td", device]] - def sub_td(self, device): + @classmethod + def sub_td(cls, device): td = TensorDict( source={ "a": torch.randn(2, 4, 3, 2, 1, 5), @@ -195,7 +202,8 @@ def sub_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["sub_td", device]] - def sub_td2(self, device): + @classmethod + def sub_td2(cls, device): td = TensorDict( source={ "a": torch.randn(4, 2, 3, 2, 1, 5), @@ -212,17 +220,19 @@ def sub_td2(self, 
device): temp_path_memmap = tempfile.TemporaryDirectory() - def memmap_td(self, device): - path = pathlib.Path(self.temp_path_memmap.name) + @classmethod + def memmap_td(cls, device): + path = pathlib.Path(cls.temp_path_memmap.name) shutil.rmtree(path) path.mkdir() - return self.td(device).memmap_(path) + return cls.td(device).memmap_(path) TYPES_DEVICES += [["memmap_td", torch.device("cpu")]] TYPES_DEVICES_NOLAZY += [["memmap_td", torch.device("cpu")]] + @classmethod @set_lazy_legacy(True) - def permute_td(self, device): + def permute_td(cls, device): return TensorDict( source={ "a": torch.randn(3, 1, 4, 2, 5), @@ -236,8 +246,9 @@ def permute_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["permute_td", device]] + @classmethod @set_lazy_legacy(True) - def unsqueezed_td(self, device): + def unsqueezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 5), @@ -252,8 +263,9 @@ def unsqueezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["unsqueezed_td", device]] + @classmethod @set_lazy_legacy(True) - def squeezed_td(self, device): + def squeezed_td(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 1, 2, 1, 5), @@ -268,7 +280,8 @@ def squeezed_td(self, device): for device in get_available_devices(): TYPES_DEVICES += [["squeezed_td", device]] - def td_reset_bs(self, device): + @classmethod + def td_reset_bs(cls, device): td = TensorDict( source={ "a": torch.randn(4, 3, 2, 1, 5), @@ -285,13 +298,14 @@ def td_reset_bs(self, device): TYPES_DEVICES += [["td_reset_bs", device]] TYPES_DEVICES_NOLAZY += [["td_reset_bs", device]] + @classmethod def td_h5( - self, + cls, device, ): file = tempfile.NamedTemporaryFile() filename = file.name - nested_td = self.nested_td(device) + nested_td = cls.nested_td(device) td_h5 = PersistentTensorDict.from_dict( nested_td, filename=filename, device=device ) @@ -303,15 +317,17 @@ def td_h5( TYPES_DEVICES += [["td_h5", device]] TYPES_DEVICES_NOLAZY += [["td_h5", device]] - def td_params(self, device): - return TensorDictParams(self.td(device)) + @classmethod + def td_params(cls, device): + return TensorDictParams(cls.td(device)) for device in get_available_devices(): TYPES_DEVICES += [["td_params", device]] TYPES_DEVICES_NOLAZY += [["td_params", device]] - def td_with_non_tensor(self, device): - td = self.td(device) + @classmethod + def td_with_non_tensor(cls, device): + td = cls.td(device) return td.set_non_tensor( ("data", "non_tensor"), # this is allowed since nested NonTensorData are automatically unwrapped diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 4753c3704..0f71bd743 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -23,6 +23,7 @@ import tensordict.utils import torch from tensordict import TensorClass +from tensordict.tensorclass import from_dataclass try: import torchsnapshot @@ -94,6 +95,21 @@ class MyData2: z: list +@dataclasses.dataclass +class MyDataClass: + a: int + b: torch.Tensor + c: str + + +try: + MyTensorClass_autocast = from_dataclass(MyDataClass, autocast=True) + MyTensorClass_nocast = from_dataclass(MyDataClass, nocast=True) + MyTensorClass = from_dataclass(MyDataClass) +except Exception: + MyTensorClass_autocast = MyTensorClass_nocast = MyTensorClass = None + + class TestTensorClass: def test_all_any(self): @tensorclass @@ -517,6 +533,43 @@ class MyClass2: assert (a != c.clone().zero_()).any() assert (c != a.clone().zero_()).any() + def test_from_dataclass(self): + assert 
is_tensorclass(MyTensorClass_autocast) + assert MyTensorClass_nocast is not MyDataClass + assert MyTensorClass_autocast._autocast + x = MyTensorClass_autocast(a=0, b=0, c=0) + assert isinstance(x.a, int) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, str) + + assert is_tensorclass(MyTensorClass_nocast) + assert MyTensorClass_nocast is not MyTensorClass_autocast + assert MyTensorClass_nocast._nocast + + x = MyTensorClass_nocast(a=0, b=0, c=0) + assert is_tensorclass(MyTensorClass) + assert not MyTensorClass._autocast + assert not MyTensorClass._nocast + assert isinstance(x.a, int) + assert isinstance(x.b, int) + assert isinstance(x.c, int) + + x = MyTensorClass(a=0, b=0, c=0) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + + x = TensorDict.from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert isinstance(x, TensorDict) + assert isinstance(x["a"], torch.Tensor) + assert isinstance(x["b"], torch.Tensor) + assert isinstance(x["c"], torch.Tensor) + x = from_dataclass(MyTensorClass(a=0, b=0, c=0)) + assert is_tensorclass(x) + assert isinstance(x.a, torch.Tensor) + assert isinstance(x.b, torch.Tensor) + assert isinstance(x.c, torch.Tensor) + def test_from_dict(self): td = TensorDict( { @@ -531,7 +584,7 @@ def test_from_dict(self): class MyClass: a: TensorDictBase - tc = MyClass.from_dict(d) + tc = MyClass.from_dict(d, auto_batch_size=True) assert isinstance(tc, MyClass) assert isinstance(tc.a, TensorDict) assert tc.batch_size == torch.Size([10]) @@ -2095,7 +2148,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3] ) - assert (test_class == TestClass.from_dict(test_class.to_dict())).all() + assert ( + test_class == TestClass.from_dict(test_class.to_dict(), auto_batch_size=True) + ).all() # Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such # test_class2 = TestClass( @@ -2108,7 +2163,9 @@ class TestClass: my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3] ) - assert not (test_class == TestClass.from_dict(test_class3.to_dict())).all() + assert not ( + test_class == TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True) + ).all() @tensorclass(autocast=True) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 257c2b712..73d401c03 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -20,6 +20,7 @@ import warnings from dataclasses import dataclass from pathlib import Path +from typing import Any import numpy as np import pytest @@ -949,6 +950,64 @@ def test_fromkeys(self): td = TensorDict.fromkeys({"a", "b", "c"}, 1) assert td["a"] == 1 + def test_from_any(self): + from dataclasses import dataclass + + @dataclass + class MyClass: + a: int + + pytree = ( + [torch.randint(10, (3,)), torch.zeros(2)], + { + "tensor": torch.randn( + 2, + ), + "td": TensorDict({"one": 1}), + "tuple": (1, 2, 3), + }, + {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, + {"dataclass": MyClass(a=0)}, + ) + if _has_h5py: + pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},) + td = TensorDict.from_any(pytree) + expected = { + ("0", "0"), + ("0", "1"), + ("1", "td", "one"), + ("1", "tensor"), + ("1", "tuple", "0"), + ("1", "tuple", "1"), + ("1", "tuple", "2"), + ("2", "named_tuple", "two"), + ("3", "dataclass", "a"), + } + if _has_h5py: + expected = expected.union( + { + ("4", "h5py", "a"), + ("4", "h5py", "b"), + ("4", "h5py", "c"), + ("4", "h5py", "my_nested_td", 
"inner"), + } + ) + assert set(td.keys(True, True)) == expected, set( + td.keys(True, True) + ).symmetric_difference(expected) + + def test_from_dataclass(self): + @dataclass + class MyClass: + a: int + b: Any + + obj = MyClass(a=0, b=1) + obj_td = TensorDict.from_dataclass(obj) + obj_tc = TensorDict.from_dataclass(obj, as_tensorclass=True) + assert is_tensorclass(obj_tc) + assert not is_tensorclass(obj_td) + @pytest.mark.parametrize("batch_size", [None, [3, 4]]) @pytest.mark.parametrize("batch_dims", [None, 1, 2]) @pytest.mark.parametrize("device", get_available_devices()) @@ -967,7 +1026,11 @@ def test_from_dict(self, batch_size, batch_dims, device): ) return data = TensorDict.from_dict( - data, batch_size=batch_size, batch_dims=batch_dims, device=device + data, + batch_size=batch_size, + batch_dims=batch_dims, + device=device, + auto_batch_size=True, ) assert data.device == device assert "a" in data.keys() @@ -1001,7 +1064,7 @@ class MyClass: assert isinstance(td_dict["b"]["y"], int) assert isinstance(td_dict["b"]["z"], dict) assert isinstance(td_dict["b"]["z"]["y"], int) - td_recon = td.from_dict_instance(td_dict) + td_recon = td.from_dict_instance(td_dict, auto_batch_size=True) assert isinstance(td_recon["a"], torch.Tensor) assert isinstance(td_recon["b"], MyClass) assert isinstance(td_recon["b"].x, torch.Tensor) @@ -6443,7 +6506,7 @@ def recursive_checker(cur_dict): assert recursive_checker(td_dict) if td_name == "td_with_non_tensor": assert td_dict["data"]["non_tensor"] == "some text data" - assert (TensorDict.from_dict(td_dict) == td).all() + assert (TensorDict.from_dict(td_dict, auto_batch_size=False) == td).all() def test_to_namedtuple(self, td_name, device): def is_namedtuple(obj): @@ -7771,7 +7834,7 @@ def test_mp(self, td_type, unbind_as): class TestMakeTensorDict: def test_create_tensordict(self): - tensordict = make_tensordict(a=torch.zeros(3, 4)) + tensordict = make_tensordict(a=torch.zeros(3, 4), auto_batch_size=True) assert (tensordict["a"] == torch.zeros(3, 4)).all() def test_nested(self): @@ -7779,7 +7842,7 @@ def test_nested(self): "a": {"b": torch.randn(3, 4), "c": torch.randn(3, 4, 5)}, "d": torch.randn(3), } - tensordict = make_tensordict(input_dict) + tensordict = make_tensordict(input_dict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) input_tensordict = TensorDict( @@ -7789,7 +7852,7 @@ def test_nested(self): }, [], ) - tensordict = make_tensordict(input_tensordict) + tensordict = make_tensordict(input_tensordict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) input_dict = { @@ -7797,30 +7860,40 @@ def test_nested(self): ("a", "c"): torch.randn(3, 4, 5), "d": torch.randn(3), } - tensordict = make_tensordict(input_dict) + tensordict = make_tensordict(input_dict, auto_batch_size=True) assert tensordict.shape == torch.Size([3]) assert tensordict["a"].shape == torch.Size([3, 4]) def test_tensordict_batch_size(self): - tensordict = make_tensordict() + tensordict = make_tensordict(auto_batch_size=True) assert tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4)) + tensordict = make_tensordict(a=torch.randn(3, 4), auto_batch_size=True) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(3, 4, 5)) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == 
torch.Size([3, 4]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(3, 5)) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(3, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([3]) - nested_tensordict = make_tensordict(c=tensordict, d=torch.randn(4, 5)) # nested + nested_tensordict = make_tensordict( + c=tensordict, d=torch.randn(4, 5), auto_batch_size=True + ) # nested assert nested_tensordict.batch_size == torch.Size([]) - tensordict = make_tensordict(a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5)) + tensordict = make_tensordict( + a=torch.randn(3, 4, 2), b=torch.randn(3, 4, 5), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([3, 4]) - tensordict = make_tensordict(a=torch.randn(3, 4), b=torch.randn(1)) + tensordict = make_tensordict( + a=torch.randn(3, 4), b=torch.randn(1), auto_batch_size=True + ) assert tensordict.batch_size == torch.Size([]) tensordict = make_tensordict( @@ -7836,7 +7909,10 @@ def test_tensordict_batch_size(self): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_device(self, device): tensordict = make_tensordict( - a=torch.randn(3, 4), b=torch.randn(3, 4), device=device + a=torch.randn(3, 4), + b=torch.randn(3, 4), + device=device, + auto_batch_size=True, ) assert tensordict.device == device assert tensordict["a"].device == device @@ -7847,6 +7923,7 @@ def test_tensordict_device(self, device): b=torch.randn(3, 4), c=torch.randn(3, 4, device="cpu"), device=device, + auto_batch_size=True, ) assert tensordict.device == device assert tensordict["a"].device == device @@ -10584,7 +10661,8 @@ def test_non_tensor_call(self): def test_nontensor_dict(self, non_tensor_data): assert ( - TensorDict.from_dict(non_tensor_data.to_dict()) == non_tensor_data + TensorDict.from_dict(non_tensor_data.to_dict(), auto_batch_size=True) + == non_tensor_data ).all() def test_nontensor_tensor(self): @@ -11125,7 +11203,7 @@ def _to_float(td, td_name, tmpdir): td._source = td._source.float() elif td_name in ("td_h5",): td = PersistentTensorDict.from_dict( - td.float().to_dict(), filename=tmpdir + "/file.t" + td.float().to_dict(), filename=tmpdir + "/file.t", auto_batch_size=True ) elif td_name in ("td_params",): td = TensorDictParams(td.data.float())
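
Usage sketch (not part of the diff): a minimal example of the API surface this patch adds, assuming this branch of tensordict is installed; the `Config` dataclass below is a hypothetical illustration.

```python
# Minimal sketch of the new from_dict / from_any / from_dataclass entry points.
from dataclasses import dataclass

import torch
from tensordict import TensorDict, from_dataclass


@dataclass
class Config:  # hypothetical example dataclass
    lr: torch.Tensor
    steps: int


# from_dict: batch-size inference is now opt-in through auto_batch_size.
td = TensorDict.from_dict(
    {"a": torch.zeros(3, 4), "b": {"c": torch.ones(3)}},
    auto_batch_size=True,  # infers batch_size=[3]; False keeps batch_size=[]
)
assert td.batch_size == torch.Size([3])

# from_any: recursive conversion of dicts, (named)tuples, lists and dataclasses;
# tuple and list entries are stored under string keys "0", "1", ...
td2 = TensorDict.from_any(
    (torch.zeros(2), {"cfg": Config(lr=torch.tensor(1e-3), steps=10)})
)
assert (td2["1", "cfg", "lr"] == 1e-3).all()

# from_dataclass: the classmethod returns a TensorDict, while the free function
# returns a tensorclass (a class named "Config_tc" here).
cfg = Config(lr=torch.tensor(1e-3), steps=10)
as_td = TensorDict.from_dataclass(cfg)  # TensorDict with keys "lr" and "steps"
as_tc = from_dataclass(cfg)             # tensorclass instance
```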