Fix PairwiseGP comparisons might be implicitly modified (#1811)
Summary:
Pull Request resolved: #1811

Fix the bug where the unconsolidated comparisons might be implicitly modified when calling forward.
This happens because `self.comparisons` is passed into `set_train_data` in `forward`, even though it might itself already have been consolidated/deduped.
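
For illustration, here is a minimal, self-contained sketch of the failure mode and of the fix. The class and methods below are simplified stand-ins, not the actual PairwiseGP implementation (real consolidation dedups datapoints and remaps comparison indices; here we just dedup comparison rows to show the aliasing):

import torch


class ToyPairwiseModel:
    """Simplified stand-in illustrating the aliasing bug; NOT the real PairwiseGP."""

    def set_train_data(self, datapoints, comparisons):
        # Keep the user-supplied (unconsolidated) comparisons ...
        self.unconsolidated_comparisons = comparisons
        # ... and a consolidated/deduped copy for internal use.
        self.comparisons = torch.unique(comparisons, dim=0)

    def forward_buggy(self, datapoints):
        # Bug: re-feeding the already-consolidated self.comparisons silently
        # overwrites the unconsolidated copy the user originally passed in.
        self.set_train_data(datapoints, self.comparisons)

    def forward_fixed(self, datapoints):
        # Fix: always re-feed the pristine unconsolidated comparisons.
        self.set_train_data(datapoints, self.unconsolidated_comparisons)


comp = torch.tensor([[0, 1], [0, 1], [2, 3]])  # contains a duplicated pair
X = torch.rand(4, 2)

model = ToyPairwiseModel()
model.set_train_data(X, comp)
model.forward_buggy(X)
print(torch.equal(model.unconsolidated_comparisons, comp))  # False: user's comparisons were replaced

model.set_train_data(X, comp)
model.forward_fixed(X)
print(torch.equal(model.unconsolidated_comparisons, comp))  # True: user's comparisons preserved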

Reviewed By: qingfeng10

Differential Revision: D45472153

fbshipit-source-id: 7f7659e277feed05fa405cd202aea5173351c617
ItsMrLin authored and facebook-github-bot committed May 3, 2023
1 parent 1264316 commit f36eee0
Showing 2 changed files with 42 additions and 4 deletions.
14 changes: 11 additions & 3 deletions botorch/models/pairwise_gp.py
@@ -914,7 +914,11 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
         # We pass in the untransformed datapoints into set_train_data
         # as we will be setting self.datapoints as the untransformed datapoints
         # self.transform_inputs will be called inside before calling _update()
-        self.set_train_data(datapoints, self.comparisons, update_model=True)
+        self.set_train_data(
+            datapoints=datapoints,
+            comparisons=self.unconsolidated_comparisons,
+            update_model=True,
+        )
 
         transformed_dp = self.transform_inputs(self.datapoints)
 
@@ -1099,9 +1103,13 @@ def __init__(self, likelihood, model: GP):
     def forward(self, post: Posterior, comp: Tensor) -> Tensor:
         r"""Calculate approximated log evidence, i.e., log(P(D|theta))
 
+        Note that post will be based on the consolidated/deduped datapoints for
+        numerical stability, but comp will still be the unconsolidated comparisons
+        so that it's still compatible with fit_gpytorch_*.
+
         Args:
-            post: training posterior distribution from self.model
-            comp: Comparisons pairs, see PairwiseGP.__init__ for more details
+            post: training posterior distribution from self.model (after consolidation)
+            comp: Comparisons pairs (before consolidation)
 
         Returns:
             The approximated evidence, i.e., the marginal log likelihood
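
For context on the terminology in this docstring, a rough sketch (illustrative only, not BoTorch's actual consolidation routine) of what deduplicating datapoints and remapping comparison indices looks like:

import torch

def consolidate(datapoints, comparisons):
    # Illustrative only: dedup rows of `datapoints` and remap `comparisons`
    # so that they index into the deduplicated datapoints.
    unique_dp, inverse = torch.unique(datapoints, dim=0, return_inverse=True)
    remapped_comp = inverse[comparisons]
    return unique_dp, remapped_comp

# Rows 3 and 4 duplicate rows 0 and 1.
X = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [0.0, 0.0], [1.0, 1.0]])
comp = torch.tensor([[0, 1], [1, 2], [3, 4]])  # unconsolidated comparisons

unique_dp, remapped_comp = consolidate(X, comp)
# unique_dp has 3 rows; remapped_comp now points at them, e.g. [3, 4] -> [0, 1].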
32 changes: 31 additions & 1 deletion test/models/test_pairwise_gp.py
@@ -180,11 +180,29 @@ def test_pairwise_gp(self):
            with self.assertRaises(RuntimeError):
                model.set_train_data(train_X, changed_train_comp, strict=True)

    def test_consolidation(self):
        for batch_shape, dtype, likelihood_cls in itertools.product(
            (torch.Size(), torch.Size([2])),
            (torch.float, torch.double),
            (PairwiseLogitLikelihood, PairwiseProbitLikelihood),
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            X_dim = 2

            _, model_kwargs = self._get_model_and_data(
                batch_shape=batch_shape,
                X_dim=X_dim,
                likelihood_cls=likelihood_cls,
                **tkwargs,
            )
            train_X = model_kwargs["datapoints"]
            train_comp = model_kwargs["comparisons"]

            # Test consolidation
            i1, i2 = train_X.shape[-2], train_X.shape[-2] + 1
            dup_comp = torch.cat(
                [
                    train_comp,
                    model_kwargs["comparisons"],
                    torch.tensor(
                        [[i1, i2]], dtype=train_comp.dtype, device=train_comp.device
                    ).expand(*batch_shape, 1, 2),
@@ -206,6 +224,18 @@ def test_pairwise_gp(self):
                torch.equal(model.utility, model.unconsolidated_utility)
            )

            # Calling forward with duplicated datapoints should work after consolidation
            mll = PairwiseLaplaceMarginalLogLikelihood(model.likelihood, model)
            # make sure model is in training mode
            self.assertTrue(model.training)
            pred = model(dup_X)
            # posterior shape in training should match the consolidated utility
            self.assertEqual(pred.shape(), model.utility.shape)
            # which is equal to the non-duplicated train_X
            self.assertEqual(pred.shape(), train_X.shape[:-1])
            # Pass the original comparisons through mll should work
            mll(pred, dup_comp)

    def test_condition_on_observations(self):
        for batch_shape, dtype, likelihood_cls in itertools.product(
            (torch.Size(), torch.Size([2])),
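As a usage-level illustration of the behavior the new test exercises, a rough sketch (the data shapes, index values, and asserts below are illustrative assumptions, not taken from the test suite):

import torch
from botorch.models.pairwise_gp import (
    PairwiseGP,
    PairwiseLaplaceMarginalLogLikelihood,
)

X = torch.rand(4, 2, dtype=torch.double)
dup_X = torch.cat([X, X[:2]], dim=0)           # rows 4 and 5 duplicate rows 0 and 1
comp = torch.tensor([[0, 1], [2, 3], [4, 5]])  # last pair compares the duplicated rows

model = PairwiseGP(dup_X, comp)
mll = PairwiseLaplaceMarginalLogLikelihood(model.likelihood, model)

comp_before = comp.clone()
model.train()
post = model(dup_X)

# With the fix, the user-supplied comparisons are left untouched by forward ...
assert torch.equal(model.unconsolidated_comparisons, comp_before)
# ... and the unconsolidated comparisons can still be fed to the MLL.
loss = -mll(post, comp)

In real use, fit_gpytorch_mll(mll) drives this same path during model fitting, which is why the MLL keeps accepting the unconsolidated comparisons.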
