Bernoulli Target and Bernoulli Multi-Armed Bandit Surrogate #231

Closed

Changes from all commits
22 commits
6c2e550
Draft BernoulliMultiArmedBanditSurrogate class
julianStreibel May 3, 2024
c843fd7
Remove wrong/undesired default values
AdrianSosic May 6, 2024
93ba114
Use built-in validators for number of arms
AdrianSosic May 6, 2024
e876940
Do not expose counts via constructor
AdrianSosic May 6, 2024
1a9fe30
Fix default, conversion and validation of beta prior attribute
AdrianSosic May 6, 2024
5e45f52
Add factory for counts attribute
AdrianSosic May 6, 2024
b3e068a
Add missing type hints
AdrianSosic May 6, 2024
c20d286
Avoid importing torch eagerly
AdrianSosic May 6, 2024
30f49d7
Use built-in mean and var methods
AdrianSosic May 6, 2024
b9d2f0d
Rename variance method to variances
AdrianSosic May 6, 2024
ea5a146
Fix return value of _posterior method
AdrianSosic May 6, 2024
587c362
Fix input validation of _fit method
AdrianSosic May 6, 2024
685c27d
Automatically infer number of arms
AdrianSosic May 6, 2024
9d1d6f5
Rename attributes and methods
AdrianSosic May 6, 2024
d9b7df2
Adjust comment and docstring style
AdrianSosic May 6, 2024
cc40415
Ignore trained model for equality operator
AdrianSosic May 6, 2024
c721015
Store prior parameters as tuple
AdrianSosic May 6, 2024
1389b1d
Add TODO notes
AdrianSosic May 6, 2024
1b9ef19
Implemented the beta prior
julianStreibel May 7, 2024
0193e51
Renamed bernoulli target to BinaryTarget and checked target values in…
julianStreibel May 7, 2024
f64296f
reworked multi armed bandit example
julianStreibel May 7, 2024
328aa7f
Added pytests for the bernoulli multi arm bandit
julianStreibel May 7, 2024
12 changes: 12 additions & 0 deletions baybe/campaign.py
@@ -19,6 +19,7 @@
validate_searchspace_from_config,
)
from baybe.serialization import SerialMixin, converter
from baybe.targets import BinaryTarget
from baybe.targets.base import Target
from baybe.telemetry import (
TELEM_LABELS,
@@ -217,6 +218,17 @@ def add_measurements(
f"The target '{target.name}' has non-numeric entries in the "
f"provided dataframe. Non-numeric target values are not supported."
)
if (
isinstance(target, BinaryTarget)
and not data[target.name].isin(BinaryTarget.accepted_values).all()
):
raise ValueError(
f"'{BinaryTarget.__name__}' only accepts "
f"{BinaryTarget.accepted_values} as target values."
)
# TODO: Check that values of other target types fall within their bounds.
# This should most likely be done in the recommender to also cover
# standalone use.

# Check if all targets have valid values
for param in self.parameters:
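For illustration, a minimal standalone sketch of the value check added above (not BayBE code; it mirrors the `BinaryTarget.accepted_values` logic from this diff):

```python
import pandas as pd

accepted_values = [0, 1]  # mirrors BinaryTarget.accepted_values from this PR


def validate_binary_measurements(data: pd.DataFrame, target_name: str) -> None:
    """Raise if a binary target column contains values other than 0 or 1."""
    if not data[target_name].isin(accepted_values).all():
        raise ValueError(
            f"'BinaryTarget' only accepts {accepted_values} as target values."
        )


validate_binary_measurements(pd.DataFrame({"win_rate": [0, 1, 1]}), "win_rate")  # ok
# validate_binary_measurements(pd.DataFrame({"win_rate": [0.5]}), "win_rate")  # raises ValueError
```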
6 changes: 5 additions & 1 deletion baybe/exceptions.py
@@ -17,7 +17,7 @@ class NoMCAcquisitionFunctionError(Exception):

class IncompatibleSearchSpaceError(Exception):
"""
A recommender is used with a search space that contains incompatible parts,
A BayBE component is used with a search space that contains incompatible parts,
e.g. a discrete recommender is used with a hybrid or continuous search space.
"""

@@ -50,3 +50,7 @@ class DeprecationError(Exception):

class UnidentifiedSubclassError(Exception):
"""A specified subclass cannot be found in the given class hierarchy."""


class NotFitError(Exception):
"""A surrogate is not fit but accessed in an unsafe way."""
2 changes: 2 additions & 0 deletions baybe/kernels/priors/__init__.py
@@ -1,6 +1,7 @@
"""Available priors."""

from baybe.kernels.priors.basic import (
BetaPrior,
GammaPrior,
HalfCauchyPrior,
HalfNormalPrior,
@@ -16,4 +17,5 @@
"LogNormalPrior",
"NormalPrior",
"SmoothedBoxPrior",
"BetaPrior",
]
29 changes: 29 additions & 0 deletions baybe/kernels/priors/basic.py
@@ -1,6 +1,7 @@
"""Priors that can be used for kernels."""
from typing import Any

import numpy as np
from attrs import define, field
from attrs.validators import gt

@@ -84,3 +85,31 @@ def _validate_order(self, _: Any, b: float) -> None: # noqa: DOC101, DOC103
f"For {self.__class__.__name__}, the upper bound `b` (provided: {b}) "
f"must be larger than the lower bound `a` (provided: {self.a})."
)


@define(frozen=True)
class BetaPrior(Prior):
"""A beta prior parameterized by alpha and beta."""

alpha: float = field(converter=float)
"""Alpha of the beta distribution."""

beta: float = field(converter=float)
"""Beta of the beta distribution."""

@alpha.validator
@beta.validator
def _validate_parameter(self, attribute, value) -> None:
if value <= 0.0:
raise ValueError(
f"The value of '{attribute.name}' must be strictly positive."
)

def to_gpytorch(self, *args, **kwargs): # noqa: D102
raise NotImplementedError(
f"The '{self.__class__.__name__}' does not have an analog in gpytorch"
)

def numpy(self) -> np.ndarray:
"""Return alpha and beta as a numpy ndarray."""
return np.array([self.alpha, self.beta]).reshape(-1, 1)
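For intuition about the default used by the bandit surrogate below, a short scipy sketch (illustrative only, not part of the PR): Beta(1, 1) is the uniform distribution on [0, 1], which is why it serves as the uninformative default prior.

```python
from scipy.stats import beta

uniform_prior = beta(1, 1)      # Beta(1, 1) == uniform on [0, 1]
print(uniform_prior.mean())     # 0.5
print(uniform_prior.var())      # 1/12 ≈ 0.0833

optimistic_prior = beta(10, 2)  # strong belief in a high success probability
print(optimistic_prior.mean())  # ≈ 0.83
```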
2 changes: 2 additions & 0 deletions baybe/surrogates/__init__.py
@@ -3,6 +3,7 @@
from baybe.surrogates.custom import _ONNX_INSTALLED, register_custom_architecture
from baybe.surrogates.gaussian_process import GaussianProcessSurrogate
from baybe.surrogates.linear import BayesianLinearSurrogate
from baybe.surrogates.multi_armed_bandit import BernoulliMultiArmedBanditSurrogate
from baybe.surrogates.naive import MeanPredictionSurrogate
from baybe.surrogates.ngboost import NGBoostSurrogate
from baybe.surrogates.random_forest import RandomForestSurrogate
@@ -14,6 +15,7 @@
"MeanPredictionSurrogate",
"NGBoostSurrogate",
"RandomForestSurrogate",
"BernoulliMultiArmedBanditSurrogate",
]

if _ONNX_INSTALLED:
90 changes: 90 additions & 0 deletions baybe/surrogates/multi_armed_bandit.py
@@ -0,0 +1,90 @@
"""Multi-armed bandit surrogate."""

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Optional

import numpy as np
from attrs import define, field
from scipy.stats import beta

from baybe.exceptions import IncompatibleSearchSpaceError, NotFitError
from baybe.kernels.priors import BetaPrior
from baybe.parameters import CategoricalParameter
from baybe.parameters.enum import CategoricalEncoding
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.base import Surrogate

if TYPE_CHECKING:
from torch import Tensor


@define
class BernoulliMultiArmedBanditSurrogate(Surrogate):
"""A multi-armed bandit model with Bernoulli likelihood and beta prior."""

joint_posterior: ClassVar[bool] = False
# See base class.

supports_transfer_learning: ClassVar[bool] = False
# See base class.

prior: BetaPrior = field(default=BetaPrior(1, 1))
"""Beta prior parameters. By default, configured to produce a uniform prior."""

_win_lose_counts: Optional[np.ndarray[int]] = field(
init=False, default=None, eq=False
)
"""Sufficient statistics of the trained model (i.e., win and lose counts)."""

@property
def _posterior_beta_parameters(self) -> np.ndarray[float]:
"""The parameters of the posterior beta distribution."""
if self._win_lose_counts is None:
raise NotFitError(
f"'{self.__class__.__name__}' must be "
"fitted to access likelihood information"
)
# TODO: this could be removed once the number of arms can be inferred
return self._win_lose_counts + self.prior.numpy()

@property
def means(self) -> np.ndarray[float]:
"""Posterior means of the bandit arms."""
return beta(*self._posterior_beta_parameters).mean()

@property
def variances(self) -> np.ndarray[float]:
"""Posterior variances of the bandit arms."""
return beta(*self._posterior_beta_parameters).var()

def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
# See base class.

import torch

candidate_arms = candidates.argmax(dim=-1)
posterior_mean = self.means[candidate_arms]
posterior_variance = self.variances[candidate_arms]
return torch.tensor(posterior_mean), torch.tensor(posterior_variance)

def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
# See base class.

# TODO: Fix requirement of OHE encoding
# TODO: Generalize to arbitrary number of categorical parameters
if not (
(len(searchspace.parameters) == 1)
and isinstance(p := searchspace.parameters[0], CategoricalParameter)
and p.encoding is CategoricalEncoding.OHE
):
raise IncompatibleSearchSpaceError(
f"'{self.__class__.__name__}' currently only supports search spaces "
f"spanned by exactly one categorical parameter using one-hot encoding."
)

# TODO: Incorporate training target validation at the appropriate place in
# the BayBE ecosystem.
wins = (train_x * train_y).sum(axis=0)
losses = (train_x * (1 - train_y)).sum(axis=0)
self._win_lose_counts = np.vstack([wins, losses])
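The `_fit` and `_posterior_beta_parameters` logic above implements the standard conjugate Beta-Bernoulli update: with a Beta(α, β) prior and per-arm win/loss counts (w, l), the posterior per arm is Beta(α + w, β + l). A minimal standalone sketch of that update using scipy directly (illustrative, outside BayBE):

```python
import numpy as np
from scipy.stats import beta

# One-hot encoded arm pulls (3 arms) and binary rewards, analogous to _fit's inputs
train_x = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
train_y = np.array([[1], [0], [1], [1]])

wins = (train_x * train_y).sum(axis=0)          # per-arm win counts
losses = (train_x * (1 - train_y)).sum(axis=0)  # per-arm loss counts

prior_alpha, prior_beta = 1.0, 1.0  # uniform Beta(1, 1) prior, the surrogate's default
posterior = beta(prior_alpha + wins, prior_beta + losses)
print(posterior.mean())  # posterior win-rate estimate per arm
print(posterior.var())   # posterior uncertainty per arm
```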
2 changes: 2 additions & 0 deletions baybe/targets/__init__.py
@@ -1,5 +1,6 @@
"""BayBE targets."""

from baybe.targets.binary import BinaryTarget
from baybe.targets.deprecation import Objective
from baybe.targets.enum import TargetMode, TargetTransformation
from baybe.targets.numerical import NumericalTarget
@@ -9,4 +10,5 @@
"Objective",
"TargetMode",
"TargetTransformation",
"BinaryTarget",
]
27 changes: 27 additions & 0 deletions baybe/targets/binary.py
@@ -0,0 +1,27 @@
"""Binary target."""

import pandas as pd
from attrs import define

from baybe.serialization import SerialMixin
from baybe.targets.base import Target


@define(frozen=True)
class BinaryTarget(Target, SerialMixin):
"""Class for bernoulli targets."""

accepted_values = [0, 1]

def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102
# see base class
assert data.shape[1] == 1
# TODO: negation (1 - data) for min mode?!
return data

def summary(self) -> dict: # noqa: D102
# see base class
return dict(
Type=self.__class__.__name__,
Name=self.name,
)
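A quick usage sketch of the new target (assuming only the `name` argument inherited from the `Target` base class, as in the example further below):

```python
import pandas as pd
from baybe.targets import BinaryTarget

target = BinaryTarget(name="win_rate")
print(target.summary())  # {'Type': 'BinaryTarget', 'Name': 'win_rate'}

# transform currently acts as the identity on a single binary column
measurements = pd.DataFrame({"win_rate": [0, 1, 1, 0]})
print(target.transform(measurements))
```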
8 changes: 8 additions & 0 deletions baybe/utils/dataframe.py
@@ -16,6 +16,7 @@
import pandas as pd

from baybe.parameters.base import ContinuousParameter, DiscreteParameter
from baybe.targets import BinaryTarget
from baybe.targets.enum import TargetMode
from baybe.utils.numerical import DTypeFloatNumpy

@@ -129,6 +130,8 @@ def add_fake_results(
if good_intervals is None:
good_intervals = {}
for target in campaign.targets:
if isinstance(target, BinaryTarget):
continue
if target.mode is TargetMode.MAX:
lbound = target.bounds.lower if np.isfinite(target.bounds.lower) else 66
ubound = (
@@ -158,6 +161,8 @@
if bad_intervals is None:
bad_intervals = {}
for target in campaign.targets:
if isinstance(target, BinaryTarget):
continue
if target.mode is TargetMode.MAX:
lbound = target.bounds.lower if np.isfinite(target.bounds.lower) else 0
ubound = target.bounds.upper if np.isfinite(target.bounds.upper) else 33
@@ -186,6 +191,9 @@

# Add the fake data for each target
for target in campaign.targets:
if isinstance(target, BinaryTarget):
data[target.name] = np.random.choice([0, 1])
[Review comment (Collaborator Author): use target.accepted_values]
continue
# Add bad values
data[target.name] = np.random.uniform(
bad_intervals[target.name][0], bad_intervals[target.name][1], len(data)
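Note that `np.random.choice([0, 1])` without a `size` argument returns a single scalar, so every row receives the same fake value. A hypothetical sketch of the reviewer's suggestion ("use target.accepted_values"), drawing one value per row (not part of the diff):

```python
import numpy as np

accepted_values = [0, 1]  # would come from target.accepted_values
n_rows = 5                # would be len(data)

# One independent fake binary measurement per row
fake_results = np.random.choice(accepted_values, size=n_rows)
print(fake_results)
```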
4 changes: 4 additions & 0 deletions examples/Multi_Armed_Bandit/Multi_Armed_Bandit_Header.md
@@ -0,0 +1,4 @@
# Multi-Armed Bandit

These examples demonstrate BayBE's
{doc}`Multi-Armed Bandit Capabilities </userguide/multi_armed_bandit>`.
1 change: 1 addition & 0 deletions examples/Multi_Armed_Bandit/__init__.py
@@ -0,0 +1 @@
"""Multi-armed bandit examples."""
Review comment (Collaborator):
@julianStreibel @AdrianSosic just sharing the following ICML paper here with you in case this is interesting for this or a future PR in this direction: https://proceedings.mlr.press/v235/jun24a.html

108 changes: 108 additions & 0 deletions examples/Multi_Armed_Bandit/find_maximizing_arm.py
@@ -0,0 +1,108 @@
## Example for a Multi-Armed Bandit

# This example shows how to use the Bernoulli multi-armed bandit surrogate.

from collections.abc import Iterable
from typing import Union

import numpy as np
from attrs import define
from scipy.stats import bernoulli, rv_continuous, rv_discrete

from baybe import Campaign
from baybe.acquisition import ProbabilityOfImprovement
from baybe.objectives import SingleTargetObjective
from baybe.parameters import CategoricalParameter
from baybe.recommenders import (
FPSRecommender,
SequentialGreedyRecommender,
TwoPhaseMetaRecommender,
)
from baybe.searchspace import SearchSpace
from baybe.surrogates import BernoulliMultiArmedBanditSurrogate
from baybe.targets import BinaryTarget

### Setup

# We use a 5-armed bandit in this example. Each arm has a randomly drawn win rate.

N_ARMS = 5
N_ITERATIONS = 300
np.random.seed(0)


@define
class MultiArmedBanditModel:
"""Representation of a multi armed bandit."""

real_distributions: list[Union[rv_discrete, rv_continuous]]
"""List of the reward distribution per arm."""

def sample(self, arm_idxs: Iterable[int]):
"""Draw reward samples from the arms indexed in arm_idxs."""
return [self.real_distributions[arm_idx].rvs() for arm_idx in arm_idxs]

@property
def means(self):
"""Return the real means of the reward distributions."""
return [dist.stats(moments="m") for dist in self.real_distributions]


mab = MultiArmedBanditModel(
real_distributions=[bernoulli(np.random.rand()) for _ in range(N_ARMS)]
)
print("real means", mab.means)


### Campaign

# We use a BinaryTarget since we are modeling a Bernoulli reward.
# The search space has one categorical parameter that models the arms of the bandit.
# The probability-of-improvement acquisition function is not ideal in this setting,
# as it assumes a normally distributed win rate.

target = BinaryTarget(name="win_rate")
objective = SingleTargetObjective(target=target)
parameters = [
CategoricalParameter(
name="arm",
values=[str(i) for i in range(N_ARMS)],
)
]
searchspace = SearchSpace.from_product(parameters)
mabs = BernoulliMultiArmedBanditSurrogate()
recommender = TwoPhaseMetaRecommender(
initial_recommender=FPSRecommender(
allow_repeated_recommendations=True,
allow_recommending_already_measured=True,
),
recommender=SequentialGreedyRecommender(
surrogate_model=mabs,
allow_repeated_recommendations=True,
allow_recommending_already_measured=True,
acquisition_function=ProbabilityOfImprovement(),
),
)
campaign = Campaign(searchspace, objective, recommender)


### Optimization Loop

total_reward = 0
for i in range(N_ITERATIONS):
df = campaign.recommend(batch_size=1)
reward = mab.sample(df.index.tolist())
total_reward += sum(reward)
df["win_rate"] = reward
campaign.add_measurements(df)

if (i + 1) % 50 == 0:
print("iter", i + 1)
print("estimated means", mabs.means)
print("-" * 5)

real_means = mab.means
print("real means", real_means)
print("optimal expected reward", max(real_means) * N_ITERATIONS)
print("total_reward", total_reward)
print("mean reward", total_reward / N_ITERATIONS)