[BE] tensorclass method registration check #1175

Merged: 3 commits, Jan 9, 2025
8 changes: 4 additions & 4 deletions tensordict/base.py
@@ -7938,7 +7938,7 @@ def reduce(
async_op=False,
return_premature=False,
group=None,
):
) -> None:
"""Reduces the tensordict across all machines.

Only the process with rank ``dst`` is going to receive the final result.
@@ -9028,7 +9028,7 @@ def newfn(item_and_out):
return out

# Stream
def record_stream(self, stream: torch.cuda.Stream):
def record_stream(self, stream: torch.cuda.Stream) -> T:
"""Marks the tensordict as having been used by this stream.

When the tensordict is deallocated, ensure the tensor memory is not reused for other tensors until all work
@@ -11345,7 +11345,7 @@ def copy(self):
"""
return self.clone(recurse=False)

def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None):
def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None) -> T:
"""Converts all nested tensors to a padded version and adapts the batch-size accordingly.

Args:
@@ -12430,7 +12430,7 @@ def split_keys(
default: Any = NO_DEFAULT,
strict: bool = True,
reproduce_struct: bool = False,
):
) -> Tuple[T, ...]:
"""Splits the tensordict in subsets given one or more set of keys.

The method will return ``N+1`` tensordicts, where ``N`` is the number of
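The base.py changes are annotation-only: ``record_stream`` and ``to_padded_tensor`` now advertise that they return the tensordict type ``T``, and ``split_keys`` a ``Tuple[T, ...]``. A minimal sketch of the ``split_keys`` contract from the caller's side, using the public tensordict API rather than code from this PR:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])

# One key set in, two tensordicts out: N key sets yield N + 1 results,
# with the last tensordict holding whatever keys were not selected.
td_a, rest = td.split_keys(["a"])
assert set(td_a.keys()) == {"a"}
assert set(rest.keys()) == {"b"}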
120 changes: 113 additions & 7 deletions tensordict/tensorclass.py
@@ -119,6 +119,7 @@ def __subclasscheck__(self, subclass):
}
# Methods to be executed from tensordict, any ref to self means 'tensorclass'
_METHOD_FROM_TD = [
"dumps",
"load_",
"memmap",
"memmap_",
@@ -143,21 +144,48 @@ def __subclasscheck__(self, subclass):
"_items_list",
"_maybe_names",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild", # rebuild checks if self is a non tensor
"_propagate_lock",
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"bytes",
"cat_tensors",
"data_ptr",
"depth",
"dim",
"dtype",
"entry_class",
"get_item_shape",
"get_non_tensor",
"irecv",
"is_consolidated",
"is_contiguous",
"is_cpu",
"is_cuda",
"is_empty",
"is_floating_point",
"is_memmap",
"is_meta",
"is_shared",
"isend",
"items",
"keys",
"make_memmap",
"make_memmap_from_tensor",
"ndimension",
"numel",
"numpy",
"param_count",
"pop",
"recv",
"reduce",
"saved_path",
"send",
"size",
"sorted_keys",
"to_struct_array",
"values",
# "ndim",
]
@@ -212,9 +240,6 @@ def __subclasscheck__(self, subclass):
"_map",
"_maybe_remove_batch_dim",
"_memmap_",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild",
"_permute",
"_remove_batch_dim",
"_repeat",
@@ -233,6 +258,8 @@
"addcmul",
"addcmul_",
"all",
"amax",
"amin",
"any",
"apply",
"apply_",
@@ -243,31 +270,43 @@
"atan_",
"auto_batch_size_",
"auto_device_",
"bfloat16",
"bitwise_and",
"bool",
"cat",
"cat_from_tensordict",
"ceil",
"ceil_",
"chunk",
"clamp",
"clamp_max",
"clamp_max_",
"clamp_min",
"clamp_min_",
"clear",
"clear_device_",
"complex128",
"complex32",
"complex64",
"consolidate",
"contiguous",
"copy_",
"copy_at_",
"cos",
"cos_",
"cosh",
"cosh_",
"cpu",
"create_nested",
"cuda",
"cummax",
"cummin",
"densify",
"detach",
"detach_",
"div",
"div_",
"double",
"empty",
"erf",
"erf_",
@@ -280,20 +319,43 @@
"expand_as",
"expm1",
"expm1_",
"fill_",
"filter_empty_",
"filter_non_tensor_data",
"flatten",
"flatten_keys",
"float",
"float16",
"float32",
"float64",
"floor",
"floor_",
"frac",
"frac_",
"from_any",
"from_consolidated",
"from_dataclass",
"from_h5",
"from_modules",
"from_namedtuple",
"from_pytree",
"from_struct_array",
"from_tuple",
"fromkeys",
"gather",
"gather_and_stack",
"half",
"int",
"int16",
"int32",
"int64",
"int8",
"isfinite",
"isnan",
"isneginf",
"isposinf",
"isreal",
"lazy_stack",
"lerp",
"lerp_",
"lgamma",
@@ -310,13 +372,16 @@
"log_",
"logical_and",
"logsumexp",
"make_memmap_from_storage",
"map",
"map_iter",
"masked_fill",
"masked_fill_",
"masked_select",
"max",
"maximum",
"maximum_",
"maybe_dense_stack",
"mean",
"min",
"minimum",
@@ -336,13 +401,22 @@
"norm",
"permute",
"pin_memory",
"pin_memory_",
"popitem",
"pow",
"pow_",
"prod",
"qint32",
"qint8",
"quint4x2",
"quint8",
"reciprocal",
"reciprocal_",
"record_stream",
"refine_names",
"rename",
"rename_", # TODO: must be specialized
"rename_key_",
"repeat",
"repeat_interleave",
"replace",
@@ -351,6 +425,10 @@
"round",
"round_",
"select",
"separates",
"set_",
"set_non_tensor",
"setdefault",
"sigmoid",
"sigmoid_",
"sign",
@@ -361,9 +439,13 @@
"sinh_",
"softmax",
"split",
"split_keys",
"sqrt",
"sqrt_",
"squeeze",
"stack",
"stack_from_tensordict",
"stack_tensors",
"std",
"sub",
"sub_",
@@ -373,13 +455,21 @@
"tanh",
"tanh_",
"to",
"to_h5",
"to_module",
"to_namedtuple",
"to_padded_tensor",
"to_pytree",
"transpose",
"trunc",
"trunc_",
"type",
"uint16",
"uint32",
"uint64",
"uint8",
"unflatten",
"unflatten_keys",
"unlock_",
"unsqueeze",
"var",
@@ -388,10 +478,6 @@
"zero_",
"zero_grad",
]
assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set(
_METHOD_FROM_TD
).intersection(_FALLBACK_METHOD_FROM_TD)
assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD)

# These methods require a copy of the non tensor data
_FALLBACK_METHOD_FROM_TD_COPY = [
@@ -863,6 +949,14 @@ def __torch_function__(
cls.device = property(_device, _device_setter)
if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
cls.batch_size = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys:
cls.batch_dims = property(_batch_dims)
if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys:
cls.requires_grad = property(_requires_grad)
if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys:
cls.is_locked = property(_is_locked)
if not hasattr(cls, "ndim") and "ndim" not in expected_keys:
cls.ndim = property(_batch_dims)
if not hasattr(cls, "shape") and "shape" not in expected_keys:
cls.shape = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "names") and "names" not in expected_keys:
@@ -2158,6 +2252,18 @@ def _batch_size(self) -> torch.Size:
return self._tensordict.batch_size


def _batch_dims(self) -> int:
    return self._tensordict.batch_dims


def _requires_grad(self) -> bool:
    return self._tensordict.requires_grad


def _is_locked(self) -> bool:
    return self._tensordict.is_locked


def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417
"""Set the value of batch_size.

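Besides the method-list additions, the tensorclass.py hunk registers four new read-only properties on generated tensorclasses (``batch_dims``, ``requires_grad``, ``is_locked``, and ``ndim``, the last reusing the ``_batch_dims`` getter), each proxying the wrapped TensorDict. A hedged sketch of the resulting behavior, assuming a tensordict build that includes this PR:

import torch
from tensordict import tensorclass

@tensorclass
class Data:
    x: torch.Tensor

d = Data(x=torch.zeros(3, 4), batch_size=[3, 4])
assert d.ndim == d.batch_dims == 2  # ndim is an alias for batch_dims
assert not d.requires_grad          # proxied from the wrapped TensorDict
assert not d.is_locked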
40 changes: 39 additions & 1 deletion test/test_tensorclass.py
@@ -99,7 +99,11 @@ def _get_methods_from_class(cls):
methods = set()
for name in dir(cls):
attr = getattr(cls, name)
if inspect.isfunction(attr) or inspect.ismethod(attr):
if (
inspect.isfunction(attr)
or inspect.ismethod(attr)
or isinstance(attr, property)
):
methods.add(name)

return methods
@@ -122,6 +126,34 @@ def test_tensorclass_stub_methods():

if missing_methods:
raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")
raise Exception(
f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
)


def test_tensorclass_instance_methods():
@tensorclass
class X:
x: torch.Tensor

tensorclass_pyi_path = (
pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi"
)
tensorclass_abstract_methods = _get_methods_from_pyi(str(tensorclass_pyi_path))

tensorclass_methods = _get_methods_from_class(X)

missing_methods = (
tensorclass_abstract_methods - tensorclass_methods - {"data", "grad"}
)
missing_methods = [
method for method in missing_methods if (not method.startswith("_"))
]

if missing_methods:
raise Exception(
f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
)


def _make_data(shape):
@@ -188,6 +220,12 @@ class MyClass1:
MyClass1(torch.zeros(3, 1), "z", batch_size=[3, 1]),
batch_size=[3, 1],
)
assert x.shape == x.batch_size
assert x.batch_size == (3, 1)
assert x.ndim == 2
assert x.batch_dims == 2
assert x.numel() == 3

assert not x.all()
assert not x.any()
assert isinstance(x.all(), bool)
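The new ``test_tensorclass_instance_methods`` reuses ``_get_methods_from_pyi``, which is referenced but not shown in this diff. For context only, a hypothetical reconstruction of such a helper (the real one in test_tensorclass.py may differ):

import ast

def _get_methods_from_pyi(path):
    # Hypothetical sketch: a .pyi stub is valid Python syntax, so walking
    # the AST and collecting every function definition approximates the
    # set of methods the stub declares.
    with open(path) as f:
        tree = ast.parse(f.read())
    return {
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    }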