Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use _fast_solves in HOGP.posterior #1682

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions botorch/models/higher_order_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
LinearOperator,
ZeroLinearOperator,
)
from linear_operator.settings import _fast_solves
from torch import Tensor
from torch.nn import ModuleList, Parameter, ParameterList

Expand Down Expand Up @@ -158,6 +159,19 @@ class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
they would have a 6,000 x 6,000 covariance matrix, with 36 million entries.
The Kronecker structure allows representing this as a product of 10x10,
20x20, and 30x30 covariance matrices, with only 1,400 entries.

NOTE: This model requires the use of specialized Kronecker solves in
linear operator, which are disabled by default in BoTorch. These are enabled
by default in the `HigherOrderGP.posterior` call. However, they need to be
manually enabled by the user during model fitting.

Example:
>>> from linear_operator.settings import _fast_solves
>>> model = SingleTaskGP(train_X, train_Y)
>>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
>>> with _fast_solves(True):
>>> fit_gpytorch_mll_torch(mll)
>>> samples = model.posterior(test_X).rsample()
"""

def __init__(
Expand Down Expand Up @@ -448,6 +462,7 @@ def posterior(
with ExitStack() as es:
es.enter_context(gpt_posterior_settings())
es.enter_context(fast_pred_var(True))
es.enter_context(_fast_solves(True))

# we need to skip posterior variances here
es.enter_context(skip_posterior_variances(True))
Expand Down