Merge: Add iteration test retries (emdgroup#374)
An attempt to reduce pipeline failures caused by numerical issues:
- uses `tenacity` to retry the iteration test when a known numerical exception is thrown (a condensed sketch of the pattern follows the file summary below)
- adds a temporary random seed to the iteration tests to provide more numerical variance
- enforces integer size limits in the baybe random seed utilities
Scienfitz authored Sep 9, 2024
2 parents 9bda168 + d6648a4 commit 8e40acf
Showing 10 changed files with 55 additions and 13 deletions.
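The central piece of the change is a `tenacity` retry decorator around the test helper in `tests/conftest.py` (full diff below). A condensed sketch of the pattern, with an illustrative `flaky_iteration` function and `FloatingPointError` standing in for the exception types actually caught there:

import warnings

from tenacity import (
    retry,
    retry_any,
    retry_if_exception_message,
    retry_if_exception_type,
    stop_after_attempt,
)


@retry(
    stop=stop_after_attempt(5),  # give up after five attempts
    retry=retry_any(
        retry_if_exception_type(FloatingPointError),  # stand-in for the known numerical exceptions
        retry_if_exception_message(match=r".*to be within the support.*"),
    ),
    before_sleep=lambda state: warnings.warn(
        f"Retrying due to '{state.outcome.exception()}'"
    ),
)
def flaky_iteration() -> None:
    ...  # re-executed with a fresh random seed on every attempt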
1 change: 1 addition & 0 deletions .lockfiles/py310-dev.lock
@@ -843,6 +843,7 @@ tbb==2021.13.0 ; platform_system == 'Windows'
# via mkl
tenacity==8.5.0
# via
# baybe (pyproject.toml)
# plotly
# streamlit
terminado==0.18.1
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
controlling whether pending experiments are excluded from candidates in purely
discrete search spaces
- `get_surrogate` and `posterior` methods to `Campaign`
- `tenacity` test dependency

### Changed
- The transition from experimental to computational representation no longer happens
@@ -42,6 +43,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Search spaces now store their parameters in alphabetical order by name
- Improvement-based acquisition functions now consider the maximum posterior mean
instead of the maximum noisy measurement as reference value
- Iteration tests now attempt up to 5 repeated executions if they fail due to numerical
reasons

### Fixed
- `CategoricalParameter` and `TaskParameter` no longer incorrectly coerce a single
2 changes: 1 addition & 1 deletion baybe/recommenders/pure/bayesian/base.py
@@ -119,7 +119,7 @@ def recommend( # noqa: D102
f"cannot be used for batch recommendation."
)

if isinstance(self.surrogate_model, CustomONNXSurrogate):
if isinstance(self._surrogate_model, CustomONNXSurrogate):
CustomONNXSurrogate.validate_compatibility(searchspace)

self._setup_botorch_acqf(
2 changes: 1 addition & 1 deletion baybe/recommenders/pure/bayesian/botorch.py
@@ -300,7 +300,7 @@ def _recommend_hybrid(

def __str__(self) -> str:
fields = [
to_string("Surrogate", self.surrogate_model),
to_string("Surrogate", self._surrogate_model),
to_string(
"Acquisition function", self.acquisition_function, single_line=True
),
6 changes: 6 additions & 0 deletions baybe/utils/random.py
@@ -14,6 +14,9 @@ def set_random_seed(seed: int):
"""
import torch

# Ensure seed limits
seed %= 2**32

torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
@@ -28,6 +31,9 @@ def temporary_seed(seed: int): # noqa: DOC402, DOC404
"""
import torch

# Ensure seed limits
seed %= 2**32

# Collect the current RNG states
state_builtin = random.getstate()
state_np = np.random.get_state()
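For context on the `seed %= 2**32` lines added above: NumPy's `np.random.seed` only accepts integer seeds in the range 0 to 2**32 - 1 and raises a `ValueError` otherwise, while `random.seed` and `torch.manual_seed` are more permissive, so the modulo wrap keeps arbitrary integers valid for all three. A minimal illustration (not part of the diff):

import numpy as np

seed = 2**32 + 7      # too large for np.random.seed; passing it directly raises ValueError
seed %= 2**32         # wrap into the accepted range [0, 2**32 - 1]
np.random.seed(seed)  # now valid for NumPy (and still valid for random.seed / torch.manual_seed)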
2 changes: 1 addition & 1 deletion examples/Custom_Hooks/campaign_stopping.py
@@ -141,7 +141,7 @@ def stop_on_PI(
)
acqf = ProbabilityOfImprovement()
botorch_acqf = acqf.to_botorch(
self.surrogate_model, searchspace, objective, measurements
self._surrogate_model, searchspace, objective, measurements
)
_, candidates_comp_rep = searchspace.discrete.get_candidates(
allow_repeated_recommendations=self.allow_repeated_recommendations,
2 changes: 1 addition & 1 deletion examples/Custom_Hooks/probability_of_improvement.py
@@ -84,7 +84,7 @@ def extract_pi(
)
acqf = ProbabilityOfImprovement()
botorch_acqf = acqf.to_botorch(
self.surrogate_model, searchspace, objective, measurements
self._surrogate_model, searchspace, objective, measurements
)
comp_rep_tensor = to_tensor(searchspace.discrete.comp_rep).unsqueeze(1)
with torch.no_grad():
7 changes: 5 additions & 2 deletions examples/Multi_Armed_Bandit/bernoulli_multi_armed_bandit.py
@@ -7,6 +7,7 @@

### Imports

import os
from dataclasses import dataclass
from typing import NamedTuple

@@ -76,12 +77,14 @@ def draw_arm(self, arm_index: int) -> bool:

# To estimate the corresponding effects, we simulate each campaign for a certain number of steps and repeat this process in multiple Monte Carlo runs:

SMOKE_TEST = "SMOKE_TEST" in os.environ

ACQFS = [
qThompsonSampling(), # Online optimization
PosteriorStandardDeviation(), # Active learning
]
N_MC_RUNS = 10
N_ITERATIONS = 200
N_MC_RUNS = 2 if SMOKE_TEST else 10
N_ITERATIONS = 2 if SMOKE_TEST else 200


### Building the Model
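The `SMOKE_TEST` switch added above is a presence check on an environment variable: the example shrinks its Monte Carlo budget whenever the variable is set to any value (presumably done in the test pipeline) and uses the full budget otherwise. A small illustration of the behavior (hypothetical values):

import os

os.environ["SMOKE_TEST"] = "1"            # any value works; only the key's presence is checked
SMOKE_TEST = "SMOKE_TEST" in os.environ   # -> True
N_MC_RUNS = 2 if SMOKE_TEST else 10       # -> 2 (reduced budget)
N_ITERATIONS = 2 if SMOKE_TEST else 200   # -> 2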
1 change: 1 addition & 0 deletions pyproject.toml
@@ -144,6 +144,7 @@ simulation = [

test = [
"hypothesis[pandas]>=6.88.4",
"tenacity>=8.5.0",
"pytest>=7.2.0",
"pytest-cov>=4.1.0",
]
42 changes: 35 additions & 7 deletions tests/conftest.py
@@ -3,14 +3,25 @@
from __future__ import annotations

import os
import time
import warnings
from itertools import chain
from unittest.mock import Mock

import numpy as np
import pandas as pd
import pytest
import torch
from botorch.exceptions import ModelFittingError
from hypothesis import settings as hypothesis_settings
from tenacity import (
retry,
retry_any,
retry_if_exception_message,
retry_if_exception_type,
stop_after_attempt,
)
from torch._C import _LinAlgError

from baybe._optional.info import CHEM_INSTALLED
from baybe.acquisition import qExpectedImprovement
@@ -68,6 +79,7 @@
from baybe.utils.basic import hilberts_factory
from baybe.utils.boolean import strtobool
from baybe.utils.dataframe import add_fake_results, add_parameter_noise
from baybe.utils.random import temporary_seed

# Hypothesis settings
hypothesis_settings.register_profile("ci", deadline=500, max_examples=100)
@@ -861,26 +873,42 @@ def fixture_default_onnx_surrogate(onnx_str) -> CustomONNXSurrogate:

# TODO consider turning this into a fixture returning a campaign after running some
# fake iterations
@retry(
stop=stop_after_attempt(5),
retry=retry_any(
retry_if_exception_type((ModelFittingError, _LinAlgError)),
retry_if_exception_message(
match=r".*Expected value argument.*to be within the support.*"
),
),
before_sleep=lambda x: warnings.warn(
f"Retrying iteration test due to '{x.outcome.exception()}'"
),
)
def run_iterations(
campaign: Campaign, n_iterations: int, batch_size: int, add_noise: bool = True
) -> None:
"""Run a campaign for some fake iterations.
This function attempts up to five executions if numerical errors were encountered.
Each retry is done with a different seed to ensure numerical variance.
Args:
campaign: The campaign encapsulating the experiments.
n_iterations: Number of iterations run.
batch_size: Number of recommended points per iteration.
add_noise: Flag whether measurement noise should be added every 2nd iteration.
"""
for k in range(n_iterations):
rec = campaign.recommend(batch_size=batch_size)
# dont use parameter noise for these tests
with temporary_seed(int(time.time())):
for k in range(n_iterations):
rec = campaign.recommend(batch_size=batch_size)
# dont use parameter noise for these tests

add_fake_results(rec, campaign.targets)
if add_noise and (k % 2):
add_parameter_noise(rec, campaign.parameters, noise_level=0.1)
add_fake_results(rec, campaign.targets)
if add_noise and (k % 2):
add_parameter_noise(rec, campaign.parameters, noise_level=0.02)

campaign.add_measurements(rec)
campaign.add_measurements(rec)


def select_recommender(
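The new `with temporary_seed(int(time.time())):` block in `run_iterations` is what makes each retry attempt numerically different: `time.time()` changes between attempts, so every retry reseeds the RNGs with a fresh value, and the context manager restores the previously collected RNG states on exit (that is what the state-collection lines in the `baybe/utils/random.py` diff set up). A rough usage sketch, assuming that restore-on-exit behavior:

import random
import time

from baybe.utils.random import temporary_seed

random.seed(0)
first = random.random()        # draw from the seed-0 stream
with temporary_seed(int(time.time())):
    random.random()            # draws under the temporary, time-based seed
second = random.random()       # global RNG state was restored, so this continues the seed-0 stream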
