Skip to content

Commit

Permalink
Added pytests for the bernoulli multi arm bandit
Browse files Browse the repository at this point in the history
  • Loading branch information
julianStreibel committed May 7, 2024
1 parent f64296f commit 328aa7f
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 5 deletions.
8 changes: 8 additions & 0 deletions baybe/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pandas as pd

from baybe.parameters.base import ContinuousParameter, DiscreteParameter
from baybe.targets import BinaryTarget
from baybe.targets.enum import TargetMode
from baybe.utils.numerical import DTypeFloatNumpy

Expand Down Expand Up @@ -129,6 +130,8 @@ def add_fake_results(
if good_intervals is None:
good_intervals = {}
for target in campaign.targets:
if isinstance(target, BinaryTarget):
continue
if target.mode is TargetMode.MAX:
lbound = target.bounds.lower if np.isfinite(target.bounds.lower) else 66
ubound = (
Expand Down Expand Up @@ -158,6 +161,8 @@ def add_fake_results(
if bad_intervals is None:
bad_intervals = {}
for target in campaign.targets:
if isinstance(target, BinaryTarget):
continue
if target.mode is TargetMode.MAX:
lbound = target.bounds.lower if np.isfinite(target.bounds.lower) else 0
ubound = target.bounds.upper if np.isfinite(target.bounds.upper) else 33
Expand Down Expand Up @@ -186,6 +191,9 @@ def add_fake_results(

# Add the fake data for each target
for target in campaign.targets:
if isinstance(target, BinaryTarget):
data[target.name] = np.random.choice([0, 1])
continue
# Add bad values
data[target.name] = np.random.uniform(
bad_intervals[target.name][0], bad_intervals[target.name][1], len(data)
Expand Down
6 changes: 5 additions & 1 deletion examples/Multi_Armed_Bandit/find_maximizing_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from baybe.targets import BinaryTarget

### Setup

# We are using a 5-armed bandit in this example. The bandit has a random win rate for now.

N_ARMS = 5
N_ITERATIONS = 300
np.random.seed(0)
Expand Down Expand Up @@ -53,7 +55,8 @@ def means(self):


### Campaign
# We are using the BinaryTarget as we are modeling a brnoulli reward.

# We are using the BinaryTarget as we are modeling a Bernoulli reward.
# The searchspace has one categorical parameter to model the arms of the bandit.
# The probability of improvement acquisition function is not perfect in this setting
# as it assumes a normal distribution of the win rate.
Expand Down Expand Up @@ -84,6 +87,7 @@ def means(self):


### Optimization Loop

total_reward = 0
for i in range(N_ITERATIONS):
df = campaign.recommend(batch_size=1)
Expand Down
13 changes: 11 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender
from baybe.searchspace import SearchSpace
from baybe.surrogates import _ONNX_INSTALLED, GaussianProcessSurrogate
from baybe.targets import NumericalTarget
from baybe.targets import BinaryTarget, NumericalTarget
from baybe.telemetry import (
VARNAME_TELEMETRY_ENABLED,
VARNAME_TELEMETRY_HOSTNAME,
Expand Down Expand Up @@ -390,6 +390,7 @@ def fixture_targets(target_names: list[str]):
bounds=(0, 100),
transformation="TRIANGULAR",
),
BinaryTarget(name="Target_binary"),
]
return [t for t in valid_targets if t.name in target_names]

Expand Down Expand Up @@ -628,14 +629,22 @@ def fixture_initial_recommender():
return RandomRecommender()


@pytest.fixture(name="allow_repeated_recommendations")
def fixture_allow_repeated_recommendations():
"""The default setting for allowing repeated recommendations (disabled)."""
return False


@pytest.fixture(name="recommender")
def fixture_recommender(initial_recommender, surrogate_model, acqf):
def fixture_recommender(
initial_recommender, surrogate_model, acqf, allow_repeated_recommendations
):
"""The default recommender to be used if not specified differently."""
return TwoPhaseMetaRecommender(
initial_recommender=initial_recommender,
recommender=SequentialGreedyRecommender(
surrogate_model=surrogate_model,
acquisition_function=acqf,
allow_repeated_recommendations=allow_repeated_recommendations,
),
)

Expand Down
9 changes: 9 additions & 0 deletions tests/hypothesis_strategies/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hypothesis.strategies as st

from baybe.kernels.priors import (
BetaPrior,
GammaPrior,
HalfCauchyPrior,
HalfNormalPrior,
Expand Down Expand Up @@ -47,6 +48,13 @@
)
"""A strategy that generates Log-Normal priors."""

beta_priors = st.builds(
BetaPrior,
st.floats(min_value=0.0, exclude_min=True),
st.floats(min_value=0.0, exclude_min=True),
)
"""A strategy that generates Beta priors"""


@st.composite
def _smoothed_box_priors(draw: st.DrawFn):
Expand All @@ -69,6 +77,7 @@ def _smoothed_box_priors(draw: st.DrawFn):
log_normal_priors,
normal_priors,
smoothed_box_priors,
beta_priors,
]
)
"""A strategy that generates priors."""
5 changes: 4 additions & 1 deletion tests/hypothesis_strategies/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import hypothesis.strategies as st

from baybe.targets import BinaryTarget
from baybe.targets.enum import TargetMode
from baybe.targets.numerical import _VALID_TRANSFORMATIONS, NumericalTarget
from baybe.utils.interval import Interval
Expand Down Expand Up @@ -41,5 +42,7 @@ def numerical_targets(
)


targets = numerical_targets()
binary_targets = st.builds(BinaryTarget)

targets = st.one_of([binary_targets, numerical_targets()])
"""A strategy that generates targets."""
18 changes: 17 additions & 1 deletion tests/test_iterations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from baybe.recommenders.pure.nonpredictive.base import NonPredictiveRecommender
from baybe.searchspace import SearchSpaceType
from baybe.surrogates import BernoulliMultiArmedBanditSurrogate
from baybe.surrogates.base import Surrogate
from baybe.utils.basic import get_subclasses

Expand All @@ -32,7 +33,10 @@
# Settings of the individual components to be tested
########################################################################################
valid_surrogate_models = [
cls() for cls in get_subclasses(Surrogate) if cls.__name__ != "CustomONNXSurrogate"
cls()
for cls in get_subclasses(Surrogate)
if cls.__name__
not in ["CustomONNXSurrogate", BernoulliMultiArmedBanditSurrogate.__name__]
]
valid_initial_recommenders = [cls() for cls in get_subclasses(NonPredictiveRecommender)]
# TODO the TwoPhaseMetaRecommender below can be removed if the SeqGreedy recommender
Expand Down Expand Up @@ -219,3 +223,15 @@ def test_iter_recommender_hybrid(campaign, n_iterations, batch_size):
@pytest.mark.parametrize("recommender", valid_meta_recommenders, indirect=True)
def test_meta_recommenders(campaign, n_iterations, batch_size):
run_iterations(campaign, n_iterations, batch_size)


@pytest.mark.parametrize("surrogate_model", [BernoulliMultiArmedBanditSurrogate()])
@pytest.mark.parametrize(
"parameter_names",
[["Categorical_1"], ["Categorical_2"], ["Switch_1"], ["Switch_2"]],
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("target_names", [["Target_binary"]])
@pytest.mark.parametrize("allow_repeated_recommendations", [True])
def test_multi_arm_bandit(campaign, n_iterations, batch_size):
run_iterations(campaign, n_iterations, batch_size, add_noise=False)

0 comments on commit 328aa7f

Please sign in to comment.