Skip to content

Commit

Permalink
Merge branch 'main' into better-dtypes
Browse files Browse the repository at this point in the history
# Conflicts:
#	tensordict/utils.py
  • Loading branch information
vmoens committed Jun 25, 2024
2 parents ae3555f + 19f10d0 commit bdcbe17
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 8 deletions.
168 changes: 161 additions & 7 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import importlib
import json
import numbers
import uuid
import weakref
from collections.abc import MutableMapping

Expand Down Expand Up @@ -66,6 +67,8 @@
IndexType,
infer_size_impl,
int_generator,
is_namedtuple,
is_namedtuple_class,
is_non_tensor,
lazy_legacy,
lock_blocked,
Expand Down Expand Up @@ -796,6 +799,134 @@ def from_dict_instance(
"""
...

@classmethod
def from_pytree(
cls,
pytree,
*,
batch_size: torch.Size | None = None,
auto_batch_size: bool = False,
batch_dims: int | None = None,
):
"""Converts a pytree to a TensorDict instance.
This method is designed to keep the pytree nested structure as much as possible.
Additional non-tensor keys are added to keep track of each level's identity, providing
a built-in pytree-to-tensordict bijective transform API.
Accepted classes currently include lists, tuples, named tuples and dict.
.. note:: for dictionaries, non-NestedKey keys are registered separately as :class:`~tensordict.NonTensorData`
instances.
.. note:: Tensor-castable types (such as int, float or np.ndarray) will be converted to torch.Tensor instances.
NOte that this transformation is surjective: transforming back the tensordict to a pytree will not
recover the original types.
Examples:
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key
>>> class WeirdLookingClass:
... pass
...
>>> weird_key = WeirdLookingClass()
>>> # Make a pytree with tuple, lists, dict and namedtuple
>>> pytree = (
... [torch.randint(10, (3,)), torch.zeros(2)],
... {
... "tensor": torch.randn(
... 2,
... ),
... "td": TensorDict({"one": 1}),
... weird_key: torch.randint(10, (2,)),
... "list": [1, 2, 3],
... },
... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
... )
>>> # Build a TensorDict from that pytree
>>> td = TensorDict.from_pytree(pytree)
>>> # Recover the pytree
>>> pytree_recon = td.to_pytree()
>>> # Check that the leaves match
>>> def check(v1, v2):
>>> assert (v1 == v2).all()
>>>
>>> torch.utils._pytree.tree_map(check, pytree, pytree_recon)
>>> assert weird_key in pytree_recon[1]
"""
if is_tensor_collection(pytree):
return pytree
if isinstance(pytree, (torch.Tensor,)):
return pytree

from tensordict._td import TensorDict

result = None
if is_namedtuple(pytree):
result = TensorDict.from_namedtuple(named_tuple=pytree)
if batch_dims is not None:
result.batch_size = batch_size
result["_pytree_type"] = type(pytree)
elif isinstance(pytree, (list, tuple)):
source = {str(i): cls.from_pytree(elt) for i, elt in enumerate(pytree)}
source["_pytree_type"] = type(pytree)
result = TensorDict(source, batch_size=batch_size)
elif isinstance(pytree, dict):
source = {}
for key, item in pytree.items():
if isinstance(key, NestedKey):
source[key] = cls.from_pytree(item)
else:
subs_key = "<NON_NESTED>" + str(uuid.uuid1())
source[subs_key] = TensorDict(
{"value": cls.from_pytree(item), "key": key}
)
source["_pytree_type"] = type(pytree)
result = TensorDict(source, batch_size=batch_size)
if result is not None:
if auto_batch_size:
result.auto_batch_size_(batch_dims)
return result
if isinstance(pytree, (int, float, np.ndarray)):
return torch.as_tensor(pytree)
raise NotImplementedError(f"Unknown type {type(pytree)}.")

def to_pytree(self):
"""Converts a tensordict to a PyTree.
If the tensordict was not created from a pytree, this method just returns ``self`` without modification.
See :meth:`~.from_pytree` for more information and examples.
"""
_pytree_type = self._get_str("_pytree_type", default=None)
if _pytree_type is None:
return self
_pytree_type = _pytree_type.data
items = {key: val for (key, val) in self.items() if key != "_pytree_type"}
items = {
key: val if not is_tensor_collection(val) else val.to_pytree()
for key, val in items.items()
}
if _pytree_type in (list, tuple):
return _pytree_type((items[str(i)] for i in range(len(items))))
if _pytree_type is dict:
items = dict(
(
(val["key"], val["value"])
if key.startswith("<NON_NESTED>")
else (key, val)
for (key, val) in items.items()
)
)
return items
if is_namedtuple_class(_pytree_type):
from tensordict._td import TensorDict

return TensorDict(items).to_namedtuple(dest_cls=_pytree_type)
raise NotImplementedError(f"unknown type {_pytree_type}")

@classmethod
def from_h5(cls, filename, mode="r"):
"""Creates a PersistentTensorDict from a h5 file.
Expand Down Expand Up @@ -1651,6 +1782,10 @@ def gather(self, dim: int, index: Tensor, out: T | None = None) -> T:
def view(self, *shape: int):
...

@overload
def view(self, dtype):
...

@overload
def view(self, shape: torch.Size):
...
Expand All @@ -1668,13 +1803,23 @@ def view(
self,
*shape: int,
size: list | tuple | torch.Size | None = None,
batch_size: torch.Size | None = None,
):
"""Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size.
Alternatively, a dtype can be provided as a first unnamed argument. In that case, all tensors will be viewed
with the according dtype. Note that this assume that the new shapes will be compatible with the provided dtype.
See :meth:`~torch.view` for more information on dtype views.
Args:
*shape (int): new shape of the resulting tensordict.
dtype (torch.dtype): alternatively, a dtype to use to represent the tensor content.
size: iterable
Keyword Args:
batch_size (torch.Size, optional): if a dtype is provided, the batch-size can be reset using this
keyword argument. If the ``view`` is called with a shape, this is without effect.
Returns:
a new tensordict with the desired batch_size.
Expand All @@ -1689,6 +1834,9 @@ def view(
>>> print(td_view.get("b").shape) # torch.Size([1, 4, 3, 10, 1])
"""
if len(shape) == 1 and isinstance(shape[0], torch.dtype):
dtype = shape[0]
return self._view_dtype(dtype=dtype, batch_size=batch_size)
_lazy_legacy = lazy_legacy()

if _lazy_legacy:
Expand All @@ -1699,6 +1847,10 @@ def view(
result.lock_()
return result

def _view_dtype(self, *, dtype, batch_size):
# We use apply because we want to check the shapes
return self.apply(lambda x: x.view(dtype), batch_size=batch_size)

def _legacy_view(
self,
*shape: int,
Expand Down Expand Up @@ -6789,9 +6941,12 @@ def to_numpy(x):

return torch.utils._pytree.tree_map(to_numpy, as_dict)

def to_namedtuple(self):
def to_namedtuple(self, dest_cls: type | None = None):
"""Converts a tensordict to a namedtuple.
Args:
dest_cls (Type, optional): an optional namedtuple class to use.
Examples:
>>> from tensordict import TensorDict
>>> import torch
Expand All @@ -6807,9 +6962,12 @@ def dict_to_namedtuple(dictionary):
for key, value in dictionary.items():
if isinstance(value, dict):
dictionary[key] = dict_to_namedtuple(value)
return collections.namedtuple("GenericDict", dictionary.keys())(
**dictionary
cls = (
collections.namedtuple("GenericDict", dictionary.keys())
if dest_cls is None
else dest_cls
)
return cls(**dictionary)

return dict_to_namedtuple(self.to_dict())

Expand Down Expand Up @@ -6848,10 +7006,6 @@ def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False):
"""
from tensordict import TensorDict

def is_namedtuple(obj):
"""Check if obj is a namedtuple."""
return isinstance(obj, tuple) and hasattr(obj, "_fields")

def namedtuple_to_dict(namedtuple_obj):
if is_namedtuple(namedtuple_obj):
namedtuple_obj = namedtuple_obj._asdict()
Expand Down
37 changes: 36 additions & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import abc
import collections
import concurrent.futures
import inspect
Expand Down Expand Up @@ -130,7 +131,30 @@ def dims(self, *args, **kwargs):

IndexType = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
DeviceType = Union[torch.device, str, int]
NestedKey = Union[str, Tuple[str, ...]]


class _NestedKeyMeta(abc.ABCMeta):
def __instancecheck__(self, instance):
return isinstance(instance, str) or (
isinstance(instance, tuple)
and len(instance)
and all(isinstance(subkey, NestedKey) for subkey in instance)
)


class NestedKey(metaclass=_NestedKeyMeta):
"""An abstract class for nested keys.
Nested keys are the generic key type accepted by TensorDict.
A nested key is either a string or a non-empty tuple of NestedKeys instances.
The NestedKey class supports instance checks.
"""

pass


_KEY_ERROR = 'key "{}" not found in {} with ' "keys {}"
_LOCK_ERROR = (
Expand Down Expand Up @@ -2277,6 +2301,17 @@ def __missing__(self, key):
return value


def is_namedtuple(obj):
"""Check if obj is a namedtuple."""
return isinstance(obj, tuple) and hasattr(obj, "_fields")


def is_namedtuple_class(cls):
"""Check if a class is a namedtuple class."""
base_attrs = {"_fields", "_replace", "_asdict"}
return all(hasattr(cls, attr) for attr in base_attrs)


def _make_dtype_promotion(func):
dtype = getattr(torch, func.__name__)

Expand Down
33 changes: 33 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,33 @@ def get_leaf(leaf):
assert p.grad is None
assert all(param.grad is not None for param in params.values(True, True))

def test_from_pytree(self):
class WeirdLookingClass:
pass

weird_key = WeirdLookingClass()

pytree = (
[torch.randint(10, (3,)), torch.zeros(2)],
{
"tensor": torch.randn(
2,
),
"td": TensorDict({"one": 1}),
weird_key: torch.randint(10, (2,)),
"list": [1, 2, 3],
},
{"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()},
)
td = TensorDict.from_pytree(pytree)
pytree_recon = td.to_pytree()

def check(v1, v2):
assert (v1 == v2).all()

torch.utils._pytree.tree_map(check, pytree, pytree_recon)
assert weird_key in pytree_recon[1]

@pytest.mark.parametrize(
"idx",
[
Expand Down Expand Up @@ -5722,6 +5749,12 @@ def test_view(self, td_name, device):

assert (td_view.get("a") == 1).all()

@set_lazy_legacy(False)
def test_view_dtype(self, td_name, device):
td = getattr(self, td_name)(device)
tview = td.view(torch.uint8, batch_size=[])
assert all(p.dtype == torch.uint8 for p in tview.values(True, True))

@set_lazy_legacy(False)
def test_view_decorator(self, td_name, device):
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit bdcbe17

Please sign in to comment.