Model.construct_inputs expects a single TrainingData now -- BoTorch changes (#529)

Summary:
Pull Request resolved: #529

Change the structure of `TrainingData` and `construct_inputs` in BoTorch.

**Previously:**
```
class TrainingData:
    Xs: List[Tensor]
    Ys: List[Tensor]
    Yvars: Optional[List[Tensor]] = None

training_data: TrainingData
```
**Now:**
```
class TrainingData:
    X: Tensor
    Y: Tensor
    Yvar: Optional[Tensor] = None

training_data: Dict[str, TrainingData]
```
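
For illustration, a minimal sketch of how the new structure might be used (the tensor shapes and the "objective" outcome name below are hypothetical, not part of this commit):
```
import torch

from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.containers import TrainingData

# Hypothetical single-outcome data: 10 points in a 2-dimensional input space.
X = torch.rand(10, 2)
Y = torch.rand(10, 1)

# One TrainingData per outcome, keyed by outcome name.
training_data = {"objective": TrainingData(X=X, Y=Y)}

# Each model's construct_inputs now receives a single TrainingData
# and returns kwargs for that model's constructor.
kwargs = SingleTaskGP.construct_inputs(training_data["objective"])
model = SingleTaskGP(**kwargs)
```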

Reviewed By: Balandat

Differential Revision: D23380460

fbshipit-source-id: 1e1dbf1e2c19f06b9b0edb1b2042dc0226e52926
Elena Kashtelyan authored and facebook-github-bot committed Sep 2, 2020
1 parent c80c4fd commit 3621198
Showing 10 changed files with 184 additions and 241 deletions.
3 changes: 2 additions & 1 deletion botorch/acquisition/multi_objective/objective.py
@@ -31,7 +31,8 @@ def forward(self, samples: Tensor, **kwargs) -> Tensor:
Returns:
A `sample_shape x batch_shape x q x m'`-dim Tensor of objective values with
`m'` the output dimension. This assumes maximization in each output dimension).
`m'` the output dimension. This assumes maximization in each output
dimension).
This method is usually not called directly, but via the objectives
47 changes: 22 additions & 25 deletions botorch/models/gp_regression.py
@@ -141,15 +141,14 @@ def forward(self, x: Tensor) -> MultivariateNormal:
def construct_inputs(
cls, training_data: TrainingData, **kwargs: Any
) -> Dict[str, Any]:
r"""Standardize kwargs of the model constructor."""
Xs = training_data.Xs
Ys = training_data.Ys
if len(Xs) == len(Ys) == 1:
return {"train_X": Xs[0], "train_Y": Ys[0]}
if all(torch.equal(Xs[0], X) for X in Xs[1:]):
# Use batched multioutput, single task GP.
return {"train_X": Xs[0], "train_Y": torch.cat(Ys, dim=-1)}
raise ValueError("Unexpected training data format.")
r"""Construct kwargs for the `Model` from `TrainingData`.
Args:
training_data: `TrainingData` container with data for single outcome
or for multiple outcomes for batched multi-output case.
**kwargs: None expected for this class.
"""
return {"train_X": training_data.X, "train_Y": training_data.Y}


class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP):
@@ -296,22 +295,20 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
def construct_inputs(
cls, training_data: TrainingData, **kwargs: Any
) -> Dict[str, Any]:
r"""Standardize kwargs of the model constructor."""
if training_data.Yvars is None:
raise ValueError(f"Yvars required for {cls.__name__}.")
Xs = training_data.Xs
Ys = training_data.Ys
Yvars = training_data.Yvars
if len(Xs) == len(Ys) == 1:
return {"train_X": Xs[0], "train_Y": Ys[0], "train_Yvar": Yvars[0]}
if all(torch.equal(Xs[0], X) for X in Xs[1:]):
# Use batched multioutput, single task GP.
return {
"train_X": Xs[0],
"train_Y": torch.cat(Ys, dim=-1),
"train_Yvar": torch.cat(Yvars, dim=-1),
}
raise ValueError("Unexpected training data format.")
r"""Construct kwargs for the `Model` from `TrainingData`.
Args:
training_data: `TrainingData` container with data for single outcome
or for multiple outcomes for batched multi-output case.
**kwargs: None expected for this class.
"""
if training_data.Yvar is None:
raise ValueError(f"Yvar required for {cls.__name__}.")
return {
"train_X": training_data.X,
"train_Y": training_data.Y,
"train_Yvar": training_data.Yvar,
}


class HeteroskedasticSingleTaskGP(SingleTaskGP):
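As a quick usage sketch of the updated `FixedNoiseGP.construct_inputs` (the data below is illustrative): the method now reads `X`, `Y`, and `Yvar` directly from the container and raises a `ValueError` if `Yvar` is missing.
```
import torch

from botorch.models.gp_regression import FixedNoiseGP
from botorch.utils.containers import TrainingData

X = torch.rand(10, 2)
Y = torch.rand(10, 1)
Yvar = torch.full_like(Y, 0.01)  # known observation noise per point

kwargs = FixedNoiseGP.construct_inputs(TrainingData(X=X, Y=Y, Yvar=Yvar))
model = FixedNoiseGP(**kwargs)  # train_X, train_Y, train_Yvar

# Omitting Yvar raises: "Yvar required for FixedNoiseGP."
```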
70 changes: 31 additions & 39 deletions botorch/models/gp_regression_fidelity.py
@@ -117,26 +117,23 @@ def __init__(

@classmethod
def construct_inputs(cls, training_data: TrainingData, **kwargs) -> Dict[str, Any]:
r"""Standardize kwargs of the model constructor."""
r"""Construct kwargs for the `Model` from `TrainingData`.
Args:
training_data: `TrainingData` container with data for single outcome
or for multiple outcomes for batched multi-output case.
**kwargs: Options, expected for this class:
- fidelity_features: List of columns of X that are fidelity parameters.
"""
fidelity_features = kwargs.get("fidelity_features")
if fidelity_features is None:
raise ValueError(f"Fidelity features required for {cls.__name__}.")
Xs = training_data.Xs
Ys = training_data.Ys
if len(Xs) == len(Ys) == 1:
return {
"train_X": Xs[0],
"train_Y": Ys[0],
"data_fidelity": fidelity_features[0],
}
if all(torch.equal(Xs[0], X) for X in Xs[1:]):
# Use batched multioutput, single task GP.
return {
"train_X": Xs[0],
"train_Y": torch.cat(Ys, dim=-1),
"data_fidelity": fidelity_features[0],
}
raise ValueError("Unexpected training data format.")

return {
"train_X": training_data.X,
"train_Y": training_data.Y,
"data_fidelity": fidelity_features[0],
}


class FixedNoiseMultiFidelityGP(FixedNoiseGP):
@@ -220,31 +217,26 @@ def __init__(

@classmethod
def construct_inputs(cls, training_data: TrainingData, **kwargs) -> Dict[str, Any]:
r"""Standardize kwargs of the model constructor."""
if training_data.Yvars is None:
raise ValueError(f"Yvars required for {cls.__name__}.")
Xs = training_data.Xs
Ys = training_data.Ys
Yvars = training_data.Yvars
r"""Construct kwargs for the `Model` from `TrainingData`.
Args:
training_data: `TrainingData` container with data for single outcome
or for multiple outcomes for batched multi-output case.
**kwargs: Options, expected for this class:
- fidelity_features: List of columns of X that are fidelity parameters.
"""
fidelity_features = kwargs.get("fidelity_features")
if fidelity_features is None:
raise ValueError(f"Fidelity features required for {cls.__name__}.")
if len(Xs) == len(Ys) == 1:
return {
"train_X": Xs[0],
"train_Y": Ys[0],
"train_Yvar": Yvars[0],
"data_fidelity": fidelity_features[0],
}
if all(torch.equal(Xs[0], X) for X in Xs[1:]):
# Use batched multioutput, single task GP.
return {
"train_X": Xs[0],
"train_Y": torch.cat(Ys, dim=-1),
"train_Yvar": torch.cat(Yvars, dim=-1),
"data_fidelity": fidelity_features[0],
}
raise ValueError("Unexpected training data format.")
if training_data.Yvar is None:
raise ValueError(f"Yvar required for {cls.__name__}.")

return {
"train_X": training_data.X,
"train_Y": training_data.Y,
"train_Yvar": training_data.Yvar,
"data_fidelity": fidelity_features[0],
}


def _setup_multifidelity_covar_module(
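A hedged usage sketch for the multi-fidelity variants (the data and the choice of fidelity column are illustrative): `fidelity_features` is passed through `**kwargs`, and only its first entry is mapped to `data_fidelity`.
```
import torch

from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.utils.containers import TrainingData

# Illustrative data: column 2 of X is treated as a (binary) fidelity parameter.
X = torch.rand(10, 3)
X[:, 2] = X[:, 2].round()
Y = torch.rand(10, 1)

kwargs = SingleTaskMultiFidelityGP.construct_inputs(
    TrainingData(X=X, Y=Y), fidelity_features=[2]
)
model = SingleTaskMultiFidelityGP(**kwargs)  # train_X, train_Y, data_fidelity=2
```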
2 changes: 1 addition & 1 deletion botorch/models/model.py
@@ -129,7 +129,7 @@ def fantasize(
def construct_inputs(
cls, training_data: TrainingData, **kwargs: Any
) -> Dict[str, Any]:
r"""Standardize kwargs of the model constructor."""
r"""Construct kwargs for the `Model` from `TrainingData`."""
raise NotImplementedError(
f"`construct_inputs` not implemented for {cls.__name__}."
)
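Since the base class only raises `NotImplementedError`, a custom model is expected to override `construct_inputs` itself; a hypothetical subclass (other abstract methods omitted) might map the container onto its constructor like this:
```
from typing import Any, Dict

from botorch.models.model import Model
from botorch.utils.containers import TrainingData


class MyModel(Model):  # hypothetical example, not part of this commit
    @classmethod
    def construct_inputs(
        cls, training_data: TrainingData, **kwargs: Any
    ) -> Dict[str, Any]:
        # Map the standardized container onto this model's constructor kwargs.
        return {"train_X": training_data.X, "train_Y": training_data.Y}
```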
51 changes: 50 additions & 1 deletion botorch/models/multitask.py
@@ -10,10 +10,11 @@

from __future__ import annotations

from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
from botorch.models.gpytorch import MultiTaskGPyTorchModel
from botorch.utils.containers import TrainingData
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.index_kernel import IndexKernel
from gpytorch.kernels.matern_kernel import MaternKernel
@@ -166,6 +167,29 @@ def get_all_tasks(
all_tasks = train_X[:, task_feature].unique().to(dtype=torch.long).tolist()
return all_tasks, task_feature, d

@classmethod
def construct_inputs(cls, training_data: TrainingData, **kwargs) -> Dict[str, Any]:
r"""Construct kwargs for the `Model` from `TrainingData`.
Args:
training_data: `TrainingData` container with data for single outcome
or for multiple outcomes for batched multi-output case.
**kwargs: Additional options for the model that pertain to the
training data:
- `task_features` – indices of the input columns containing the task
features (expected list of length 1).
"""

task_features = kwargs.get("task_features")
if task_features is None:
raise ValueError(f"task features required for {cls.__name__}.")

return {
"train_X": training_data.X,
"train_Y": training_data.Y,
"task_feature": task_features[0],
}


class FixedNoiseMultiTaskGP(MultiTaskGP):
r"""Multi-Task GP model using an ICM kernel, with known observation noise.
@@ -226,3 +250,28 @@ def __init__(
)
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))
self.to(train_X)

@classmethod
def construct_inputs(cls, training_data: TrainingData, **kwargs) -> Dict[str, Any]:
r"""Construct kwargs for the `Model` from `TrainingData`.
Args:
training_data: `TrainingData` container with data for single outcome
or for multiple outcomes for batched multi-output case.
**kwargs: Additional options for the model that pertain to the
training data:
- `task_features` – indices of task features in X.
"""

task_features = kwargs.get("task_features")
if task_features is None:
raise ValueError(f"task features required for {cls.__name__}.")
if training_data.Yvar is None:
raise ValueError(f"Yvar required for {cls.__name__}.")

return {
"train_X": training_data.X,
"train_Y": training_data.Y,
"train_Yvar": training_data.Yvar,
"task_feature": task_features[0],
}
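
For the multi-task models, a brief sketch (the task column and its values are illustrative): `task_features` is required in `**kwargs`, and its first index becomes `task_feature` in the returned dict.
```
import torch

from botorch.models.multitask import MultiTaskGP
from botorch.utils.containers import TrainingData

# Illustrative data: the last column of X holds an integer task index (0 or 1).
task = torch.arange(10).remainder(2).unsqueeze(-1).to(torch.float)
X = torch.cat([torch.rand(10, 2), task], dim=-1)
Y = torch.rand(10, 1)

kwargs = MultiTaskGP.construct_inputs(TrainingData(X=X, Y=Y), task_features=[2])
# kwargs == {"train_X": X, "train_Y": Y, "task_feature": 2}
model = MultiTaskGP(**kwargs)
```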
10 changes: 5 additions & 5 deletions botorch/utils/containers.py
@@ -8,14 +8,14 @@
Containers to standardize inputs into models and acquisition functions.
"""

from typing import List, NamedTuple, Optional
from typing import NamedTuple, Optional

from torch import Tensor


class TrainingData(NamedTuple):
r"""Standardized struct of model training data."""
r"""Standardized struct of model training data for a single outcome."""

Xs: List[Tensor]
Ys: List[Tensor]
Yvars: Optional[List[Tensor]] = None
X: Tensor
Y: Tensor
Yvar: Optional[Tensor] = None
106 changes: 9 additions & 97 deletions test/models/test_gp_regression.py
@@ -280,47 +280,12 @@ def test_construct_inputs(self):
model, model_kwargs = self._get_model_and_data(
batch_shape=batch_shape, m=2, **tkwargs
)
# len(Xs) == len(Ys) == 1
training_data = TrainingData(
Xs=[model_kwargs["train_X"][0]], Ys=[model_kwargs["train_Y"][0]]
)
data_dict = model.construct_inputs(training_data)
self.assertTrue(
torch.equal(data_dict["train_X"], model_kwargs["train_X"][0])
)
self.assertTrue(
torch.equal(data_dict["train_Y"], model_kwargs["train_Y"][0])
)
# all X's are equal
training_data = TrainingData(
Xs=[model_kwargs["train_X"], model_kwargs["train_X"]],
Ys=[model_kwargs["train_Y"], model_kwargs["train_Y"]],
X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
)
data_dict = model.construct_inputs(training_data)
self.assertTrue(torch.equal(data_dict["train_X"], model_kwargs["train_X"]))
self.assertTrue(
torch.equal(
data_dict["train_Y"],
torch.cat(
[model_kwargs["train_Y"], model_kwargs["train_Y"]], dim=-1
),
)
)
# unexpected data format
training_data = TrainingData(
Xs=[model_kwargs["train_X"], torch.add(model_kwargs["train_X"], 1)],
Ys=[model_kwargs["train_Y"], model_kwargs["train_Y"]],
)
with self.assertRaises(ValueError):
model.construct_inputs(training_data)
# make sure Yvar is not added to dict
training_data = TrainingData(
Xs=[model_kwargs["train_X"]],
Ys=[model_kwargs["train_Y"]],
Yvars=[torch.full_like(model_kwargs["train_Y"], 0.01)],
)
data_dict = model.construct_inputs(training_data)
self.assertTrue("train_Yvar" not in data_dict)
self.assertTrue(torch.equal(data_dict["train_Y"], model_kwargs["train_Y"]))


class TestFixedNoiseGP(TestSingleTaskGP):
@@ -363,73 +328,20 @@ def test_construct_inputs(self):
batch_shape=batch_shape, m=2, **tkwargs
)
training_data = TrainingData(
Xs=[model_kwargs["train_X"][0]],
Ys=[model_kwargs["train_Y"][0]],
Yvars=[model_kwargs["train_Yvar"][0]],
X=model_kwargs["train_X"],
Y=model_kwargs["train_Y"],
Yvar=model_kwargs["train_Yvar"],
)
data_dict = model.construct_inputs(training_data)
self.assertTrue("train_Yvar" in data_dict)
self.assertTrue(
torch.equal(data_dict["train_X"], model_kwargs["train_X"][0])
)
self.assertTrue(
torch.equal(data_dict["train_Y"], model_kwargs["train_Y"][0])
)
self.assertTrue(
torch.equal(data_dict["train_Yvar"], model_kwargs["train_Yvar"][0])
)
# if Yvars is missing, then raise error
training_data = TrainingData(
Xs=model_kwargs["train_X"], Ys=model_kwargs["train_Y"]
)
with self.assertRaises(ValueError):
model.construct_inputs(training_data)

# len(Xs) == len(Ys) == 1
training_data = TrainingData(
Xs=[model_kwargs["train_X"][0]],
Ys=[model_kwargs["train_Y"][0]],
Yvars=[model_kwargs["train_Yvar"][0]],
)
data_dict = model.construct_inputs(training_data)
self.assertTrue(
torch.equal(data_dict["train_X"], model_kwargs["train_X"][0])
)
self.assertTrue(
torch.equal(data_dict["train_Y"], model_kwargs["train_Y"][0])
)
self.assertTrue(
torch.equal(data_dict["train_Yvar"], model_kwargs["train_Yvar"][0])
)
# all X's are equal
training_data = TrainingData(
Xs=[model_kwargs["train_X"], model_kwargs["train_X"]],
Ys=[model_kwargs["train_Y"], model_kwargs["train_Y"]],
Yvars=[model_kwargs["train_Yvar"], model_kwargs["train_Yvar"]],
)
data_dict = model.construct_inputs(training_data)
self.assertTrue(torch.equal(data_dict["train_X"], model_kwargs["train_X"]))
self.assertTrue(torch.equal(data_dict["train_Y"], model_kwargs["train_Y"]))
self.assertTrue(
torch.equal(
data_dict["train_Y"],
torch.cat(
[model_kwargs["train_Y"], model_kwargs["train_Y"]], dim=-1
),
)
)
self.assertTrue(
torch.equal(
data_dict["train_Yvar"],
torch.cat(
[model_kwargs["train_Yvar"], model_kwargs["train_Yvar"]], dim=-1
),
)
torch.equal(data_dict["train_Yvar"], model_kwargs["train_Yvar"])
)
# unexpected data format
# if Yvars is missing, then raise error
training_data = TrainingData(
Xs=[model_kwargs["train_X"], torch.add(model_kwargs["train_X"], 1)],
Ys=[model_kwargs["train_Y"], model_kwargs["train_Y"]],
Yvars=[model_kwargs["train_Yvar"]],
X=model_kwargs["train_X"], Y=model_kwargs["train_Y"]
)
with self.assertRaises(ValueError):
model.construct_inputs(training_data)