diff --git a/botorch/optim/closures/core.py b/botorch/optim/closures/core.py
index d2363ad1fb..33e45954c9 100644
--- a/botorch/optim/closures/core.py
+++ b/botorch/optim/closures/core.py
@@ -110,11 +110,14 @@ def __init__(
         """
         if get_state is None:
             # Note: Numpy supports copying data between ndarrays with different dtypes.
-            # Hence, our default behavior need not coerce the ndarray represenations of
-            # tensors in `parameters` to float64 when copying over data.
+            # Hence, our default behavior need not coerce the ndarray representations
+            # of tensors in `parameters` to float64 when copying over data.
             _as_array = as_ndarray if as_array is None else as_array
             get_state = partial(
-                get_tensors_as_ndarray_1d, parameters, as_array=_as_array
+                get_tensors_as_ndarray_1d,
+                tensors=parameters,
+                dtype=np_float64,
+                as_array=_as_array,
             )
 
         if as_array is None:  # per the note, do this after resolving `get_state`
@@ -154,7 +157,7 @@ def __call__(
                     grads[index : index + size] = self.as_array(grad.view(-1))
                 index += size
         except RuntimeError as e:
-            value, grads = _handle_numerical_errors(error=e, x=self.state)
+            value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
 
         return value, grads
 
@@ -174,9 +177,9 @@ def _get_gradient_ndarray(self, fill_value: Optional[float] = None) -> ndarray:
 
         size = sum(param.numel() for param in self.parameters.values())
         array = (
-            np_zeros(size)
+            np_zeros(size, dtype=np_float64)
             if fill_value is None or fill_value == 0.0
-            else np_full(size, fill_value)
+            else np_full(size, fill_value, dtype=np_float64)
         )
         if self.persistent:
             self._gradient_ndarray = array
diff --git a/botorch/optim/core.py b/botorch/optim/core.py
index 9110312fb3..49cb92831d 100644
--- a/botorch/optim/core.py
+++ b/botorch/optim/core.py
@@ -18,7 +18,7 @@
 
 from botorch.optim.closures import NdarrayOptimizationClosure
 from botorch.optim.utils import get_bounds_as_ndarray
-from numpy import asarray, ndarray
+from numpy import asarray, float64 as np_float64, ndarray
 from scipy.optimize import minimize
 from torch import Tensor
 from torch.optim.adam import Adam
@@ -105,7 +105,7 @@ def wrapped_callback(x: ndarray):
 
     raw = minimize(
         wrapped_closure,
-        wrapped_closure.state if x0 is None else x0,
+        wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
         jac=True,
         bounds=bounds_np,
         method=method,
diff --git a/botorch/optim/utils/common.py b/botorch/optim/utils/common.py
index fbb5a20252..886820a0a7 100644
--- a/botorch/optim/utils/common.py
+++ b/botorch/optim/utils/common.py
@@ -33,7 +33,7 @@ def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
 
 
 def _handle_numerical_errors(
-    error: RuntimeError, x: np.ndarray
+    error: RuntimeError, x: np.ndarray, dtype: Optional[np.dtype] = None
 ) -> Tuple[np.ndarray, np.ndarray]:
     if isinstance(error, NotPSDError):
         raise error
@@ -43,7 +43,8 @@ def _handle_numerical_errors(
         or "singular" in error_message  # old pytorch message
         or "input is not positive-definite" in error_message  # since pytorch #63864
     ):
-        return np.full((), "nan", dtype=x.dtype), np.full_like(x, "nan")
+        _dtype = x.dtype if dtype is None else dtype
+        return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
     raise error  # pragma: nocover
 
 
diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py
index 052b58ec69..af68e0eb2a 100644
--- a/botorch/optim/utils/numpy_utils.py
+++ b/botorch/optim/utils/numpy_utils.py
@@ -66,7 +66,7 @@ def as_ndarray(
 
     # Convert to ndarray and maybe cast to `dtype`
     out = out.numpy()
-    return out if (dtype is None or dtype == out.dtype) else out.astype(dtype)
+    return out.astype(dtype, copy=False)
 
 
 def get_tensors_as_ndarray_1d(
diff --git a/test/optim/utils/test_common.py b/test/optim/utils/test_common.py
index 713a7acf15..85b6c626f8 100644
--- a/test/optim/utils/test_common.py
+++ b/test/optim/utils/test_common.py
@@ -17,23 +17,27 @@ class TestUtilsCommon(BotorchTestCase):
     def test_handle_numerical_errors(self):
-        x = np.zeros(1)
+        x = np.zeros(1, dtype=np.float64)
 
         with self.assertRaisesRegex(NotPSDError, "foo"):
-            _handle_numerical_errors(error=NotPSDError("foo"), x=x)
+            _handle_numerical_errors(NotPSDError("foo"), x=x)
 
         for error in (
             NanError(),
             RuntimeError("singular"),
             RuntimeError("input is not positive-definite"),
         ):
-            fake_loss, fake_grad = _handle_numerical_errors(error=error, x=x)
+            fake_loss, fake_grad = _handle_numerical_errors(error, x=x)
             self.assertTrue(np.isnan(fake_loss))
             self.assertEqual(fake_grad.shape, x.shape)
             self.assertTrue(np.isnan(fake_grad).all())
 
+            fake_loss, fake_grad = _handle_numerical_errors(error, x=x, dtype=np.float32)
+            self.assertEqual(np.float32, fake_loss.dtype)
+            self.assertEqual(np.float32, fake_grad.dtype)
+
         with self.assertRaisesRegex(RuntimeError, "foo"):
-            _handle_numerical_errors(error=RuntimeError("foo"), x=x)
+            _handle_numerical_errors(RuntimeError("foo"), x=x)
 
     def test_warning_handler_template(self):
         with catch_warnings(record=True) as ws:
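
Note: the patch leans on two NumPy behaviors that can be sanity-checked in isolation. The following is a minimal sketch in plain NumPy (no BoTorch imports; variable names are illustrative, not taken from the patch):

    import numpy as np

    # 1) `astype(dtype, copy=False)` returns the input array itself when the
    #    requested dtype (and layout) already match, so the float64 fast path
    #    in `as_ndarray` incurs no extra allocation.
    out = np.zeros(3, dtype=np.float64)
    assert out.astype(np.float64, copy=False) is out  # same object, no copy

    # 2) An explicit `dtype` overrides `x.dtype` when building the NaN-filled
    #    stand-ins returned on numerical errors, which is what the new keyword
    #    argument and the float32 test case exercise.
    x = np.zeros(1, dtype=np.float64)
    fake_loss = np.full((), "nan", dtype=np.float32)  # 0-d scalar array
    fake_grad = np.full_like(x, "nan", dtype=np.float32)
    assert fake_loss.dtype == np.float32 and fake_grad.dtype == np.float32
    assert np.isnan(fake_loss) and np.isnan(fake_grad).all()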