[BugFix] Fix torch_function for uninit param (#683)
vmoens committed Feb 26, 2024
1 parent bfbe24a commit 3cee52b
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions tensordict/_torch_func.py
@@ -520,6 +520,10 @@ def where(condition, input, other, *, out=None):
     return input.where(condition, other, out=out)
 
 
+# monkey patch
+__prev_torch_function__ = UninitializedTensorMixin.__torch_function__
+
+
 def __torch_function__(
     cls,
     func: Callable,
@@ -532,8 +536,20 @@ def __torch_function__(
     fnc_uninit = UNINIT_TENSOR_FUNCTIONS.get(func, None)
     if fnc_uninit is not None:
         return fnc_uninit(*args, **kwargs)
-    with _C.DisableTorchFunctionSubclass():
-        return func(*args, **kwargs)
+    if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
+        if kwargs is None:
+            kwargs = {}
+        with _C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+    # Ideally we'd like to use this from the original __torch_function__
+    # return super().__torch_function__(func, types, args, kwargs)
+    raise ValueError(
+        f"Attempted to use an uninitialized parameter in {func}. "
+        "This error happens when you are using a `LazyModule` or "
+        f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` "
+        "objects. When using LazyModules Call `forward` with a dummy batch "
+        "to initialize the parameters before calling torch functions"
+    )
 
 
 UninitializedTensorMixin.__torch_function__ = classmethod(__torch_function__)
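
For context (not part of the commit): a minimal sketch of the behavior the patched __torch_function__ targets, assuming the monkey patch is installed when `tensordict` is imported. A torch function applied to a still-uninitialized parameter is expected to raise the ValueError above, while the same call works once a dummy forward pass has materialized the parameter.

    import torch
    import tensordict  # assumption: importing tensordict installs the patched __torch_function__

    lazy = torch.nn.LazyLinear(out_features=4)   # lazy.weight is an UninitializedParameter
    try:
        torch.add(lazy.weight, 1.0)              # uninitialized -> expected to raise ValueError
    except ValueError as err:
        print(err)

    lazy(torch.randn(2, 3))                      # dummy forward materializes the parameters
    print(torch.add(lazy.weight, 1.0).shape)     # now behaves like a regular tensor: torch.Size([4, 3])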
