diff --git a/botorch/utils/probability/lin_ess.py b/botorch/utils/probability/lin_ess.py index 7a664273d4..a02ea170db 100644 --- a/botorch/utils/probability/lin_ess.py +++ b/botorch/utils/probability/lin_ess.py @@ -29,7 +29,6 @@ from torch import Tensor _twopi = 2.0 * math.pi -_delta_theta = 1.0e-6 * _twopi class LinearEllipticalSliceSampler(PolytopeSampler): @@ -37,11 +36,11 @@ class LinearEllipticalSliceSampler(PolytopeSampler): TODOs: - clean up docstrings + - Add batch support, broadcasting over parallel chains. - optimize computations (if possible) Maybe TODOs: - Support degenerate domains (with zero volume)? - - Add batch support ? """ def __init__( @@ -52,6 +51,7 @@ def __init__( mean: Optional[Tensor] = None, covariance_matrix: Optional[Tensor] = None, covariance_root: Optional[Tensor] = None, + check_feasibility: bool = False, ) -> None: r"""Initialize LinearEllipticalSliceSampler. @@ -73,6 +73,9 @@ def __init__( distribution (if omitted, use the identity). covariance_root: A `d x k`-dim root of the covariance matrix such that covariance_root @ covariance_root.T = covariance_matrix. + check_feasibility: If True, raise an error if the sampling results in an + infeasible sample. This creates some overhead and so is switched off + by default. This sampler samples from a multivariante Normal `N(mean, covariance_matrix)` subject to linear domain constraints `A x <= b` (intersected with box bounds, @@ -107,6 +110,7 @@ def __init__( self._zero = torch.zeros(1, **tkwargs) self._nan = torch.tensor(float("nan"), **tkwargs) self._full_angular_range = torch.tensor([0.0, _twopi], **tkwargs) + self.check_feasibility = check_feasibility def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]: r"""Draw samples. @@ -117,7 +121,7 @@ def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]: Returns: A `n x d`-dim tensor of `n` samples. """ - # TODO: Do we need to do any thinnning or warm-up here? + # TODO: Should apply thinning in higher dimensions, can check step size. samples = torch.cat([self.step() for _ in range(n)], dim=-1) return samples.transpose(-1, -2) @@ -129,8 +133,18 @@ def step(self) -> Tensor: """ nu = self._sample_base_rv() theta = self._draw_angle(nu=nu) - self._x = self._get_cart_coords(nu=nu, theta=theta) - return self._x + x = self._get_cart_coords(nu=nu, theta=theta) + self._x[:] = x + if self.check_feasibility and (not self._is_feasible(self._x)): + Axmb = self.A @ self._x - self.b + violated_indices = Axmb > 0 + raise RuntimeError( + "Sampling resulted in infeasible point. \n\t- Number " + f"of violated constraints: {violated_indices.sum()}." + f"\n\t- Magnitude of violations: {Axmb[violated_indices]}" + "\n\t- If the error persists, please report this bug on GitHub." + ) + return x def _sample_base_rv(self) -> Tensor: r"""Sample a base random variable from N(mean, covariance_matrix). @@ -152,7 +166,7 @@ def _draw_angle(self, nu: Tensor) -> Tensor: nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)). Returns: - A + A `1`-dim Tensor containing the rotation angle (radians). """ rot_angle, rot_slices = self._find_rotated_intersections(nu) rot_lengths = rot_slices[:, 1] - rot_slices[:, 0] @@ -162,7 +176,7 @@ def _draw_angle(self, nu: Tensor) -> Tensor: 1, device=cum_lengths.device, dtype=cum_lengths.dtype ) idx = torch.searchsorted(cum_lengths, rnd_angle) - 1 - return rot_slices[idx, 0] + rnd_angle - cum_lengths[idx] + rot_angle + return (rot_slices[idx, 0] + rnd_angle + rot_angle) - cum_lengths[idx] def _get_cart_coords(self, nu: Tensor, theta: Tensor) -> Tensor: r"""Determine location on ellipsoid in cartesian coordinates. @@ -191,9 +205,16 @@ def _find_rotated_intersections(self, nu: Tensor) -> Tuple[Tensor, Tensor]: """ slices = self._find_active_intersections(nu) rot_angle = slices[0] - slices = slices - rot_angle - slices = torch.where(slices < 0, slices + _twopi, slices) - return rot_angle, slices.reshape(-1, 2) + slices = (slices - rot_angle).reshape(-1, 2) + # Ensuring that we don't sample within numerical precision of the boundaries + # due to resulting instabilities in the constraint satisfaction. + eps = 1e-6 if slices.dtype == torch.float32 else 1e-12 + eps = torch.tensor(eps, dtype=slices.dtype, device=slices.device) + eps = eps.minimum(slices.diff(dim=-1).abs() / 4) + slices = slices + torch.cat((eps, -eps), dim=-1) + # NOTE: The remainder call relies on the epsilon contraction, since the + # remainder of_twopi divided by _twopi is zero, not _twopi. + return rot_angle, slices.remainder(_twopi) def _find_active_intersections(self, nu: Tensor) -> Tensor: """ @@ -212,23 +233,15 @@ def _find_active_intersections(self, nu: Tensor) -> Tensor: slice sampling. """ theta = self._find_intersection_angles(nu) - active_directions = self._index_active( - nu=nu, theta=theta, delta_theta=_delta_theta + theta_active, delta_active = self._active_theta_and_delta( + nu=nu, + theta=theta, ) - theta_active = theta[active_directions.nonzero()] - delta_theta = _delta_theta - while theta_active.numel() % 2 == 1: - # Almost tangential ellipses, reduce delta_theta - delta_theta /= 10 - active_directions = self._index_active( - theta=theta, nu=nu, delta_theta=delta_theta - ) - theta_active = theta[active_directions.nonzero()] - if theta_active.numel() == 0: theta_active = self._full_angular_range - # TODO: What about `self.ellipse_in_domain = False` in the original code ?? - elif active_directions[active_directions.nonzero()][0] == -1: + # TODO: What about `self.ellipse_in_domain = False` in the original code? + elif delta_active[0] == -1: # ensuring that the first interval is feasible + theta_active = torch.cat((theta_active[1:], theta_active[:1])) return theta_active.view(-1) @@ -258,45 +271,55 @@ def _find_intersection_angles(self, nu: Tensor) -> Tensor: arg = -(self.b / r).squeeze() # Write NaNs if there is no intersection arg = torch.where(torch.absolute(arg) <= 1, arg, self._nan) - # Two solutions per linear constraint, shape of theta: (n_ineq_con, 2) acos_arg = torch.arccos(arg) theta = torch.stack((phi + acos_arg, phi - acos_arg), dim=-1) theta = theta[torch.isfinite(theta)] # shape: `n_ineq_con - num_not_finite` - theta = torch.where(theta < 0, theta + _twopi, theta) # [0, 2*pi] - + theta = torch.where(theta < 0, theta + _twopi, theta) # in [0, 2*pi] return torch.sort(theta).values - def _index_active( - self, nu: Tensor, theta: Tensor, delta_theta: float = _delta_theta - ) -> Tensor: + def _active_theta_and_delta(self, nu: Tensor, theta: Tensor) -> Tensor: r"""Determine active indices. Args: nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)). - theta: An `M`-dim tensor of intersection angles. - delta_theta: A small perturbation to be used for determining whether - intersections are at the boundary of the integration domain. + theta: A sorted `M`-dim tensor of intersection angles in [0, 2pi]. Returns: - An `M`-dim tensor with elements taking on values in {-1, 0, 1}. - A non-zero value indicates whether the associated intersection angle - is an active constraint. For active constraints, the sign indicates - the direction of the relevant domain (i.e. +1 (-1) means that - increasing (decreasing) the angle renders the sample feasible). + A tuple of Tensors of active constraint intersection angles `theta_active`, + and the change in the feasibility of the points on the ellipse on the left + and right of the active intersection angles `delta_active`. `delta_active` + is is negative if decreasing the angle renders the sample feasible, and + positive if increasing the angle renders the sample feasible. """ - samples_pos = self._get_cart_coords(nu=nu, theta=theta + delta_theta) - samples_neg = self._get_cart_coords(nu=nu, theta=theta - delta_theta) - pos_diffs = self._is_feasible(samples_pos) - neg_diffs = self._is_feasible(samples_neg) - # We don't use bit-wise XOR here since we need the signs of the directions - return pos_diffs.to(nu) - neg_diffs.to(nu) + # In order to determine if an angle that gives rise to an intersection with a + # constraint boundary leads to a change in the feasibility of the solution, + # we evaluate the constraints on the midpoint of the intersection angles. + # This gets rid of the `delta_theta` parameter in the original implementation, + # which cannot be set universally since it can be both 1) too large, when + # the distance in adjacent intersection angles is small, and 2) too small, + # when it approaches the numerical precision limit. + # The implementation below solves both problems and gets rid of the parameter. + if len(theta) < 2: # if we have no or only a tangential intersection + theta_active = torch.tensor([], dtype=theta.dtype, device=theta.device) + delta_active = torch.tensor([], dtype=int, device=theta.device) + return theta_active, delta_active + theta_mid = (theta[:-1] + theta[1:]) / 2 # midpoints of intersection angles + last_mid = (theta[:1] + theta[-1:] + _twopi) / 2 + last_mid = last_mid.where(last_mid < _twopi, last_mid - _twopi) + theta_mid = torch.cat((last_mid, theta_mid, last_mid), dim=0) + samples_mid = self._get_cart_coords(nu=nu, theta=theta_mid) + delta_feasibility = self._is_feasible(samples_mid).to(dtype=int).diff() + active_indices = delta_feasibility.nonzero() + return theta[active_indices], delta_feasibility[active_indices] def _is_feasible(self, points: Tensor) -> Tensor: - r""" + r"""Returns a Boolean tensor indicating whether the `points` are feasible, + i.e. they satisfy `A @ points <= b`, where `(A, b)` are the tensors passed + as the `inequality_constraints` to the constructor of the sampler. Args: - points: A `M x d`-dim tensor of points. + points: A `d x M`-dim tensor of points. Returns: An `M`-dim binary tensor where `True` indicates that the associated diff --git a/test/utils/probability/test_lin_ess.py b/test/utils/probability/test_lin_ess.py index 76b60686f5..a340387993 100644 --- a/test/utils/probability/test_lin_ess.py +++ b/test/utils/probability/test_lin_ess.py @@ -6,6 +6,8 @@ from __future__ import annotations +from unittest.mock import patch + import torch from botorch.exceptions.errors import BotorchError from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler @@ -158,14 +160,16 @@ def test_bivariate(self): self.assertFalse(torch.equal(sampler._x, sampler.x0)) def test_multivariate(self): - d = 3 - lower_bound = 1 for dtype in (torch.float, torch.double): + d = 3 tkwargs = {"device": self.device, "dtype": dtype} # special case: N(0, I) truncated to greater than lower_bound A = -torch.eye(d, **tkwargs) + lower_bound = 1 b = -torch.full((d, 1), lower_bound, **tkwargs) - sampler = LinearEllipticalSliceSampler(inequality_constraints=(A, b)) + sampler = LinearEllipticalSliceSampler( + inequality_constraints=(A, b), check_feasibility=True + ) self.assertIsNone(sampler._mean) self.assertIsNone(sampler._covariance_root) self.assertTrue(torch.all(sampler._is_feasible(sampler.x0))) @@ -189,29 +193,72 @@ def test_multivariate(self): self.assertFalse(torch.equal(sampler._x, sampler.x0)) # two special cases of _find_intersection_angles below: - # testing _find_intersection_angles with a proposal "nu" + # 1) testing _find_intersection_angles with a proposal "nu" # that ensures that the full ellipse is feasible - # NOTE: this test passes even though the full ellipse might - # not be feasible, which should be investigated further. - # However, this case is unlikely to be of much practical - # importance, as sampling a bound that is *exactly* on the - # constraint boundary is highly unlikely. - nu = torch.full((d, 1), lower_bound, **tkwargs) + # setting lower bound below the mean to ensure there's no intersection + lower_bound = -2 + b = -torch.full((d, 1), lower_bound, **tkwargs) + nu = torch.full((d, 1), lower_bound + 1, **tkwargs) sampler = LinearEllipticalSliceSampler( - interior_point=nu, inequality_constraints=(A, b) + interior_point=nu, + inequality_constraints=(A, b), + check_feasibility=True, ) - nu = torch.tensor([[-0.9199], [1.3555], [1.3738]], **tkwargs) + nu = torch.full((d, 1), lower_bound + 2, **tkwargs) theta_active = sampler._find_active_intersections(nu) self.assertTrue( torch.equal(theta_active, sampler._full_angular_range.view(-1)) ) + rot_angle, slices = sampler._find_rotated_intersections(nu) + self.assertEqual(rot_angle, 0.0) + self.assertAllClose( + slices, torch.tensor([[0.0, 2 * torch.pi]], **tkwargs), atol=1e-6 + ) - # testing tangential intersection of ellipse with constraint + # 2) testing tangential intersection of ellipse with constraint nu = torch.full((d, 1), lower_bound, **tkwargs) sampler = LinearEllipticalSliceSampler( - interior_point=nu, inequality_constraints=(A, b) + interior_point=nu, + inequality_constraints=(A, b), + check_feasibility=True, ) - nu = torch.full((d, 1), lower_bound, **tkwargs) - nu[1] += 1 + nu = torch.full((d, 1), lower_bound + 1, **tkwargs) + # nu[1] += 1 theta_active = sampler._find_active_intersections(nu) self.assertTrue(theta_active.numel() % 2 == 0) + + # testing error message for infeasible sample + sampler.check_feasibility = True + infeasible_x = torch.full((d, 1), lower_bound - 1, **tkwargs) + with patch.object( + sampler, "_draw_angle", return_value=torch.tensor(0.0, **tkwargs) + ): + with patch.object( + sampler, + "_get_cart_coords", + return_value=infeasible_x, + ): + with self.assertRaisesRegex( + RuntimeError, "Sampling resulted in infeasible point" + ): + sampler.step() + + # high dimensional test case + d = 128 + # this encodes order constraints on all d variables: Ax < b + # x[i] < x[i + 1] + A = torch.zeros(d - 1, d, **tkwargs) + for i in range(d - 1): + A[i, i] = 1 + A[i, i + 1] = -1 + b = torch.zeros(d - 1, 1, **tkwargs) + + interior_point = torch.arange(d, **tkwargs).unsqueeze(-1) / d - 1 / 2 + sampler = LinearEllipticalSliceSampler( + inequality_constraints=(A, b), + interior_point=interior_point, + check_feasibility=True, + ) + X_high_d = sampler.draw(n=16) + self.assertEqual(X_high_d.shape, torch.Size([16, d])) + self.assertTrue(sampler._is_feasible(X_high_d.T).all())