Skip to content

Commit

Permalink
Fix TransformedPosterior missing batch shape error in _update_base_sa…
Browse files Browse the repository at this point in the history
…mples

Summary: See pytorch#1623. This is the only use case for `posterior.batch_shape`, so fixing it locally makes sense to me.

Differential Revision: D42421494

fbshipit-source-id: 613ab2aa822d53fda615b53979fba857c9b9c931
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jan 9, 2023
1 parent c2e502c commit 0d07649
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
8 changes: 7 additions & 1 deletion botorch/sampling/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from botorch.posteriors import Posterior
from botorch.posteriors.higher_order import HigherOrderGPPosterior
from botorch.posteriors.multitask import MultitaskGPPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.base import MCSampler
from botorch.utils.sampling import draw_sobol_normal_samples, manual_seed
from torch import Tensor
Expand Down Expand Up @@ -112,8 +113,13 @@ def _update_base_samples(
..., -n_train_samples:
]
else:
batch_shape = (
posterior._posterior.batch_shape
if isinstance(posterior, TransformedPosterior)
else posterior.batch_shape
)
single_output = (
len(posterior.base_sample_shape) - len(posterior.batch_shape)
len(posterior.base_sample_shape) - len(batch_shape)
) == 1
if single_output:
self.base_samples[
Expand Down
28 changes: 28 additions & 0 deletions test/acquisition/multi_objective/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input import InputPerturbation
from botorch.models.transforms.outcome import Standardize
from botorch.posteriors.posterior_list import PosteriorList
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.low_rank import sample_cached_cholesky
from botorch.utils.multi_objective.box_decompositions.dominated import (
Expand Down Expand Up @@ -1624,3 +1627,28 @@ def get_acqf(model, matheron):
rtol=1e-4,
)
)

def test_with_transformed(self):
# Verify that _set_sampler works with transformed posteriors.
mm = MockModel(
posterior=PosteriorList(
TransformedPosterior(
MockPosterior(samples=torch.rand(2, 3, 1, 1)), lambda X: X
),
TransformedPosterior(
MockPosterior(samples=torch.rand(2, 3, 1, 1)), lambda X: X
),
)
)
sampler = ListSampler(
IIDNormalSampler(sample_shape=torch.Size([2])),
IIDNormalSampler(sample_shape=torch.Size([2])),
)
# This calls _set_sampler which would previously fail.
qNoisyExpectedHypervolumeImprovement(
model=mm,
ref_point=torch.tensor([0.0, 0.0]),
X_baseline=torch.rand(3, 2),
sampler=sampler,
cache_root=False,
)

0 comments on commit 0d07649

Please sign in to comment.