Skip to content

Commit

Permalink
LinearEllipticalSliceSampler Robustness Improvements (#1859)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1859

This commit improves the robustness of the linear elliptical slice sampler, primarily with two modifications:
1) A rewrite of the computation of the ellipse angles at which the ellipse intersects the constraint boundaries. This gets rid of the `delta_theta` parameter, which cannot be set universally without either sacrificing correctness or causing errors due to floating-point imprecision.
2) Contracting the feasible slices of the ellipse by an amount close to the numerical precision to guarantee that the resulting samples satisfy the constraints numerically.

The commit also introduces a high dimensional test case that enforces the monotonicity of adjacent elements of the sample vectors, which originally led to the discovery of all the issues that are fixed by the above steps.

Reviewed By: Balandat

Differential Revision: D46422349

fbshipit-source-id: 33bd16edf527f3dfa49aeff5c7aaaf889b43099b
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Jun 9, 2023
1 parent 30eddaf commit ef52ea9
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 62 deletions.
115 changes: 69 additions & 46 deletions botorch/utils/probability/lin_ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,18 @@
from torch import Tensor

_twopi = 2.0 * math.pi
_delta_theta = 1.0e-6 * _twopi


class LinearEllipticalSliceSampler(PolytopeSampler):
r"""Linear Elliptical Slice Sampler.
TODOs:
- clean up docstrings
- Add batch support, broadcasting over parallel chains.
- optimize computations (if possible)
Maybe TODOs:
- Support degenerate domains (with zero volume)?
- Add batch support ?
"""

def __init__(
Expand All @@ -52,6 +51,7 @@ def __init__(
mean: Optional[Tensor] = None,
covariance_matrix: Optional[Tensor] = None,
covariance_root: Optional[Tensor] = None,
check_feasibility: bool = False,
) -> None:
r"""Initialize LinearEllipticalSliceSampler.
Expand All @@ -73,6 +73,9 @@ def __init__(
distribution (if omitted, use the identity).
covariance_root: A `d x k`-dim root of the covariance matrix such that
covariance_root @ covariance_root.T = covariance_matrix.
check_feasibility: If True, raise an error if the sampling results in an
infeasible sample. This creates some overhead and so is switched off
by default.
This sampler samples from a multivariate Normal `N(mean, covariance_matrix)`
subject to linear domain constraints `A x <= b` (intersected with box bounds,
Expand Down Expand Up @@ -107,6 +110,7 @@ def __init__(
self._zero = torch.zeros(1, **tkwargs)
self._nan = torch.tensor(float("nan"), **tkwargs)
self._full_angular_range = torch.tensor([0.0, _twopi], **tkwargs)
self.check_feasibility = check_feasibility

def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]:
r"""Draw samples.
Expand All @@ -117,7 +121,7 @@ def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]:
Returns:
A `n x d`-dim tensor of `n` samples.
"""
# TODO: Do we need to do any thinnning or warm-up here?
# TODO: Should apply thinning in higher dimensions, can check step size.
samples = torch.cat([self.step() for _ in range(n)], dim=-1)
return samples.transpose(-1, -2)

Expand All @@ -129,8 +133,18 @@ def step(self) -> Tensor:
"""
nu = self._sample_base_rv()
theta = self._draw_angle(nu=nu)
self._x = self._get_cart_coords(nu=nu, theta=theta)
return self._x
x = self._get_cart_coords(nu=nu, theta=theta)
self._x[:] = x
if self.check_feasibility and (not self._is_feasible(self._x)):
Axmb = self.A @ self._x - self.b
violated_indices = Axmb > 0
raise RuntimeError(
"Sampling resulted in infeasible point. \n\t- Number "
f"of violated constraints: {violated_indices.sum()}."
f"\n\t- Magnitude of violations: {Axmb[violated_indices]}"
"\n\t- If the error persists, please report this bug on GitHub."
)
return x

def _sample_base_rv(self) -> Tensor:
r"""Sample a base random variable from N(mean, covariance_matrix).
Expand All @@ -152,7 +166,7 @@ def _draw_angle(self, nu: Tensor) -> Tensor:
nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
Returns:
A
A `1`-dim Tensor containing the rotation angle (radians).
"""
rot_angle, rot_slices = self._find_rotated_intersections(nu)
rot_lengths = rot_slices[:, 1] - rot_slices[:, 0]
Expand All @@ -162,7 +176,7 @@ def _draw_angle(self, nu: Tensor) -> Tensor:
1, device=cum_lengths.device, dtype=cum_lengths.dtype
)
idx = torch.searchsorted(cum_lengths, rnd_angle) - 1
return rot_slices[idx, 0] + rnd_angle - cum_lengths[idx] + rot_angle
return (rot_slices[idx, 0] + rnd_angle + rot_angle) - cum_lengths[idx]

def _get_cart_coords(self, nu: Tensor, theta: Tensor) -> Tensor:
r"""Determine location on ellipsoid in cartesian coordinates.
Expand Down Expand Up @@ -191,9 +205,16 @@ def _find_rotated_intersections(self, nu: Tensor) -> Tuple[Tensor, Tensor]:
"""
slices = self._find_active_intersections(nu)
rot_angle = slices[0]
slices = slices - rot_angle
slices = torch.where(slices < 0, slices + _twopi, slices)
return rot_angle, slices.reshape(-1, 2)
slices = (slices - rot_angle).reshape(-1, 2)
# Ensuring that we don't sample within numerical precision of the boundaries
# due to resulting instabilities in the constraint satisfaction.
eps = 1e-6 if slices.dtype == torch.float32 else 1e-12
eps = torch.tensor(eps, dtype=slices.dtype, device=slices.device)
eps = eps.minimum(slices.diff(dim=-1).abs() / 4)
slices = slices + torch.cat((eps, -eps), dim=-1)
# NOTE: The remainder call relies on the epsilon contraction, since the
# remainder of _twopi divided by _twopi is zero, not _twopi.
return rot_angle, slices.remainder(_twopi)

def _find_active_intersections(self, nu: Tensor) -> Tensor:
"""
Expand All @@ -212,23 +233,15 @@ def _find_active_intersections(self, nu: Tensor) -> Tensor:
slice sampling.
"""
theta = self._find_intersection_angles(nu)
active_directions = self._index_active(
nu=nu, theta=theta, delta_theta=_delta_theta
theta_active, delta_active = self._active_theta_and_delta(
nu=nu,
theta=theta,
)
theta_active = theta[active_directions.nonzero()]
delta_theta = _delta_theta
while theta_active.numel() % 2 == 1:
# Almost tangential ellipses, reduce delta_theta
delta_theta /= 10
active_directions = self._index_active(
theta=theta, nu=nu, delta_theta=delta_theta
)
theta_active = theta[active_directions.nonzero()]

if theta_active.numel() == 0:
theta_active = self._full_angular_range
# TODO: What about `self.ellipse_in_domain = False` in the original code ??
elif active_directions[active_directions.nonzero()][0] == -1:
# TODO: What about `self.ellipse_in_domain = False` in the original code?
elif delta_active[0] == -1: # ensuring that the first interval is feasible

theta_active = torch.cat((theta_active[1:], theta_active[:1]))

return theta_active.view(-1)
Expand Down Expand Up @@ -258,45 +271,55 @@ def _find_intersection_angles(self, nu: Tensor) -> Tensor:
arg = -(self.b / r).squeeze()
# Write NaNs if there is no intersection
arg = torch.where(torch.absolute(arg) <= 1, arg, self._nan)

# Two solutions per linear constraint, shape of theta: (n_ineq_con, 2)
acos_arg = torch.arccos(arg)
theta = torch.stack((phi + acos_arg, phi - acos_arg), dim=-1)
theta = theta[torch.isfinite(theta)] # shape: `n_ineq_con - num_not_finite`
theta = torch.where(theta < 0, theta + _twopi, theta) # [0, 2*pi]

theta = torch.where(theta < 0, theta + _twopi, theta) # in [0, 2*pi]
return torch.sort(theta).values

def _index_active(
self, nu: Tensor, theta: Tensor, delta_theta: float = _delta_theta
) -> Tensor:
def _active_theta_and_delta(self, nu: Tensor, theta: Tensor) -> Tensor:
r"""Determine active indices.
Args:
nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
theta: An `M`-dim tensor of intersection angles.
delta_theta: A small perturbation to be used for determining whether
intersections are at the boundary of the integration domain.
theta: A sorted `M`-dim tensor of intersection angles in [0, 2pi].
Returns:
An `M`-dim tensor with elements taking on values in {-1, 0, 1}.
A non-zero value indicates whether the associated intersection angle
is an active constraint. For active constraints, the sign indicates
the direction of the relevant domain (i.e. +1 (-1) means that
increasing (decreasing) the angle renders the sample feasible).
A tuple of Tensors of active constraint intersection angles `theta_active`,
and the change in the feasibility of the points on the ellipse on the left
and right of the active intersection angles `delta_active`. `delta_active`
is negative if decreasing the angle renders the sample feasible, and
positive if increasing the angle renders the sample feasible.
"""
samples_pos = self._get_cart_coords(nu=nu, theta=theta + delta_theta)
samples_neg = self._get_cart_coords(nu=nu, theta=theta - delta_theta)
pos_diffs = self._is_feasible(samples_pos)
neg_diffs = self._is_feasible(samples_neg)
# We don't use bit-wise XOR here since we need the signs of the directions
return pos_diffs.to(nu) - neg_diffs.to(nu)
# In order to determine if an angle that gives rise to an intersection with a
# constraint boundary leads to a change in the feasibility of the solution,
# we evaluate the constraints on the midpoint of the intersection angles.
# This gets rid of the `delta_theta` parameter in the original implementation,
# which cannot be set universally since it can be both 1) too large, when
# the distance in adjacent intersection angles is small, and 2) too small,
# when it approaches the numerical precision limit.
# The implementation below solves both problems and gets rid of the parameter.
if len(theta) < 2: # if we have no or only a tangential intersection
theta_active = torch.tensor([], dtype=theta.dtype, device=theta.device)
delta_active = torch.tensor([], dtype=int, device=theta.device)
return theta_active, delta_active
theta_mid = (theta[:-1] + theta[1:]) / 2 # midpoints of intersection angles
last_mid = (theta[:1] + theta[-1:] + _twopi) / 2
last_mid = last_mid.where(last_mid < _twopi, last_mid - _twopi)
theta_mid = torch.cat((last_mid, theta_mid, last_mid), dim=0)
samples_mid = self._get_cart_coords(nu=nu, theta=theta_mid)
delta_feasibility = self._is_feasible(samples_mid).to(dtype=int).diff()
active_indices = delta_feasibility.nonzero()
return theta[active_indices], delta_feasibility[active_indices]

def _is_feasible(self, points: Tensor) -> Tensor:
r"""
r"""Returns a Boolean tensor indicating whether the `points` are feasible,
i.e. they satisfy `A @ points <= b`, where `(A, b)` are the tensors passed
as the `inequality_constraints` to the constructor of the sampler.
Args:
points: A `M x d`-dim tensor of points.
points: A `d x M`-dim tensor of points.
Returns:
An `M`-dim binary tensor where `True` indicates that the associated
Expand Down
79 changes: 63 additions & 16 deletions test/utils/probability/test_lin_ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

from unittest.mock import patch

import torch
from botorch.exceptions.errors import BotorchError
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
Expand Down Expand Up @@ -158,14 +160,16 @@ def test_bivariate(self):
self.assertFalse(torch.equal(sampler._x, sampler.x0))

def test_multivariate(self):
d = 3
lower_bound = 1
for dtype in (torch.float, torch.double):
d = 3
tkwargs = {"device": self.device, "dtype": dtype}
# special case: N(0, I) truncated to greater than lower_bound
A = -torch.eye(d, **tkwargs)
lower_bound = 1
b = -torch.full((d, 1), lower_bound, **tkwargs)
sampler = LinearEllipticalSliceSampler(inequality_constraints=(A, b))
sampler = LinearEllipticalSliceSampler(
inequality_constraints=(A, b), check_feasibility=True
)
self.assertIsNone(sampler._mean)
self.assertIsNone(sampler._covariance_root)
self.assertTrue(torch.all(sampler._is_feasible(sampler.x0)))
Expand All @@ -189,29 +193,72 @@ def test_multivariate(self):
self.assertFalse(torch.equal(sampler._x, sampler.x0))

# two special cases of _find_intersection_angles below:
# testing _find_intersection_angles with a proposal "nu"
# 1) testing _find_intersection_angles with a proposal "nu"
# that ensures that the full ellipse is feasible
# NOTE: this test passes even though the full ellipse might
# not be feasible, which should be investigated further.
# However, this case is unlikely to be of much practical
# importance, as sampling a bound that is *exactly* on the
# constraint boundary is highly unlikely.
nu = torch.full((d, 1), lower_bound, **tkwargs)
# setting lower bound below the mean to ensure there's no intersection
lower_bound = -2
b = -torch.full((d, 1), lower_bound, **tkwargs)
nu = torch.full((d, 1), lower_bound + 1, **tkwargs)
sampler = LinearEllipticalSliceSampler(
interior_point=nu, inequality_constraints=(A, b)
interior_point=nu,
inequality_constraints=(A, b),
check_feasibility=True,
)
nu = torch.tensor([[-0.9199], [1.3555], [1.3738]], **tkwargs)
nu = torch.full((d, 1), lower_bound + 2, **tkwargs)
theta_active = sampler._find_active_intersections(nu)
self.assertTrue(
torch.equal(theta_active, sampler._full_angular_range.view(-1))
)
rot_angle, slices = sampler._find_rotated_intersections(nu)
self.assertEqual(rot_angle, 0.0)
self.assertAllClose(
slices, torch.tensor([[0.0, 2 * torch.pi]], **tkwargs), atol=1e-6
)

# testing tangential intersection of ellipse with constraint
# 2) testing tangential intersection of ellipse with constraint
nu = torch.full((d, 1), lower_bound, **tkwargs)
sampler = LinearEllipticalSliceSampler(
interior_point=nu, inequality_constraints=(A, b)
interior_point=nu,
inequality_constraints=(A, b),
check_feasibility=True,
)
nu = torch.full((d, 1), lower_bound, **tkwargs)
nu[1] += 1
nu = torch.full((d, 1), lower_bound + 1, **tkwargs)
# nu[1] += 1
theta_active = sampler._find_active_intersections(nu)
self.assertTrue(theta_active.numel() % 2 == 0)

# testing error message for infeasible sample
sampler.check_feasibility = True
infeasible_x = torch.full((d, 1), lower_bound - 1, **tkwargs)
with patch.object(
sampler, "_draw_angle", return_value=torch.tensor(0.0, **tkwargs)
):
with patch.object(
sampler,
"_get_cart_coords",
return_value=infeasible_x,
):
with self.assertRaisesRegex(
RuntimeError, "Sampling resulted in infeasible point"
):
sampler.step()

# high dimensional test case
d = 128
# this encodes order constraints on all d variables: Ax < b
# x[i] < x[i + 1]
A = torch.zeros(d - 1, d, **tkwargs)
for i in range(d - 1):
A[i, i] = 1
A[i, i + 1] = -1
b = torch.zeros(d - 1, 1, **tkwargs)

interior_point = torch.arange(d, **tkwargs).unsqueeze(-1) / d - 1 / 2
sampler = LinearEllipticalSliceSampler(
inequality_constraints=(A, b),
interior_point=interior_point,
check_feasibility=True,
)
X_high_d = sampler.draw(n=16)
self.assertEqual(X_high_d.shape, torch.Size([16, d]))
self.assertTrue(sampler._is_feasible(X_high_d.T).all())

0 comments on commit ef52ea9

Please sign in to comment.