diff --git a/botorch/models/contextual.py b/botorch/models/contextual.py
index 2ef486d50f..0e99c4336c 100644
--- a/botorch/models/contextual.py
+++ b/botorch/models/contextual.py
@@ -4,11 +4,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from botorch.models.gp_regression import FixedNoiseGP
 from botorch.models.kernels.contextual_lcea import LCEAKernel
 from botorch.models.kernels.contextual_sac import SACKernel
+from botorch.utils.datasets import SupervisedDataset
 from torch import Tensor
 
 
@@ -40,6 +41,26 @@ def __init__(
         self.decomposition = decomposition
         self.to(train_X)
 
+    @classmethod
+    def construct_inputs(
+        cls,
+        training_data: SupervisedDataset,
+        decomposition: Dict[str, List[int]],
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        r"""Construct `Model` keyword arguments from a `SupervisedDataset`.
+
+        Args:
+            training_data: A `SupervisedDataset` containing the training data.
+            decomposition: Dictionary mapping context names to the indices of
+                the corresponding active context parameters.
+        """
+        base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
+        return {
+            **base_inputs,
+            "decomposition": decomposition,
+        }
+
 
 class LCEAGP(FixedNoiseGP):
     r"""A GP using a Latent Context Embedding Additive (LCE-A) Kernel.
@@ -67,6 +88,8 @@ def __init__(
             train_Yvar: (n x 1) Noise variance of Y.
             decomposition: Keys are context names. Values are the indexes of
                 parameters belong to the context.
+            train_embedding: Whether to train the embedding layer or not. If False,
+                the model will use pre-trained embeddings in embs_feature_dict.
             cat_feature_dict: Keys are context names and values are list of categorical
                 features i.e. {"context_name" : [cat_0, ..., cat_k]}, where k is the
                 number of categorical variables. If None, we use context names in the
@@ -91,3 +114,44 @@ def __init__(
         )
         self.decomposition = decomposition
         self.to(train_X)
+
+    @classmethod
+    def construct_inputs(
+        cls,
+        training_data: SupervisedDataset,
+        decomposition: Dict[str, List[int]],
+        train_embedding: bool = True,
+        cat_feature_dict: Optional[Dict] = None,
+        embs_feature_dict: Optional[Dict] = None,
+        embs_dim_list: Optional[List[int]] = None,
+        context_weight_dict: Optional[Dict] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        r"""Construct `Model` keyword arguments from a `SupervisedDataset`.
+
+        Args:
+            training_data: A `SupervisedDataset` containing the training data.
+            decomposition: Dictionary mapping context names to the indices of
+                the corresponding active context parameters.
+            train_embedding: Whether to train the embedding layer or not.
+            cat_feature_dict: Keys are context names and values are lists of categorical
+                features i.e. {"context_name" : [cat_0, ..., cat_k]}, where k is the
+                number of categorical variables. If None, we use context names in the
+                decomposition as the only categorical feature, i.e., k = 1.
+            embs_feature_dict: Pre-trained continuous embedding features of each
+                context.
+            embs_dim_list: Embedding dimension for each categorical variable. The length
+                equals the number of categorical features k. If None, the embedding
+                dimension is set to 1 for each categorical variable.
+            context_weight_dict: Known population weights of each context.
+        """
+        base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
+        return {
+            **base_inputs,
+            "decomposition": decomposition,
+            "train_embedding": train_embedding,
+            "cat_feature_dict": cat_feature_dict,
+            "embs_feature_dict": embs_feature_dict,
+            "embs_dim_list": embs_dim_list,
+            "context_weight_dict": context_weight_dict,
+        }
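Note (not part of the patch): a minimal usage sketch of the new classmethod, assuming the `SupervisedDataset` constructor signature used in the tests below. `construct_inputs` returns the dataset's `train_X`/`train_Y`/`train_Yvar` tensors plus the passed-through `decomposition`, so the model can be built directly from the resulting dict:

    import torch
    from botorch.models.contextual import SACGP
    from botorch.utils.datasets import SupervisedDataset

    # Toy data: 3 observations of 4 context parameters (hypothetical values).
    train_X = torch.rand(3, 4)
    train_Y = torch.rand(3, 1)
    train_Yvar = 0.01 * torch.ones(3, 1)
    dataset = SupervisedDataset(
        X=train_X,
        Y=train_Y,
        feature_names=[f"x{i}" for i in range(4)],
        outcome_names=["y"],
        Yvar=train_Yvar,
    )
    # Two contexts, each owning two of the four parameters.
    decomposition = {"1": [0, 3], "2": [1, 2]}
    model_kwargs = SACGP.construct_inputs(
        training_data=dataset, decomposition=decomposition
    )
    model = SACGP(**model_kwargs)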
+ """ + base_inputs = super().construct_inputs(training_data=training_data, **kwargs) + return { + **base_inputs, + "decomposition": decomposition, + "train_embedding": train_embedding, + "cat_feature_dict": cat_feature_dict, + "embs_feature_dict": embs_feature_dict, + "embs_dim_list": embs_dim_list, + "context_weight_dict": context_weight_dict, + } diff --git a/test/models/test_contextual.py b/test/models/test_contextual.py index 673ec3df28..758a084907 100644 --- a/test/models/test_contextual.py +++ b/test/models/test_contextual.py @@ -5,30 +5,46 @@ # LICENSE file in the root directory of this source tree. +from typing import Dict, Tuple + import torch from botorch.fit import fit_gpytorch_mll from botorch.models.contextual import LCEAGP, SACGP from botorch.models.gp_regression import FixedNoiseGP from botorch.models.kernels.contextual_lcea import LCEAKernel from botorch.models.kernels.contextual_sac import SACKernel +from botorch.utils.datasets import SupervisedDataset from botorch.utils.testing import BotorchTestCase from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.means import ConstantMean from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from torch import Tensor + + +def _gen_datasets( + **tkwargs, +) -> Tuple[Dict[int, SupervisedDataset], Tuple[Tensor, Tensor, Tensor]]: + train_X = torch.tensor( + [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]], **tkwargs + ) + train_Y = torch.tensor([[1.0], [2.0], [3.0]], **tkwargs) + train_Yvar = 0.01 * torch.ones(3, 1, **tkwargs) + + datasets = SupervisedDataset( + X=train_X, + Y=train_Y, + feature_names=[f"x{i}" for i in range(train_X.shape[-1])], + outcome_names=["y"], + Yvar=train_Yvar, + ) + return datasets, (train_X, train_Y, train_Yvar) class TestContextualGP(BotorchTestCase): def test_SACGP(self): for dtype in (torch.float, torch.double): - train_X = torch.tensor( - [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]], - device=self.device, - dtype=dtype, - ) - train_Y = torch.tensor( - [[1.0], [2.0], [3.0]], device=self.device, dtype=dtype - ) - train_Yvar = 0.01 * torch.ones(3, 1, device=self.device, dtype=dtype) + tkwargs = {"device": self.device, "dtype": dtype} + datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs) self.decomposition = {"1": [0, 3], "2": [1, 2]} model = SACGP(train_X, train_Y, train_Yvar, self.decomposition) @@ -59,17 +75,25 @@ def test_SACGP(self): posterior = model(test_x) self.assertIsInstance(posterior, MultivariateNormal) - def testLCEAGP(self): + def test_SACGP_construct_inputs(self): for dtype in (torch.float, torch.double): - train_X = torch.tensor( - [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]], - device=self.device, - dtype=dtype, - ) - train_Y = torch.tensor( - [[1.0], [2.0], [3.0]], device=self.device, dtype=dtype + tkwargs = {"device": self.device, "dtype": dtype} + datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs) + self.decomposition = {"1": [0, 3], "2": [1, 2]} + model = SACGP(train_X, train_Y, train_Yvar, self.decomposition) + data_dict = model.construct_inputs( + training_data=datasets, decomposition=self.decomposition ) - train_Yvar = 0.01 * torch.ones(3, 1, device=self.device, dtype=dtype) + + self.assertTrue(train_X.equal(data_dict["train_X"])) + self.assertTrue(train_Y.equal(data_dict["train_Y"])) + self.assertTrue(train_Yvar.equal(data_dict["train_Yvar"])) + self.assertDictEqual(data_dict["decomposition"], self.decomposition) + + def 
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs)
 
             # Test setting attributes
             decomposition = {"1": [0, 1], "2": [2, 3]}
@@ -90,3 +114,22 @@ def testLCEAGP(self):
             test_x = torch.rand(5, 4, device=self.device, dtype=dtype)
             posterior = model(test_x)
             self.assertIsInstance(posterior, MultivariateNormal)
+
+    def test_LCEAGP_construct_inputs(self):
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs)
+            decomposition = {"1": [0, 1], "2": [2, 3]}
+
+            model = LCEAGP(train_X, train_Y, train_Yvar, decomposition)
+            data_dict = model.construct_inputs(
+                training_data=datasets,
+                decomposition=decomposition,
+                train_embedding=False,
+            )
+
+            self.assertTrue(train_X.equal(data_dict["train_X"]))
+            self.assertTrue(train_Y.equal(data_dict["train_Y"]))
+            self.assertTrue(train_Yvar.equal(data_dict["train_Yvar"]))
+            self.assertDictEqual(data_dict["decomposition"], decomposition)
+            self.assertFalse(data_dict["train_embedding"])
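Likewise for the LCE-A model, a sketch (not part of the patch) of building the kwargs with fixed pre-trained embeddings. It reuses `dataset` from the sketch above; `my_embs_feature_dict` is a hypothetical user-provided mapping of context name to pre-trained embedding features:

    from botorch.models.contextual import LCEAGP

    # my_embs_feature_dict is a hypothetical placeholder; it must supply the
    # pre-trained embeddings when train_embedding=False.
    model_kwargs = LCEAGP.construct_inputs(
        training_data=dataset,
        decomposition={"1": [0, 1], "2": [2, 3]},
        train_embedding=False,
        embs_feature_dict=my_embs_feature_dict,
    )
    # The embedding-related options are passed through to the kwargs unchanged,
    # as the tests above verify for train_embedding.
    assert model_kwargs["train_embedding"] is False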