add constructor inputs for contextual GP (#2057)
Summary:

In order to enable SACGP and LCEAGP in MBM, we need to add input constructors for the two models.
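
As a rough usage sketch (illustrative only: the data values and context names below are not part of this diff), the new constructors let a generic caller such as MBM build a model directly from a `SupervisedDataset`:

```python
import torch
from botorch.models.contextual import SACGP
from botorch.utils.datasets import SupervisedDataset

# Illustrative training data: 3 points in a 4-dimensional search space.
train_X = torch.rand(3, 4)
train_Y = torch.rand(3, 1)
train_Yvar = 0.01 * torch.ones(3, 1)

dataset = SupervisedDataset(
    X=train_X,
    Y=train_Y,
    feature_names=[f"x{i}" for i in range(4)],
    outcome_names=["y"],
    Yvar=train_Yvar,
)

# Each context owns a subset of the input dimensions.
decomposition = {"context_a": [0, 3], "context_b": [1, 2]}

# construct_inputs maps the dataset to the model's __init__ kwargs, so a
# generic caller can instantiate the model without model-specific glue.
model = SACGP(**SACGP.construct_inputs(dataset, decomposition=decomposition))
```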

Reviewed By: saitcakmak

Differential Revision: D50417893
qingfeng10 authored and facebook-github-bot committed Oct 18, 2023
1 parent 41e1c37 commit af665af
Showing 2 changed files with 126 additions and 19 deletions.
66 changes: 65 additions & 1 deletion botorch/models/contextual.py
@@ -4,11 +4,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.kernels.contextual_lcea import LCEAKernel
from botorch.models.kernels.contextual_sac import SACKernel
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


@@ -40,6 +41,26 @@ def __init__(
        self.decomposition = decomposition
        self.to(train_X)

    @classmethod
    def construct_inputs(
        cls,
        training_data: SupervisedDataset,
        decomposition: Dict[str, List[int]],
        **kwargs: Any,
    ) -> Dict[str, Any]:
        r"""Construct `Model` keyword arguments from a `SupervisedDataset`.

        Args:
            training_data: A `SupervisedDataset` containing the training data.
            decomposition: Dictionary mapping context names to the indices of
                the corresponding active context parameters.
        """
        base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
        return {
            **base_inputs,
            "decomposition": decomposition,
        }
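
For reference, the `decomposition` format here is the one the tests below exercise: each context name maps to the indices of its active parameters, and in those tests the index sets partition the full set of input columns. An illustrative check (not part of the commit):

```python
# Illustrative: the two context index sets jointly cover each of the
# four input columns exactly once.
decomposition = {"1": [0, 3], "2": [1, 2]}
active = sorted(i for idxs in decomposition.values() for i in idxs)
assert active == list(range(4))
```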


class LCEAGP(FixedNoiseGP):
    r"""A GP using a Latent Context Embedding Additive (LCE-A) Kernel.
@@ -67,6 +88,8 @@ def __init__(
            train_Yvar: (n x 1) Noise variance of Y.
            decomposition: Keys are context names. Values are the indices of
                the parameters belonging to the context.
            train_embedding: Whether to train the embedding layer. If False,
                the model uses the pre-trained embeddings in embs_feature_dict.
            cat_feature_dict: Keys are context names and values are lists of
                categorical features, i.e. {"context_name": [cat_0, ..., cat_k]},
                where k is the number of categorical variables. If None, we use
                context names in the
@@ -91,3 +114,44 @@
        )
        self.decomposition = decomposition
        self.to(train_X)

    @classmethod
    def construct_inputs(
        cls,
        training_data: SupervisedDataset,
        decomposition: Dict[str, List[int]],
        train_embedding: bool = True,
        cat_feature_dict: Optional[Dict] = None,
        embs_feature_dict: Optional[Dict] = None,
        embs_dim_list: Optional[List[int]] = None,
        context_weight_dict: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        r"""Construct `Model` keyword arguments from a `SupervisedDataset`.

        Args:
            training_data: A `SupervisedDataset` containing the training data.
            decomposition: Dictionary mapping context names to the indices of
                the corresponding active context parameters.
            train_embedding: Whether to train the embedding layer.
            cat_feature_dict: Keys are context names and values are lists of
                categorical features, i.e. {"context_name": [cat_0, ..., cat_k]},
                where k is the number of categorical variables. If None, we use
                context names in the decomposition as the only categorical
                feature, i.e., k = 1.
            embs_feature_dict: Pre-trained continuous embedding features of each
                context.
            embs_dim_list: Embedding dimension for each categorical variable. The
                length equals the number of categorical features k. If None, the
                embedding dimension is set to 1 for each categorical variable.
            context_weight_dict: Known population weights of each context.
        """
        base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
        return {
            **base_inputs,
            "decomposition": decomposition,
            "train_embedding": train_embedding,
            "cat_feature_dict": cat_feature_dict,
            "embs_feature_dict": embs_feature_dict,
            "embs_dim_list": embs_dim_list,
            "context_weight_dict": context_weight_dict,
        }
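
And a minimal end-to-end sketch of the LCE-A path (illustrative values; the dataset construction mirrors the `_gen_datasets` helper added in the test diff below):

```python
import torch
from botorch.models.contextual import LCEAGP
from botorch.utils.datasets import SupervisedDataset

train_X = torch.rand(3, 4)
train_Y = torch.rand(3, 1)
train_Yvar = 0.01 * torch.ones(3, 1)
dataset = SupervisedDataset(
    X=train_X,
    Y=train_Y,
    feature_names=[f"x{i}" for i in range(4)],
    outcome_names=["y"],
    Yvar=train_Yvar,
)

# With cat_feature_dict=None there is a single categorical feature (the
# context name), so embs_dim_list has length 1.
inputs = LCEAGP.construct_inputs(
    training_data=dataset,
    decomposition={"1": [0, 1], "2": [2, 3]},
    train_embedding=True,
    embs_dim_list=[2],  # 2-dimensional learned context embedding
)
model = LCEAGP(**inputs)
```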
79 changes: 61 additions & 18 deletions test/models/test_contextual.py
@@ -5,30 +5,46 @@
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.contextual import LCEAGP, SACGP
from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.kernels.contextual_lcea import LCEAKernel
from botorch.models.kernels.contextual_sac import SACKernel
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.testing import BotorchTestCase
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.means import ConstantMean
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from torch import Tensor


def _gen_datasets(
    **tkwargs,
) -> Tuple[SupervisedDataset, Tuple[Tensor, Tensor, Tensor]]:
    train_X = torch.tensor(
        [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]], **tkwargs
    )
    train_Y = torch.tensor([[1.0], [2.0], [3.0]], **tkwargs)
    train_Yvar = 0.01 * torch.ones(3, 1, **tkwargs)

    datasets = SupervisedDataset(
        X=train_X,
        Y=train_Y,
        feature_names=[f"x{i}" for i in range(train_X.shape[-1])],
        outcome_names=["y"],
        Yvar=train_Yvar,
    )
    return datasets, (train_X, train_Y, train_Yvar)


class TestContextualGP(BotorchTestCase):
    def test_SACGP(self):
        for dtype in (torch.float, torch.double):
-            train_X = torch.tensor(
-                [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]],
-                device=self.device,
-                dtype=dtype,
-            )
-            train_Y = torch.tensor(
-                [[1.0], [2.0], [3.0]], device=self.device, dtype=dtype
-            )
-            train_Yvar = 0.01 * torch.ones(3, 1, device=self.device, dtype=dtype)
+            tkwargs = {"device": self.device, "dtype": dtype}
+            datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs)
            self.decomposition = {"1": [0, 3], "2": [1, 2]}

            model = SACGP(train_X, train_Y, train_Yvar, self.decomposition)
@@ -59,17 +75,25 @@ def test_SACGP(self):
            posterior = model(test_x)
            self.assertIsInstance(posterior, MultivariateNormal)

-    def testLCEAGP(self):
+    def test_SACGP_construct_inputs(self):
        for dtype in (torch.float, torch.double):
-            train_X = torch.tensor(
-                [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]],
-                device=self.device,
-                dtype=dtype,
-            )
-            train_Y = torch.tensor(
-                [[1.0], [2.0], [3.0]], device=self.device, dtype=dtype
-            )
-            train_Yvar = 0.01 * torch.ones(3, 1, device=self.device, dtype=dtype)
+            tkwargs = {"device": self.device, "dtype": dtype}
+            datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs)
+            self.decomposition = {"1": [0, 3], "2": [1, 2]}
+            model = SACGP(train_X, train_Y, train_Yvar, self.decomposition)
+            data_dict = model.construct_inputs(
+                training_data=datasets, decomposition=self.decomposition
+            )

+            self.assertTrue(train_X.equal(data_dict["train_X"]))
+            self.assertTrue(train_Y.equal(data_dict["train_Y"]))
+            self.assertTrue(train_Yvar.equal(data_dict["train_Yvar"]))
+            self.assertDictEqual(data_dict["decomposition"], self.decomposition)

+    def testLCEAGP(self):
        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs)
            # Test setting attributes
            decomposition = {"1": [0, 1], "2": [2, 3]}

@@ -90,3 +114,22 @@
            test_x = torch.rand(5, 4, device=self.device, dtype=dtype)
            posterior = model(test_x)
            self.assertIsInstance(posterior, MultivariateNormal)

    def test_LCEAGP_construct_inputs(self):
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs)
            decomposition = {"1": [0, 1], "2": [2, 3]}

            model = LCEAGP(train_X, train_Y, train_Yvar, decomposition)
            data_dict = model.construct_inputs(
                training_data=datasets,
                decomposition=decomposition,
                train_embedding=False,
            )

            self.assertTrue(train_X.equal(data_dict["train_X"]))
            self.assertTrue(train_Y.equal(data_dict["train_Y"]))
            self.assertTrue(train_Yvar.equal(data_dict["train_Yvar"]))
            self.assertDictEqual(data_dict["decomposition"], decomposition)
            self.assertFalse(data_dict["train_embedding"])
