-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bernoulli Target and Bernoulli Multi-Armed Bandit Surrogate #231
Closed
julianStreibel
wants to merge
22
commits into
emdgroup:main
from
julianStreibel:feature/bernoulli_target
Closed
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
6c2e550
Draft BernoulliMultiArmedBanditSurrogate class
julianStreibel c843fd7
Remove wrong/undesired default values
AdrianSosic 93ba114
Use built-in validators for number of arms
AdrianSosic e876940
Do not expose counts via constructor
AdrianSosic 1a9fe30
Fix default, conversion and validation of beta prior attribute
AdrianSosic 5e45f52
Add factory for counts attribute
AdrianSosic b3e068a
Add missing type hints
AdrianSosic c20d286
Avoid importing torch eagerly
AdrianSosic 30f49d7
Use built-in mean and var methods
AdrianSosic b9d2f0d
Rename variance method to variances
AdrianSosic ea5a146
Fix return value of _posterior method
AdrianSosic 587c362
Fix input validation of _fit method
AdrianSosic 685c27d
Automatically infer number of arms
AdrianSosic 9d1d6f5
Rename attributes and methods
AdrianSosic d9b7df2
Adjust comment and docstring style
AdrianSosic cc40415
Ignore trained model for equality operator
AdrianSosic c721015
Store prior parameters as tuple
AdrianSosic 1389b1d
Add TODO notes
AdrianSosic 1b9ef19
Implemented the beta prior
julianStreibel 0193e51
Renamed bernoulli target to BinaryTarget and checked target values in…
julianStreibel f64296f
reworked multi armed bandit example
julianStreibel 328aa7f
Added pytests for the bernoulli multi arm bandit
julianStreibel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""Multi-armed bandit surrogate.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, ClassVar, Optional | ||
|
||
import numpy as np | ||
from attrs import define, field | ||
from scipy.stats import beta | ||
|
||
from baybe.exceptions import IncompatibleSearchSpaceError, NotFitError | ||
from baybe.kernels.priors import BetaPrior | ||
from baybe.parameters import CategoricalParameter | ||
from baybe.parameters.enum import CategoricalEncoding | ||
from baybe.searchspace.core import SearchSpace | ||
from baybe.surrogates.base import Surrogate | ||
|
||
if TYPE_CHECKING: | ||
from torch import Tensor | ||
|
||
|
||
@define | ||
class BernoulliMultiArmedBanditSurrogate(Surrogate): | ||
"""A multi-armed bandit model with Bernoulli likelihood and beta prior.""" | ||
|
||
joint_posterior: ClassVar[bool] = False | ||
# See base class. | ||
|
||
supports_transfer_learning: ClassVar[bool] = False | ||
# See base class. | ||
|
||
prior: BetaPrior = field(default=BetaPrior(1, 1)) | ||
"""Beta prior parameters. By default, configured to produce a uniform prior.""" | ||
|
||
_win_lose_counts: Optional[np.ndarray[int]] = field( | ||
init=False, default=None, eq=False | ||
) | ||
"""Sufficient statistics of the trained model (i.e., win and lose counts).""" | ||
|
||
@property | ||
def _posterior_beta_parameters(self) -> np.ndarray[float]: | ||
"""The parameters of the posterior beta distribution.""" | ||
if self._win_lose_counts is None: | ||
raise NotFitError( | ||
f"'{self.__class__.__name__}' must be " | ||
"fitted to access likelihood information" | ||
) | ||
# TODO: this could be removed when the number of arms could be inferred | ||
return self._win_lose_counts + self.prior.numpy() | ||
|
||
@property | ||
def means(self) -> np.ndarray[float]: | ||
"""Posterior means of the bandit arms.""" | ||
return beta(*self._posterior_beta_parameters).mean() | ||
|
||
@property | ||
def variances(self) -> np.ndarray[float]: | ||
"""Posterior variances of the bandit arms.""" | ||
return beta(*self._posterior_beta_parameters).var() | ||
|
||
def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: | ||
# See base class. | ||
|
||
import torch | ||
|
||
candidate_arms = candidates.argmax(dim=-1) | ||
posterior_mean = self.means[candidate_arms] | ||
posterior_variance = self.variances[candidate_arms] | ||
return torch.tensor(posterior_mean), torch.tensor(posterior_variance) | ||
|
||
def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None: | ||
# See base class. | ||
|
||
# TODO: Fix requirement of OHE encoding | ||
# TODO: Generalize to arbitrary number of categorical parameters | ||
if not ( | ||
(len(searchspace.parameters) == 1) | ||
and isinstance(p := searchspace.parameters[0], CategoricalParameter) | ||
and p.encoding is CategoricalEncoding.OHE | ||
): | ||
raise IncompatibleSearchSpaceError( | ||
f"'{self.__class__.__name__}' currently only supports search spaces " | ||
f"spanned by exactly one categorical parameter using one-hot encoding." | ||
) | ||
|
||
# TODO: Incorporate training target validation at the appropriate place in | ||
# the BayBE ecosystem. | ||
wins = (train_x * train_y).sum(axis=0) | ||
losses = (train_x * (1 - train_y)).sum(axis=0) | ||
self._win_lose_counts = np.vstack([wins, losses]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""Binary target.""" | ||
|
||
import pandas as pd | ||
from attrs import define | ||
|
||
from baybe.serialization import SerialMixin | ||
from baybe.targets.base import Target | ||
|
||
|
||
@define(frozen=True) | ||
class BinaryTarget(Target, SerialMixin): | ||
"""Class for bernoulli targets.""" | ||
|
||
accepted_values = [0, 1] | ||
|
||
def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102 | ||
# see base class | ||
assert data.shape[1] == 1 | ||
# TODO: negation (1 - data) for min mode?! | ||
return data | ||
|
||
def summary(self) -> dict: # noqa: D102 | ||
# see base class | ||
return dict( | ||
Type=self.__class__.__name__, | ||
Name=self.name, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Multi-Armed Bandit | ||
|
||
These examples demonstrate BayBE's | ||
{doc}`Multi-Armed Bandit Capabilities </userguide/multi_armed_bandit>`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Multi-armed bandit examples.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @julianStreibel @AdrianSosic just sharing the following ICML paper here with you in case this is interesting for this or a future PR in this direction: https://proceedings.mlr.press/v235/jun24a.html |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
## Example for a Multi Armed Bandit | ||
|
||
# This example shows how to use the bernoulli multi armed bandit surrogate. | ||
|
||
from collections.abc import Iterable | ||
from typing import Union | ||
|
||
import numpy as np | ||
from attrs import define | ||
from scipy.stats import bernoulli, rv_continuous, rv_discrete | ||
|
||
from baybe import Campaign | ||
from baybe.acquisition import ProbabilityOfImprovement | ||
from baybe.objectives import SingleTargetObjective | ||
from baybe.parameters import CategoricalParameter | ||
from baybe.recommenders import ( | ||
FPSRecommender, | ||
SequentialGreedyRecommender, | ||
TwoPhaseMetaRecommender, | ||
) | ||
from baybe.searchspace import SearchSpace | ||
from baybe.surrogates import BernoulliMultiArmedBanditSurrogate | ||
from baybe.targets import BinaryTarget | ||
|
||
### Setup | ||
|
||
# We are using a 5-armed bandit in this example. The bandit has a random win rate for now. | ||
|
||
N_ARMS = 5 | ||
N_ITERATIONS = 300 | ||
np.random.seed(0) | ||
|
||
|
||
@define | ||
class MultiArmedBanditModel: | ||
"""Representation of a multi armed bandit.""" | ||
|
||
real_distributions: list[Union[rv_discrete, rv_continuous]] | ||
"""List of the reward distribution per arm.""" | ||
|
||
def sample(self, arm_idxs: Iterable[int]): | ||
"""Draw reward samples from the arms indexed in arm_idxs.""" | ||
return [self.real_distributions[arm_idx].rvs() for arm_idx in arm_idxs] | ||
|
||
@property | ||
def means(self): | ||
"""Return the real means of the reward distributions.""" | ||
return [dist.stats(moments="m") for dist in self.real_distributions] | ||
|
||
|
||
mab = MultiArmedBanditModel( | ||
real_distributions=[bernoulli(np.random.rand()) for _ in range(N_ARMS)] | ||
) | ||
print("real means", mab.means) | ||
|
||
|
||
### Campaign | ||
|
||
# We are using the BinaryTarget as we are modeling a bernoulli reward. | ||
# The searchspace has one categorical parameter to model the arms of the bandit. | ||
# The probability of improvement acquisition function is not perfect in this setting | ||
# as it assumes a normal distribution of the win rate. | ||
|
||
target = BinaryTarget(name="win_rate") | ||
objective = SingleTargetObjective(target=target) | ||
parameters = [ | ||
CategoricalParameter( | ||
name="arm", | ||
values=[str(i) for i in range(N_ARMS)], | ||
) | ||
] | ||
searchspace = SearchSpace.from_product(parameters) | ||
mabs = BernoulliMultiArmedBanditSurrogate() | ||
recommender = TwoPhaseMetaRecommender( | ||
initial_recommender=FPSRecommender( | ||
allow_repeated_recommendations=True, | ||
allow_recommending_already_measured=True, | ||
), | ||
recommender=SequentialGreedyRecommender( | ||
surrogate_model=mabs, | ||
allow_repeated_recommendations=True, | ||
allow_recommending_already_measured=True, | ||
acquisition_function=ProbabilityOfImprovement(), | ||
), | ||
) | ||
campaign = Campaign(searchspace, objective, recommender) | ||
|
||
|
||
### Optimization Loop | ||
|
||
total_reward = 0 | ||
for i in range(N_ITERATIONS): | ||
df = campaign.recommend(batch_size=1) | ||
reward = mab.sample(df.index.tolist()) | ||
total_reward += sum(reward) | ||
df["win_rate"] = reward | ||
campaign.add_measurements(df) | ||
|
||
if (i + 1) % 50 == 0: | ||
print("iter", i + 1) | ||
print("estimated means", mabs.means) | ||
print("-" * 5) | ||
|
||
real_means = mab.means | ||
print("real means", real_means) | ||
print("optimal expected reward", max(real_means) * N_ITERATIONS) | ||
print("total_reward", total_reward) | ||
print("mean reward", total_reward / N_ITERATIONS) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use target.accepted_values