Enforce use of float64 in NdarrayOptimizationClosure
Reviewed By: esantorella

Differential Revision: D41355824

fbshipit-source-id: 0ed23de9e2b70fcb7406c10aecdc511a7a5caa19
James Wilson authored and facebook-github-bot committed Nov 19, 2022
1 parent 17b1bb7 · commit 938fdbe
Showing 5 changed files with 23 additions and 15 deletions.
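Context for the change: scipy.optimize.minimize (used below in botorch/optim/core.py) works in double precision, so float32 state or gradients leaking into the optimizer invite avoidable copies and dtype mismatches. A minimal sketch, not from this commit, of the convention being enforced:

```python
# A minimal sketch (not part of this commit) of the float64 convention the
# closure enforces: hand scipy a float64 x0 and return float64 (value, grad).
import numpy as np
from scipy.optimize import minimize

def fun(x: np.ndarray):
    # With jac=True, scipy expects a single callable returning (value, gradient).
    return float(np.sum(x ** 2)), 2.0 * x

x0 = np.zeros(3, dtype=np.float32)  # e.g. mirrors float32 model parameters
# Upcast once up front, as the commit does for user-supplied x0 below.
res = minimize(fun, x0.astype(np.float64, copy=False), jac=True, method="L-BFGS-B")
print(res.x.dtype)  # float64
```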
botorch/optim/closures/core.py (9 additions, 6 deletions)

```diff
@@ -110,11 +110,14 @@ def __init__(
         """
         if get_state is None:
             # Note: Numpy supports copying data between ndarrays with different dtypes.
-            # Hence, our default behavior need not coerce the ndarray represenations of
-            # tensors in `parameters` to float64 when copying over data.
+            # Hence, our default behavior need not coerce the ndarray representations
+            # of tensors in `parameters` to float64 when copying over data.
             _as_array = as_ndarray if as_array is None else as_array
             get_state = partial(
-                get_tensors_as_ndarray_1d, parameters, as_array=_as_array
+                get_tensors_as_ndarray_1d,
+                tensors=parameters,
+                dtype=np_float64,
+                as_array=_as_array,
             )
 
         if as_array is None:  # per the note, do this after resolving `get_state`
```
```diff
@@ -154,7 +157,7 @@ def __call__(
                 grads[index : index + size] = self.as_array(grad.view(-1))
                 index += size
         except RuntimeError as e:
-            value, grads = _handle_numerical_errors(error=e, x=self.state)
+            value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
 
         return value, grads
 
```
```diff
@@ -174,9 +177,9 @@ def _get_gradient_ndarray(self, fill_value: Optional[float] = None) -> ndarray:
 
         size = sum(param.numel() for param in self.parameters.values())
         array = (
-            np_zeros(size)
+            np_zeros(size, dtype=np_float64)
             if fill_value is None or fill_value == 0.0
-            else np_full(size, fill_value)
+            else np_full(size, fill_value, dtype=np_float64)
         )
         if self.persistent:
             self._gradient_ndarray = array
```
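The note in the first hunk above relies on a NumPy guarantee worth making concrete: assigning into an ndarray casts the source to the destination dtype, so the float64 state and gradient buffers can be filled from float32 tensor views without coercing each tensor first. A minimal sketch, not from this commit:

```python
# A minimal sketch (not part of this commit) of the cross-dtype copy the
# note describes: assignment into a float64 buffer casts float32 sources.
import numpy as np

state = np.zeros(4, dtype=np.float64)  # float64 buffer, like the closure state
src = np.arange(4, dtype=np.float32)   # stand-in for a float32 tensor's ndarray view
state[:] = src                         # element-wise cast happens on copy
assert state.dtype == np.float64 and state[3] == 3.0
```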
botorch/optim/core.py (2 additions, 2 deletions)

```diff
@@ -18,7 +18,7 @@
 
 from botorch.optim.closures import NdarrayOptimizationClosure
 from botorch.optim.utils import get_bounds_as_ndarray
-from numpy import asarray, ndarray
+from numpy import asarray, float64 as np_float64, ndarray
 from scipy.optimize import minimize
 from torch import Tensor
 from torch.optim.adam import Adam
```
```diff
@@ -105,7 +105,7 @@ def wrapped_callback(x: ndarray):
 
     raw = minimize(
         wrapped_closure,
-        wrapped_closure.state if x0 is None else x0,
+        wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
         jac=True,
         bounds=bounds_np,
         method=method,
```
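The `x0.astype(np_float64, copy=False)` guard above is free when the input is already float64: with `copy=False`, NumPy returns the array itself unless a cast is actually required. A minimal sketch, not from this commit:

```python
# A minimal sketch (not part of this commit) of astype(..., copy=False):
# no copy for a matching dtype, a single upcast otherwise.
import numpy as np

x64 = np.zeros(3, dtype=np.float64)
assert x64.astype(np.float64, copy=False) is x64  # same object returned

x32 = np.zeros(3, dtype=np.float32)
x = x32.astype(np.float64, copy=False)            # cast needed, so a copy is made
assert x.dtype == np.float64 and x is not x32
```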
botorch/optim/utils/common.py (3 additions, 2 deletions)

```diff
@@ -33,7 +33,7 @@ def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
 
 
 def _handle_numerical_errors(
-    error: RuntimeError, x: np.ndarray
+    error: RuntimeError, x: np.ndarray, dtype: Optional[np.dtype] = None
 ) -> Tuple[np.ndarray, np.ndarray]:
     if isinstance(error, NotPSDError):
         raise error
```
```diff
@@ -43,7 +43,8 @@ def _handle_numerical_errors(
         or "singular" in error_message  # old pytorch message
         or "input is not positive-definite" in error_message  # since pytorch #63864
     ):
-        return np.full((), "nan", dtype=x.dtype), np.full_like(x, "nan")
+        _dtype = x.dtype if dtype is None else dtype
+        return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
     raise error  # pragma: nocover
 
 
```
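For reference, the fallback above hands the optimizer a 0-d NaN loss plus a NaN gradient in the resolved dtype; NumPy parses the string "nan" into a floating NaN when filling float arrays. A minimal sketch, not from this commit, exercising the new override:

```python
# A minimal sketch (not part of this commit) of the NaN fallback, using the
# dtype override added in this hunk (defaults to x.dtype when None).
import numpy as np

x = np.zeros(2, dtype=np.float64)
_dtype = np.float32  # stand-in for a caller-supplied override
fake_loss = np.full((), "nan", dtype=_dtype)      # 0-d NaN "loss"
fake_grad = np.full_like(x, "nan", dtype=_dtype)  # NaN gradient, shape of x
assert np.isnan(fake_loss) and np.isnan(fake_grad).all()
assert fake_loss.dtype == np.float32 == fake_grad.dtype
```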
botorch/optim/utils/numpy_utils.py (1 addition, 1 deletion)

```diff
@@ -66,7 +66,7 @@ def as_ndarray(
 
     # Convert to ndarray and maybe cast to `dtype`
     out = out.numpy()
-    return out if (dtype is None or dtype == out.dtype) else out.astype(dtype)
+    return out.astype(dtype, copy=False)
 
 
 def get_tensors_as_ndarray_1d(
```
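A subtlety in the `as_ndarray` change above: `np.dtype(None)` resolves to float64 (NumPy's documented default, at least in releases current when this commit landed), so `out.astype(dtype, copy=False)` with `dtype=None` now upcasts non-float64 arrays instead of returning them unchanged, in line with the commit title. A minimal sketch, not from this commit:

```python
# A minimal sketch (not part of this commit): dtype=None means float64 to
# astype, so the rewritten return enforces float64 by default, and is
# copy-free when the array already complies.
import numpy as np

assert np.dtype(None) == np.float64
out32 = np.zeros(3, dtype=np.float32)
assert out32.astype(None, copy=False).dtype == np.float64  # upcast
out64 = np.zeros(3, dtype=np.float64)
assert out64.astype(None, copy=False) is out64             # unchanged, no copy
```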
test/optim/utils/test_common.py (8 additions, 4 deletions)

```diff
@@ -17,23 +17,27 @@
 
 class TestUtilsCommon(BotorchTestCase):
     def test_handle_numerical_errors(self):
-        x = np.zeros(1)
+        x = np.zeros(1, dtype=np.float64)
 
         with self.assertRaisesRegex(NotPSDError, "foo"):
-            _handle_numerical_errors(error=NotPSDError("foo"), x=x)
+            _handle_numerical_errors(NotPSDError("foo"), x=x)
 
         for error in (
             NanError(),
             RuntimeError("singular"),
             RuntimeError("input is not positive-definite"),
         ):
-            fake_loss, fake_grad = _handle_numerical_errors(error=error, x=x)
+            fake_loss, fake_grad = _handle_numerical_errors(error, x=x)
             self.assertTrue(np.isnan(fake_loss))
             self.assertEqual(fake_grad.shape, x.shape)
             self.assertTrue(np.isnan(fake_grad).all())
 
+            fake_loss, fake_grad = _handle_numerical_errors(error, x=x, dtype=np.float32)
+            self.assertEqual(np.float32, fake_loss.dtype)
+            self.assertEqual(np.float32, fake_grad.dtype)
+
         with self.assertRaisesRegex(RuntimeError, "foo"):
-            _handle_numerical_errors(error=RuntimeError("foo"), x=x)
+            _handle_numerical_errors(RuntimeError("foo"), x=x)
 
     def test_warning_handler_template(self):
         with catch_warnings(record=True) as ws:
```
