Model Input Standardization Using TrainingData #477

Closed · wants to merge 1 commit
botorch/models/gp_regression.py: 18 additions & 1 deletion

@@ -10,14 +10,15 @@

from __future__ import annotations

-from typing import Any, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

import torch
from botorch import settings
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.transforms.outcome import Log, OutcomeTransform
from botorch.models.utils import validate_input_scaling
from botorch.sampling.samplers import MCSampler
from botorch.utils.containers import TrainingData
from gpytorch.constraints.constraints import GreaterThan
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.matern_kernel import MaternKernel
@@ -136,6 +137,11 @@ def forward(self, x: Tensor) -> MultivariateNormal:
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    @classmethod
    def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]:
        r"""Standardize kwargs of the model constructor."""
        return {"train_X": training_data.Xs[0], "train_Y": training_data.Ys[-1]}


class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP):
r"""A single-task exact GP model using fixed noise levels.
@@ -276,6 +282,17 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
        new_model.likelihood.noise_covar.noise = new_noise
        return new_model

    @classmethod
    def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]:
        r"""Standardize kwargs of the model constructor."""
        if training_data.Yvars is None:
            raise ValueError("Training data is missing Yvars member")
        return {
            "train_X": training_data.Xs[0],
            "train_Y": training_data.Ys[-1],
            "train_Yvar": training_data.Yvars[0],
        }


class HeteroskedasticSingleTaskGP(SingleTaskGP):
    r"""A single-task exact GP model using a heteroskedastic noise model.
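For orientation, here is a minimal sketch of how these new construct_inputs hooks are intended to be consumed. The tensors below are invented for illustration; only TrainingData and the construct_inputs classmethods come from this diff.

import torch
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
from botorch.utils.containers import TrainingData

# Hypothetical training data, just for the example.
train_X = torch.rand(10, 2)
train_Y = torch.rand(10, 1)
train_Yvar = torch.full_like(train_Y, 0.01)

training_data = TrainingData(Xs=[train_X], Ys=[train_Y], Yvars=[train_Yvar])

# SingleTaskGP ignores Yvars; FixedNoiseGP requires them and raises a
# ValueError if they are absent.
gp = SingleTaskGP(**SingleTaskGP.construct_inputs(training_data))
fixed_noise_gp = FixedNoiseGP(**FixedNoiseGP.construct_inputs(training_data))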
botorch/models/model.py: 9 additions & 1 deletion

@@ -11,11 +11,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional

from botorch import settings
from botorch.posteriors import Posterior
from botorch.sampling.samplers import MCSampler
from botorch.utils.containers import TrainingData
from torch import Tensor
from torch.nn import Module

@@ -123,3 +124,10 @@ def fantasize(
        post_X = self.posterior(X, observation_noise=observation_noise)
        Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
        return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)

    @classmethod
    def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]:
        r"""Standardize kwargs of the model constructor."""
        raise NotImplementedError(
            f"`construct_inputs` not implemented for {cls.__name__}."
        )
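This base-class hook makes it possible to build models generically from standardized training data. A minimal sketch of that pattern; the make_model helper is hypothetical and not part of this diff:

from typing import Type

from botorch.models.model import Model
from botorch.utils.containers import TrainingData


def make_model(model_cls: Type[Model], training_data: TrainingData) -> Model:
    # Any Model subclass that overrides construct_inputs can be constructed
    # from a TrainingData struct without model-specific glue code; the base
    # implementation above raises NotImplementedError.
    return model_cls(**model_cls.construct_inputs(training_data))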
botorch/utils/containers.py: 21 additions & 0 deletions (new file)

@@ -0,0 +1,21 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Containers to standardize inputs into models and acquisition functions.
"""

from typing import List, NamedTuple, Optional

from torch import Tensor


class TrainingData(NamedTuple):
    r"""Standardized struct of model training data."""

    Xs: List[Tensor]
    Ys: List[Tensor]
    Yvars: Optional[List[Tensor]] = None
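A brief usage note on the new container: TrainingData is a NamedTuple, so Yvars defaults to None and fields are read by attribute. A small sketch with invented tensors; the single-element lists follow the List[Tensor] annotations above:

import torch
from botorch.utils.containers import TrainingData

X = torch.rand(8, 3)
Y = torch.rand(8, 1)

# Yvars is optional and defaults to None.
td = TrainingData(Xs=[X], Ys=[Y])
assert td.Yvars is None

# NamedTuple instances are immutable; _replace returns an updated copy.
td_noisy = td._replace(Yvars=[torch.full_like(Y, 0.1)])
assert torch.equal(td_noisy.Xs[0], X)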
sphinx/source/utils.rst: 5 additions & 0 deletions

@@ -12,6 +12,11 @@ Constraints
.. automodule:: botorch.utils.constraints
    :members:

Containers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.containers
    :members:

Objective
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.objective
test/models/test_gp_regression.py: 39 additions & 0 deletions

@@ -19,6 +19,7 @@
from botorch.models.utils import add_output_dim
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling import SobolQMCNormalSampler
from botorch.utils.containers import TrainingData
from botorch.utils.sampling import manual_seed
from botorch.utils.testing import BotorchTestCase, _get_random_data
from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -271,6 +272,22 @@ def test_subset_model(self):
)
)

    def test_construct_inputs(self):
        for batch_shape, dtype in itertools.product(
            (torch.Size(), torch.Size([2])), (torch.float, torch.double)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            model, model_kwargs = self._get_model_and_data(
                batch_shape=batch_shape, m=2, **tkwargs
            )
            training_data = TrainingData(
                Xs=[model_kwargs["train_X"]],
                Ys=[model_kwargs["train_Y"]],
                Yvars=[torch.full_like(model_kwargs["train_Y"], 0.01)],
            )
            data_dict = model.construct_inputs(training_data)
            self.assertTrue("train_Yvar" not in data_dict)


class TestFixedNoiseGP(TestSingleTaskGP):
    def _get_model_and_data(self, batch_shape, m, outcome_transform=None, **tkwargs):
@@ -303,6 +320,28 @@ def test_fixed_noise_likelihood(self):
)
)

    def test_construct_inputs(self):
        for batch_shape, dtype in itertools.product(
            (torch.Size(), torch.Size([2])), (torch.float, torch.double)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            model, model_kwargs = self._get_model_and_data(
                batch_shape=batch_shape, m=2, **tkwargs
            )
            training_data = TrainingData(
                Xs=[model_kwargs["train_X"]],
                Ys=[model_kwargs["train_Y"]],
                Yvars=[model_kwargs["train_Yvar"]],
            )
            data_dict = model.construct_inputs(training_data)
            self.assertTrue("train_Yvar" in data_dict)
            # if Yvars is missing, a ValueError should be raised
            training_data = TrainingData(
                Xs=[model_kwargs["train_X"]], Ys=[model_kwargs["train_Y"]]
            )
            with self.assertRaises(ValueError):
                model.construct_inputs(training_data)


class TestHeteroskedasticSingleTaskGP(TestSingleTaskGP):
    def _get_model_and_data(self, batch_shape, m, outcome_transform=None, **tkwargs):
test/models/test_model.py: 2 additions & 0 deletions

@@ -26,3 +26,5 @@ def test_not_so_abstract_base_model(self):
            model.num_outputs
        with self.assertRaises(NotImplementedError):
            model.subset_output([0])
        with self.assertRaises(NotImplementedError):
            model.construct_inputs(None)
test/utils/test_containers.py: 26 additions & 0 deletions (new file)

@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.utils.containers import TrainingData
from botorch.utils.testing import BotorchTestCase


class TestConstructContainers(BotorchTestCase):
    def test_TrainingData(self):
        Xs = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
        Ys = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
        Yvars = torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])

        training_data = TrainingData(Xs, Ys)
        self.assertTrue(torch.equal(training_data.Xs, Xs))
        self.assertTrue(torch.equal(training_data.Ys, Ys))
        self.assertEqual(training_data.Yvars, None)

        training_data = TrainingData(Xs, Ys, Yvars)
        self.assertTrue(torch.equal(training_data.Xs, Xs))
        self.assertTrue(torch.equal(training_data.Ys, Ys))
        self.assertTrue(torch.equal(training_data.Yvars, Yvars))