Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not count hitting maxiter as optimization failure & update default maxiter #1478

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions botorch/generation/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from botorch.acquisition import AcquisitionFunction
from botorch.exceptions.warnings import OptimizationWarning
from botorch.generation.utils import _remove_fixed_features_from_optimization
from botorch.logging import _get_logger
from botorch.optim.parameter_constraints import (
_arrayify,
make_scipy_bounds,
Expand All @@ -29,9 +30,12 @@
from botorch.optim.stopping import ExpMAStoppingCriterion
from botorch.optim.utils import _filter_kwargs, columnwise_clamp, fix_features
from scipy.optimize import minimize
from scipy.optimize.optimize import OptimizeResult
from torch import Tensor
from torch.optim import Optimizer

logger = _get_logger()


def gen_candidates_scipy(
initial_conditions: Tensor,
Expand Down Expand Up @@ -95,6 +99,7 @@ def gen_candidates_scipy(
)
"""
options = options or {}
options = {**options, "maxiter": options.get("maxiter", 2000)}

# if there are fixed features we may optimize over a domain of lower dimension
reduced_domain = False
Expand Down Expand Up @@ -211,23 +216,8 @@ def f(x):
callback=options.get("callback", None),
options={k: v for k, v in options.items() if k not in ["method", "callback"]},
)
_process_scipy_result(res=res, options=options)

if "success" not in res.keys() or "status" not in res.keys():
with warnings.catch_warnings():
warnings.simplefilter("always", category=OptimizationWarning)
warnings.warn(
"Optimization failed within `scipy.optimize.minimize` with no "
"status returned to `res.`",
OptimizationWarning,
)
elif not res.success:
with warnings.catch_warnings():
warnings.simplefilter("always", category=OptimizationWarning)
warnings.warn(
f"Optimization failed within `scipy.optimize.minimize` with status "
f"{res.status}.",
OptimizationWarning,
)
candidates = fix_features(
X=torch.from_numpy(res.x).to(initial_conditions).reshape(shapeX),
fixed_features=fixed_features,
Expand Down Expand Up @@ -399,3 +389,37 @@ def get_best_candidates(batch_candidates: Tensor, batch_values: Tensor) -> Tenso
"""
best = torch.argmax(batch_values.view(-1), dim=0)
return batch_candidates[best]


def _process_scipy_result(res: OptimizeResult, options: Dict[str, Any]) -> None:
r"""Process scipy optimization result to produce relevant logs and warnings."""
if "success" not in res.keys() or "status" not in res.keys():
with warnings.catch_warnings():
warnings.simplefilter("always", category=OptimizationWarning)
warnings.warn(
"Optimization failed within `scipy.optimize.minimize` with no "
"status returned to `res.`",
OptimizationWarning,
)
elif not res.success:
if (
"ITERATIONS REACHED LIMIT" in res.message
or "Iteration limit reached" in res.message
):
logger.info(
"`scipy.minimize` exited by reaching the iteration limit of "
f"`maxiter: {options.get('maxiter')}`."
)
elif "EVALUATIONS EXCEEDS LIMIT" in res.message:
logger.info(
"`scipy.minimize` exited by reaching the function evaluation limit of "
f"`maxfun: {options.get('maxfun')}`."
)
else:
with warnings.catch_warnings():
warnings.simplefilter("always", category=OptimizationWarning)
warnings.warn(
f"Optimization failed within `scipy.optimize.minimize` with status "
f"{res.status}.",
OptimizationWarning,
)
89 changes: 48 additions & 41 deletions test/generation/test_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _setUp(self, double=False, expand=False):
class TestGenCandidates(TestBaseCandidateGeneration):
def test_gen_candidates(self, gen_candidates=gen_candidates_scipy, options=None):
options = options or {}
options = {**options, "maxiter": 5}
options = {**options, "maxiter": options.get("maxiter", 5)}
for double in (True, False):
self._setUp(double=double)
acqfs = [
Expand Down Expand Up @@ -125,19 +125,14 @@ def test_gen_candidates_with_none_fixed_features(
ics = self.initial_conditions
if isinstance(acqf, qKnowledgeGradient):
ics = ics.repeat(5, 1)
# we are getting a warning that this fails with status 1:
# 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
# This is expected since we have set "maxiter" low, so suppress
# the warning
with warnings.catch_warnings(record=True):
candidates, _ = gen_candidates(
initial_conditions=ics,
acquisition_function=acqf,
lower_bounds=0,
upper_bounds=1,
fixed_features={1: None},
options=options or {},
)
candidates, _ = gen_candidates(
initial_conditions=ics,
acquisition_function=acqf,
lower_bounds=0,
upper_bounds=1,
fixed_features={1: None},
options=options or {},
)
if isinstance(acqf, qKnowledgeGradient):
candidates = acqf.extract_candidates(candidates)
candidates = candidates.squeeze(0)
Expand Down Expand Up @@ -166,19 +161,14 @@ def test_gen_candidates_with_fixed_features(
ics = self.initial_conditions
if isinstance(acqf, qKnowledgeGradient):
ics = ics.repeat(5, 1)
# we are getting a warning that this fails with status 1:
# 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
# This is expected since we have set "maxiter" low, so suppress
# the warning
with warnings.catch_warnings(record=True):
candidates, _ = gen_candidates(
initial_conditions=ics,
acquisition_function=acqf,
lower_bounds=0,
upper_bounds=1,
fixed_features={1: 0.25},
options=options,
)
candidates, _ = gen_candidates(
initial_conditions=ics,
acquisition_function=acqf,
lower_bounds=0,
upper_bounds=1,
fixed_features={1: 0.25},
options=options,
)

if isinstance(acqf, qKnowledgeGradient):
candidates = acqf.extract_candidates(candidates)
Expand All @@ -192,20 +182,16 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self):
for double in (True, False):
self._setUp(double=double, expand=True)
qEI = qExpectedImprovement(self.model, best_f=self.f_best)
# we are getting a warning that this fails with status 9:
# "Iteration limit reached." This is expected since we have set
# "maxiter" low, so suppress the warning.
with warnings.catch_warnings(record=True):
candidates, _ = gen_candidates_scipy(
initial_conditions=self.initial_conditions.reshape(1, 1, -1),
acquisition_function=qEI,
inequality_constraints=[
(torch.tensor([0]), torch.tensor([1]), 0),
(torch.tensor([1]), torch.tensor([-1]), -1),
],
fixed_features={1: 0.25},
options=options,
)
candidates, _ = gen_candidates_scipy(
initial_conditions=self.initial_conditions.reshape(1, 1, -1),
acquisition_function=qEI,
inequality_constraints=[
(torch.tensor([0]), torch.tensor([1]), 0),
(torch.tensor([1]), torch.tensor([-1]), -1),
],
fixed_features={1: 0.25},
options=options,
)
# candidates is of dimension 1 x 1 x 2
# so we are squeezing all the singleton dimensions
candidates = candidates.squeeze()
Expand All @@ -227,6 +213,27 @@ def test_gen_candidates_scipy_warns_opt_failure(self):
)
self.assertTrue(expected_warning_raised)

def test_gen_candidates_scipy_maxiter_behavior(self):
# Check that no warnings are raised & log produced on hitting maxiter.
for method in ("SLSQP", "L-BFGS-B"):
with warnings.catch_warnings(record=True) as ws, self.assertLogs(
"botorch", level="INFO"
) as logs:
self.test_gen_candidates(options={"maxiter": 1, "method": method})
self.assertFalse(
any(issubclass(w.category, OptimizationWarning) for w in ws)
)
self.assertTrue("iteration limit" in logs.output[-1])
# Check that we handle maxfun as well.
with warnings.catch_warnings(record=True) as ws, self.assertLogs(
"botorch", level="INFO"
) as logs:
self.test_gen_candidates(
options={"maxiter": 100, "maxfun": 1, "method": "L-BFGS-B"}
)
self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
self.assertTrue("function evaluation limit" in logs.output[-1])

def test_gen_candidates_scipy_warns_opt_no_res(self):
ckwargs = {"dtype": torch.float, "device": self.device}

Expand Down