From d9fabefc33d3c9a91cb2d47287603639a99d919b Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Tue, 8 Nov 2022 09:10:59 -0800 Subject: [PATCH] Removing custom `BlockDiagLazyTensor` logic when using `Standardize` (Take 2) (#1414) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1414 Due to [this linear operator PR](https://github.com/cornellius-gp/linear_operator/pull/14), we should now be able to remove the custom logic in `Standardize` without performance impact. Reviewed By: saitcakmak Differential Revision: D39746709 fbshipit-source-id: c1477bcc14ec145583a5d0501fbe1cdac5bfe9bd --- botorch/models/transforms/outcome.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 8bbe82b0be..217cd6c551 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -33,11 +33,7 @@ ) from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior from botorch.utils.transforms import normalize_indices -from linear_operator.operators import ( - BlockDiagLinearOperator, - CholLinearOperator, - DiagLinearOperator, -) +from linear_operator.operators import CholLinearOperator, DiagLinearOperator from torch import Tensor from torch.nn import Module, ModuleDict @@ -386,19 +382,8 @@ def untransform_posterior(self, posterior: Posterior) -> Posterior: else: lcv = mvn.lazy_covariance_matrix scale_fac = scale_fac.expand(lcv.shape[:-1]) - # TODO: Remove the custom logic with next GPyTorch release (T126095032). - if isinstance(lcv, BlockDiagLinearOperator): - # Keep the block diag structure of lcv. - base_lcv = lcv.base_linear_op - scale_mat = DiagLinearOperator( - scale_fac.view(*scale_fac.shape[:-1], lcv.num_blocks, -1) - ) - base_lcv_tf = scale_mat @ base_lcv @ scale_mat - covar_tf = BlockDiagLinearOperator(base_linear_op=base_lcv_tf) - else: - # allow batch-evaluation of the model - scale_mat = DiagLinearOperator(scale_fac) - covar_tf = scale_mat @ lcv @ scale_mat + scale_mat = DiagLinearOperator(scale_fac) + covar_tf = scale_mat @ lcv @ scale_mat kwargs = {"interleaved": mvn._interleaved} if posterior._is_mt else {} mvn_tf = mvn.__class__(mean=mean_tf, covariance_matrix=covar_tf, **kwargs)