From feec334b19782ea360bbee7100993e8bac6032fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Fri, 22 Mar 2024 10:22:21 +0100 Subject: [PATCH] Feature/linear deterministic (#385) * add linear deterministic surrogate --- bofire/data_models/surrogates/api.py | 3 + .../surrogates/botorch_surrogates.py | 2 + .../data_models/surrogates/deterministic.py | 41 ++++++++++++ bofire/surrogates/api.py | 1 + bofire/surrogates/deterministic.py | 25 +++++++ bofire/surrogates/mapper.py | 2 + tests/bofire/data_models/specs/surrogates.py | 67 +++++++++++++++++++ tests/bofire/surrogates/test_deterministic.py | 26 +++++++ 8 files changed, 167 insertions(+) create mode 100644 bofire/data_models/surrogates/deterministic.py create mode 100644 bofire/surrogates/deterministic.py create mode 100644 tests/bofire/surrogates/test_deterministic.py diff --git a/bofire/data_models/surrogates/api.py b/bofire/data_models/surrogates/api.py index 3c464b245..d984cac7a 100644 --- a/bofire/data_models/surrogates/api.py +++ b/bofire/data_models/surrogates/api.py @@ -5,6 +5,7 @@ AnyBotorchSurrogate, BotorchSurrogates, ) +from bofire.data_models.surrogates.deterministic import LinearDeterministicSurrogate from bofire.data_models.surrogates.empirical import EmpiricalSurrogate from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate from bofire.data_models.surrogates.linear import LinearSurrogate @@ -46,6 +47,7 @@ LinearSurrogate, PolynomialSurrogate, TanimotoGPSurrogate, + LinearDeterministicSurrogate, ] AnyTrainableSurrogate = Union[ @@ -74,6 +76,7 @@ LinearSurrogate, PolynomialSurrogate, TanimotoGPSurrogate, + LinearDeterministicSurrogate, ] AnyClassificationSurrogate = ClassificationMLPEnsemble diff --git a/bofire/data_models/surrogates/botorch_surrogates.py b/bofire/data_models/surrogates/botorch_surrogates.py index 2a5f1ee19..6314060f4 100644 --- a/bofire/data_models/surrogates/botorch_surrogates.py +++ b/bofire/data_models/surrogates/botorch_surrogates.py @@ -5,6 +5,7 @@ from bofire.data_models.base import BaseModel from bofire.data_models.domain.api import Inputs, Outputs +from bofire.data_models.surrogates.deterministic import LinearDeterministicSurrogate from bofire.data_models.surrogates.empirical import EmpiricalSurrogate from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate from bofire.data_models.surrogates.linear import LinearSurrogate @@ -34,6 +35,7 @@ TanimotoGPSurrogate, LinearSurrogate, PolynomialSurrogate, + LinearDeterministicSurrogate, ] diff --git a/bofire/data_models/surrogates/deterministic.py b/bofire/data_models/surrogates/deterministic.py new file mode 100644 index 000000000..a7acdfa2a --- /dev/null +++ b/bofire/data_models/surrogates/deterministic.py @@ -0,0 +1,41 @@ +from typing import Annotated, Dict, Literal, Type + +from pydantic import Field, model_validator + +from bofire.data_models.features.api import ( + AnyOutput, + ContinuousInput, + ContinuousOutput, + DiscreteInput, +) +from bofire.data_models.surrogates.botorch import BotorchSurrogate + + +class LinearDeterministicSurrogate(BotorchSurrogate): + type: Literal["LinearDeterministicSurrogate"] = "LinearDeterministicSurrogate" + coefficients: Annotated[Dict[str, float], Field(min_length=1)] + intercept: float + + @classmethod + def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool: + """Abstract method to check output type for surrogate models + Args: + my_type: continuous or categorical output + Returns: + bool: True if the output type is valid for the surrogate chosen, False otherwise + """ + return isinstance(my_type, type(ContinuousOutput)) + + @model_validator(mode="after") + def validate_input_types(self): + if len(self.inputs.get([ContinuousInput, DiscreteInput])) != len(self.inputs): + raise ValueError( + "Only numerical inputs are suppoerted for the `LinearDeterministicSurrogate`" + ) + return self + + @model_validator(mode="after") + def validate_coefficients(self): + if sorted(self.inputs.get_keys()) != sorted(self.coefficients.keys()): + raise ValueError("coefficient keys do not match input feature keys.") + return self diff --git a/bofire/surrogates/api.py b/bofire/surrogates/api.py index a2a935380..9115f2d19 100644 --- a/bofire/surrogates/api.py +++ b/bofire/surrogates/api.py @@ -1,4 +1,5 @@ from bofire.surrogates.botorch_surrogates import BotorchSurrogates +from bofire.surrogates.deterministic import LinearDeterministicSurrogate from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.mapper import map from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate diff --git a/bofire/surrogates/deterministic.py b/bofire/surrogates/deterministic.py new file mode 100644 index 000000000..a8519e959 --- /dev/null +++ b/bofire/surrogates/deterministic.py @@ -0,0 +1,25 @@ +import torch +from botorch.models.deterministic import AffineDeterministicModel + +from bofire.data_models.surrogates.api import LinearDeterministicSurrogate as DataModel +from bofire.surrogates.botorch import BotorchSurrogate +from bofire.utils.torch_tools import tkwargs + + +class LinearDeterministicSurrogate(BotorchSurrogate): + def __init__( + self, + data_model: DataModel, + **kwargs, + ): + self.intercept = data_model.intercept + self.coefficients = data_model.coefficients + super().__init__(data_model=data_model, **kwargs) + self.model = AffineDeterministicModel( + b=data_model.intercept, + a=torch.tensor( + [data_model.coefficients[key] for key in self.inputs.get_keys()] + ) + .to(**tkwargs) + .unsqueeze(-1), + ) diff --git a/bofire/surrogates/mapper.py b/bofire/surrogates/mapper.py index 3c1817bce..8c711e647 100644 --- a/bofire/surrogates/mapper.py +++ b/bofire/surrogates/mapper.py @@ -1,6 +1,7 @@ from typing import Dict, Type from bofire.data_models.surrogates import api as data_models +from bofire.surrogates.deterministic import LinearDeterministicSurrogate from bofire.surrogates.empirical import EmpiricalSurrogate from bofire.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate @@ -24,6 +25,7 @@ data_models.LinearSurrogate: SingleTaskGPSurrogate, data_models.PolynomialSurrogate: SingleTaskGPSurrogate, data_models.TanimotoGPSurrogate: SingleTaskGPSurrogate, + data_models.LinearDeterministicSurrogate: LinearDeterministicSurrogate, } diff --git a/tests/bofire/data_models/specs/surrogates.py b/tests/bofire/data_models/specs/surrogates.py index 5fb4bcbb4..d4bf765ee 100644 --- a/tests/bofire/data_models/specs/surrogates.py +++ b/tests/bofire/data_models/specs/surrogates.py @@ -394,3 +394,70 @@ "hyperconfig": None, }, ) + +specs.add_valid( + models.LinearDeterministicSurrogate, + lambda: { + "inputs": Inputs( + features=[ + ContinuousInput(key="a", bounds=(0, 1)), + ContinuousInput(key="b", bounds=(0, 1)), + ] + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ] + ).model_dump(), + "intercept": 5.0, + "coefficients": {"a": 2.0, "b": -3.0}, + "input_preprocessing_specs": {}, + "dump": None, + }, +) + +specs.add_invalid( + models.LinearDeterministicSurrogate, + lambda: { + "inputs": Inputs( + features=[ + ContinuousInput(key="a", bounds=(0, 1)), + ContinuousInput(key="b", bounds=(0, 1)), + ] + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ] + ).model_dump(), + "intercept": 5.0, + "coefficients": {"a": 2.0, "b": -3.0, "c": 5.0}, + "input_preprocessing_specs": {}, + "dump": None, + }, + error=ValueError, + message="coefficient keys do not match input feature keys.", +) + +specs.add_invalid( + models.LinearDeterministicSurrogate, + lambda: { + "inputs": Inputs( + features=[ + ContinuousInput(key="a", bounds=(0, 1)), + CategoricalInput(key="b", categories=["a", "b"]), + ] + ).model_dump(), + "outputs": Outputs( + features=[ + features.valid(ContinuousOutput).obj(), + ] + ).model_dump(), + "intercept": 5.0, + "coefficients": {"a": 2.0, "b": -3.0}, + "input_preprocessing_specs": {}, + "dump": None, + }, + error=ValueError, + message="Only numerical inputs are suppoerted for the `LinearDeterministicSurrogate`", +) diff --git a/tests/bofire/surrogates/test_deterministic.py b/tests/bofire/surrogates/test_deterministic.py new file mode 100644 index 000000000..22f48daf2 --- /dev/null +++ b/tests/bofire/surrogates/test_deterministic.py @@ -0,0 +1,26 @@ +import pandas as pd +from pandas.testing import assert_frame_equal + +import bofire.surrogates.api as surrogates +from bofire.data_models.domain.api import Inputs, Outputs +from bofire.data_models.features.api import ContinuousInput, ContinuousOutput +from bofire.data_models.surrogates.api import LinearDeterministicSurrogate + + +def test_linear_deterministic_surrogate(): + surrogate_data = LinearDeterministicSurrogate( + inputs=Inputs( + features=[ + ContinuousInput(key="a", bounds=(0, 1)), + ContinuousInput(key="b", bounds=(0, 1)), + ] + ), + outputs=Outputs(features=[ContinuousOutput(key="y")]), + intercept=2.0, + coefficients={"b": 3.0, "a": -2.0}, + ) + surrogate = surrogates.map(surrogate_data) + assert surrogate.input_preprocessing_specs == {} + experiments = pd.DataFrame(data={"a": [1.0, 2.0], "b": [0.5, 4.0]}) + preds = surrogate.predict(experiments) + assert_frame_equal(preds, pd.DataFrame(data={"y_pred": [1.5, 10.0], "y_sd": 0.0}))