diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py index 842d0b3656..48770d19f5 100644 --- a/botorch/generation/gen.py +++ b/botorch/generation/gen.py @@ -19,6 +19,7 @@ from botorch.acquisition import AcquisitionFunction from botorch.exceptions.warnings import OptimizationWarning from botorch.generation.utils import _remove_fixed_features_from_optimization +from botorch.logging import _get_logger from botorch.optim.parameter_constraints import ( _arrayify, make_scipy_bounds, @@ -29,9 +30,12 @@ from botorch.optim.stopping import ExpMAStoppingCriterion from botorch.optim.utils import _filter_kwargs, columnwise_clamp, fix_features from scipy.optimize import minimize +from scipy.optimize.optimize import OptimizeResult from torch import Tensor from torch.optim import Optimizer +logger = _get_logger() + def gen_candidates_scipy( initial_conditions: Tensor, @@ -95,6 +99,7 @@ def gen_candidates_scipy( ) """ options = options or {} + options = {**options, "maxiter": options.get("maxiter", 2000)} # if there are fixed features we may optimize over a domain of lower dimension reduced_domain = False @@ -211,23 +216,8 @@ def f(x): callback=options.get("callback", None), options={k: v for k, v in options.items() if k not in ["method", "callback"]}, ) + _process_scipy_result(res=res, options=options) - if "success" not in res.keys() or "status" not in res.keys(): - with warnings.catch_warnings(): - warnings.simplefilter("always", category=OptimizationWarning) - warnings.warn( - "Optimization failed within `scipy.optimize.minimize` with no " - "status returned to `res.`", - OptimizationWarning, - ) - elif not res.success: - with warnings.catch_warnings(): - warnings.simplefilter("always", category=OptimizationWarning) - warnings.warn( - f"Optimization failed within `scipy.optimize.minimize` with status " - f"{res.status}.", - OptimizationWarning, - ) candidates = fix_features( X=torch.from_numpy(res.x).to(initial_conditions).reshape(shapeX), fixed_features=fixed_features, @@ -399,3 +389,37 @@ def get_best_candidates(batch_candidates: Tensor, batch_values: Tensor) -> Tenso """ best = torch.argmax(batch_values.view(-1), dim=0) return batch_candidates[best] + + +def _process_scipy_result(res: OptimizeResult, options: Dict[str, Any]) -> None: + r"""Process scipy optimization result to produce relevant logs and warnings.""" + if "success" not in res.keys() or "status" not in res.keys(): + with warnings.catch_warnings(): + warnings.simplefilter("always", category=OptimizationWarning) + warnings.warn( + "Optimization failed within `scipy.optimize.minimize` with no " + "status returned to `res.`", + OptimizationWarning, + ) + elif not res.success: + if ( + "ITERATIONS REACHED LIMIT" in res.message + or "Iteration limit reached" in res.message + ): + logger.info( + "`scipy.minimize` exited by reaching the iteration limit of " + f"`maxiter: {options.get('maxiter')}`." + ) + elif "EVALUATIONS EXCEEDS LIMIT" in res.message: + logger.info( + "`scipy.minimize` exited by reaching the function evaluation limit of " + f"`maxfun: {options.get('maxfun')}`." + ) + else: + with warnings.catch_warnings(): + warnings.simplefilter("always", category=OptimizationWarning) + warnings.warn( + f"Optimization failed within `scipy.optimize.minimize` with status " + f"{res.status}.", + OptimizationWarning, + ) diff --git a/test/generation/test_gen.py b/test/generation/test_gen.py index caeed47086..36b317cdac 100644 --- a/test/generation/test_gen.py +++ b/test/generation/test_gen.py @@ -69,7 +69,7 @@ def _setUp(self, double=False, expand=False): class TestGenCandidates(TestBaseCandidateGeneration): def test_gen_candidates(self, gen_candidates=gen_candidates_scipy, options=None): options = options or {} - options = {**options, "maxiter": 5} + options = {**options, "maxiter": options.get("maxiter", 5)} for double in (True, False): self._setUp(double=double) acqfs = [ @@ -125,19 +125,14 @@ def test_gen_candidates_with_none_fixed_features( ics = self.initial_conditions if isinstance(acqf, qKnowledgeGradient): ics = ics.repeat(5, 1) - # we are getting a warning that this fails with status 1: - # 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT' - # This is expected since we have set "maxiter" low, so suppress - # the warning - with warnings.catch_warnings(record=True): - candidates, _ = gen_candidates( - initial_conditions=ics, - acquisition_function=acqf, - lower_bounds=0, - upper_bounds=1, - fixed_features={1: None}, - options=options or {}, - ) + candidates, _ = gen_candidates( + initial_conditions=ics, + acquisition_function=acqf, + lower_bounds=0, + upper_bounds=1, + fixed_features={1: None}, + options=options or {}, + ) if isinstance(acqf, qKnowledgeGradient): candidates = acqf.extract_candidates(candidates) candidates = candidates.squeeze(0) @@ -166,19 +161,14 @@ def test_gen_candidates_with_fixed_features( ics = self.initial_conditions if isinstance(acqf, qKnowledgeGradient): ics = ics.repeat(5, 1) - # we are getting a warning that this fails with status 1: - # 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT' - # This is expected since we have set "maxiter" low, so suppress - # the warning - with warnings.catch_warnings(record=True): - candidates, _ = gen_candidates( - initial_conditions=ics, - acquisition_function=acqf, - lower_bounds=0, - upper_bounds=1, - fixed_features={1: 0.25}, - options=options, - ) + candidates, _ = gen_candidates( + initial_conditions=ics, + acquisition_function=acqf, + lower_bounds=0, + upper_bounds=1, + fixed_features={1: 0.25}, + options=options, + ) if isinstance(acqf, qKnowledgeGradient): candidates = acqf.extract_candidates(candidates) @@ -192,20 +182,16 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self): for double in (True, False): self._setUp(double=double, expand=True) qEI = qExpectedImprovement(self.model, best_f=self.f_best) - # we are getting a warning that this fails with status 9: - # "Iteration limit reached." This is expected since we have set - # "maxiter" low, so suppress the warning. - with warnings.catch_warnings(record=True): - candidates, _ = gen_candidates_scipy( - initial_conditions=self.initial_conditions.reshape(1, 1, -1), - acquisition_function=qEI, - inequality_constraints=[ - (torch.tensor([0]), torch.tensor([1]), 0), - (torch.tensor([1]), torch.tensor([-1]), -1), - ], - fixed_features={1: 0.25}, - options=options, - ) + candidates, _ = gen_candidates_scipy( + initial_conditions=self.initial_conditions.reshape(1, 1, -1), + acquisition_function=qEI, + inequality_constraints=[ + (torch.tensor([0]), torch.tensor([1]), 0), + (torch.tensor([1]), torch.tensor([-1]), -1), + ], + fixed_features={1: 0.25}, + options=options, + ) # candidates is of dimension 1 x 1 x 2 # so we are squeezing all the singleton dimensions candidates = candidates.squeeze() @@ -227,6 +213,27 @@ def test_gen_candidates_scipy_warns_opt_failure(self): ) self.assertTrue(expected_warning_raised) + def test_gen_candidates_scipy_maxiter_behavior(self): + # Check that no warnings are raised & log produced on hitting maxiter. + for method in ("SLSQP", "L-BFGS-B"): + with warnings.catch_warnings(record=True) as ws, self.assertLogs( + "botorch", level="INFO" + ) as logs: + self.test_gen_candidates(options={"maxiter": 1, "method": method}) + self.assertFalse( + any(issubclass(w.category, OptimizationWarning) for w in ws) + ) + self.assertTrue("iteration limit" in logs.output[-1]) + # Check that we handle maxfun as well. + with warnings.catch_warnings(record=True) as ws, self.assertLogs( + "botorch", level="INFO" + ) as logs: + self.test_gen_candidates( + options={"maxiter": 100, "maxfun": 1, "method": "L-BFGS-B"} + ) + self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws)) + self.assertTrue("function evaluation limit" in logs.output[-1]) + def test_gen_candidates_scipy_warns_opt_no_res(self): ckwargs = {"dtype": torch.float, "device": self.device}