Skip to content

Commit

Permalink
Use _fast_solves in HOGP.posterior (#1682)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1682

This significantly reduces the memory usage of HOGP.

Reviewed By: Balandat

Differential Revision: D43337107

fbshipit-source-id: f7e173a534a23c2233910dc7c47ba3e132413aaa
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 17, 2023
1 parent 6940b53 commit 0f7d352
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions botorch/models/higher_order_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
LinearOperator,
ZeroLinearOperator,
)
from linear_operator.settings import _fast_solves
from torch import Tensor
from torch.nn import ModuleList, Parameter, ParameterList

Expand Down Expand Up @@ -158,6 +159,19 @@ class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
they would have a 6,000 x 6,000 covariance matrix, with 36 million entries.
The Kronecker structure allows representing this as a product of 10x10,
20x20, and 30x30 covariance matrices, with only 1,400 entries.
NOTE: This model requires the use of specialized Kronecker solves in
linear operator, which are disabled by default in BoTorch. These are enabled
by default in the `HigherOrderGP.posterior` call. However, they need to be
manually enabled by the user during model fitting.
Example:
>>> from linear_operator.settings import _fast_solves
>>> model = SingleTaskGP(train_X, train_Y)
>>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
>>> with _fast_solves(True):
>>> fit_gpytorch_mll_torch(mll)
>>> samples = model.posterior(test_X).rsample()
"""

def __init__(
Expand Down Expand Up @@ -448,6 +462,7 @@ def posterior(
with ExitStack() as es:
es.enter_context(gpt_posterior_settings())
es.enter_context(fast_pred_var(True))
es.enter_context(_fast_solves(True))

# we need to skip posterior variances here
es.enter_context(skip_posterior_variances(True))
Expand Down

0 comments on commit 0f7d352

Please sign in to comment.