Skip to content

Commit

Permalink
Allow passing in task features as part of X in MTGP.posterior (#1868)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1868

Prior to this change, `MultiTaskGP` could be evaluated using inputs of shape `batch x d` when using the model directly, and using inputs of shape `batch x (d + 1)` when the model was wrapped in a `ModelListGP`. This diff (combined with #1854) legitimizes both input shapes regardless of what API is used to evaluate the model. This will unify the APIs and increase the flexibility in how the model is used.

Reviewed By: Balandat

Differential Revision: D46496775

fbshipit-source-id: f44f5e529d9105d203923013101c37a8c1697e52
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jun 6, 2023
1 parent 79a1162 commit 5224872
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 30 deletions.
68 changes: 47 additions & 21 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,34 +706,59 @@ def posterior(
r"""Computes the posterior over model outputs at the provided points.
Args:
X: A `q x d` or `batch_shape x q x d` (batch mode) tensor, where `d` is the
dimension of the feature space (not including task indices) and
`q` is the number of points considered jointly.
output_indices: A list of indices, corresponding to the outputs over
which to compute the posterior (if the model is multi-output).
Can be used to speed up computation if only a subset of the
model's outputs are required for optimization. If omitted,
computes the posterior over all model outputs.
X: A tensor of shape `batch_shape x q x d` or `batch_shape x q x (d + 1)`,
where `d` is the dimension of the feature space (not including task
indices) and `q` is the number of points considered jointly. The `+ 1`
dimension is the optional task feature / index. If given, the model
produces the outputs for the given task indices. If omitted, the
model produces outputs for tasks in `self._output_tasks` (specified
as `output_tasks` while constructing the model), which can be overridden
using `output_indices`.
output_indices: A list of indices, corresponding to the tasks over
which to compute the posterior. Only used if `X` does not include the
task feature. If omitted, defaults to `self._output_tasks`.
observation_noise: If True, add observation noise from the respective
likelihoods. If a Tensor, specifies the observation noise levels
to add.
posterior_transform: An optional PosteriorTransform.
Returns:
A `GPyTorchPosterior` object, representing `batch_shape` joint
distributions over `q` points and the outputs selected by
`output_indices`. Includes measurement noise if
`observation_noise` is specified.
distributions over `q` points. If the task features are included in `X`,
the posterior will be single-output. Otherwise, the posterior will be
single- or multi-output, corresponding to the tasks included in
either the `output_indices` or `self._output_tasks`.
"""
if output_indices is None:
output_indices = self._output_tasks
num_outputs = len(output_indices)
if any(i not in self._output_tasks for i in output_indices):
raise ValueError("Too many output indices")
cls_name = self.__class__.__name__

# construct evaluation X
X_full = _make_X_full(X=X, output_indices=output_indices, tf=self._task_feature)
includes_task_feature = X.shape[-1] == self.num_non_task_features + 1
if includes_task_feature:
# Make sure all task feature values are valid.
task_features = X[..., self._task_feature].unique()
if not (
(task_features >= 0).all() and (task_features < self.num_tasks).all()
):
raise ValueError(
"Expected all task features in `X` to be between 0 and "
f"self.num_tasks - 1. Got {task_features}."
)
if output_indices is not None:
raise ValueError(
"`output_indices` must be None when `X` includes task features."
)
num_outputs = 1
X_full = X
else:
# Add the task features to construct the full X for evaluation.
if output_indices is None:
output_indices = self._output_tasks
num_outputs = len(output_indices)
if not all(0 <= i < self.num_tasks for i in output_indices):
raise ValueError(
"Expected `output_indices` to be between 0 and self.num_tasks - 1. "
f"Got {output_indices}."
)
X_full = _make_X_full(
X=X, output_indices=output_indices, tf=self._task_feature
)

self.eval() # make sure model is in eval mode
# input transforms are applied at `posterior` in `eval` mode, and at
Expand All @@ -743,7 +768,8 @@ def posterior(
mvn = self(X_full)
if observation_noise is not False:
raise NotImplementedError(
f"Specifying observation noise is not yet supported by {cls_name}"
"Specifying observation noise is not yet supported by "
f"{self.__class__.__name__}."
)
# If single-output, return the posterior of a single-output model
if num_outputs == 1:
Expand Down
15 changes: 8 additions & 7 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ def __init__(
X=train_X, input_transform=input_transform
)
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
all_tasks, task_feature, d = self.get_all_tasks(
all_tasks, task_feature, self.num_non_task_features = self.get_all_tasks(
transformed_X, task_feature, output_tasks
)
self.num_tasks = len(all_tasks)
if outcome_transform is not None:
train_Y, train_Yvar = outcome_transform(Y=train_Y, Yvar=train_Yvar)

Expand All @@ -174,7 +175,7 @@ def __init__(

# construct indexer to be used in forward
self._task_feature = task_feature
self._base_idxr = torch.arange(d)
self._base_idxr = torch.arange(self.num_non_task_features)
self._base_idxr[task_feature:] += 1 # exclude task feature

super().__init__(
Expand All @@ -184,18 +185,18 @@ def __init__(
if covar_module is None:
self.covar_module = ScaleKernel(
base_kernel=MaternKernel(
nu=2.5, ard_num_dims=d, lengthscale_prior=GammaPrior(3.0, 6.0)
nu=2.5,
ard_num_dims=self.num_non_task_features,
lengthscale_prior=GammaPrior(3.0, 6.0),
),
outputscale_prior=GammaPrior(2.0, 0.15),
)
else:
self.covar_module = covar_module

num_tasks = len(all_tasks)
self._rank = rank if rank is not None else num_tasks

self._rank = rank if rank is not None else self.num_tasks
self.task_covar_module = IndexKernel(
num_tasks=num_tasks, rank=self._rank, prior=task_covar_prior
num_tasks=self.num_tasks, rank=self._rank, prior=task_covar_prior
)
if input_transform is not None:
self.input_transform = input_transform
Expand Down
30 changes: 28 additions & 2 deletions test/models/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import itertools
import math
import warnings
from typing import Optional
from typing import List, Optional

import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform
Expand Down Expand Up @@ -63,13 +63,18 @@ def _gen_datasets(yvar: Optional[float] = None, **tkwargs):


def _gen_model_and_data(
task_feature: int = 0, input_transform=None, outcome_transform=None, **tkwargs
task_feature: int = 0,
output_tasks: Optional[List[int]] = None,
input_transform=None,
outcome_transform=None,
**tkwargs
):
datasets, (train_X, train_Y) = _gen_datasets(**tkwargs)
model = MultiTaskGP(
train_X,
train_Y,
task_feature=task_feature,
output_tasks=output_tasks,
input_transform=input_transform,
outcome_transform=outcome_transform,
)
Expand Down Expand Up @@ -264,6 +269,27 @@ def test_MultiTaskGP(self):
self.assertIsInstance(posterior_f, GPyTorchPosterior)
self.assertIsInstance(posterior_f.distribution, MultitaskMultivariateNormal)

# test posterior with X including the task features
posterior_expected = model.posterior(test_x, output_indices=[0])
test_x = torch.cat([torch.zeros_like(test_x), test_x], dim=-1)
posterior_f = model.posterior(test_x)
self.assertIsInstance(posterior_f, GPyTorchPosterior)
self.assertIsInstance(posterior_f.distribution, MultivariateNormal)
self.assertAllClose(posterior_f.mean, posterior_expected.mean)
self.assertAllClose(
posterior_f.covariance_matrix, posterior_expected.covariance_matrix
)

# test task features in X and output_indices is not None.
with self.assertRaisesRegex(ValueError, "`output_indices` must be None"):
model.posterior(test_x, output_indices=[0, 1])

# test invalid task feature in X.
invalid_x = test_x.clone()
invalid_x[0, 0, 0] = 3
with self.assertRaisesRegex(ValueError, "task features in `X`"):
model.posterior(invalid_x)

# test that unsupported batch shape MTGPs throw correct error
with self.assertRaises(ValueError):
MultiTaskGP(torch.rand(2, 2, 2), torch.rand(2, 2, 1), 0)
Expand Down

0 comments on commit 5224872

Please sign in to comment.