Performant Scaling of BlockDiagLinearOperator by DiagLinearOperator #14
Conversation
Force-pushed from 37635f9 to 9e10bc8
@@ -558,12 +558,7 @@ def _mul_matrix(self, other: Union[torch.Tensor, "LinearOperator"]) -> LinearOperator:
if isinstance(self, DenseLinearOperator) or isinstance(other, DenseLinearOperator):
This line was the source of the performance regressions. The reasoning behind it appears to be that MulLinearOperator always computes a root decomposition, which is both inefficient and introduces dead code in its implementation (see below). I am sidestepping this by replacing the * with secondary @ operators in the new special cases of the DiagLinearOperator and BlockDiagLinearOperator matmul methods, which produce MatmulLinearOperators instead.
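To illustrate why routing through @ is cheap here: the product of a block-diagonal matrix with a diagonal matrix only rescales the columns of each block, so it can be computed directly on the batched blocks without forming anything dense or decomposing anything. A pure-torch sketch of that arithmetic (illustrative only, not the code in this diff):

```python
import torch

# B @ D, where B is block-diagonal and D is diagonal, rescales the columns of B,
# so the scaling can act directly on B's batched blocks.
num_blocks, block_size = 4, 3
blocks = torch.randn(num_blocks, block_size, block_size)  # blocks of the BlockDiagLinearOperator
diag = torch.rand(num_blocks * block_size) + 0.5          # diagonal of the DiagLinearOperator

# Align the diagonal with the blocks and broadcast over each block's rows.
scaled_blocks = blocks * diag.view(num_blocks, 1, block_size)

# Sanity check against the dense computation.
dense = torch.block_diag(*blocks) @ torch.diag(diag)
assert torch.allclose(torch.block_diag(*scaled_blocks), dense)
```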
However, this does not get rid of the more general issue. To fix that, I propose two steps in a future PR:
- Introduce logic in the constructor of MulLinearOperator that decides whether or not to build a root decomposition.
- Even if a root decomposition seems beneficial, delay its computation until the very last moment, when it is needed in matmul, and cache the result (see the sketch after this list). This will give us ~0 overhead in the case where the linear operator represents a posterior covariance matrix that is constructed via a posterior call but only the posterior mean is needed, as was the case in the notebook that exhibited the regression.
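A minimal sketch of the delay-and-cache idea from the second step, using a hypothetical stand-in class rather than MulLinearOperator itself:

```python
from functools import cached_property

import torch


class LazyRootExample:
    """Hypothetical operator that defers its root decomposition until matmul needs it."""

    def __init__(self, matrix: torch.Tensor):
        self.matrix = matrix  # assumed symmetric positive definite

    @cached_property
    def root(self) -> torch.Tensor:
        # Expensive step: runs only on first access, then is cached on the instance.
        return torch.linalg.cholesky(self.matrix)

    def matmul(self, rhs: torch.Tensor) -> torch.Tensor:
        # The root is built here, at the last possible moment, and reused afterwards.
        return self.root @ (self.root.mT @ rhs)


op = LazyRootExample(2.0 * torch.eye(3))
print(op.matmul(torch.ones(3, 1)))  # first call triggers the Cholesky; later calls reuse it
```

If only the posterior mean is needed, matmul is never called and the decomposition never runs, which is the ~0-overhead case described above.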
I think this change should be fine, since the MulLinearOperator constructor performs root decompositions on left_linear_op and right_linear_op:
if not isinstance(left_linear_op, RootLinearOperator):
> Even if a root decomposition seems beneficial, delay its computation until the very last moment, when it is needed in matmul, and cache the result.

Agreed.
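For context, the eager pattern being referenced looks roughly like this (a paraphrase, not the actual constructor code; it assumes RootLinearOperator is importable from linear_operator.operators and that root_decomposition() is the expensive call the lazy proposal would defer):

```python
from linear_operator.operators import RootLinearOperator


def ensure_root(linear_op):
    # Hypothetical helper mirroring the quoted check: if a factor is not already a
    # RootLinearOperator, its root decomposition is computed eagerly, at construction time.
    if not isinstance(linear_op, RootLinearOperator):
        linear_op = linear_op.root_decomposition()
    return linear_op
```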
Force-pushed from 9e10bc8 to a93fe94
Overall this lgtm but I'll let @gpleiss double check the broader implications of this change.
…(Take 2) (#1414)
Summary: Pull Request resolved: #1414. Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact.
Reviewed By: saitcakmak
Differential Revision: D39746709
fbshipit-source-id: 286b092f073861cb52d409ef85ff3dc9047bae4a
The primary goal of this PR is to enable the efficient scaling of BlockDiagLinearOperators by DiagLinearOperators. This will allow us to remove a special case in BoTorch's outcome transform. To achieve this, this PR modifies and adds special cases to DiagLinearOperator's and BlockDiagLinearOperator's matmul methods.

I tested the notebook that exhibited the regression by patching in the function definitions here and ended up with a 16-second runtime, as opposed to 30+ minutes before.
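A hedged usage sketch of the scaling this PR enables (it assumes the DiagLinearOperator, BlockDiagLinearOperator, and DenseLinearOperator classes exported by linear_operator.operators and the to_dense() method; shapes are illustrative and the snippet is not taken from the PR's tests):

```python
import torch
from linear_operator.operators import (
    BlockDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

num_blocks, block_size = 4, 3
blocks = torch.randn(num_blocks, block_size, block_size)
diag_values = torch.rand(num_blocks * block_size) + 0.5

block_diag = BlockDiagLinearOperator(DenseLinearOperator(blocks))  # 12 x 12 block-diagonal
scale = DiagLinearOperator(diag_values)                            # 12 x 12 diagonal

# With the special cases in this PR, this stays a structured matmul-style operator
# instead of a MulLinearOperator that eagerly builds a root decomposition.
scaled = block_diag @ scale

# Verify against the dense computation.
expected = torch.block_diag(*blocks) @ torch.diag(diag_values)
assert torch.allclose(scaled.to_dense(), expected)
```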