Skip to content

Commit

Permalink
Model Input Standardization Using TrainingData (pytorch#477)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#477

Different GP models take different kwargs as inputs into their constructors. To standardize the inputs, we create a `TrainingData` dataclass in conjunction with a classmethod `construct_inputs()`.

Reviewed By: Balandat

Differential Revision: D22395030

fbshipit-source-id: be89da3e2878993d8ba8972e48712762e9c3ccc8
  • Loading branch information
EricZLou authored and facebook-github-bot committed Jul 8, 2020
1 parent a153151 commit dc111ba
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 2 deletions.
19 changes: 18 additions & 1 deletion botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

from __future__ import annotations

from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from botorch import settings
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model import TrainingData
from botorch.models.transforms.outcome import Log, OutcomeTransform
from botorch.models.utils import validate_input_scaling
from botorch.sampling.samplers import MCSampler
Expand Down Expand Up @@ -136,6 +137,11 @@ def forward(self, x: Tensor) -> MultivariateNormal:
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)

@classmethod
def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]:
    r"""Standardize kwargs of the model constructor.

    Args:
        training_data: A `TrainingData` container. Only the first entry of
            `Xs` and the first entry of `Ys` are used; `Yvars` is not read.

    Returns:
        A dict with the keys `train_X` and `train_Y`.
    """
    # Read X and Y from the same list position so the returned pair always
    # belongs together, even if the lists hold more than one entry.
    return {"train_X": training_data.Xs[0], "train_Y": training_data.Ys[0]}


class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP):
r"""A single-task exact GP model using fixed noise levels.
Expand Down Expand Up @@ -276,6 +282,17 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
new_model.likelihood.noise_covar.noise = new_noise
return new_model

@classmethod
def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]:
    r"""Standardize kwargs of the model constructor.

    Args:
        training_data: A `TrainingData` container. The first entry of each
            of `Xs`, `Ys` and `Yvars` is used.

    Returns:
        A dict with the keys `train_X`, `train_Y` and `train_Yvar`.

    Raises:
        ValueError: If `training_data.Yvars` is None.
    """
    if training_data.Yvars is None:
        raise ValueError("Training data is missing Yvars member")
    # All three tensors are read from the same list position so that they
    # describe the same data even if the lists hold more than one entry.
    return {
        "train_X": training_data.Xs[0],
        "train_Y": training_data.Ys[0],
        "train_Yvar": training_data.Yvars[0],
    }


class HeteroskedasticSingleTaskGP(SingleTaskGP):
r"""A single-task exact GP model using a heteroskedastic noise model.
Expand Down
10 changes: 9 additions & 1 deletion botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional

from botorch import settings
from botorch.posteriors import Posterior
from botorch.sampling.samplers import MCSampler
from botorch.utils.containers import TrainingData
from torch import Tensor
from torch.nn import Module

Expand Down Expand Up @@ -123,3 +124,10 @@ def fantasize(
post_X = self.posterior(X, observation_noise=observation_noise)
Y_fantasized = sampler(post_X) # num_fantasies x batch_shape x n' x m
return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)

@classmethod
def construct_inputs(cls, training_data: TrainingData) -> Dict[str, Any]:
    r"""Standardize kwargs of the model constructor.

    This implementation does not build any kwargs; it only raises.

    Args:
        training_data: A `TrainingData` container (not read here).

    Raises:
        NotImplementedError: Always, naming the class it was called on.
    """
    model_name = cls.__name__
    raise NotImplementedError(
        f"`construct_inputs` not implemented for {model_name}."
    )
23 changes: 23 additions & 0 deletions botorch/utils/containers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Containers to standardize inputs into models and acquisition functions.
"""

from dataclasses import dataclass
from typing import List, Optional

from torch import Tensor


@dataclass
class TrainingData:
    r"""Standardized struct of model training data.

    Plain container with no validation or conversion of its fields.

    Attributes:
        Xs: A list of tensors (required).
        Ys: A list of tensors (required).
        Yvars: An optional list of tensors; defaults to None when omitted.
    """

    # Required fields come first; dataclass fields are positional in this order.
    Xs: List[Tensor]
    Ys: List[Tensor]
    # Optional, so models that do not need it can simply leave it unset.
    Yvars: Optional[List[Tensor]] = None
15 changes: 15 additions & 0 deletions sphinx/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import re
import sys
import warnings


# Absolute path of the `botorch` directory, built by joining three ".."
# components onto this file's own path and normalizing the result.
base_path = os.path.abspath(os.path.join(__file__, "..", "..", "..", "botorch"))
Expand Down Expand Up @@ -199,3 +200,17 @@

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True


# -- Other -------------------------------------------------------------------

# Suppress warnings from sphinx_autodoc_typehints related to using
# @dataclass classes, specifically the `TrainingData` class in
# `botorch/utils/containers.py`. This is an open Sphinx issue:
# https://github.com/agronholm/sphinx-autodoc-typehints/issues/123

# `message` is interpreted as a regular expression matched against the start
# of the warning text. The literal message contains parentheses and dots, so
# it must be escaped; unescaped, "(use ...)" is a regex group and never
# matches the literal "(" in the warning.
warnings.filterwarnings(
    "ignore",
    message=re.escape(
        "Cannot treat a function defined as a local function: "
        '"botorch.utils.containers.TrainingData" (use @functools.wraps)'
    ),
)
5 changes: 5 additions & 0 deletions sphinx/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Constraints
.. automodule:: botorch.utils.constraints
:members:

Containers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.containers
:members:

Objective
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.objective
Expand Down
39 changes: 39 additions & 0 deletions test/models/test_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from botorch.models.utils import add_output_dim
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling import SobolQMCNormalSampler
from botorch.utils.containers import TrainingData
from botorch.utils.sampling import manual_seed
from botorch.utils.testing import BotorchTestCase, _get_random_data
from gpytorch.kernels import MaternKernel, ScaleKernel
Expand Down Expand Up @@ -271,6 +272,22 @@ def test_subset_model(self):
)
)

def test_construct_inputs(self):
    r"""`construct_inputs` returns X and Y and drops any supplied Yvars."""
    for batch_shape, dtype in itertools.product(
        (torch.Size(), torch.Size([2])), (torch.float, torch.double)
    ):
        tkwargs = {"device": self.device, "dtype": dtype}
        model, model_kwargs = self._get_model_and_data(
            batch_shape=batch_shape, m=2, **tkwargs
        )
        # `TrainingData` declares its fields as lists of tensors, so wrap
        # each tensor in a list rather than passing it bare.
        training_data = TrainingData(
            Xs=[model_kwargs["train_X"]],
            Ys=[model_kwargs["train_Y"]],
            Yvars=[torch.full_like(model_kwargs["train_Y"], 0.01)],
        )
        data_dict = model.construct_inputs(training_data)
        self.assertNotIn("train_Yvar", data_dict)
        # The returned tensors must be the ones that were put in.
        self.assertTrue(torch.equal(data_dict["train_X"], model_kwargs["train_X"]))
        self.assertTrue(torch.equal(data_dict["train_Y"], model_kwargs["train_Y"]))


class TestFixedNoiseGP(TestSingleTaskGP):
def _get_model_and_data(self, batch_shape, m, outcome_transform=None, **tkwargs):
Expand Down Expand Up @@ -303,6 +320,28 @@ def test_fixed_noise_likelihood(self):
)
)

def test_construct_inputs(self):
    r"""`construct_inputs` returns X, Y and Yvar, and requires Yvars."""
    for batch_shape, dtype in itertools.product(
        (torch.Size(), torch.Size([2])), (torch.float, torch.double)
    ):
        tkwargs = {"device": self.device, "dtype": dtype}
        model, model_kwargs = self._get_model_and_data(
            batch_shape=batch_shape, m=2, **tkwargs
        )
        # `TrainingData` declares its fields as lists of tensors, so wrap
        # each tensor in a list rather than passing it bare.
        training_data = TrainingData(
            Xs=[model_kwargs["train_X"]],
            Ys=[model_kwargs["train_Y"]],
            Yvars=[model_kwargs["train_Yvar"]],
        )
        data_dict = model.construct_inputs(training_data)
        self.assertIn("train_Yvar", data_dict)
        # The returned tensors must be the ones that were put in.
        for key in ("train_X", "train_Y", "train_Yvar"):
            self.assertTrue(torch.equal(data_dict[key], model_kwargs[key]))
        # if Yvars is missing, then raise error
        training_data = TrainingData(
            Xs=[model_kwargs["train_X"]], Ys=[model_kwargs["train_Y"]]
        )
        with self.assertRaises(ValueError):
            model.construct_inputs(training_data)


class TestHeteroskedasticSingleTaskGP(TestSingleTaskGP):
def _get_model_and_data(self, batch_shape, m, outcome_transform=None, **tkwargs):
Expand Down
2 changes: 2 additions & 0 deletions test/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ def test_not_so_abstract_base_model(self):
model.num_outputs
with self.assertRaises(NotImplementedError):
model.subset_output([0])
with self.assertRaises(NotImplementedError):
model.construct_inputs(None)
26 changes: 26 additions & 0 deletions test/utils/test_containers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.utils.containers import TrainingData
from botorch.utils.testing import BotorchTestCase


class TestConstructContainers(BotorchTestCase):
    r"""Tests for the containers in `botorch.utils.containers`."""

    def _assert_tensor_lists_equal(self, actual, expected):
        r"""Assert two lists of tensors have equal length and equal entries."""
        self.assertEqual(len(actual), len(expected))
        for actual_t, expected_t in zip(actual, expected):
            self.assertTrue(torch.equal(actual_t, expected_t))

    def test_TrainingData(self):
        r"""`TrainingData` stores its fields as given; `Yvars` defaults to None."""
        # `TrainingData` declares its fields as lists of tensors, so the
        # fixtures are single-element lists rather than bare tensors.
        Xs = [torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])]
        Ys = [torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])]
        Yvars = [torch.tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])]

        training_data = TrainingData(Xs, Ys)
        self._assert_tensor_lists_equal(training_data.Xs, Xs)
        self._assert_tensor_lists_equal(training_data.Ys, Ys)
        self.assertIsNone(training_data.Yvars)

        training_data = TrainingData(Xs, Ys, Yvars)
        self._assert_tensor_lists_equal(training_data.Xs, Xs)
        self._assert_tensor_lists_equal(training_data.Ys, Ys)
        self._assert_tensor_lists_equal(training_data.Yvars, Yvars)

0 comments on commit dc111ba

Please sign in to comment.