LinearEllipticalSliceSampler with fixed feature constraints (#1883)
Summary:
Pull Request resolved: #1883

This adds support for fixed feature equality constraints to `LinearEllipticalSliceSampler`.
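
A minimal usage sketch (illustrative only: the bounds, fixed value, and draw count
below are made up; the constructor arguments are those added in this diff):

    import torch
    from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler

    # Sample a standard Normal truncated to [-1, 1]^3, with dimension 1 fixed to
    # 0.5; the fixed value is read off the (strictly feasible) `interior_point`.
    sampler = LinearEllipticalSliceSampler(
        bounds=torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]),
        interior_point=torch.tensor([[0.0], [0.5], [0.0]]),
        fixed_indices=[1],
    )
    samples = sampler.draw(n=8)  # `8 x 3`-dim Tensor with samples[:, 1] == 0.5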

Reviewed By: Balandat

Differential Revision: D46613288

fbshipit-source-id: 6c292e269723a75f4d4966e257b34c8ad6cafe53
SebastianAment authored and facebook-github-bot committed Jun 16, 2023
1 parent 00fa4a8 commit 1bfdfaf
Showing 2 changed files with 284 additions and 70 deletions.
237 changes: 190 additions & 47 deletions botorch/utils/probability/lin_ess.py
@@ -15,14 +15,25 @@
This implementation is based (with multiple changes / optimizations) on
the following implementations of the algorithm in [Gessner2020]_:
- https://github.com/alpiges/LinConGauss
- https://github.com/wjmaddox/pytorch_ess
The implementation here differs from the original implementations in the following ways:
1) Support for fixed feature equality constraints.
2) Support for non-standard Normal distributions.
3) Numerical stability improvements, especially relevant for high-dimensional cases.
Notably, this implementation does not rely on an adaptive `delta_theta` parameter in
order to determine if two neighboring constraint intersection angles `theta` lead to a
change in the feasibility of the sample. This both simplifies the implementation and
makes it more robust to numerical imprecisions when two constraint intersection angles
are close to each other.
"""

from __future__ import annotations

import math
from typing import List, Optional, Tuple, Union

import torch
from botorch.utils.sampling import PolytopeSampler
@@ -34,20 +45,18 @@
class LinearEllipticalSliceSampler(PolytopeSampler):
r"""Linear Elliptical Slice Sampler.
TODOs:
- clean up docstrings
Ideas:
- Add batch support, broadcasting over parallel chains.
- optimize computations (if possible)
Maybe TODOs:
- Support degenerate domains (with zero volume)?
- Optimize computations if possible, potentially with torch.compile.
- Extend fixed features constraint to general linear equality constraints.
"""

def __init__(
self,
inequality_constraints: Optional[Tuple[Tensor, Tensor]] = None,
bounds: Optional[Tensor] = None,
interior_point: Optional[Tensor] = None,
fixed_indices: Optional[Union[List[int], Tensor]] = None,
mean: Optional[Tensor] = None,
covariance_matrix: Optional[Tensor] = None,
covariance_root: Optional[Tensor] = None,
@@ -70,6 +79,10 @@ def __init__(
                automatically by solving a Linear Program. Note: It is crucial that
                the point lie in the interior of the feasible set (rather than on the
                boundary), otherwise the sampler will produce invalid samples.
            fixed_indices: Integer list or `d`-dim Tensor representing the indices of
                dimensions that are constrained to be fixed to the values specified in
                the `interior_point`, which is required to be passed in conjunction with
                `fixed_indices`.
            mean: The `d x 1`-dim mean of the MVN distribution (if omitted, use zero).
            covariance_matrix: The `d x d`-dim covariance matrix of the MVN
                distribution (if omitted, use the identity).
@@ -94,13 +107,29 @@ def __init__(
            bounds=bounds,
        )
        tkwargs = {"device": self.x0.device, "dtype": self.x0.dtype}
        if covariance_matrix is not None and covariance_root is not None:
            raise ValueError(
                "Provide either covariance_matrix or covariance_root, not both."
            )

        # can't unpack inequality constraints directly if bounds are passed
        A, b = self.A, self.b
        self._Az, self._bz = A, b
        self._is_fixed, self._not_fixed = None, None
        if fixed_indices is not None:
            mean, covariance_matrix = self._fixed_features_initialization(
                A=A,
                b=b,
                interior_point=interior_point,
                fixed_indices=fixed_indices,
                mean=mean,
                covariance_matrix=covariance_matrix,
                covariance_root=covariance_root,
            )

        self._mean = mean
        # Have to delay factorization until after fixed features initialization.
        if covariance_matrix is not None:  # implies root is None
            covariance_root, info = torch.linalg.cholesky_ex(covariance_matrix)
            not_psd = torch.any(info)
            if not_psd:
@@ -110,19 +139,12 @@
                )
        self._covariance_root = covariance_root

        # Rewrite the constraints as a system that constrains a standard Normal.
        self._standardization_initialization()

        # state of the sampler ("current point")
        self._x = self.x0.clone()
        self._z = self._transform(self._x)

        # We will need the following repeatedly, so allocate them once.
        self._zero = torch.zeros(1, **tkwargs)
@@ -135,6 +157,65 @@ def __init__(
            self.draw(burnin)
        self.thinning = thinning

    def _fixed_features_initialization(
        self,
        A: Tensor,
        b: Tensor,
        interior_point: Optional[Tensor],
        fixed_indices: Union[List[int], Tensor],
        mean: Optional[Tensor],
        covariance_matrix: Optional[Tensor],
        covariance_root: Optional[Tensor],
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Modifies the constraint system (A, b) due to fixed indices and assigns
        the modified constraint system to `self._Az`, `self._bz`. NOTE: Needs to be
        called prior to `self._standardization_initialization` in the constructor.

        Returns:
            Tuple of `mean` and `covariance_matrix` tensors of the non-fixed
            dimensions.
        """
        if interior_point is None:
            raise ValueError(
                "If `fixed_indices` are provided, an interior point must also be "
                "provided in order to infer feasible values of the fixed features."
            )
        if covariance_root is not None:
            raise ValueError(
                "Provide either covariance_root or fixed_indices, not both."
            )
        d = interior_point.shape[0]
        is_fixed, not_fixed = get_index_tensors(fixed_indices=fixed_indices, d=d)
        self._is_fixed = is_fixed
        self._not_fixed = not_fixed
        # Transforming constraint system to incorporate fixed features:
        # A @ x - b = (A[:, fixed] @ x[fixed] + A[:, not fixed] @ x[not fixed]) - b
        #           = A[:, not fixed] @ x[not fixed] - (b - A[:, fixed] @ x[fixed])
        #           = Az @ z - bz
        self._Az = A[:, not_fixed]
        self._bz = b - A[:, is_fixed] @ interior_point[is_fixed]
        if mean is not None:
            mean = mean[not_fixed]
        if covariance_matrix is not None:  # subselect active dimensions
            covariance_matrix = covariance_matrix[
                not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)
            ]
        return mean, covariance_matrix
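
    # Illustrative check of the identity above (added for exposition): with d = 3
    # and fixed_indices = [1], it reads
    #     A[:, [0, 2]] @ x[[0, 2]] - (b - A[:, [1]] @ x[[1]]) == A @ x - b,
    # so a point `x` whose entry `x[1]` is pinned to `interior_point[1]` is
    # feasible for (A, b) exactly when `z = x[[0, 2]]` is feasible for
    # (self._Az, self._bz).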

    def _standardization_initialization(self) -> None:
        """For non-standard mean and covariance, we're going to rewrite the problem as
        sampling from a standard normal distribution subject to modified constraints:

            A @ x - b = A @ (covar_root @ z + mean) - b
                      = (A @ covar_root) @ z - (b - A @ mean)
                      = _Az @ z - _bz

        NOTE: We need to standardize `bz` before `Az` in the following, because the
        update of `bz` relies on the untransformed `Az`. We can't simply use `A`
        instead because `Az` might have been subject to the fixed features
        transformation.
        """
        if self._mean is not None:
            self._bz = self._bz - self._Az @ self._mean
        if self._covariance_root is not None:
            self._Az = self._Az @ self._covariance_root
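
    # Note on the derivation above (added for exposition): after
    # `_standardization_initialization` runs, for any x = covar_root @ z + mean,
    #     _Az_new @ z - _bz_new = (_Az @ covar_root) @ z - (_bz - _Az @ mean)
    #                           = _Az @ (covar_root @ z + mean) - _bz
    #                           = _Az @ x - _bz,
    # so feasibility of the standardized `z` under the updated system coincides
    # with feasibility of `x` under the pre-standardization system.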

    @property
    def lifetime_samples(self) -> int:
        """The total number of samples generated by the sampler during its lifetime."""
@@ -156,22 +237,6 @@ def draw(self, n: int = 1) -> Tuple[Tensor, Tensor]:
            samples.append(self.step())
        return torch.cat(samples, dim=-1).transpose(-1, -2)

    def step(self) -> Tensor:
        r"""Take a step, return the new sample, update the internal state.
@@ -182,7 +247,7 @@ def step(self) -> Tensor:
        theta = self._draw_angle(nu=nu)
        z = self._get_cart_coords(nu=nu, theta=theta)
        self._z[:] = z
        x = self._untransform(z)
        self._x[:] = x
        self._lifetime_samples += 1
        if self.check_feasibility and (not self._is_feasible(self._x)):
@@ -347,25 +412,103 @@ def _active_theta_and_delta(self, nu: Tensor, theta: Tensor) -> Tensor:
        theta_mid = torch.cat((last_mid, theta_mid, last_mid), dim=0)
        samples_mid = self._get_cart_coords(nu=nu, theta=theta_mid)
        delta_feasibility = (
            self._is_feasible(samples_mid, transformed=True).to(dtype=int).diff()
        )
        active_indices = delta_feasibility.nonzero()
        return theta[active_indices], delta_feasibility[active_indices]

    def _is_feasible(self, points: Tensor, transformed: bool = False) -> Tensor:
        r"""Returns a Boolean tensor indicating whether the `points` are feasible,
        i.e. they satisfy `A @ points <= b`, where `(A, b)` are the tensors passed
        as the `inequality_constraints` to the constructor of the sampler.

        Args:
            points: A `d x M`-dim tensor of points.
            transformed: Whether points are assumed to be transformed by a change of
                basis, which means feasibility should be computed based on the
                transformed constraint system (_Az, _bz), instead of (A, b).

        Returns:
            An `M`-dim binary tensor where `True` indicates that the associated
            point is feasible.
        """
        A, b = (self._Az, self._bz) if transformed else (self.A, self.b)
        return (A @ points <= b).all(dim=0)
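
    # Shape note (added for exposition): with `A` of shape `m x d` and `points` of
    # shape `d x M`, `A @ points` is `m x M`; the `m x 1` tensor `b` broadcasts
    # across the `M` columns and `.all(dim=0)` reduces over the `m` constraints,
    # yielding the `M`-dim Boolean result.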

    def _transform(self, x: Tensor) -> Tensor:
        """Transforms the input so that it is equivalent to a standard Normal variable
        subject to the modified constraint system (self._Az, self._bz).

        Args:
            x: The input tensor to be transformed, usually `d x 1`-dimensional.

        Returns:
            A `d x 1`-dimensional tensor of transformed values subject to the modified
            system of constraints (`len(self._not_fixed) x 1`-dimensional if there are
            fixed indices).
        """
        if self._not_fixed is not None:
            x = x[self._not_fixed]
        return self._standardize(x)

    def _untransform(self, z: Tensor) -> Tensor:
        """The inverse of `_transform`, i.e. maps `z` back to the original space,
        where it is subject to the original constraint system (self.A, self.b).

        Args:
            z: The transformed tensor to be un-transformed, usually `d x 1`-dimensional.

        Returns:
            A `d x 1`-dimensional tensor of un-transformed values subject to the
            original system of constraints.
        """
        if self._is_fixed is None:
            return self._unstandardize(z)
        else:
            x = self._x.clone()  # _x already contains the fixed values
            x[self._not_fixed] = self._unstandardize(z)
            return x

    def _standardize(self, x: Tensor) -> Tensor:
        """`_transform` helper standardizing the input `x`, which is assumed to be a
        `d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim Tensor if there are
        fixed indices.
        """
        z = x
        if self._mean is not None:
            z = z - self._mean
        if self._covariance_root is not None:
            z = torch.linalg.solve_triangular(self._covariance_root, z, upper=False)
        return z

    def _unstandardize(self, z: Tensor) -> Tensor:
        """`_untransform` helper un-standardizing the input `z`, which is assumed to
        be a `d x 1`-dim Tensor, or a `len(self._not_fixed) x 1`-dim Tensor if there
        are fixed indices.
        """
        x = z
        if self._covariance_root is not None:
            x = self._covariance_root @ x
        if self._mean is not None:
            x = x + self._mean
        return x
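
    # Round-trip invariant (added for exposition): on the non-fixed coordinates,
    # `_untransform` inverts `_transform`, so for a feasible `x` whose fixed
    # entries agree with those of `self._x`,
    #     torch.allclose(self._untransform(self._transform(x)), x)
    # should hold up to the numerical precision of the triangular solve.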


def get_index_tensors(
    fixed_indices: Union[List[int], Tensor], d: int
) -> Tuple[Tensor, Tensor]:
    """Converts `fixed_indices` into a pair of integral index Tensors: the indices
    that are fixed and, as their complement, the indices that are not.

    Args:
        fixed_indices: A list or Tensor of integer indices to fix.
        d: The dimensionality of the Tensors to be indexed.

    Returns:
        A Tuple of integral Tensors partitioning `{0, ..., d - 1}` into indices that
        are fixed (first tensor) and non-fixed (second tensor).
    """
    is_fixed = torch.as_tensor(fixed_indices)
    dtype, device = is_fixed.dtype, is_fixed.device
    dims = torch.arange(d, dtype=dtype, device=device)
    not_fixed = torch.tensor([i for i in dims if i not in is_fixed])
    return is_fixed, not_fixed
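
For illustration (exposition only, not part of the diff), the partition produced by
`get_index_tensors`:

    from botorch.utils.probability.lin_ess import get_index_tensors

    is_fixed, not_fixed = get_index_tensors(fixed_indices=[1], d=3)
    # is_fixed:  tensor([1])
    # not_fixed: tensor([0, 2])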
