From 328aa7f9b2e6f9cad555e66ebe7ecbfb6dacba67 Mon Sep 17 00:00:00 2001 From: Julian Streibel Date: Tue, 7 May 2024 18:21:16 +0200 Subject: [PATCH] Added pytests for the bernoulli multi arm bandit --- baybe/utils/dataframe.py | 8 ++++++++ .../Multi_Armed_Bandit/find_maximizing_arm.py | 6 +++++- tests/conftest.py | 13 +++++++++++-- tests/hypothesis_strategies/priors.py | 9 +++++++++ tests/hypothesis_strategies/targets.py | 5 ++++- tests/test_iterations.py | 18 +++++++++++++++++- 6 files changed, 54 insertions(+), 5 deletions(-) diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 1c492af04..053a14284 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -16,6 +16,7 @@ import pandas as pd from baybe.parameters.base import ContinuousParameter, DiscreteParameter +from baybe.targets import BinaryTarget from baybe.targets.enum import TargetMode from baybe.utils.numerical import DTypeFloatNumpy @@ -129,6 +130,8 @@ def add_fake_results( if good_intervals is None: good_intervals = {} for target in campaign.targets: + if isinstance(target, BinaryTarget): + continue if target.mode is TargetMode.MAX: lbound = target.bounds.lower if np.isfinite(target.bounds.lower) else 66 ubound = ( @@ -158,6 +161,8 @@ def add_fake_results( if bad_intervals is None: bad_intervals = {} for target in campaign.targets: + if isinstance(target, BinaryTarget): + continue if target.mode is TargetMode.MAX: lbound = target.bounds.lower if np.isfinite(target.bounds.lower) else 0 ubound = target.bounds.upper if np.isfinite(target.bounds.upper) else 33 @@ -186,6 +191,9 @@ def add_fake_results( # Add the fake data for each target for target in campaign.targets: + if isinstance(target, BinaryTarget): + data[target.name] = np.random.choice([0, 1]) + continue # Add bad values data[target.name] = np.random.uniform( bad_intervals[target.name][0], bad_intervals[target.name][1], len(data) diff --git a/examples/Multi_Armed_Bandit/find_maximizing_arm.py 
b/examples/Multi_Armed_Bandit/find_maximizing_arm.py index b5175aace..4b0d384e5 100644 --- a/examples/Multi_Armed_Bandit/find_maximizing_arm.py +++ b/examples/Multi_Armed_Bandit/find_maximizing_arm.py @@ -23,7 +23,9 @@ from baybe.targets import BinaryTarget ### Setup + # We are using a 5-armed bandit in this example. The bandit has a random win rate for now. + N_ARMS = 5 N_ITERATIONS = 300 np.random.seed(0) @@ -53,7 +55,8 @@ def means(self): ### Campaign -# We are using the BinaryTarget as we are modeling a brnoulli reward. + +# We are using the BinaryTarget as we are modeling a Bernoulli reward. # The searchspace has one categorical parameter to model the arms of the bandit. # The probability of improvement acquisition function is not perfect in this setting # as it assumes a normal distribution of the win rate. @@ -84,6 +87,7 @@ def means(self): ### Optimization Loop + total_reward = 0 for i in range(N_ITERATIONS): df = campaign.recommend(batch_size=1) diff --git a/tests/conftest.py b/tests/conftest.py index 6216ccf5a..a27574920 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,7 +52,7 @@ from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender from baybe.searchspace import SearchSpace from baybe.surrogates import _ONNX_INSTALLED, GaussianProcessSurrogate -from baybe.targets import NumericalTarget +from baybe.targets import BinaryTarget, NumericalTarget from baybe.telemetry import ( VARNAME_TELEMETRY_ENABLED, VARNAME_TELEMETRY_HOSTNAME, @@ -390,6 +390,7 @@ def fixture_targets(target_names: list[str]): bounds=(0, 100), transformation="TRIANGULAR", ), + BinaryTarget(name="Target_binary"), ] return [t for t in valid_targets if t.name in target_names] @@ -628,14 +629,22 @@ def fixture_initial_recommender(): return RandomRecommender() +@pytest.fixture(name="allow_repeated_recommendations") +def fixture_allow_repeated_recommendations(): + return False + + @pytest.fixture(name="recommender") -def fixture_recommender(initial_recommender, 
surrogate_model, acqf): +def fixture_recommender( + initial_recommender, surrogate_model, acqf, allow_repeated_recommendations +): """The default recommender to be used if not specified differently.""" return TwoPhaseMetaRecommender( initial_recommender=initial_recommender, recommender=SequentialGreedyRecommender( surrogate_model=surrogate_model, acquisition_function=acqf, + allow_repeated_recommendations=allow_repeated_recommendations, ), ) diff --git a/tests/hypothesis_strategies/priors.py b/tests/hypothesis_strategies/priors.py index c438f9404..7e0e5bed7 100644 --- a/tests/hypothesis_strategies/priors.py +++ b/tests/hypothesis_strategies/priors.py @@ -3,6 +3,7 @@ import hypothesis.strategies as st from baybe.kernels.priors import ( + BetaPrior, GammaPrior, HalfCauchyPrior, HalfNormalPrior, @@ -47,6 +48,13 @@ ) """A strategy that generates Log-Normal priors.""" +beta_priors = st.builds( + BetaPrior, + st.floats(min_value=0.0, exclude_min=True), + st.floats(min_value=0.0, exclude_min=True), +) +"""A strategy that generates Beta priors.""" + @st.composite def _smoothed_box_priors(draw: st.DrawFn): @@ -69,6 +77,7 @@ def _smoothed_box_priors(draw: st.DrawFn): log_normal_priors, normal_priors, smoothed_box_priors, + beta_priors, ] ) """A strategy that generates priors.""" diff --git a/tests/hypothesis_strategies/targets.py b/tests/hypothesis_strategies/targets.py index eebbcb11a..ba4b6f7e7 100644 --- a/tests/hypothesis_strategies/targets.py +++ b/tests/hypothesis_strategies/targets.py @@ -4,6 +4,7 @@ import hypothesis.strategies as st +from baybe.targets import BinaryTarget from baybe.targets.enum import TargetMode from baybe.targets.numerical import _VALID_TRANSFORMATIONS, NumericalTarget from baybe.utils.interval import Interval @@ -41,5 +42,7 @@ def numerical_targets( ) -targets = numerical_targets() +binary_targets = st.builds(BinaryTarget) + +targets = st.one_of([binary_targets, numerical_targets()]) """A strategy that generates targets.""" diff --git 
a/tests/test_iterations.py b/tests/test_iterations.py index 98b1e50c4..dc4a422ba 100644 --- a/tests/test_iterations.py +++ b/tests/test_iterations.py @@ -23,6 +23,7 @@ ) from baybe.recommenders.pure.nonpredictive.base import NonPredictiveRecommender from baybe.searchspace import SearchSpaceType +from baybe.surrogates import BernoulliMultiArmedBanditSurrogate from baybe.surrogates.base import Surrogate from baybe.utils.basic import get_subclasses @@ -32,7 +33,10 @@ # Settings of the individual components to be tested ######################################################################################## valid_surrogate_models = [ - cls() for cls in get_subclasses(Surrogate) if cls.__name__ != "CustomONNXSurrogate" + cls() + for cls in get_subclasses(Surrogate) + if cls.__name__ + not in ["CustomONNXSurrogate", BernoulliMultiArmedBanditSurrogate.__name__] ] valid_initial_recommenders = [cls() for cls in get_subclasses(NonPredictiveRecommender)] # TODO the TwoPhaseMetaRecommender below can be removed if the SeqGreedy recommender @@ -219,3 +223,15 @@ def test_iter_recommender_hybrid(campaign, n_iterations, batch_size): @pytest.mark.parametrize("recommender", valid_meta_recommenders, indirect=True) def test_meta_recommenders(campaign, n_iterations, batch_size): run_iterations(campaign, n_iterations, batch_size) + + +@pytest.mark.parametrize("surrogate_model", [BernoulliMultiArmedBanditSurrogate()]) +@pytest.mark.parametrize( + "parameter_names", + [["Categorical_1"], ["Categorical_2"], ["Switch_1"], ["Switch_2"]], +) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("target_names", [["Target_binary"]]) +@pytest.mark.parametrize("allow_repeated_recommendations", [True]) +def test_multi_arm_bandit(campaign, n_iterations, batch_size): + run_iterations(campaign, n_iterations, batch_size, add_noise=False)