Skip to content

Commit

Permalink
Pass botorch_model_class to Surrogate._set_formatted_inputs (#2653)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2653

Context:

`Surrogate._set_formatted_inputs` is used only in the following context:
- A model input constructor sets inputs on the basis of a `botorch_model_class` and a `surrogate`. It checks which inputs are valid based on the `botorch_model_class`
- The input constructor calls `_set_formatted_inputs`; if the inputs are not valid (as per the above bullet), it raises an exception.
- However, `_set_formatted_inputs` uses `surrogate.botorch_model_class` rather than `botorch_model_class`, which may not be the same, and can raise a nonsensical error. Hence there was a unit test raising a puzzling exception that SaasFullyBayesianSingleTaskGP does not support `outcome_transform` (it does), when it should have been saying that `SingleTaskGPWithDifferentConstructor`, the model in question, doesn't support `outcome_transform`.

This PR:

- Passes `botorch_model_class` to `_set_formatted_inputs`
- changes some `list` annotations to `Sequence` to fix type errors

Reviewed By: Balandat

Differential Revision: D61212316
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 13, 2024
1 parent 0314b02 commit bbbb9a8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
16 changes: 10 additions & 6 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ def __init__(
model_options: Optional[dict[str, Any]] = None,
mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
mll_options: Optional[dict[str, Any]] = None,
outcome_transform_classes: Optional[list[type[OutcomeTransform]]] = None,
outcome_transform_classes: Optional[Sequence[type[OutcomeTransform]]] = None,
outcome_transform_options: Optional[dict[str, dict[str, Any]]] = None,
input_transform_classes: Optional[list[type[InputTransform]]] = None,
input_transform_classes: Optional[Sequence[type[InputTransform]]] = None,
input_transform_options: Optional[dict[str, dict[str, Any]]] = None,
covar_module_class: Optional[type[Kernel]] = None,
covar_module_options: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -355,7 +355,9 @@ def _set_formatted_inputs(
dataset: SupervisedDataset,
botorch_model_class_args: list[str],
search_space_digest: SearchSpaceDigest,
botorch_model_class: type[Model],
) -> None:
"""Modifies `formatted_model_inputs` in place."""
for input_name, input_class, input_options in inputs:
if input_class is None:
# This is a temporary solution until all BoTorch models use
Expand All @@ -376,7 +378,7 @@ def _set_formatted_inputs(
# to be expanded to a ModelFactory, see D22457664, to accommodate
# different models in the future.
raise UserInputError(
f"The BoTorch model class {self.botorch_model_class} does not "
f"The BoTorch model class {botorch_model_class.__name__} does not "
f"support the input {input_name}."
)
input_options = deepcopy(input_options) or {}
Expand All @@ -385,7 +387,7 @@ def _set_formatted_inputs(
covar_module_with_defaults = covar_module_argparse(
input_class,
dataset=dataset,
botorch_model_class=self.botorch_model_class,
botorch_model_class=botorch_model_class,
**input_options,
)

Expand Down Expand Up @@ -736,7 +738,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:

def _extract_construct_input_transform_args(
self, search_space_digest: SearchSpaceDigest
) -> tuple[Optional[list[type[InputTransform]]], dict[str, dict[str, Any]]]:
) -> tuple[Optional[Sequence[type[InputTransform]]], dict[str, dict[str, Any]]]:
"""
Extracts input transform classes and input transform options that will
be used in `self._set_formatted_inputs` and ultimately passed to
Expand Down Expand Up @@ -764,7 +766,7 @@ def _extract_construct_input_transform_args(
)
}

submodel_input_transform_classes: list[type[InputTransform]] = [
submodel_input_transform_classes: Sequence[type[InputTransform]] = [
InputPerturbation
]

Expand Down Expand Up @@ -862,6 +864,8 @@ def _submodel_input_constructor_base(
search_space_digest=search_space_digest,
# This is used to check if the arguments are supported.
botorch_model_class_args=botorch_model_class_args,
# Used to raise the appropriate error if arguments are not supported
botorch_model_class=botorch_model_class,
)
return formatted_model_inputs

Expand Down
5 changes: 4 additions & 1 deletion ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,10 @@ def test_with_botorch_transforms(self) -> None:
outcome_transform_classes=[Standardize],
outcome_transform_options={"Standardize": {"m": 1}},
)
with self.assertRaisesRegex(UserInputError, "The BoTorch model class"):
with self.assertRaisesRegex(
UserInputError,
"The BoTorch model class SingleTaskGPWithDifferentConstructor",
):
surrogate.fit(
datasets=self.supervised_training_data,
search_space_digest=SearchSpaceDigest(
Expand Down

0 comments on commit bbbb9a8

Please sign in to comment.