Removing custom BlockDiagLazyTensor logic when using Standardize (Take 2)

Summary: Due to [this linear operator PR](cornellius-gp/linear_operator#14), we should now be able to remove the custom logic in `Standardize` without performance impact.

Differential Revision: D39746709

fbshipit-source-id: d5bad3c56254a9820dcfe60aa3a9c8dd1f5edb59
SebastianAment authored and facebook-github-bot committed Sep 22, 2022
1 parent 28f079b commit b1460d7
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions botorch/models/transforms/outcome.py
@@ -33,11 +33,7 @@
 )
 from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
 from botorch.utils.transforms import normalize_indices
-from linear_operator.operators import (
-    BlockDiagLinearOperator,
-    CholLinearOperator,
-    DiagLinearOperator,
-)
+from linear_operator.operators import CholLinearOperator, DiagLinearOperator
 from torch import Tensor
 from torch.nn import Module, ModuleDict

@@ -386,19 +382,8 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior:
         else:
             lcv = mvn.lazy_covariance_matrix
             scale_fac = scale_fac.expand(lcv.shape[:-1])
-            # TODO: Remove the custom logic with next GPyTorch release (T126095032).
-            if isinstance(lcv, BlockDiagLinearOperator):
-                # Keep the block diag structure of lcv.
-                base_lcv = lcv.base_linear_op
-                scale_mat = DiagLinearOperator(
-                    scale_fac.view(*scale_fac.shape[:-1], lcv.num_blocks, -1)
-                )
-                base_lcv_tf = scale_mat @ base_lcv @ scale_mat
-                covar_tf = BlockDiagLinearOperator(base_linear_op=base_lcv_tf)
-            else:
-                # allow batch-evaluation of the model
-                scale_mat = DiagLinearOperator(scale_fac)
-                covar_tf = scale_mat @ lcv @ scale_mat
+            scale_mat = DiagLinearOperator(scale_fac)
+            covar_tf = scale_mat @ lcv @ scale_mat
 
         kwargs = {"interleaved": mvn._interleaved} if posterior._is_mt else {}
         mvn_tf = mvn.__class__(mean=mean_tf, covariance_matrix=covar_tf, **kwargs)
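
For context, a minimal sketch (not part of this commit) of what the retained general path computes: the outcome standard deviations are wrapped in a `DiagLinearOperator` and applied to both sides of the posterior's lazy covariance. The shapes, `q`, `m`, and `stdvs` below are hypothetical and only illustrate the `D @ K @ D` product; the real code derives `scale_fac` from the transform's stored stdvs and the posterior's event shape.

```python
# Hypothetical illustration of the single code path kept by this commit
# (Standardize.untransform_posterior's scale_mat @ lcv @ scale_mat product).
import torch
from linear_operator.operators import (
    BlockDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

q, m = 4, 2  # hypothetical number of test points and outcomes
# Block-diagonal covariance with one (q x q) PSD block per outcome, similar in
# structure to what a non-interleaved multi-output posterior produces.
blocks = torch.randn(m, q, q)
blocks = blocks @ blocks.transpose(-1, -2) + q * torch.eye(q)
lcv = BlockDiagLinearOperator(DenseLinearOperator(blocks))  # (m * q, m * q)

# Per-row scale factors: each outcome's standard deviation repeated q times,
# matching the block ordering above.
stdvs = torch.rand(m) + 0.5
scale_fac = stdvs.repeat_interleave(q)
scale_mat = DiagLinearOperator(scale_fac)

# The general path, with no BlockDiag special case; per the linked
# linear_operator PR, this product should remain efficient for structured lcv.
covar_tf = scale_mat @ lcv @ scale_mat
print(type(covar_tf).__name__, tuple(covar_tf.shape))
```

The deleted branch rebuilt the block-diagonal operator by hand; with the referenced linear_operator change, the plain `scale_mat @ lcv @ scale_mat` product is expected to handle structured operators efficiently on its own, which is why the special case can be dropped.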
