Skip to content

Commit

Permalink
Fix batch_shape property of ModelListGPyTorchModel
Browse files Browse the repository at this point in the history
Summary: Previously, this was assuming the submodel batch shape is same as the input batch shape, which is not true for SAAS models. With this update, the batch shape correctly reflects the batch shape of the underlying models.

Differential Revision: D40165810

fbshipit-source-id: d69f961182421211a0fc88b7d537ccf8911250b7
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 7, 2022
1 parent ceb7000 commit 0440432
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def batch_shape(self) -> torch.Size:
to the `posterior` method returns a Posterior object over an output of
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
batch_shapes = {ti[0].shape[:-2] for ti in self.train_inputs}
batch_shapes = {m.batch_shape for m in self.models}
if len(batch_shapes) > 1:
msg = (
f"Component models of {self.__class__.__name__} have different "
Expand Down
2 changes: 2 additions & 0 deletions test/models/test_fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ def test_fit_model(self):
mean, var = posterior.mean, posterior.variance
self.assertEqual(mean.shape, expected_shape)
self.assertEqual(var.shape, expected_shape)
# This check is only for ModelListGP.
self.assertEqual(model_list.batch_shape, model.batch_shape)

# Mixing fully Bayesian models with different batch shapes isn't supported
_, _, _, model2 = self._get_data_and_model(
Expand Down
16 changes: 16 additions & 0 deletions test/models/test_gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import itertools
import warnings
from typing import Optional

import torch
from botorch import settings
Expand Down Expand Up @@ -96,6 +97,8 @@ def forward(self, x):


class SimpleBatchedMultiOutputGPyTorchModel(BatchedMultiOutputGPyTorchModel, ExactGP):
_batch_shape: Optional[torch.Size] = None

def __init__(self, train_X, train_Y, outcome_transform=None, input_transform=None):
r"""
Args:
Expand Down Expand Up @@ -134,6 +137,12 @@ def forward(self, x):
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)

@property
def batch_shape(self) -> torch.Size:
if self._batch_shape is not None:
return self._batch_shape
return super().batch_shape


class SimpleModelListGPyTorchModel(IndependentModelList, ModelListGPyTorchModel):
def __init__(self, *gp_models: GPyTorchModel):
Expand Down Expand Up @@ -397,6 +406,13 @@ def test_model_list_gpytorch_model(self):
)
train_Y1 = torch.sin(train_X1)
train_Y2 = torch.cos(train_X2)
# test SAAS type batch shape
m1 = SimpleBatchedMultiOutputGPyTorchModel(train_X1, train_Y1)
m2 = SimpleBatchedMultiOutputGPyTorchModel(train_X2, train_Y2)
m1._batch_shape = torch.Size([2])
m2._batch_shape = torch.Size([2])
model = SimpleModelListGPyTorchModel(m1, m2)
self.assertEqual(model.batch_shape, torch.Size([2]))
# test different batch shapes (broadcastable)
m1 = SimpleGPyTorchModel(
train_X1.expand(2, *train_X1.shape), train_Y1.expand(2, *train_Y1.shape)
Expand Down

0 comments on commit 0440432

Please sign in to comment.