Skip to content

Commit

Permalink
Fix PairwiseGP comparisons might be implicitly modified (#1811)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1811

Fix the bug where the unconsolidated comparisons might be implicitly modified when calling forward.
This is because `self.comparisons` is passed into `set_train_data` in `forward`, and it might already have been consolidated/deduped.

Reviewed By: qingfeng10

Differential Revision: D45472153

fbshipit-source-id: e7dda1ecd5245f753a6dd67c7f13a31b375b5e49
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed May 3, 2023
1 parent 1264316 commit 47eac7a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
14 changes: 11 additions & 3 deletions botorch/models/pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,11 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
# We pass in the untransformed datapoints into set_train_data
# as we will be setting self.datapoints as the untransformed datapoints
# self.transform_inputs will be called inside before calling _update()
self.set_train_data(datapoints, self.comparisons, update_model=True)
self.set_train_data(
datapoints=datapoints,
comparisons=self.unconsolidated_comparisons,
update_model=True,
)

transformed_dp = self.transform_inputs(self.datapoints)

Expand Down Expand Up @@ -1099,9 +1103,13 @@ def __init__(self, likelihood, model: GP):
def forward(self, post: Posterior, comp: Tensor) -> Tensor:
r"""Calculate approximated log evidence, i.e., log(P(D|theta))
Note that post will be based on the consolidated/deduped datapoints for
numerical stability, but comp will still be the unconsolidated comparisons
so that it's still compatible with fit_gpytorch_*.
Args:
post: training posterior distribution from self.model
comp: Comparisons pairs, see PairwiseGP.__init__ for more details
post: training posterior distribution from self.model (after consolidation)
comp: Comparisons pairs (before consolidation)
Returns:
The approximated evidence, i.e., the marginal log likelihood
Expand Down
34 changes: 34 additions & 0 deletions test/models/test_pairwise_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,24 @@ def test_pairwise_gp(self):
with self.assertRaises(RuntimeError):
model.set_train_data(train_X, changed_train_comp, strict=True)

def test_consolidation(self):
for batch_shape, dtype, likelihood_cls in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
(PairwiseLogitLikelihood, PairwiseProbitLikelihood),
):
tkwargs = {"device": self.device, "dtype": dtype}
X_dim = 2

_, model_kwargs = self._get_model_and_data(
batch_shape=batch_shape,
X_dim=X_dim,
likelihood_cls=likelihood_cls,
**tkwargs,
)
train_X = model_kwargs["datapoints"]
train_comp = model_kwargs["comparisons"]

# Test consolidation
i1, i2 = train_X.shape[-2], train_X.shape[-2] + 1
dup_comp = torch.cat(
Expand All @@ -206,6 +224,22 @@ def test_pairwise_gp(self):
torch.equal(model.utility, model.unconsolidated_utility)
)

# calling forward with duplicated datapoints should work after consolidation
mll = PairwiseLaplaceMarginalLogLikelihood(model.likelihood, model)
# make sure model is in training mode
self.assertTrue(model.training)
pred = model(dup_X)
# posterior shape in training should match the consolidated utility
self.assertEqual(pred.shape(), model.utility.shape)
if batch_shape:
# do not perform consolidation in batch mode
# because the block structure cannot be guaranteed
self.assertEqual(pred.shape(), dup_X.shape[:-1])
else:
self.assertEqual(pred.shape(), train_X.shape[:-1])
# Pass the original comparisons through mll should work
mll(pred, dup_comp)

def test_condition_on_observations(self):
for batch_shape, dtype, likelihood_cls in itertools.product(
(torch.Size(), torch.Size([2])),
Expand Down

0 comments on commit 47eac7a

Please sign in to comment.