Skip to content

Commit

Permalink
[BugFix] Fix torch_function for uninit param (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 20, 2024
1 parent 3594dd0 commit 8ea4e87
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,10 @@ def where(condition, input, other, *, out=None):
return input.where(condition, other, out=out)


# monkey patch
__prev_torch_function__ = UninitializedTensorMixin.__torch_function__


def __torch_function__(
cls,
func: Callable,
Expand All @@ -532,8 +536,20 @@ def __torch_function__(
fnc_uninit = UNINIT_TENSOR_FUNCTIONS.get(func, None)
if fnc_uninit is not None:
return fnc_uninit(*args, **kwargs)
with _C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
if kwargs is None:
kwargs = {}
with _C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
# Ideally we'd like to use this from the original __torch_function__
# return super().__torch_function__(func, types, args, kwargs)
raise ValueError(
f"Attempted to use an uninitialized parameter in {func}. "
"This error happens when you are using a `LazyModule` or "
f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` "
"objects. When using LazyModules Call `forward` with a dummy batch "
"to initialize the parameters before calling torch functions"
)


UninitializedTensorMixin.__torch_function__ = classmethod(__torch_function__)
Expand Down

1 comment on commit 8ea4e87

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 8ea4e87 Previous: 3594dd0 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 79257.99842557898 iter/sec (stddev: 6.669060966151845e-7) 190106.3053325936 iter/sec (stddev: 4.1012761732294033e-7) 2.40
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 79288.46945395312 iter/sec (stddev: 6.531248856849023e-7) 189005.92308016526 iter/sec (stddev: 4.093706001749644e-7) 2.38

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.