Skip to content

Commit

Permalink
Unpin scipy & use threadpoolctl in minimize_with_timeout (#2712)
Browse files Browse the repository at this point in the history
Summary:
Scipy>=1.15 was leading to significant slowdowns in CI due to issues with OpenBLAS configuration (see scipy/scipy#22438). Thanks to suggestions by SciPy maintainers, we've found a solution that resolves the slowdown without modifying and global configurations for OpenBLAS.

This PR wraps `minimize` calls in `minimize_with_timeout` in a `with threadpool_limits(limits=1, user_api="blas")` context, which makes OpenBLAS operate single-threaded within the context. Since PyTorch relies on OpenMP & MKL rather than OpenBLAS, this does not affect the evaluations of the acquisition functions or the MLL.

Pull Request resolved: #2712

Test Plan:
Compare runtimes in the following tutorial workflows
Baseline - with scipy 1.14.1, without `threadpool_limits`: https://github.com/pytorch/botorch/actions/runs/13044166839/job/36391816602
With scipy 1.14.1, with `threadpool_limits`: https://github.com/pytorch/botorch/actions/runs/13044167834/job/36391818438
With scipy 1.15.1, with `threadpool_limits`: https://github.com/pytorch/botorch/actions/runs/13044245918/job/36392026548

Reviewed By: Balandat

Differential Revision: D68874700

Pulled By: saitcakmak

fbshipit-source-id: c6273bd43a8312e215204207bf92a765eb5201ce
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jan 30, 2025
1 parent 4a595f2 commit acae12d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
32 changes: 18 additions & 14 deletions botorch/optim/utils/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy.typing as npt
from botorch.exceptions.errors import OptimizationTimeoutError
from scipy import optimize
from threadpoolctl import threadpool_limits


def minimize_with_timeout(
Expand Down Expand Up @@ -79,20 +80,23 @@ def wrapped_callback(xk: npt.NDArray) -> None:

try:
warnings.filterwarnings("error", message="Method .* cannot handle")
return optimize.minimize(
fun=fun,
x0=x0,
args=args,
method=method,
jac=jac,
hess=hess,
hessp=hessp,
bounds=bounds,
constraints=constraints,
tol=tol,
callback=wrapped_callback,
options=options,
)
# To prevent slowdowns after scipy 1.15.
# See https://github.com/scipy/scipy/issues/22438.
with threadpool_limits(limits=1, user_api="blas"):
return optimize.minimize(
fun=fun,
x0=x0,
args=args,
method=method,
jac=jac,
hess=hess,
hessp=hessp,
bounds=bounds,
constraints=constraints,
tol=tol,
callback=wrapped_callback,
options=options,
)
except OptimizationTimeoutError as e:
msg = f"Optimization timed out after {e.runtime} seconds."
current_fun, *_ = fun(e.current_x, *args)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ gpytorch==1.14
linear_operator==0.6
torch>=2.0.1
pyro-ppl>=1.8.4
scipy<1.15
scipy
multipledispatch
threadpoolctl

0 comments on commit acae12d

Please sign in to comment.