Skip to content

Commit

Permalink
Merge: Expose surrogate (#355)
Browse files Browse the repository at this point in the history
This PR enables convenient access to the surrogate model and posterior
predictive distribution via the `Campaign` class.
  • Loading branch information
Scienfitz authored Sep 9, 2024
2 parents db20318 + 46441b9 commit 9bda168
Show file tree
Hide file tree
Showing 14 changed files with 298 additions and 39 deletions.
1 change: 1 addition & 0 deletions .lockfiles/py310-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ jinja2==3.1.4
# torch
joblib==1.4.2
# via
# baybe (pyproject.toml)
# scikit-learn
# xyzpy
json5==0.9.25
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Pure recommenders now have the `allow_recommending_pending_experiments` flag,
controlling whether pending experiments are excluded from candidates in purely
discrete search spaces
- `get_surrogate` and `posterior` methods to `Campaign`

### Changed
- The transition from experimental to computational representation no longer happens
Expand Down Expand Up @@ -62,6 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecations
- The role of `register_custom_architecture` has been taken over by
`baybe.surrogates.base.SurrogateProtocol`
- `BayesianRecommender.surrogate_model` has been replaced with `get_surrogate`

## [0.10.0] - 2024-08-02
### Breaking Changes
Expand Down
84 changes: 84 additions & 0 deletions baybe/campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING

import cattrs
import numpy as np
Expand All @@ -11,18 +12,24 @@
from attrs.converters import optional
from attrs.validators import instance_of

from baybe.exceptions import IncompatibilityError
from baybe.objectives.base import Objective, to_objective
from baybe.objectives.single import SingleTargetObjective
from baybe.parameters.base import Parameter
from baybe.recommenders.base import RecommenderProtocol
from baybe.recommenders.meta.base import MetaRecommender
from baybe.recommenders.meta.sequential import TwoPhaseMetaRecommender
from baybe.recommenders.pure.bayesian.base import BayesianRecommender
from baybe.searchspace.core import (
SearchSpace,
SearchSpaceType,
to_searchspace,
validate_searchspace_from_config,
)
from baybe.serialization import SerialMixin, converter
from baybe.surrogates.base import SurrogateProtocol
from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.telemetry import (
TELEM_LABELS,
telemetry_record_recommended_measurement_percentage,
Expand All @@ -31,6 +38,9 @@
from baybe.utils.boolean import eq_dataframe
from baybe.utils.plotting import to_string

if TYPE_CHECKING:
from botorch.posteriors import Posterior


@define
class Campaign(SerialMixin):
Expand Down Expand Up @@ -269,6 +279,80 @@ def recommend(

return rec

def posterior(self, candidates: pd.DataFrame) -> Posterior:
"""Get the posterior predictive distribution for the given candidates.
The predictive distribution is based on the surrogate model of the last used
recommender.
Args:
candidates: The candidate points in experimental recommendations.
For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`.
Raises:
IncompatibilityError: If the underlying surrogate model exposes no
method for computing the posterior distribution.
Returns:
Posterior: The corresponding posterior object.
For details, see :meth:`baybe.surrogates.base.Surrogate.posterior`.
"""
surrogate = self.get_surrogate()
if not hasattr(surrogate, method_name := "posterior"):
raise IncompatibilityError(
f"The used surrogate type '{surrogate.__class__.__name__}' does not "
f"provide a '{method_name}' method."
)

import torch

with torch.no_grad():
return surrogate.posterior(candidates)

def get_surrogate(self) -> SurrogateProtocol:
"""Get the current surrogate model.
Raises:
RuntimeError: If the current recommender does not provide a surrogate model.
Returns:
Surrogate: The surrogate of the current recommender.
"""
# TODO: remove temporary restriction when target transformations can be handled
match self.objective:
case SingleTargetObjective(
_target=NumericalTarget(bounds=b)
) if not b.is_bounded:
pass
case _:
raise NotImplementedError(
"Surrogate model access is currently only supported for a single "
"untransformed target."
)

if self.objective is None:
raise IncompatibilityError(
f"No surrogate is available since no '{Objective.__name__}' is defined."
)

pure_recommender: RecommenderProtocol
if isinstance(self.recommender, MetaRecommender):
pure_recommender = self.recommender.get_current_recommender()
else:
pure_recommender = self.recommender

if isinstance(pure_recommender, BayesianRecommender):
return pure_recommender.get_surrogate(
self.searchspace, self.objective, self.measurements
)
else:
raise RuntimeError(
f"The current recommender is of type "
f"'{pure_recommender.__class__.__name__}', which does not provide "
f"a surrogate model. Surrogate models are only available for "
f"recommender subclasses of '{BayesianRecommender.__name__}'."
)


def _add_version(dict_: dict) -> dict:
"""Add the package version to the given dictionary."""
Expand Down
29 changes: 27 additions & 2 deletions baybe/recommenders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import cattrs
import pandas as pd
from cattrs import override

from baybe.objectives.base import Objective
from baybe.searchspace import SearchSpace
Expand All @@ -15,6 +16,10 @@
class RecommenderProtocol(Protocol):
"""Type protocol specifying the interface recommenders need to implement."""

# Use slots so that derived classes also remain slotted
# See also: https://www.attrs.org/en/stable/glossary.html#term-slotted-classes
__slots__ = ()

def recommend(
self,
batch_size: int,
Expand Down Expand Up @@ -47,15 +52,35 @@ def recommend(
...


# TODO: The workarounds below are currently required since the hooks created through
# `unstructure_base` and `get_base_structure_hook` do not reuse the hooks of the
# actual class, hence we cannot control things there. Fix is already planned and also
# needed for other reasons.

# Register (un-)structure hooks
converter.register_unstructure_hook(
RecommenderProtocol,
lambda x: unstructure_base(
x,
# TODO: Remove once deprecation got expired:
overrides=dict(acquisition_function_cls=cattrs.override(omit=True)),
overrides=dict(
acquisition_function_cls=cattrs.override(omit=True),
# Temporary workaround (see TODO note above)
_surrogate_model=override(rename="surrogate_model"),
_current_recommender=override(omit=False),
_used_recommender_ids=override(omit=False),
),
),
)
converter.register_structure_hook(
RecommenderProtocol, get_base_structure_hook(RecommenderProtocol)
RecommenderProtocol,
get_base_structure_hook(
RecommenderProtocol,
# Temporary workaround (see TODO note above)
overrides=dict(
_surrogate_model=override(rename="surrogate_model"),
_current_recommender=override(omit=False),
_used_recommender_ids=override(omit=False),
),
),
)
83 changes: 65 additions & 18 deletions baybe/recommenders/meta/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import cattrs
import pandas as pd
from attrs import define
from attrs import define, field

from baybe.objectives.base import Objective
from baybe.recommenders.base import RecommenderProtocol
Expand All @@ -20,6 +20,12 @@
class MetaRecommender(SerialMixin, RecommenderProtocol, ABC):
"""Abstract base class for all meta recommenders."""

_current_recommender: PureRecommender | None = field(default=None, init=False)
"""The current recommender."""

_used_recommender_ids: set[int] = field(factory=set, init=False)
"""Set of ids from recommenders that were used by this meta recommender."""

@abstractmethod
def select_recommender(
self,
Expand All @@ -31,22 +37,60 @@ def select_recommender(
) -> PureRecommender:
"""Select a pure recommender for the given experimentation context.
Args:
batch_size:
See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`.
searchspace:
See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`.
objective:
See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`.
measurements:
See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`.
pending_experiments:
See :func:`baybe.recommenders.meta.base.MetaRecommender.recommend`.
Returns:
The selected recommender.
See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend` for details
on the method arguments.
"""

def get_current_recommender(self) -> PureRecommender:
"""Get the current recommender, if available."""
if self._current_recommender is None:
raise RuntimeError(
f"No recommendation has been requested from the "
f"'{self.__class__.__name__}' yet. Because the recommender is a "
f"'{MetaRecommender.__name__}', this means no actual recommender has "
f"been selected so far. The recommender will be available after the "
f"next '{self.recommend.__name__}' call."
)
return self._current_recommender

def get_next_recommender(
self,
batch_size: int,
searchspace: SearchSpace,
objective: Objective | None = None,
measurements: pd.DataFrame | None = None,
pending_experiments: pd.DataFrame | None = None,
) -> PureRecommender:
"""Get the recommender for the next recommendation.
Returns the next recommender in row that has not yet been used for generating
recommendations. In case of multiple consecutive calls, this means that
the same recommender instance is returned until its :meth:`recommend` method
is called.
See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend` for details
on the method arguments.
"""
# Check if the stored recommender instance can be returned
if (
self._current_recommender is not None
and id(self._current_recommender) not in self._used_recommender_ids
):
recommender = self._current_recommender

# Otherwise, fetch the next recommender waiting in row
else:
recommender = self.select_recommender(
batch_size=batch_size,
searchspace=searchspace,
objective=objective,
measurements=measurements,
pending_experiments=pending_experiments,
)
self._current_recommender = recommender

return recommender

def recommend(
self,
batch_size: int,
Expand All @@ -55,8 +99,8 @@ def recommend(
measurements: pd.DataFrame | None = None,
pending_experiments: pd.DataFrame | None = None,
) -> pd.DataFrame:
"""See :func:`baybe.recommenders.base.RecommenderProtocol.recommend`."""
recommender = self.select_recommender(
"""See :meth:`baybe.recommenders.base.RecommenderProtocol.recommend`."""
recommender = self.get_next_recommender(
batch_size=batch_size,
searchspace=searchspace,
objective=objective,
Expand All @@ -76,12 +120,15 @@ def recommend(
}
)

return recommender.recommend(
recommendations = recommender.recommend(
batch_size=batch_size,
searchspace=searchspace,
pending_experiments=pending_experiments,
**optional_args,
)
self._used_recommender_ids.add(id(recommender))

return recommendations


# Register (un-)structure hooks
Expand Down
5 changes: 4 additions & 1 deletion baybe/recommenders/pure/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from baybe.searchspace.discrete import SubspaceDiscrete


@define
# TODO: Slots are currently disabled since they also block the monkeypatching necessary
# to use `register_hooks`. Probably, we need to update our documentation and
# explain how to work around that before we re-enable slots.
@define(slots=False)
class PureRecommender(ABC, RecommenderProtocol):
"""Abstract base class for all pure recommenders."""

Expand Down
Loading

0 comments on commit 9bda168

Please sign in to comment.