Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 30, 2025
1 parent dcb5663 commit 2dd983a
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,20 @@ def _stack_uninit_params(list_of_params, dim=0, out=None):
)
out.batch_size = torch.Size([len(list_of_params)])
return out

def implements_for_tdtype(torch_function: Callable) -> Callable[[Callable], Callable]:
"""Register a torch function override for TensorDict."""

from tensordict.dtype import TDTYPE_HANDLED_FUNCTIONS

@functools.wraps(torch_function)
def decorator(func: Callable) -> Callable:
TDTYPE_HANDLED_FUNCTIONS[torch_function] = func
return func

return decorator

@implements_for_tdtype(torch.Tensor.view)
def view(tensor: torch.tensor, dtype: Any) -> TensorDictBase:
from tensordict.dtype import StructDtype
return StructDtype.view(tensor, dtype)
112 changes: 112 additions & 0 deletions tensordict/dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import deque
import orjson as json
from typing import Callable, Any


TDTYPE_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}

class StructDtype:
# def __new__(cls, map=None):
# if isinstance(map, StructDtype):
# return map
# return super().__new__(cls)
def __init__(self, map=None):
if map is None:
map = {}
assert isinstance(map, dict)
self._maps = map

@classmethod
def from_td(cls, data: "TensorDictBase"):
from tensordict.base import _is_tensor_collection
self = cls()
map = self._maps
stack = deque()
stack.append((self, data))
while len(stack):
sdtype, local_data = stack.popleft()
map = sdtype._maps
# TODO: handle lazy stacks here
for k, v in local_data.items():
cls = type(v)
if _is_tensor_collection(cls):
# TODO: handle different dtypes here
# TODO: handle LazyStacks here
newmap = map[k] = StructDtype({})
stack.append((newmap, v))
else:
map[k] = {
"shape": v.shape,
"dtype": v.dtype,
}
return self

def items(self, include_nested: bool=False, leaves_only: bool=False):
stack = deque()
stack.append(self)
while len(stack):
node = stack.popleft()
for k, v in node._maps.items():
if isinstance(v, StructDtype):
if include_nested:
stack.append(v)
if not leaves_only:
yield (k, v)
else:
yield k, v

def values(self, include_nested: bool=False, leaves_only: bool=False):
yield from (_, v in self.items(include_nested=include_nested, leaves_only=leaves_only))

def keys(self, include_nested: bool=False, leaves_only: bool=False):
yield from (k, _ in self.items(include_nested=include_nested, leaves_only=leaves_only))

# def json(self):
# return json.dumps(metadata_dict)

@classmethod
def __torch_function__(
cls,
func: Callable,
types: tuple[type, ...],
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
) -> Callable:
if kwargs is None:
kwargs = {}
if func not in TDTYPE_HANDLED_FUNCTIONS:
return NotImplemented
return TDTYPE_HANDLED_FUNCTIONS[func](*args, **kwargs)


@classmethod
def view(cls, tensor, dtype):
from tensordict import TensorDict
ns = []
shapes = []
dts = []
keys = []
stack = deque()
stack.append((dtype.items(), ()))
tensor_itemsize = tensor.dtype.itemsize
while len(stack):
items, prefix = stack.popleft()
for k, dt in items:
currentk = prefix + (k,)
if isinstance(dt, StructDtype):
stack.append((dt.items(), currentk))
continue
assert currentk not in keys, (currentk, keys)
keys.append(currentk)
s = dt["shape"]
dt = dt["dtype"]
shapes.append(s)
dts.append(dt)
nelts = (dt.itemsize * s.numel()) // tensor_itemsize
ns.append(nelts)

return TensorDict({k: v.view(dt).view(shape) for k, v, dt, shape in zip(keys, tensor.split(ns), dts, shapes, strict=True)})

0 comments on commit 2dd983a

Please sign in to comment.