Performant Scaling of BlockDiagLinearOperator by DiagLinearOperator #14

Merged

Conversation

SebastianAment
Collaborator

The primary goal of this PR is to enable the efficient scaling of BlockDiagLinearOperators by DiagLinearOperators. This will allow us to remove a special case in BoTorch's outcome transform.

To achieve this, the PR modifies and adds special cases to the matmul methods of DiagLinearOperator and BlockDiagLinearOperator.

I tested the notebook that exhibited the regression by patching in the function definitions from this PR and ended up with a 16-second runtime, as opposed to 30+ minutes before.
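For concreteness, a rough usage sketch of the kind of product this speeds up (only the constructor names are from linear_operator; the shapes and values below are made up for illustration):

```python
import torch
from linear_operator.operators import (
    BlockDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

# 3 blocks of size 4 x 4 -> a 12 x 12 block-diagonal operator
block_op = BlockDiagLinearOperator(DenseLinearOperator(torch.randn(3, 4, 4)))
diag_op = DiagLinearOperator(torch.rand(12))  # 12 x 12 diagonal operator

# With the special cases added in this PR, this product should stay structured
# instead of falling back to an expensive generic code path.
scaled = diag_op @ block_op
print(scaled.shape)  # torch.Size([12, 12])
```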

@SebastianAment force-pushed the diagonal-performance-improvements branch from 37635f9 to 9e10bc8 on September 13, 2022 18:47
@@ -558,12 +558,7 @@ def _mul_matrix(self, other: Union[torch.Tensor, "LinearOperator"]) -> LinearOpe
if isinstance(self, DenseLinearOperator) or isinstance(other, DenseLinearOperator):
Collaborator Author


This line was the source of the performance regressions. The problem appears to be that MulLinearOperator always computes a root decomposition, which is both inefficient and introduces dead code into its implementation (see below). I am sidestepping this by replacing the `*` with `@` operations in the new special cases of the DiagLinearOperator and BlockDiagLinearOperator matmul methods, which produces MatmulLinearOperators instead.

However, this does not get rid of the more general issue. To fix that, I propose two steps in a future PR:

  1. Introduce logic in the constructor of MulLinearOperator that decides whether or not to build a root decomposition.
  2. Even if a root decomposition seems beneficial, delay its computation until the moment it is actually needed in matmul, and cache the result (see the sketch below). This would give us ~0 overhead in the case where the linear operator represents a posterior covariance matrix constructed via a posterior call but only the posterior mean is needed, as was the case in the notebook that exhibited the regression.
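To make step 2 concrete, here is a minimal sketch of the defer-and-cache pattern; the class and attribute names below are hypothetical and this is not MulLinearOperator's actual implementation:

```python
import torch
from functools import cached_property


class LazyRootProduct:
    """Toy operator that defers its root decomposition until matmul needs it."""

    def __init__(self, left: torch.Tensor, right: torch.Tensor):
        self.left = left
        self.right = right  # no decomposition here -> ~0 construction overhead

    @cached_property
    def _left_root(self) -> torch.Tensor:
        # Computed at most once, on first use, then cached.
        return torch.linalg.cholesky(self.left)

    def matmul(self, rhs: torch.Tensor) -> torch.Tensor:
        root = self._left_root  # triggers (and caches) the factorization
        return root @ (root.transpose(-1, -2) @ (self.right @ rhs))


if __name__ == "__main__":
    A = torch.eye(5) * 2.0      # SPD, so the Cholesky is well defined
    B = torch.randn(5, 5)
    op = LazyRootProduct(A, B)  # cheap: nothing is factorized yet
    v = torch.randn(5, 1)
    print(op.matmul(v).shape)   # the first matmul triggers the Cholesky once
```

If an operator like this backs a posterior and only the mean is ever requested, matmul is never called and the factorization never runs.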

Member


I think this change should be fine, since the MulLinearOperator constructor performs root decompositions on left_linear_op and right_linear_op:

if not isinstance(left_linear_op, RootLinearOperator):

Even if a root decomposition seems beneficial, delay its computation until the moment it is actually needed in matmul, and cache the result.

Agreed.
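As a quick plain-torch sanity check of the identity that makes such a diagonal-times-block-diagonal special case valid (illustrative only, not code from the PR):

```python
import torch

# D @ blockdiag(B_1, ..., B_k) == blockdiag(diag(d_1) @ B_1, ..., diag(d_k) @ B_k)
blocks = [torch.randn(4, 4) for _ in range(3)]
d = torch.rand(12)

lhs = torch.diag(d) @ torch.block_diag(*blocks)
rhs = torch.block_diag(*(torch.diag(d[4 * i:4 * (i + 1)]) @ B for i, B in enumerate(blocks)))
assert torch.allclose(lhs, rhs)
```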

@SebastianAment force-pushed the diagonal-performance-improvements branch from 9e10bc8 to a93fe94 on September 13, 2022 18:51
Collaborator

@Balandat Balandat left a comment


Overall this LGTM, but I'll let @gpleiss double-check the broader implications of this change.


@gpleiss enabled auto-merge (squash) on September 22, 2022 13:05
@gpleiss merged commit 3a37f0c into cornellius-gp:main on Sep 22, 2022
SebastianAment added a commit to SebastianAment/botorch that referenced this pull request Sep 22, 2022
…(Take 2)

Summary: Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact.

Differential Revision: D39746709

fbshipit-source-id: d5bad3c56254a9820dcfe60aa3a9c8dd1f5edb59
SebastianAment added a commit to SebastianAment/botorch that referenced this pull request Nov 7, 2022
…(Take 2) (pytorch#1414)

Summary:
Pull Request resolved: pytorch#1414

Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact.

Differential Revision: D39746709

fbshipit-source-id: 506d1fc3a34778fa5bb0d91779ca5f73b24f4146
SebastianAment added a commit to SebastianAment/botorch that referenced this pull request Nov 8, 2022
…(Take 2) (pytorch#1414)

Summary:
Pull Request resolved: pytorch#1414

Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact.

Differential Revision: D39746709

fbshipit-source-id: 1c68d8033d18c4a489600dd743ad5ab24efc5fb0
SebastianAment added a commit to SebastianAment/botorch that referenced this pull request Nov 8, 2022
…(Take 2) (pytorch#1414)

Summary:
Pull Request resolved: pytorch#1414

Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact.

Reviewed By: saitcakmak

Differential Revision: D39746709

fbshipit-source-id: c1477bcc14ec145583a5d0501fbe1cdac5bfe9bd
facebook-github-bot pushed a commit to pytorch/botorch that referenced this pull request Nov 8, 2022
…(Take 2) (#1414)

Summary:
Pull Request resolved: #1414

Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact.

Reviewed By: saitcakmak

Differential Revision: D39746709

fbshipit-source-id: 286b092f073861cb52d409ef85ff3dc9047bae4a