From 8ea4e871036d656e9e2a0ecf9fd2e81aa3c1706d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 20 Feb 2024 20:46:44 +0000
Subject: [PATCH] [BugFix] Fix torch_function for uninit param (#683)

---
 tensordict/_torch_func.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index 60eb9ec6a..424de42bb 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -520,6 +520,10 @@ def where(condition, input, other, *, out=None):
     return input.where(condition, other, out=out)
 
 
+# monkey patch
+__prev_torch_function__ = UninitializedTensorMixin.__torch_function__
+
+
 def __torch_function__(
     cls,
     func: Callable,
@@ -532,8 +536,20 @@ def __torch_function__(
     fnc_uninit = UNINIT_TENSOR_FUNCTIONS.get(func, None)
     if fnc_uninit is not None:
         return fnc_uninit(*args, **kwargs)
-    with _C.DisableTorchFunctionSubclass():
-        return func(*args, **kwargs)
+    if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
+        if kwargs is None:
+            kwargs = {}
+        with _C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+    # Ideally we'd like to use this from the original __torch_function__
+    # return super().__torch_function__(func, types, args, kwargs)
+    raise ValueError(
+        f"Attempted to use an uninitialized parameter in {func}. "
+        "This error happens when you are using a `LazyModule` or "
+        f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` "
+        "objects. When using LazyModules, call `forward` with a dummy batch "
+        "to initialize the parameters before calling torch functions."
+    )
 
 
 UninitializedTensorMixin.__torch_function__ = classmethod(__torch_function__)
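
Note (not part of the patch): a minimal Python sketch of the behaviour this fix is meant to give, assuming that importing `tensordict` installs the patched `__torch_function__` (the `classmethod` assignment at the bottom of `_torch_func.py` runs at import time) and that `torch.nn.parameter.UninitializedParameter` behaves as in stock PyTorch. Before the fix, the removed branch ran `func` unconditionally under `DisableTorchFunctionSubclass`, so a torch function applied to an uninitialized parameter operated on placeholder data; with the fix it raises the descriptive `ValueError` instead.

import torch
from torch.nn.parameter import UninitializedParameter

import tensordict  # noqa: F401  # assumption: importing installs the patched __torch_function__

p = UninitializedParameter()

# torch.add is not in UNINIT_TENSOR_FUNCTIONS, not in the mixin's
# _allowed_methods and not a method-wrapper, so the patched
# __torch_function__ raises the descriptive ValueError.
try:
    torch.add(p, p)
except ValueError as err:
    print(err)

# Materializing the parameter first (what a LazyModule's dummy forward pass
# does) turns it into a regular Parameter, and torch functions work again.
p.materialize((3, 4))
print(torch.add(p, p).shape)  # torch.Size([3, 4])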