From 29d2b103b826c0394390f44449c4938d95378987 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Mon, 26 Sep 2022 12:26:35 -0700 Subject: [PATCH] bvn, MVNXPB, TruncatedMultivariateNormal, and UnifiedSkewNormal (#1394) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1394 Introduces `utils/probability` submodule with the following offers: - `bvn`: Methods for computing bivariate normal probabilities and moments. - `MVNXPB`: Approximate solver for Multivariate Normal CDF. - `LinearEllipticalSliceSampler`: Class for sampling trMVN random variables. - `TruncatedMultivariateNormal`: Truncated multivariate normal Distribution class - `UnifiedSkewNormal`: Unified skew normal Distribution class Reviewed By: Balandat Differential Revision: D39326106 fbshipit-source-id: e605ff7e3b484f128fa0a86cb6088a85a96e5969 --- botorch/utils/constants.py | 36 ++ botorch/utils/probability/__init__.py | 25 + botorch/utils/probability/bvn.py | 300 ++++++++++++ botorch/utils/probability/lin_ess.py | 299 ++++++++++++ botorch/utils/probability/linalg.py | 208 +++++++++ botorch/utils/probability/mvnxpb.py | 432 ++++++++++++++++++ .../truncated_multivariate_normal.py | 148 ++++++ .../utils/probability/unified_skew_normal.py | 238 ++++++++++ botorch/utils/probability/utils.py | 164 +++++++ botorch/utils/safe_math.py | 50 ++ test/utils/probability/__init__.py | 5 + test/utils/probability/test_bvn.py | 247 ++++++++++ test/utils/probability/test_lin_ess.py | 162 +++++++ test/utils/probability/test_linalg.py | 121 +++++ test/utils/probability/test_mvnxpb.py | 316 +++++++++++++ .../test_truncated_multivariate_normal.py | 149 ++++++ .../probability/test_unified_skew_normal.py | 182 ++++++++ test/utils/probability/test_utils.py | 106 +++++ test/utils/test_constants.py | 46 ++ test/utils/test_safe_math.py | 178 ++++++++ 20 files changed, 3412 insertions(+) create mode 100644 botorch/utils/constants.py create mode 100644 botorch/utils/probability/__init__.py create mode 100644 botorch/utils/probability/bvn.py create mode 100644 botorch/utils/probability/lin_ess.py create mode 100644 botorch/utils/probability/linalg.py create mode 100644 botorch/utils/probability/mvnxpb.py create mode 100644 botorch/utils/probability/truncated_multivariate_normal.py create mode 100644 botorch/utils/probability/unified_skew_normal.py create mode 100644 botorch/utils/probability/utils.py create mode 100644 botorch/utils/safe_math.py create mode 100644 test/utils/probability/__init__.py create mode 100644 test/utils/probability/test_bvn.py create mode 100644 test/utils/probability/test_lin_ess.py create mode 100644 test/utils/probability/test_linalg.py create mode 100644 test/utils/probability/test_mvnxpb.py create mode 100644 test/utils/probability/test_truncated_multivariate_normal.py create mode 100644 test/utils/probability/test_unified_skew_normal.py create mode 100644 test/utils/probability/test_utils.py create mode 100644 test/utils/test_constants.py create mode 100644 test/utils/test_safe_math.py diff --git a/botorch/utils/constants.py b/botorch/utils/constants.py new file mode 100644 index 0000000000..367e376d58 --- /dev/null +++ b/botorch/utils/constants.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
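+
+# Usage sketch (illustrative only, not part of the module): `get_constants_like`
+# below returns cached scalar tensors matching a reference tensor's dtype and
+# device, avoiding repeated scalar-tensor allocation in hot loops:
+#
+#   ref = torch.rand(3, dtype=torch.float64)
+#   zero, one = get_constants_like((0, 1), ref=ref)
+#   assert one.dtype == ref.dtype and one.item() == 1.0
+#   assert get_constants_like((0, 1), ref=ref)[0] is zero  # `lru_cache` hit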
+ +from __future__ import annotations + +from functools import lru_cache +from numbers import Number +from typing import Iterator, Optional, Tuple, Union + +import torch +from torch import Tensor + + +@lru_cache(maxsize=None) +def get_constants( + values: Union[Number, Iterator[Number]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> Union[Tensor, Tuple[Tensor, ...]]: + r"""Returns scalar-valued Tensors containing each of the given constants. + Used to expedite tensor operations involving scalar arithmetic. Note that + the returned Tensors should not be modified in-place.""" + if isinstance(values, Number): + return torch.full((), values, dtype=dtype, device=device) + + return tuple(torch.full((), val, dtype=dtype, device=device) for val in values) + + +def get_constants_like( + values: Union[Number, Iterator[Number]], + ref: Tensor, +) -> Union[Tensor, Iterator[Tensor]]: + return get_constants(values, device=ref.device, dtype=ref.dtype) diff --git a/botorch/utils/probability/__init__.py b/botorch/utils/probability/__init__.py new file mode 100644 index 0000000000..8d0c6981cd --- /dev/null +++ b/botorch/utils/probability/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from botorch.utils.probability.bvn import bvn, bvnmom +from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler +from botorch.utils.probability.mvnxpb import MVNXPB +from botorch.utils.probability.truncated_multivariate_normal import ( + TruncatedMultivariateNormal, +) +from botorch.utils.probability.unified_skew_normal import UnifiedSkewNormal +from botorch.utils.probability.utils import ndtr + + +__all__ = [ + "bvn", + "bvnmom", + "LinearEllipticalSliceSampler", + "MVNXPB", + "ndtr", + "TruncatedMultivariateNormal", + "UnifiedSkewNormal", +] diff --git a/botorch/utils/probability/bvn.py b/botorch/utils/probability/bvn.py new file mode 100644 index 0000000000..288e78015f --- /dev/null +++ b/botorch/utils/probability/bvn.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Methods for computing bivariate normal probabilities and statistics. + +.. [Drezner1990computation] + Z. Drezner and G. O. Wesolowsky. On the computation of the bivariate normal + integral. Journal of Statistical Computation and Simulation, 1990. + +.. [Genz2004bvnt] + A. Genz. Numerical computation of rectangular bivariate and trivariate normal and + t probabilities. Statistics and Computing, 2004. + +.. [Rosenbaum1961moments] + S. Rosenbaum. Moments of a Truncated Bivariate Normal Distribution. Journal of the + Royal Statistical Society (Series B), 1961. + +.. [Muthen1990moments] + B. Muthen. Moments of the censored and truncated bivariate normal distribution. + British Journal of Mathematical and Statistical Psychology, 1990. 
+""" + +from __future__ import annotations + +from math import pi as _pi +from typing import Optional, Tuple + +import torch +from botorch.exceptions import UnsupportedError +from botorch.utils.probability.utils import ( + case_dispatcher, + get_constants_like, + leggauss, + ndtr as Phi, + phi, + STANDARDIZED_RANGE, +) +from botorch.utils.safe_math import ( + div as safe_div, + exp as safe_exp, + mul as safe_mul, + sub as safe_sub, +) +from torch import Tensor + +# Some useful constants +_inf = float("inf") +_2pi = 2 * _pi +_sqrt_2pi = _2pi**0.5 +_inv_2pi = 1 / _2pi + + +def bvn(r: Tensor, xl: Tensor, yl: Tensor, xu: Tensor, yu: Tensor) -> Tensor: + r"""A function for computing bivariate normal probabilities. + + Calculates `P(xl < x < xu, yl < y < yu)` where `x` and `y` are bivariate normal with + unit variance and correlation coefficient `r`. See Section 2.4 of [Genz2004bvnt]_. + + This method uses a sign flip trick to improving numerical performance. Many of + `bvnu`s internal branches rely on evaluations `Phi(-bound)`. For `a < b < 0`, + the term `Phi(-a) - Phi(-b)` goes to zero faster than `Phi(b) - Phi(a)` because + `finfo(dtype).epsneg` is typically much larger than `finfo(dtype).tiny`. In these + cases, flipping the sign can prevent situations where `bvnu(...) - bvnu(...)` would + otherwise be zero due to round-off error. + + Args: + r: Tensor of correlation coefficients. + xl: Tensor of lower bounds for `x`, same shape as `r`. + yl: Tensor of lower bounds for `y`, same shape as `r`. + xu: Tensor of upper bounds for `x`, same shape as `r`. + yu: Tensor of upper bounds for `y`, same shape as `r`. + + Returns: + Tensor of probabilities `P(xl < x < xu, yl < y < yu)`. + + """ + if not (r.shape == xl.shape == xu.shape == yl.shape == yu.shape): + raise UnsupportedError("Arguments to `bvn` must have the same shape.") + + # Sign flip trick + _0, _1, _2 = get_constants_like(values=(0, 1, 2), ref=r) + flip_x = xl.abs() > xu # is xl more negative than xu is positive? + flip_y = yl.abs() > yu + flip = (flip_x & (~flip_y | yu.isinf())) | (flip_y & (~flip_x | xu.isinf())) + if flip.any(): # symmetric calls to `bvnu` below makes swapping bounds unnecessary + sign = _1 - _2 * flip.to(dtype=r.dtype) + xl = sign * xl # becomes `-xu` if flipped + xu = sign * xu # becomes `-xl` + yl = sign * yl # becomes `-yu` + yu = sign * yu # becomes `-yl` + + p = bvnu(r, xl, yl) - bvnu(r, xu, yl) - bvnu(r, xl, yu) + bvnu(r, xu, yu) + return p.clip(_0, _1) + + +def bvnu(r: Tensor, h: Tensor, k: Tensor) -> Tensor: + r"""Solves for `P(x > h, y > k)` where `x` and `y` are standard bivariate normal + random variables with correlation coefficient `r`. In [Genz2004bvnt]_, this is (1) + ``` + L(h, k, r) = P(x < -h, y < -k) + = 1/(a 2\pi) \int_{h}^{\infty} \int_{k}^{\infty} f(x, y, r) dy dx, + ``` + where `f(x, y, r) = e^{-1/(2a^2) (x^2 - 2rxy + y^2)}` and `a = (1 - r^2)^{1/2}`. + + [Genz2004bvnt]_ report the following integation scheme incurs a maximum of 5e-16 + error when run in double precision. For strongly correlated variables |r| >= 0.925, + use a 20-point quadrature rule on a 5th order Taylor expansion. Elsewhere, + numerically integrate in polar coordinates using no more than 20 quadrature points. + + Args: + r: Tensor of correlation coefficients. + h: Tensor of negative upper bounds for `x`, same shape as `r`. + k: Tensor of negative upper bounds for `y`, same shape as `r`. + + Returns: + A tensor of probabilities `P(x > h, y > k)`. 
+ """ + if not (r.shape == h.shape == k.shape): + raise UnsupportedError("Arguments to `bvnu` must have the same shape.") + _0, _1, lower, upper = get_constants_like((0, 1) + STANDARDIZED_RANGE, r) + x_free = h < lower + y_free = k < lower + return case_dispatcher( + out=torch.empty_like(r), + cases=( # Special cases admitting closed-form solutions + (lambda: (h > upper) | (k > upper), lambda mask: _0), + (lambda: x_free & y_free, lambda mask: _1), + (lambda: x_free, lambda mask: Phi(-k[mask])), + (lambda: y_free, lambda mask: Phi(-h[mask])), + (lambda: r == _0, lambda mask: Phi(-h[mask]) * Phi(-k[mask])), + ( # For |r| >= 0.925, use a Taylor approximation + lambda: r.abs() >= get_constants_like(0.925, r), + lambda m: _bvnu_taylor(r[m], h[m], k[m]), + ), + ), # For |r| < 0.925, integrate in polar coordinates. + default=lambda mask: _bvnu_polar(r[mask], h[mask], k[mask]), + ) + + +def _bvnu_polar( + r: Tensor, h: Tensor, k: Tensor, num_points: Optional[int] = None +) -> Tensor: + r"""Solves for `P(x > h, y > k)` by integrating in polar coordinates as + ``` + L(h, k, r) = \Phi(-h)\Phi(-k) + 1/(2\pi) \int_{0}^{sin^{-1}(r)} f(t) dt + f(t) = e^{-0.5 cos(t)^{-2} (h^2 + k^2 - 2hk sin(t))} + ``` + For details, see Section 2.2 of [Genz2004bvnt]_. + """ + if num_points is None: + mar = r.abs().max() + num_points = 6 if mar < 0.3 else 12 if mar < 0.75 else 20 + + _0, _1, _i2, _i2pi = get_constants_like(values=(0, 1, 0.5, _inv_2pi), ref=r) + + x, w = leggauss(num_points, dtype=r.dtype, device=r.device) + x = x + _1 + asin_r = _i2 * torch.asin(r) + sin_asrx = (asin_r.unsqueeze(-1) * x).sin() + + _h = h.unsqueeze(-1) + _k = k.unsqueeze(-1) + vals = safe_exp( + safe_sub(safe_mul(sin_asrx, _h * _k), _i2 * (_h.square() + _k.square())) + / (_1 - sin_asrx.square()) + ) + probs = Phi(-h) * Phi(-k) + _i2pi * asin_r * (vals @ w) + return probs.clip(min=_0, max=_1) # necessary due to "safe" handling of inf + + +def _bvnu_taylor(r: Tensor, h: Tensor, k: Tensor, num_points: int = 20) -> Tensor: + r"""Solves for `P(x > h, y > k)` via Taylor expansion. + + Following [Drezner1990computation], the standard BVN problem may be rewritten as + ``` + L(h, k, r) = L(h, k, s) - s/(2\pi) \int_{0}^{a} f(x) dx + f(x) = (1 - x^2){-1/2} e^{-0.5 ((h - sk)/ x)^2} e^{-shk/(1 + (1 - x^2)^{1/2})}, + ``` + where `s = sign(r)` and `a = sqrt(1 - r^{2})`. The term `L(h, k, s)` is analytic. + The second integral is approximated via Taylor expansion. See Sections 2.3 and + 2.4 of [Genz2004bvnt]_. 
+ """ + _0, _1, _ni2, _i2pi, _sq2pi = get_constants_like( + values=(0, 1, -0.5, _inv_2pi, _sqrt_2pi), ref=r + ) + + x, w = leggauss(num_points, dtype=r.dtype, device=r.device) + x = x + _1 + + s = get_constants_like(2, r) * (r > _0).to(r) - _1 # sign of `r` where sign(0) := 1 + sk = s * k + skh = sk * h + comp_r2 = _1 - r.square() + + a = comp_r2.clip(min=0).sqrt() + b = safe_sub(h, sk) + b2 = b.square() + c = get_constants_like(1 / 8, r) * (get_constants_like(4, r) - skh) + d = get_constants_like(1 / 80, r) * (get_constants_like(12, r) - skh) + + # ---- Solve for `L(h, k, s)` + int_from_0_to_s = case_dispatcher( + out=torch.empty_like(r), + cases=[(lambda: r > _0, lambda mask: Phi(-torch.maximum(h[mask], k[mask])))], + default=lambda mask: (Phi(sk[mask]) - Phi(h[mask])).clip(min=_0), + ) + + # ---- Solve for `s/(2\pi) \int_{0}^{a} f(x) dx` + # Analytic part + _a0 = _ni2 * (safe_div(b2, comp_r2) + skh) + _a1 = c * get_constants_like(1 / 3, r) * (_1 - d * b2) + _a2 = _1 - b2 * _a1 + abs_b = b.abs() + analytic_part = torch.subtract( # analytic part of solution + a * (_a2 + comp_r2 * _a1 + c * d * comp_r2.square()) * safe_exp(_a0), + _sq2pi * Phi(safe_div(-abs_b, a)) * abs_b * _a2 * safe_exp(_ni2 * skh), + ) + + # Quadrature part + _b2 = b2.unsqueeze(-1) + _skh = skh.unsqueeze(-1) + _q0 = get_constants_like(0.25, r) * comp_r2.unsqueeze(-1) * x.square() + _q1 = (_1 - _q0).sqrt() + _q2 = _ni2 * (_b2 / _q0 + _skh) + + _b2 = b2.unsqueeze(-1) + _c = c.unsqueeze(-1) + _d = d.unsqueeze(-1) + vals = (_ni2 * (_b2 / _q0 + _skh)).exp() * torch.subtract( + _1 + _c * _q0 * (_1 + get_constants_like(5, r) * _d * _q0), + safe_exp(_ni2 * _q0 / (_1 + _q1).square() * _skh) / _q1, + ) + mask = _q2 > get_constants_like(-100, r) + if not mask.all(): + vals[~mask] = _0 + quadrature_part = _ni2 * a * (vals @ w) + + # Return `P(x > h, y > k)` + int_from_0_to_a = _i2pi * s * (analytic_part + quadrature_part) + return (int_from_0_to_s - int_from_0_to_a).clip(min=_0, max=_1) + + +def bvnmom( + r: Tensor, + xl: Tensor, + yl: Tensor, + xu: Tensor, + yu: Tensor, + p: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + r"""Computes the expected values of truncated, bivariate normal random variables. + + Let `x` and `y` be a pair of standard bivariate normal random variables having + correlation `r`. This function computes `E([x,y] | [xl,yl] < [x,y] < [xu,yu])`. + + Following [Muthen1990moments]_ equations (4) and (5), we have + ``` + E(x | [xl, yl] < [x, y] < [xu, yu]) + = Z^{-1} \phi(xl) P(yl < y < yu | x=xl) - \phi(xu) P(yl < y < yu | x=xu) + ``` + where `Z = P([xl, yl] < [x, y] < [xu, yu])` and `\phi` is the standard normal PDF. + + Args: + r: Tensor of correlation coefficients. + xl: Tensor of lower bounds for `x`, same shape as `r`. + xu: Tensor of upper bounds for `x`, same shape as `r`. + yl: Tensor of lower bounds for `y`, same shape as `r`. + yu: Tensor of upper bounds for `y`, same shape as `r`. + p: Tensor of probabilities `P(xl < x < xu, yl < y < yu)`, same shape as `r`. + + Returns: + `E(x | [xl, yl] < [x, y] < [xu, yu])` and `E(y | [xl, yl] < [x, y] < [xu, yu])`. 
+ """ + if not (r.shape == xl.shape == xu.shape == yl.shape == yu.shape): + raise UnsupportedError("Arguments to `bvn` must have the same shape.") + + if p is None: + p = bvn(r=r, xl=xl, xu=xu, yl=yl, yu=yu) + + corr = r[..., None, None] + istd = (1 - corr.square()).rsqrt() + lower = torch.stack([xl, yl], -1) + upper = torch.stack([xu, yu], -1) + bounds = torch.stack([lower, upper], -1) + deltas = safe_mul(corr, bounds) + + # Compute densities and conditional probabilities + density_at_bounds = phi(bounds) + prob_given_bounds = Phi( + safe_mul(istd, safe_sub(upper.flip(-1).unsqueeze(-1), deltas)) + ) - Phi(safe_mul(istd, safe_sub(lower.flip(-1).unsqueeze(-1), deltas))) + + # Evaluate Muthen's formula + p_diffs = -(density_at_bounds * prob_given_bounds).diff().squeeze(-1) + moments = (1 / p).unsqueeze(-1) * (p_diffs + r.unsqueeze(-1) * p_diffs.flip(-1)) + return moments.unbind(-1) diff --git a/botorch/utils/probability/lin_ess.py b/botorch/utils/probability/lin_ess.py new file mode 100644 index 0000000000..1fb6d13cb6 --- /dev/null +++ b/botorch/utils/probability/lin_ess.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Linear Elliptical Slice Sampler. + +References + +.. [Gessner2020] + A. Gessner, O. Kanjilal, and P. Hennig. Integrals over gaussians under + linear domain constraints. AISTATS 2020. + + +This implementation is based (with multiple changes / optimiations) on +the following implementations based on the algorithm in [Gessner2020]_: +https://github.com/alpiges/LinConGauss +https://github.com/wjmaddox/pytorch_ess +""" + +from __future__ import annotations + +import math +from typing import Optional, Tuple + +import torch +from botorch.utils.sampling import PolytopeSampler +from torch import Tensor + +_twopi = 2.0 * math.pi +_delta_theta = 1.0e-6 * _twopi + + +class LinearEllipticalSliceSampler(PolytopeSampler): + r"""Linear Elliptical Slice Sampler. + + TODOs: + - clean up docstrings + - optimize computations (if possible) + + Maybe TODOs: + - Support degenerate domains (with zero volume)? + - Add batch support ? + """ + + def __init__( + self, + inequality_constraints: Optional[Tuple[Tensor, Tensor]] = None, + bounds: Optional[Tensor] = None, + interior_point: Optional[Tensor] = None, + mean: Optional[Tensor] = None, + covariance_matrix: Optional[Tensor] = None, + covariance_root: Optional[Tensor] = None, + ) -> None: + r"""Initialize LinearEllipticalSliceSampler. + + Args: + inequality_constraints: Tensors `(A, b)` describing inequality constraints + `A @ x <= b`, where `A` is an `n_ineq_con x d`-dim Tensor and `b` is + an `n_ineq_con x 1`-dim Tensor, with `n_ineq_con` the number of + inequalities and `d` the dimension of the sample space. If omitted, + must provide `bounds` instead. + bounds: A `2 x d`-dim tensor of box bounds. If omitted, must provide + `inequality_constraints` instead. + interior_point: A `d x 1`-dim Tensor presenting a point in the (relative) + interior of the polytope. If omitted, an interior point is determined + automatically by solving a Linear Program. Note: It is crucial that + the point lie in the interior of the feasible set (rather than on the + boundary), otherwise the sampler will produce invalid samples. + mean: The `d x 1`-dim mean of the MVN distribution (if omitted, use zero). 
+            covariance_matrix: The `d x d`-dim covariance matrix of the MVN
+                distribution (if omitted, use the identity).
+            covariance_root: A `d x k`-dim root of the covariance matrix such that
+                covariance_root @ covariance_root.T = covariance_matrix.
+
+        This sampler samples from a multivariate Normal `N(mean, covariance_matrix)`
+        subject to linear domain constraints `A x <= b` (intersected with box bounds,
+        if provided).
+        """
+        super().__init__(
+            inequality_constraints=inequality_constraints,
+            # TODO: Support equality constraints?
+            interior_point=interior_point,
+            bounds=bounds,
+        )
+        tkwargs = {"device": self.x0.device, "dtype": self.x0.dtype}
+        self._mean = mean
+        if covariance_matrix is not None:
+            if covariance_root is not None:
+                raise ValueError(
+                    "Provide either covariance_matrix or covariance_root, not both."
+                )
+            try:
+                covariance_root = torch.linalg.cholesky(covariance_matrix)
+            except RuntimeError as e:
+                if "positive-definite" in str(e):
+                    raise ValueError(
+                        "Covariance matrix is not positive definite. "
+                        "Currently only non-degenerate distributions are supported."
+                    )
+                else:
+                    raise e
+        self._covariance_root = covariance_root
+        self._x = self.x0.clone()  # state of the sampler ("current point")
+        # We will need the following repeatedly, let's allocate them once
+        self._zero = torch.zeros(1, **tkwargs)
+        self._nan = torch.tensor(float("nan"), **tkwargs)
+        self._full_angular_range = torch.tensor([0.0, _twopi], **tkwargs)
+
+    def draw(self, n: int = 1) -> Tensor:
+        r"""Draw samples.
+
+        Args:
+            n: The number of samples.
+
+        Returns:
+            An `n x d`-dim tensor of `n` samples.
+        """
+        # TODO: Do we need to do any thinning or warm-up here?
+        samples = torch.cat([self.step() for _ in range(n)], dim=-1)
+        return samples.transpose(-1, -2)
+
+    def step(self) -> Tensor:
+        r"""Take a step, return the new sample, update the internal state.
+
+        Returns:
+            A `d x 1`-dim sample from the domain.
+        """
+        nu = self._sample_base_rv()
+        theta = self._draw_angle(nu=nu)
+        self._x = self._get_cart_coords(nu=nu, theta=theta)
+        return self._x
+
+    def _sample_base_rv(self) -> Tensor:
+        r"""Sample a base random variable from N(mean, covariance_matrix).
+
+        Returns:
+            A `d x 1`-dim sample from N(mean, covariance_matrix).
+        """
+        nu = torch.randn_like(self._x)
+        if self._covariance_root is not None:
+            nu = self._covariance_root @ nu
+        if self._mean is not None:
+            nu = self._mean + nu
+        return nu
+
+    def _draw_angle(self, nu: Tensor) -> Tensor:
+        r"""Draw the rotation angle.
+
+        Args:
+            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
+
+        Returns:
+            A `1`-dim tensor containing the rotation angle, drawn uniformly at
+            random from the union of the active slices.
+        """
+        rot_angle, rot_slices = self._find_rotated_intersections(nu)
+        rot_lengths = rot_slices[:, 1] - rot_slices[:, 0]
+        cum_lengths = torch.cumsum(rot_lengths, dim=0)
+        cum_lengths = torch.cat((self._zero, cum_lengths), dim=0)
+        rnd_angle = torch.rand(1) * cum_lengths[-1]
+        idx = torch.searchsorted(cum_lengths, rnd_angle) - 1
+        return rot_slices[idx, 0] + rnd_angle - cum_lengths[idx] + rot_angle
+
+    def _get_cart_coords(self, nu: Tensor, theta: Tensor) -> Tensor:
+        r"""Determine location on ellipsoid in cartesian coordinates.
+
+        Args:
+            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
+            theta: A `k`-dim tensor of angles.
+
+        Returns:
+            A `d x k`-dim tensor of samples from the domain in cartesian coordinates.
+        """
+        return self._x * torch.cos(theta) + nu * torch.sin(theta)
+
+    def _find_rotated_intersections(self, nu: Tensor) -> Tuple[Tensor, Tensor]:
+        r"""Finds rotated intersections.
+
+        Rotates the intersections by the rotation angle and makes sure that all
+        angles lie in [0, 2*pi].
+
+        Args:
+            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
+
+        Returns:
+            A two-tuple containing the rotation angle (scalar) and a
+            `num_active / 2 x 2`-dim tensor of shifted angles.
+        """
+        slices = self._find_active_intersections(nu)
+        rot_angle = slices[0]
+        slices = slices - rot_angle
+        slices = torch.where(slices < 0, slices + _twopi, slices)
+        return rot_angle, slices.reshape(-1, 2)
+
+    def _find_active_intersections(self, nu: Tensor) -> Tensor:
+        """
+        Find the angles of those intersections that lie on the boundary of the
+        integration domain by adding and subtracting a small angle and checking
+        whether the corresponding points on the ellipse are feasible.
+
+        Args:
+            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
+
+        Returns:
+            A `num_active`-dim tensor containing the angles of active intersection in
+            increasing order so that activation happens in positive direction. If a
+            slice crosses `theta=0`, the first angle is appended at the end of the
+            tensor. Every element of the returned tensor defines a slice for
+            elliptical slice sampling.
+        """
+        theta = self._find_intersection_angles(nu)
+        active_directions = self._index_active(
+            nu=nu, theta=theta, delta_theta=_delta_theta
+        )
+        theta_active = theta[active_directions.nonzero()]
+
+        while theta_active.numel() % 2 == 1:
+            # Almost tangential ellipses, reduce delta_theta
+            active_directions = self._index_active(
+                theta=theta, nu=nu, delta_theta=0.1 * _delta_theta
+            )
+            theta_active = theta[active_directions.nonzero()]
+
+        if theta_active.numel() == 0:
+            theta_active = self._full_angular_range
+            # TODO: What about `self.ellipse_in_domain = False` in the original code?
+        elif active_directions[active_directions.nonzero()][0] == -1:
+            theta_active = torch.cat((theta_active[1:], theta_active[:1]))
+
+        return theta_active.view(-1)
+
+    def _find_intersection_angles(self, nu: Tensor) -> Tensor:
+        """Compute all of the up to 2 * n_ineq_con intersections of the ellipse
+        and the linear constraints.
+
+        Args:
+            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
+
+        Returns:
+            An `M`-dim tensor, where `M <= 2 * n_ineq_con` (with `M = 2 * n_ineq_con`
+            if all intermediate computations yield finite numbers).
+        """
+        # Compared to the implementation in https://github.com/alpiges/LinConGauss
+        # we need to flip the sign of A b/c the original algorithm considers
+        # A @ x + b >= 0 feasible, whereas we consider A @ x - b <= 0 feasible.
+        g1 = -self.A @ self._x
+        g2 = -self.A @ nu
+        r = torch.sqrt(g1**2 + g2**2)
+        phi = 2 * torch.atan(g2 / (r + g1)).squeeze()
+
+        arg = -(self.b / r).squeeze()
+        # Write NaNs if there is no intersection
+        arg = torch.where(torch.absolute(arg) <= 1, arg, self._nan)
+
+        # Two solutions per linear constraint, shape of theta: (n_ineq_con, 2)
+        acos_arg = torch.arccos(arg)
+        theta = torch.stack((phi + acos_arg, phi - acos_arg), dim=-1)
+        theta = theta[torch.isfinite(theta)]  # shape: `2 * n_ineq_con - num_not_finite`
+        theta = torch.where(theta < 0, theta + _twopi, theta)  # [0, 2*pi]
+
+        return torch.sort(theta).values
+
+    def _index_active(
+        self, nu: Tensor, theta: Tensor, delta_theta: float = 1e-4
+    ) -> Tensor:
+        r"""Determine active indices.
+
+        Args:
+            nu: A `d x 1`-dim tensor (the "new" direction, drawn from N(0, I)).
+            theta: An `M`-dim tensor of intersection angles.
+            delta_theta: A small perturbation to be used for determining whether
+                intersections are at the boundary of the integration domain.
+
+        Returns:
+            An `M`-dim tensor with elements taking on values in {-1, 0, 1}.
+            A non-zero value indicates that the associated intersection angle
+            is an active constraint. For active constraints, the sign indicates
+            the direction of the relevant domain (i.e. +1 (-1) means that
+            increasing (decreasing) the angle renders the sample feasible).
+        """
+        samples_pos = self._get_cart_coords(nu=nu, theta=theta + delta_theta)
+        samples_neg = self._get_cart_coords(nu=nu, theta=theta - delta_theta)
+        pos_diffs = self._is_feasible(samples_pos)
+        neg_diffs = self._is_feasible(samples_neg)
+        # We don't use bit-wise XOR here since we need the signs of the directions
+        return pos_diffs.to(nu) - neg_diffs.to(nu)
+
+    def _is_feasible(self, points: Tensor) -> Tensor:
+        r"""Check whether the given points satisfy all linear constraints.
+
+        Args:
+            points: A `d x M`-dim tensor of points.
+
+        Returns:
+            An `M`-dim binary tensor where `True` indicates that the associated
+            point is feasible.
+        """
+        return (self.A @ points <= self.b).all(dim=0)
diff --git a/botorch/utils/probability/linalg.py b/botorch/utils/probability/linalg.py
new file mode 100644
index 0000000000..7a0f7b9cab
--- /dev/null
+++ b/botorch/utils/probability/linalg.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from dataclasses import dataclass, InitVar
+from typing import Any, Optional
+
+import torch
+from botorch.utils.probability.utils import swap_along_dim_
+from linear_operator.utils.errors import NotPSDError
+from torch import LongTensor, Tensor
+from torch.nn.functional import pad
+
+
+def augment_cholesky(
+    Laa: Tensor,
+    Kbb: Tensor,
+    Kba: Optional[Tensor] = None,
+    Lba: Optional[Tensor] = None,
+    jitter: Optional[float] = None,
+) -> Tensor:
+    r"""Computes the Cholesky factor of a block matrix `K = [[Kaa, Kab], [Kba, Kbb]]`
+    based on a precomputed Cholesky factor `Kaa = Laa Laa^T`.
+
+    Args:
+        Laa: Cholesky factor of K's upper left block.
+        Kbb: Lower-right block of K.
+        Kba: Lower-left block of K.
+        Lba: Precomputed solve `Kba Laa^{-T}`.
+        jitter: Optional nugget to be added to the diagonal of Kbb.
+    """
+    if not (Kba is None) ^ (Lba is None):
+        raise ValueError("One and only one of `Kba` or `Lba` must be provided.")
+
+    if jitter is not None:
+        # Add the nugget to the diagonal in-place on a clone (batch-safe)
+        Kbb = Kbb.clone()
+        Kbb.diagonal(dim1=-2, dim2=-1).add_(jitter)
+
+    if Lba is None:
+        Lba = torch.linalg.solve_triangular(
+            Laa.transpose(-2, -1), Kba, left=False, upper=True
+        )
+
+    Lbb, info = torch.linalg.cholesky_ex(Kbb - Lba @ Lba.transpose(-2, -1))
+    if info.any():
+        raise NotPSDError(
+            "Schur complement of `K` with respect to `Kaa` not PSD for the given "
+            "Cholesky factor `Laa`"
+            f"{'.'
if jitter is None else f' and nugget jitter={jitter}.'}" + ) + + m = Laa.shape[-1] + n = Lbb.shape[-1] + batch_shape = torch.broadcast_shapes(Laa.shape[:-2], Lbb.shape[:-2]) + _Laa = Laa.expand(*batch_shape, m, m) + _Lba = Lba.expand(*batch_shape, n, m) + _Lbb = Lbb.expand(*batch_shape, n, n) + return torch.concat([pad(_Laa, (0, n)), torch.concat([_Lba, _Lbb], -1)], -2) + + +@dataclass +class PivotedCholesky: + step: int + tril: Tensor + perm: LongTensor + diag: Optional[Tensor] = None + validate_init: InitVar[bool] = True + + def __post_init__(self, validate_init: bool = True): + if not validate_init: + return + + if self.tril.shape[-2] != self.tril.shape[-1]: + raise ValueError( + f"Expected square matrices but `matrix` has shape {self.tril.shape}." + ) + + if self.perm.shape != self.tril.shape[:-1]: + raise ValueError( + f"`perm` of shape `{self.perm.shape}` incompatible with " + f"`matrix` of shape `{self.tril.shape}." + ) + + if self.diag is not None and self.diag.shape != self.tril.shape[:-1]: + raise ValueError( + f"`diag` of shape `{self.diag.shape}` incompatible with " + f"`matrix` of shape `{self.tril.shape}." + ) + + def __getitem__(self, key: Any) -> PivotedCholesky: + return PivotedCholesky( + step=self.step, + tril=self.tril[key], + perm=self.perm[key], + diag=None if self.diag is None else self.diag[key], + ) + + def update_(self, eps: float = 1e-10) -> None: + r"""Performs a single matrix decomposition step.""" + i = self.step + L = self.tril + Lii = self.tril[..., i, i].clone().clip(min=0).sqrt() + + # Finalize `i-th` row and column of Cholesky factor + L[..., i, i] = Lii + L[..., i, i + 1 :] = 0 + L[..., i + 1 :, i] = L[..., i + 1 :, i].clone() / Lii.unsqueeze(-1) + + # Update `tril(L[i + 1:, i + 1:])` to be the lower triangular part + # of the Schur complement of `cov` with respect to `cov[:i, :i]`. 
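+        # In other words, with `l = L[i + 1:, i]` (already divided by `Lii` above),
+        # the trailing block is downdated by the rank-1 term `l l^T`.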
+        rank1 = L[..., i + 1 :, i : i + 1].clone()
+        rank1 = (rank1 * rank1.transpose(-1, -2)).tril()
+        L[..., i + 1 :, i + 1 :] = L[..., i + 1 :, i + 1 :].clone() - rank1
+        L[Lii <= i * eps, i:, i] = 0  # numerical stability clause
+        self.step += 1
+
+    def pivot_(self, pivot: LongTensor) -> None:
+        *batch_shape, _, size = self.tril.shape
+        if pivot.shape != tuple(batch_shape):
+            raise ValueError("Argument `pivot` does not match the batch shape.")
+
+        # Perform basis swaps
+        for key in ("perm", "diag"):
+            tnsr = getattr(self, key, None)
+            if tnsr is None:
+                continue
+            swap_along_dim_(tnsr, i=self.step, j=pivot, dim=tnsr.ndim - 1)
+
+        # Perform matrix swaps; preallocate buffers for row/column linear indices
+        size2 = size**2
+        min_pivot = pivot.min()
+        tkwargs = {"device": pivot.device, "dtype": pivot.dtype}
+        buffer_col = torch.arange(size * (1 + min_pivot), size2, size, **tkwargs)
+        buffer_row = torch.arange(0, max(self.step, pivot.max()), **tkwargs)
+        head = buffer_row[: self.step]
+
+        indices_v1 = []
+        indices_v2 = []
+        for i, piv in enumerate(pivot.view(-1, 1)):
+            v1 = pad(piv, (1, 0), value=self.step).unsqueeze(-1)
+            v2 = pad(piv, (0, 1), value=self.step).unsqueeze(-1)
+            start = i * size2
+
+            indices_v1.extend((start + v1 + size * v1).ravel())
+            indices_v2.extend((start + v2 + size * v2).ravel())
+
+            indices_v1.extend((start + size * v1 + head).ravel())
+            indices_v2.extend((start + size * v2 + head).ravel())
+
+            tail = buffer_col[piv - min_pivot :]
+            indices_v1.extend((start + v1 + tail).ravel())
+            indices_v2.extend((start + v2 + tail).ravel())
+
+            interior = buffer_row[min(piv, self.step + 1) : piv]
+            indices_v1.extend(start + size * interior + self.step)
+            indices_v2.extend(start + size * piv + interior)
+
+        swap_along_dim_(
+            self.tril.view(-1),
+            i=torch.as_tensor(indices_v1, **tkwargs),
+            j=torch.as_tensor(indices_v2, **tkwargs),
+            dim=0,
+        )
+
+    def expand(self, *sizes: int) -> PivotedCholesky:
+        fields = {}
+        for name, ndim in {"perm": 1, "diag": 1, "tril": 2}.items():
+            src = getattr(self, name)
+            if src is not None:
+                fields[name] = src.expand(sizes + src.shape[-ndim:])
+        return type(self)(step=self.step, **fields)
+
+    def concat(self, other: PivotedCholesky, dim: int = 0) -> PivotedCholesky:
+        if self.step != other.step:
+            raise ValueError("Cannot concatenate decompositions at different steps.")
+
+        fields = {}
+        for name in ("tril", "perm", "diag"):
+            a = getattr(self, name)
+            b = getattr(other, name)
+            if type(a) != type(b):
+                raise NotImplementedError(f"Types of field {name} do not match.")
+
+            if a is not None:
+                fields[name] = torch.concat((a, b), dim=dim)
+
+        return type(self)(step=self.step, **fields)
+
+    def detach(self) -> PivotedCholesky:
+        fields = {}
+        for name in ("tril", "perm", "diag"):
+            obj = getattr(self, name)
+            if obj is not None:
+                fields[name] = obj.detach()
+        return type(self)(step=self.step, **fields)
+
+    def clone(self) -> PivotedCholesky:
+        fields = {}
+        for name in ("tril", "perm", "diag"):
+            obj = getattr(self, name)
+            if obj is not None:
+                fields[name] = obj.clone()
+        return type(self)(step=self.step, **fields)
diff --git a/botorch/utils/probability/mvnxpb.py b/botorch/utils/probability/mvnxpb.py
new file mode 100644
index 0000000000..e22ea326fb
--- /dev/null
+++ b/botorch/utils/probability/mvnxpb.py
@@ -0,0 +1,432 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
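+
+# Background (a sketch of the identity assumed below): `augment_cholesky` from
+# `linalg.py` extends a factor `Laa Laa^T = Kaa` to the full block matrix
+# `K = [[Kaa, Kab], [Kba, Kbb]]` via `Lba = Kba Laa^{-T}` and
+# `Lbb Lbb^T = Kbb - Lba Lba^T` (the Schur complement), which lets
+# `MVNXPB.augment` add variables without refactorizing the whole system.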
+ +r""" +Bivariate conditioning algorithm for approximating Gaussian probabilities, +see [Genz2016numerical]_ and [Trinh2015bivariate]_. + +.. [Trinh2015bivariate] + G. Trinh and A. Genz. Bivariate conditioning approximations for + multivariate normal probabilities. Statistics and Computing, 2015. + +.. [Genz2016numerical] + A. Genz and G. Tring. Numerical Computation of Multivariate Normal Probabilities + using Bivariate Conditioning. Monte Carlo and Quasi-Monte Carlo Methods, 2016. +""" + +from __future__ import annotations + +from typing import Any, Optional, TypedDict +from warnings import warn + +import torch +from botorch.exceptions import UnsupportedError +from botorch.utils.probability.bvn import bvn, bvnmom +from botorch.utils.probability.linalg import augment_cholesky, PivotedCholesky +from botorch.utils.probability.utils import ( + case_dispatcher, + get_constants_like, + ndtr as Phi, + phi, + STANDARDIZED_RANGE, + swap_along_dim_, +) +from botorch.utils.safe_math import log as safe_log, mul as safe_mul +from linear_operator.utils.cholesky import psd_safe_cholesky +from linear_operator.utils.errors import NotPSDError +from torch import LongTensor, Tensor +from torch.nn.functional import pad + + +class mvnxpbState(TypedDict): + step: int + perm: LongTensor + bounds: Tensor + piv_chol: PivotedCholesky + plug_ins: Tensor + log_prob: Tensor + log_prob_iter: Optional[Tensor] + + +class MVNXPB: + def __init__( + self, + covariance_matrix: Optional[Tensor], + bounds: Tensor, + step: int = 0, + perm: Optional[Tensor] = None, + piv_chol: Optional[PivotedCholesky] = None, + plug_ins: Optional[Tensor] = None, + log_prob: Optional[Tensor] = None, + log_prob_iter: Optional[Tensor] = None, + ) -> None: + r"""Initializes an MVNXPB instance. + + Args: + covariance_matrix: Covariance matrices of shape `batch_shape x [n, n]`. If + `piv_chol` is passed as an argument, then `covariance_matrix` may be + None and the system is assumed to be . + bounds: Tensor of lower and upper bounds, `batch_shape x [n, 2]`. + step: Integer used to track the solver's progress. + piv_chol: `PivotedCholesky` instance for the system. + plug_ins: Tensor of plug-in estimators used to update lower and upper bounds + on random variables that have yet to be integrated out. + log_prob: Tensor of log probabilities. + log_prob_iter: Tensor of conditional log probabilities for the next random + variable. Used when integrating over an odd number of random variables. + """ + tkwargs = {} + if piv_chol is None: # full initialization scheme + if covariance_matrix is None: + raise ValueError( + "`piv_chol` must be passed when `covariance_matrix` is None." 
+ ) + *batch_shape, _, ndim = covariance_matrix.shape + tkwargs["dtype"] = covariance_matrix.dtype + device = tkwargs["device"] = covariance_matrix.device + + # Standardize covariance matrices and bounds + var = covariance_matrix.diagonal(dim1=-2, dim2=-1).unsqueeze(-1) + std = var.sqrt() + istd = var.rsqrt() + matrix = istd * covariance_matrix * istd.transpose(-1, -2) + + # Clip first to avoid differentiating through `istd * inf` + bounds = istd * bounds.clip(*(std * lim for lim in STANDARDIZED_RANGE)) + + # Initialize partial pivoted Cholesky + perm = ( + torch.arange(0, ndim, device=device).expand(*batch_shape, ndim) + if perm is None + else perm.clone() + ) + piv_chol = PivotedCholesky( + step=0, + perm=perm.contiguous(), + diag=std.squeeze(-1).clone(), + tril=matrix.tril(), + ) + else: + *batch_shape, _, ndim = piv_chol.tril.shape + tkwargs["dtype"] = piv_chol.tril.dtype + tkwargs["device"] = piv_chol.tril.device + perm = piv_chol.perm if perm is None else perm.contiguous() + + if plug_ins is None: + plug_ins = torch.full(batch_shape + [ndim], float("nan"), **tkwargs) + if log_prob is None: + log_prob = torch.zeros(batch_shape, **tkwargs) + + self.step: int = step + self.perm: Tensor = perm.clone() + self.bounds: Tensor = bounds.clone() + self.piv_chol: PivotedCholesky = piv_chol + self.plug_ins: Tensor = plug_ins.clone() + self.log_prob: Tensor = log_prob.clone() + self.log_prob_iter: Optional[Tensor] = log_prob_iter + + def solve(self, num_steps: Optional[int] = None, eps: float = 1e-10) -> Tensor: + r"""Runs the MVNXPB solver instance for a fixed number of steps. + + Calculates a bivariate conditional approximation to P(X \in bounds), where + X ~ N(0, Σ). For details, see [Genz2016numerical] or [Trinh2015bivariate]_. + """ + if self.step > self.piv_chol.step: + raise ValueError("Invalid state: solver ran ahead of matrix decomposition.") + + # Unpack some terms + start = self.step + bounds = self.bounds + piv_chol = self.piv_chol + L = piv_chol.tril + y = self.plug_ins + + # Subtract marginal log probability of final term from previous result if + # it did not fit in a block. 
+ ndim = y.shape[-1] + if ndim > start and start % 2: + self.log_prob = self.log_prob - self.log_prob_iter + self.log_prob_iter = None + + # Iteratively compute bivariate conditional approximation + zero = get_constants_like(0, L) # needed when calling `torch.where` below + num_steps = num_steps or ndim - start + for i in range(start, start + num_steps): + should_update_chol = self.step == piv_chol.step + + # Determine next pivot element + if should_update_chol: + pivot = self.select_pivot() + else: # pivot using order specified by precomputed pivoted Cholesky step + mask = self.perm[..., i:] == piv_chol.perm[..., i : i + 1] + pivot = i + torch.nonzero(mask, as_tuple=True)[-1] + + if pivot is not None and torch.any(pivot > i): + self.pivot_(pivot=pivot) + + # Initialize `i`-th plug-in value as univariate conditional expectation + Lii = L[..., i, i].clone() + if should_update_chol: + Lii = Lii.clip(min=0).sqrt() + inv_Lii = Lii.reciprocal() + if i == 0: + lb, ub = bounds[..., i, :].clone().unbind(dim=-1) + else: + db = (L[..., i, :i].clone() * y[..., :i].clone()).sum(-1, keepdim=True) + lb, ub = (bounds[..., i, :].clone() - db).unbind(dim=-1) + + Phi_i = Phi(inv_Lii * ub) - Phi(inv_Lii * lb) + small = Phi_i <= i * eps + y[..., i] = case_dispatcher( # used to select next pivot + out=(phi(lb) - phi(ub)) / Phi_i, + cases=( # fallback cases for enhanced numerical stability + (lambda: small & (lb < -9), lambda m: ub[m]), + (lambda: small & (lb > 9), lambda m: lb[m]), + (lambda: small, lambda m: 0.5 * (lb[m] + ub[m])), + ), + ) + + # Maybe finalize the current block + if i and i % 2: + h = i - 1 + blk = slice(h, i + 1) + Lhh = L[..., h, h].clone() + Lih = L[..., i, h].clone() + + std_i = (Lii.square() + Lih.square()).sqrt() + istds = 1 / torch.stack([Lhh, std_i], -1) + blk_bounds = bounds[..., blk, :].clone() + if i > 1: + blk_bounds = blk_bounds - ( + L[..., blk, : i - 1].clone() @ y[..., : i - 1, None].clone() + ) + + blk_lower, blk_upper = ( + pair.unbind(-1) # pair of bounds for `yh` and `yi` + for pair in safe_mul(istds.unsqueeze(-1), blk_bounds).unbind(-1) + ) + blk_corr = Lhh * Lih * istds.prod(-1) + blk_prob = bvn(blk_corr, *blk_lower, *blk_upper) + zh, zi = bvnmom(blk_corr, *blk_lower, *blk_upper, p=blk_prob) + + # Replace 1D expectations with 2D ones `L[blk, blk]^{-1} y[..., blk]` + mask = blk_prob > zero + y[..., h] = torch.where(mask, zh, zero) + y[..., i] = torch.where(mask, (std_i * zi - Lih * zh) / Lii, zero) + + # Update running approximation to log probability + self.log_prob = self.log_prob + safe_log(blk_prob) + + self.step += 1 + if should_update_chol: + piv_chol.update_(eps=eps) + + # Factor in univariate probability if final term fell outside of a block. 
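+        # (Kept separately in `self.log_prob_iter` so that a subsequent resume,
+        # e.g. after `augment`, can subtract it again; see the start of `solve`.)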
+ if self.step % 2: + self.log_prob_iter = safe_log(Phi_i) + self.log_prob = self.log_prob + self.log_prob_iter + + return self.log_prob + + def select_pivot(self) -> Optional[LongTensor]: + r"""Returns the index of the variable with the smallest marginal probability + when conditioning on `X_{1:t-1} = y_{1:t-1}` where `t` is the current step and + `y_{1:t-1}` are the plug-in values for the random variables `X_{1:t-1}` that we + previously integrated over.""" + i = self.piv_chol.step + L = self.piv_chol.tril + bounds = self.bounds[..., i:, :] + if i: + bounds = bounds - L[..., i:, :i] @ self.plug_ins[..., :i, None] + + inv_stddev = torch.diagonal(L, dim1=-2, dim2=-1)[..., i:].clip(min=0).rsqrt() + probs_1d = Phi(inv_stddev.unsqueeze(-1) * bounds).diff(dim=-1).squeeze(-1) + return i + torch.argmin(probs_1d, dim=-1) + + def pivot_(self, pivot: LongTensor) -> None: + r"""Swap random variables at `pivot` and `step` positions.""" + step = self.step + if self.piv_chol.step == step: + self.piv_chol.pivot_(pivot) + elif self.step > self.piv_chol.step: + raise ValueError + + for tnsr in (self.perm, self.bounds): + swap_along_dim_(tnsr, i=self.step, j=pivot, dim=pivot.ndim) + + def __getitem__(self, key: Any) -> MVNXPB: + new = MVNXPB( + covariance_matrix=None, + piv_chol=self.piv_chol[key], + bounds=self.bounds[key], + perm=self.perm[key], + ) + new.step = self.step + new.plug_ins = self.plug_ins[key] + new.log_prob = self.log_prob[key] + if self.log_prob_iter is not None: + new.log_prob_iter = self.log_prob_iter[key] + return new + + def concat(self, other: MVNXPB, dim: int) -> MVNXPB: + if not isinstance(other, MVNXPB): + raise TypeError( + f"Expected `other` to be {type(self)} typed but was {type(other)}." + ) + + if self.step != other.step or self.piv_chol.step != other.piv_chol.step: + raise UnsupportedError("Cannot combine solvers at different steps.") + + batch_ndim = self.log_prob.ndim + if dim > batch_ndim or dim < -batch_ndim: + raise ValueError(f"`dim={dim}` is not a valid batch dimension.") + + state_dict = self.asdict() + for key, obj in other.asdict().items(): + if obj is None: + if state_dict.get(key) is not None: + raise RuntimeError + + elif isinstance(obj, PivotedCholesky): + state_dict[key] = state_dict[key].concat(obj, dim=dim) + elif isinstance(obj, Tensor): + state_dict[key] = torch.concat((state_dict[key], obj), dim=dim) + elif obj != state_dict.get(key): + raise RuntimeError + + return type(self).from_dict(state_dict) + + def expand(self, *sizes: int) -> MVNXPB: + state_dict = self.asdict() + state_dict["piv_chol"] = state_dict["piv_chol"].expand(*sizes) + for name, ndim in { + "bounds": 2, + "perm": 1, + "plug_ins": 1, + "log_prob": 0, + "log_prob_iter": 0, + }.items(): + src = state_dict[name] + if isinstance(src, Tensor): + state_dict[name] = src.expand( + sizes + src.shape[-ndim:] if ndim else sizes + ) + return self.from_dict(state_dict) + + def augment( + self, + covariance_matrix: Tensor, + bounds: Tensor, + cross_covariance_matrix: Optional[Tensor] = None, + unpermuted_cross_covariance_matrix: Optional[Tensor] = None, + disable_pivoting: bool = False, + jitter: Optional[float] = None, + max_tries: Optional[int] = None, + **kwargs: Any, + ) -> MVNXPB: + r"""Augment an `n`-dimensional MVNXPB instance to include `m` additional random + variables. + """ + n = self.perm.shape[-1] + m = covariance_matrix.shape[-1] + if n != self.piv_chol.step: + raise NotImplementedError( + "Augmentation of incomplete solutions not implemented yet." 
+ ) + + if cross_covariance_matrix is None: + if unpermuted_cross_covariance_matrix is None: + raise ValueError( + "Missing required argument `cross_covariance_matrix` " + "xor `unpermuted_cross_covariance_matrix`." + ) + idx = self.perm.unsqueeze(-2) + idx = idx.expand(*idx.shape[:-2], m, n) + cross_covariance_matrix = unpermuted_cross_covariance_matrix.gather(-1, idx) + elif unpermuted_cross_covariance_matrix is not None: + raise ValueError( + "Arguments `cross_covariance_matrix` and " + "`unpermuted_cross_covariance_matrix` are mutually exclusive." + ) + + var = covariance_matrix.diagonal(dim1=-2, dim2=-1).unsqueeze(-1) + std = var.sqrt() + istd = var.rsqrt() + + Kmn = istd * cross_covariance_matrix + if self.piv_chol.diag is None: + diag = pad(std.squeeze(-1), (cross_covariance_matrix.shape[-1], 0), value=1) + else: + Kmn = Kmn * (1 / self.piv_chol.diag).unsqueeze(-2) + diag = torch.concat([self.piv_chol.diag, std.squeeze(-1)], -1) + + # Augment partial pivoted Cholesky factor + Kmm = istd * covariance_matrix * istd.transpose(-1, -2) + Lnn = self.piv_chol.tril + try: + L = augment_cholesky(Laa=Lnn, Kba=Kmn, Kbb=Kmm, jitter=jitter) + except NotPSDError: + warn("Joint covariance matrix not positive definite, attempting recovery.") + Knn = Lnn @ Lnn.transpose(-1, -2) + Knm = Kmn.transpose(-1, -2) + K = torch.cat([torch.cat((Knn, Knm), -1), torch.cat((Kmn, Kmm), -1)], -2) + L = psd_safe_cholesky(K, jitter=jitter, max_tries=max_tries) + + if not disable_pivoting: + Lmm = L[..., n:, n:].clone() + L[..., n:, n:] = (Lmm @ Lmm.transpose(-2, -1)).tril() + + _bounds = istd * bounds.clip(*(std * lim for lim in STANDARDIZED_RANGE)) + _perm = torch.arange(n, n + m, dtype=self.perm.dtype, device=self.perm.device) + _perm = _perm.expand(covariance_matrix.shape[:-2] + (m,)) + + piv_chol = PivotedCholesky( + step=n + m if disable_pivoting else n, + tril=L.contiguous(), + perm=torch.cat([self.piv_chol.perm, _perm], dim=-1).contiguous(), + diag=diag, + ) + + return MVNXPB( + covariance_matrix=None, + bounds=torch.cat([self.bounds, _bounds], dim=-2), + perm=torch.cat([self.perm, _perm], dim=-1), + step=self.step, + piv_chol=piv_chol, + plug_ins=pad(self.plug_ins, (0, m), value=float("nan")), + log_prob=self.log_prob, + log_prob_iter=self.log_prob_iter, + **kwargs, + ) + + def detach(self) -> MVNXPB: + state_dict = self.asdict() + for key, obj in state_dict.items(): + if isinstance(obj, (PivotedCholesky, Tensor)): + state_dict[key] = obj.detach() + return self.from_dict(state_dict) + + def clone(self) -> MVNXPB: + state_dict = self.asdict() + for key, obj in state_dict.items(): + if isinstance(obj, (PivotedCholesky, Tensor)): + state_dict[key] = obj.clone() + return self.from_dict(state_dict) + + def asdict(self) -> mvnxpbState: + return mvnxpbState( + step=self.step, + perm=self.perm, + bounds=self.bounds, + piv_chol=self.piv_chol, + plug_ins=self.plug_ins, + log_prob=self.log_prob, + log_prob_iter=self.log_prob_iter, + ) + + @classmethod + def from_dict(cls, state: mvnxpbState): + return cls(covariance_matrix=None, **state) diff --git a/botorch/utils/probability/truncated_multivariate_normal.py b/botorch/utils/probability/truncated_multivariate_normal.py new file mode 100644 index 0000000000..c7ee81be5c --- /dev/null +++ b/botorch/utils/probability/truncated_multivariate_normal.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +from typing import Optional, Sequence + +import torch +from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler +from botorch.utils.probability.mvnxpb import MVNXPB +from botorch.utils.probability.utils import get_constants_like +from torch import Tensor +from torch.distributions.multivariate_normal import MultivariateNormal + + +class TruncatedMultivariateNormal(MultivariateNormal): + def __init__( + self, + loc: Tensor, + covariance_matrix: Optional[Tensor] = None, + precision_matrix: Optional[Tensor] = None, + scale_tril: Optional[Tensor] = None, + bounds: Tensor = None, + solver: Optional[MVNXPB] = None, + sampler: Optional[LinearEllipticalSliceSampler] = None, + validate_args: Optional[bool] = None, + ): + r"""Initializes an instance of a TruncatedMultivariateNormal distribution. + + Let `x ~ N(0, K)` be an `n`-dimensional Gaussian random vector. This class + represents the distribution of the truncated Multivariate normal random vector + `x | a <= x <= b`. + + Args: + loc: A mean vector for the distribution, `batch_shape x event_shape`. + covariance_matrix: Covariance matrix distribution parameter. + precision_matrix: Inverse covariance matrix distribution parameter. + scale_tril: Lower triangular, square-root covariance matrix distribution + parameter. + bounds: A `batch_shape x event_shape x 2` tensor of strictly increasing + bounds for `x` so that `bounds[..., 0] < bounds[..., 1]` everywhere. + solver: A pre-solved MVNXPB instance used to approximate the log partition. + sampler: A LinearEllipticalSliceSampler instance used for sample generation. + validate_args: Optional argument to super().__init__. + """ + if bounds is None: + raise SyntaxError("Missing required argument `bounds`.") + elif bounds.shape[-1] != 2: + raise ValueError( + f"Expected bounds.shape[-1] to be 2 but bounds shape is {bounds.shape}" + ) + elif torch.gt(*bounds.unbind(dim=-1)).any(): + raise ValueError("`bounds` must be strictly increasing along dim=-1.") + + super().__init__( + loc=loc, + covariance_matrix=covariance_matrix, + precision_matrix=precision_matrix, + scale_tril=scale_tril, + validate_args=validate_args, + ) + self.bounds = bounds + self._solver = solver + self._sampler = sampler + + def log_prob(self, value: Tensor) -> Tensor: + r"""Approximates the true log probability.""" + neg_inf = get_constants_like(-float("inf"), value) + inbounds = torch.logical_and( + (self.bounds[..., 0] < value).all(-1), + (self.bounds[..., 1] > value).all(-1), + ) + if inbounds.any(): + return torch.where( + inbounds, + super().log_prob(value) - self.log_partition, + neg_inf, + ) + return torch.full(value.shape[: -len(self.event_shape)], neg_inf) + + def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor: # noqa: B008 + r"""Draw samples from the Truncated Multivariate Normal. + + Args: + sample_shape: The shape of the samples. + + Returns: + The (sample_shape x batch_shape x event_shape) tensor of samples. 
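+
+        Example (an illustrative sketch):
+            ```
+            tmvn = TruncatedMultivariateNormal(
+                loc=torch.zeros(2),
+                covariance_matrix=torch.eye(2),
+                bounds=torch.tensor([[-1.0, 1.0], [-1.0, 1.0]]),
+            )
+            x = tmvn.rsample(torch.Size([16]))  # 16 x 2, entries in (-1, 1)
+            ```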
+ """ + num_samples = sample_shape.numel() if sample_shape else 1 + return self.loc + self.sampler.draw(n=num_samples).view(*sample_shape, -1) + + @property + def log_partition(self) -> Tensor: + return self.solver.log_prob + + @property + def solver(self) -> MVNXPB: + if self._solver is None: + self._solver = MVNXPB( + covariance_matrix=self.covariance_matrix, + bounds=self.bounds - self.loc.unsqueeze(-1), + ) + self._solver.solve() + return self._solver + + @property + def sampler(self) -> LinearEllipticalSliceSampler: + if self._sampler is None: + eye = torch.eye( + self.scale_tril.shape[-1], + dtype=self.scale_tril.dtype, + device=self.scale_tril.device, + ) + + A = torch.concat([-eye, eye]) + b = torch.concat( + [ + self.loc - self.bounds[..., 0], + self.bounds[..., 1] - self.loc, + ], + dim=-1, + ).unsqueeze(-1) + + self._sampler = LinearEllipticalSliceSampler( + inequality_constraints=(A, b), + covariance_root=self.scale_tril, + ) + return self._sampler + + def expand( + self, batch_shape: Sequence[int], _instance: TruncatedMultivariateNormal = None + ) -> TruncatedMultivariateNormal: + new = self._get_checked_instance(TruncatedMultivariateNormal, _instance) + super().expand(batch_shape=batch_shape, _instance=new) + + new.bounds = self.bounds.expand(*new.batch_shape, *self.event_shape, 2) + new._sampler = None # does not implement `expand` + new._solver = ( + None if self._solver is None else self._solver.expand(*batch_shape) + ) + return new + + def __repr__(self) -> str: + return super().__repr__()[:-1] + f"bounds: {self.bounds.shape})" diff --git a/botorch/utils/probability/unified_skew_normal.py b/botorch/utils/probability/unified_skew_normal.py new file mode 100644 index 0000000000..97f9f5688d --- /dev/null +++ b/botorch/utils/probability/unified_skew_normal.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from inspect import getmembers +from typing import Optional, Sequence + +import torch +from botorch.utils.probability.linalg import augment_cholesky +from botorch.utils.probability.mvnxpb import MVNXPB +from botorch.utils.probability.truncated_multivariate_normal import ( + TruncatedMultivariateNormal, +) +from torch import Tensor +from torch.distributions.multivariate_normal import Distribution, MultivariateNormal +from torch.distributions.utils import lazy_property +from torch.nn.functional import pad + + +class UnifiedSkewNormal(Distribution): + arg_constraints = {} + + def __init__( + self, + trunc: TruncatedMultivariateNormal, + gauss: MultivariateNormal, + cross_covariance_matrix: Tensor, + validate_args: Optional[bool] = None, + ): + r"""Unified Skew Normal distribution of `Y | a < X < b` for jointly Gaussian + random vectors `X ∈ R^m` and `Y ∈ R^n`. + + Args: + trunc: Distribution of `Z = (X | a < X < b) ∈ R^m`. + gauss: Distribution of `Y ∈ R^n`. + cross_covariance_matrix: Cross-covariance `Cov(X, Y) ∈ R^{m x n}`. + validate_args: Optional argument to super().__init__. 
+ """ + batch_t = trunc.batch_shape + batch_g = gauss.batch_shape + assert all(a == b for a, b in zip(batch_t, batch_g)) + assert len(gauss.event_shape) == len(trunc.event_shape) + super().__init__( + batch_shape=batch_t if len(batch_g) < len(batch_t) else batch_g, + event_shape=gauss.event_shape, + validate_args=validate_args, + ) + self.trunc = trunc + self.gauss = gauss + self.cross_covariance_matrix = cross_covariance_matrix + if validate_args: + try: + self._orthogonalized_gauss.scale_tril + except RuntimeError as e: + if "positive-definite" in str(e): + raise ValueError( + "UnifiedSkewNormal is only well-defined for positive definite" + " joint covariance matrices." + ) + raise e + + def log_prob(self, value: Tensor) -> Tensor: + r"""Computes the log probability `ln p(Y = value | a < X < b)`.""" + event_ndim = len(self.event_shape) + if value.ndim < event_ndim or value.shape[-event_ndim:] != self.event_shape: + raise ValueError( + f"`value` with shape {value.shape} does not comply with the instance's" + f"`event_shape` of {self.event_shape}." + ) + + # Iterate with a fixed batch size to keep memory overhead in check + i = 0 + pre_shape = value.shape[: -len(self.event_shape) - len(self.batch_shape)] + batch_size = self.batch_shape.numel() + log_probs = torch.empty( + pre_shape.numel() * batch_size, device=value.device, dtype=value.dtype + ) + for batch in value.view(-1, *value.shape[len(pre_shape) :]): + log_probs[i : i + batch_size] = self._log_prob(batch).view(-1) + i += batch_size + + return log_probs.view(pre_shape + self.batch_shape) + + def _log_prob(self, value: Tensor) -> Tensor: + r"""Computes the log probability `ln p(Y = value | a < X < b)`.""" + # Center by subtracting E[X | Y = value] from `bounds`. + bounds = ( + self.trunc.bounds + - self.trunc.loc.unsqueeze(-1) + - self._iKyy_Kyx.transpose(-2, -1) @ (value - self.gauss.loc).unsqueeze(-1) + ) + + # Approximately solve for MVN CDF + solver = MVNXPB(covariance_matrix=self._K_schur_Kyy, bounds=bounds) + + # p(Y = value | a < X < b) = P(a < X < b | Y = value)p(Y = value)/P(a < X < b) + return solver.solve() + self.gauss.log_prob(value) - self.trunc.log_partition + + def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor: # noqa: B008 + r"""Draw samples from the Unified Skew Normal. + + Args: + sample_shape: The shape of the samples. + + Returns: + The (sample_shape x batch_shape x event_shape) tensor of samples. 
+ """ + residuals = self._orthogonalized_gauss.rsample(sample_shape=sample_shape) + trunc_rvs = self.trunc.rsample(sample_shape=sample_shape) - self.trunc.loc + cond_expectations = self.gauss.loc + trunc_rvs @ self._iKxx_Kxy + return cond_expectations + residuals + + def expand( + self, batch_shape: Sequence[int], _instance: UnifiedSkewNormal = None + ) -> UnifiedSkewNormal: + new = self._get_checked_instance(UnifiedSkewNormal, _instance) + super(UnifiedSkewNormal, new).__init__( + batch_shape=batch_shape, event_shape=self.event_shape, validate_args=False + ) + + new._validate_args = self._validate_args + new.gauss = self.gauss.expand(batch_shape=batch_shape) + new.trunc = self.trunc.expand(batch_shape=batch_shape) + new.cross_covariance_matrix = self.cross_covariance_matrix.expand( + batch_shape + self.cross_covariance_matrix.shape[-2:] + ) + + # Expand cached properties + for name, _ in getmembers( + UnifiedSkewNormal, lambda x: isinstance(x, lazy_property) + ): + if name not in self.__dict__: + continue + + obj = getattr(self, name) + if isinstance(obj, Tensor): + base = obj if (obj._base is None) else obj._base + new_obj = obj.expand(batch_shape + base.shape) + elif isinstance(obj, Distribution): + new_obj = obj.expand(batch_shape=batch_shape) + else: + raise TypeError + + setattr(new, name, new_obj) + return new + + def __repr__(self) -> str: + args_string = ", ".join( + ( + f"trunc: {self.trunc}", + f"gauss: {self.gauss}", + f"cross_covariance_matrix: {self.cross_covariance_matrix.shape}", + ) + ) + return self.__class__.__name__ + "(" + args_string + ")" + + @lazy_property + def covariance_matrix(self) -> Tensor: + A = self.trunc.covariance_matrix + B = self.cross_covariance_matrix + C = self.cross_covariance_matrix.transpose(-1, -2) + D = self.gauss.covariance_matrix + return torch.cat([torch.cat([A, B], -1), torch.cat([C, D], -1)], -2) + + @lazy_property + def scale_tril(self) -> Tensor: + Lxx = self.trunc.scale_tril + Lyx = self._iLxx_Kxy.transpose(-2, -1) + if "_orthogonalized_gauss" in self.__dict__: + n = Lyx.shape[-2] + Lyy = self._orthogonalized_gauss.scale_tril + return torch.concat([pad(Lxx, (0, n)), torch.concat([Lyx, Lyy], -1)], -2) + return augment_cholesky(Laa=Lxx, Lba=Lyx, Kbb=self.gauss.covariance_matrix) + + @lazy_property + def _orthogonalized_gauss(self) -> MultivariateNormal: + r"""Distribution of `Y ⊥ X = Y - E[Y | X]`, where `Y ~ gauss` and `X ~ untrunc` + is the untruncated version of `Z ~ trunc`.""" + if "scale_tril" in self.__dict__: + n = self.gauss.loc.shape[-1] + return MultivariateNormal( + loc=torch.zeros_like(self.gauss.loc), + scale_tril=self.scale_tril[..., -n:, -n:], + validate_args=False, + ) + + beta = self._iLxx_Kxy + gauss = self.gauss + covariance_matrix = gauss.covariance_matrix - beta.transpose(-1, -2) @ beta + return MultivariateNormal( + loc=torch.zeros_like(gauss.loc), + covariance_matrix=covariance_matrix, + validate_args=False, + ) + + @lazy_property + def _iLyy_Kyx(self) -> Tensor: + r"""Cov(Y, Y)^{-1/2}Cov(Y, X)`.""" + return torch.linalg.solve_triangular( + self.gauss.scale_tril, + self.cross_covariance_matrix.transpose(-1, -2), + upper=False, + ) + + @lazy_property + def _iKyy_Kyx(self) -> Tensor: + r"""Cov(Y, Y)^{-1}Cov(Y, X)`.""" + return torch.linalg.solve_triangular( + self.gauss.scale_tril.transpose(-1, -2), + self._iLyy_Kyx, + upper=True, + ) + + @lazy_property + def _iLxx_Kxy(self) -> Tensor: + r"""Cov(X, X)^{-1/2}Cov(X, Y)`.""" + return torch.linalg.solve_triangular( + self.trunc.scale_tril, self.cross_covariance_matrix, 
upper=False + ) + + @lazy_property + def _iKxx_Kxy(self) -> Tensor: + r"""Cov(X, X)^{-1}Cov(X, Y)`.""" + return torch.linalg.solve_triangular( + self.trunc.scale_tril.transpose(-1, -2), + self._iLxx_Kxy, + upper=True, + ) + + @lazy_property + def _K_schur_Kyy(self) -> Tensor: + r"""Cov(X, X) - Cov(X, Y)Cov(Y, Y)^{-1} Cov(Y, X)`.""" + beta = self._iLyy_Kyx + return self.trunc.covariance_matrix - beta.transpose(-1, -2) @ beta diff --git a/botorch/utils/probability/utils.py b/botorch/utils/probability/utils.py new file mode 100644 index 0000000000..a80000afdf --- /dev/null +++ b/botorch/utils/probability/utils.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from functools import lru_cache +from math import pi +from numbers import Number +from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, Union + +import torch +from numpy.polynomial.legendre import leggauss as numpy_leggauss +from torch import BoolTensor, LongTensor, Tensor + +CaseNd = Tuple[Callable[[], BoolTensor], Callable[[BoolTensor], Tensor]] + +_inv_sqrt_2pi = (2 * pi) ** -0.5 +_neg_inv_sqrt2 = -(2**-0.5) +STANDARDIZED_RANGE: Tuple[float, float] = (-1e6, 1e6) + + +def case_dispatcher( + out: Tensor, + cases: Iterable[CaseNd] = (), + default: Callable[[BoolTensor], Tensor] = None, +) -> Tensor: + active = None + for closure, func in cases: + pred = closure() + if not pred.any(): + continue + + mask = pred if (active is None) else pred & active + if not mask.any(): + continue + + if mask.all(): # where possible, use Ellipsis to avoid indexing + out[...] = func(...) + return out + + out[mask] = func(mask) + if active is None: + active = ~mask + else: + active[mask] = False + + if not active.any(): + break + + if default is not None: + if active is None: + out[...] = default(...) + elif active.any(): + out[active] = default(active) + + return out + + +@lru_cache(maxsize=None) +def get_constants( + values: Union[Number, Iterator[Number]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> Union[Tensor, Tuple[Tensor, ...]]: + r"""Returns scalar-valued Tensors containing each of the given constants. + Used to expedite tensor operations involving scalar arithmetic. 
Note that
+ the returned Tensors should not be modified in-place."""
+ if isinstance(values, Number):
+ return torch.full((), values, dtype=dtype, device=device)
+
+ return tuple(torch.full((), val, dtype=dtype, device=device) for val in values)
+
+
+def get_constants_like(
+ values: Union[Number, Iterator[Number]],
+ ref: Tensor,
+) -> Union[Tensor, Iterator[Tensor]]:
+ return get_constants(values, device=ref.device, dtype=ref.dtype)
+
+
+def gen_positional_indices(
+ shape: torch.Size,
+ dim: int,
+ device: Optional[torch.device] = None,
+) -> Iterator[torch.LongTensor]:
+ ndim = len(shape)
+ _dim = ndim + dim if dim < 0 else dim
+ if _dim >= ndim or _dim < 0:
+ raise ValueError(f"dim={dim} invalid for shape {shape}.")
+
+ cumsize = shape[_dim + 1 :].numel()
+ for i, s in enumerate(reversed(shape[: _dim + 1])):
+ yield torch.arange(0, s * cumsize, cumsize, device=device)[(...,) + i * (None,)]
+ cumsize *= s
+
+
+def build_positional_indices(
+ shape: torch.Size,
+ dim: int,
+ device: Optional[torch.device] = None,
+) -> LongTensor:
+ return sum(gen_positional_indices(shape=shape, dim=dim, device=device))
+
+
+@lru_cache(maxsize=None)
+def leggauss(deg: int, **tkwargs: Any) -> Tuple[Tensor, Tensor]:
+ x, w = numpy_leggauss(deg)
+ return torch.as_tensor(x, **tkwargs), torch.as_tensor(w, **tkwargs)
+
+
+def ndtr(x: Tensor) -> Tensor:
+ r"""Standard normal CDF."""
+ half, ninv_sqrt2 = get_constants_like((0.5, _neg_inv_sqrt2), x)
+ return half * torch.erfc(ninv_sqrt2 * x)
+
+
+def phi(x: Tensor) -> Tensor:
+ r"""Standard normal PDF."""
+ inv_sqrt_2pi, neg_half = get_constants_like((_inv_sqrt_2pi, -0.5), x)
+ return inv_sqrt_2pi * (neg_half * x.square()).exp()
+
+
+def swap_along_dim_(
+ values: Tensor,
+ i: Union[int, LongTensor],
+ j: Union[int, LongTensor],
+ dim: int,
+ buffer: Optional[Tensor] = None,
+) -> Tensor:
+ dim = values.ndim + dim if dim < 0 else dim
+ if dim and (isinstance(i, Tensor) or isinstance(j, Tensor)):
+ # Handle n-dimensional batches of heterogeneous swaps via linear indexing
+ if isinstance(i, Tensor) and i.shape != values.shape[:dim]:
+ raise ValueError("Batch shapes of `i` and `values` do not match.")
+
+ if isinstance(j, Tensor) and j.shape != values.shape[:dim]:
+ raise ValueError("Batch shapes of `j` and `values` do not match.")
+
+ start = build_positional_indices(values.shape[: dim + 1], -2, values.device)
+ swap_along_dim_(
+ values.view(-1, *values.shape[dim + 1 :]),
+ i=(start + i).view(-1),
+ j=(start + j).view(-1),
+ dim=0,
+ buffer=buffer,
+ )
+ else:
+ # Base cases: homogeneous swaps and 1-dimensional heterogeneous swaps
+ ctx = tuple(slice(None) for _ in range(dim))
+ i = ctx + (i,)
+ j = ctx + (j,)
+
+ if buffer is None:
+ buffer = values[i].clone()
+ else:
+ buffer.copy_(values[i])
+
+ values[i] = values[j]
+ values[j] = buffer
+
+ return values
diff --git a/botorch/utils/safe_math.py b/botorch/utils/safe_math.py
new file mode 100644
index 0000000000..fa8e617358
--- /dev/null
+++ b/botorch/utils/safe_math.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
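As a usage sketch of the indexing helpers above (it mirrors `test_swap_along_dim_`
further below in this patch): `build_positional_indices` returns flat offsets into
the leading dimensions, so a batch of heterogeneous swaps reduces to a single
one-dimensional `swap_along_dim_` call.

    import torch
    from botorch.utils.probability import utils

    values = torch.rand(3, 2, 5)
    # Flat offset of each length-5 row within values.view(-1)
    start = utils.build_positional_indices(shape=values.shape, dim=-2)
    min_indices = values.min(dim=-1).indices
    max_indices = values.max(dim=-1).indices
    # Swap each row's smallest and largest entries in one call
    out = utils.swap_along_dim_(
        values.view(-1).clone(),
        i=(start + min_indices).ravel(),
        j=(start + max_indices).ravel(),
        dim=0,
    )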
+
+from __future__ import annotations
+
+import math
+
+import torch
+from botorch.utils.constants import get_constants_like
+from torch import finfo, Tensor
+
+
+# Unary ops
+def exp(x: Tensor, **kwargs) -> Tensor:
+ info = finfo(x.dtype)
+ maxexp = get_constants_like(math.log(info.max) - 1e-4, x)
+ return torch.exp(x.clip(max=maxexp), **kwargs)
+
+
+def log(x: Tensor, **kwargs) -> Tensor:
+ info = finfo(x.dtype)
+ return torch.log(x.clip(min=info.tiny), **kwargs)
+
+
+# Binary ops
+def add(a: Tensor, b: Tensor, **kwargs) -> Tensor:
+ _0 = get_constants_like(0, a)
+ case = a.isinf() & b.isinf() & (a != b)
+ return torch.where(case, _0, a + b)
+
+
+def sub(a: Tensor, b: Tensor) -> Tensor:
+ _0 = get_constants_like(0, a)
+ case = (a.isinf() & b.isinf()) & (a == b)
+ return torch.where(case, _0, a - b)
+
+
+def div(a: Tensor, b: Tensor) -> Tensor:
+ _0, _1 = get_constants_like(values=(0, 1), ref=a)
+ case = ((a == _0) & (b == _0)) | (a.isinf() & b.isinf())
+ return torch.where(case, torch.where(a != b, -_1, _1), a / torch.where(case, _1, b))
+
+
+def mul(a: Tensor, b: Tensor) -> Tensor:
+ _0 = get_constants_like(values=0, ref=a)
+ case = (a.isinf() & (b == _0)) | (b.isinf() & (a == _0))
+ return torch.where(case, _0, a * torch.where(case, _0, b))
diff --git a/test/utils/probability/__init__.py b/test/utils/probability/__init__.py
new file mode 100644
index 0000000000..4b87eb9e4d
--- /dev/null
+++ b/test/utils/probability/__init__.py
@@ -0,0 +1,5 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/test/utils/probability/test_bvn.py b/test/utils/probability/test_bvn.py
new file mode 100644
index 0000000000..4d4ba90d6e
--- /dev/null
+++ b/test/utils/probability/test_bvn.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
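A quick sketch of the special-case semantics implemented by the `safe_math` ops
above (each substitutes a finite value for an otherwise indeterminate result):

    import torch
    from botorch.utils import safe_math

    inf = torch.tensor(float("inf"))
    zero = torch.tensor(0.0)

    safe_math.add(inf, -inf)   # 0 rather than nan
    safe_math.sub(inf, inf)    # 0 rather than nan
    safe_math.mul(inf, zero)   # 0 rather than nan
    safe_math.div(zero, zero)  # 1 rather than nan
    safe_math.div(inf, -inf)   # -1 rather than nan (signs differ)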
+ +from __future__ import annotations + +from typing import Any, Callable, Dict, Optional, Tuple, Union +from unittest import TestCase + +import torch +from botorch.exceptions import UnsupportedError +from botorch.utils.probability.bvn import ( + _bvnu_polar, + _bvnu_taylor, + bvn, + bvnu, + bvnmom, + Phi +) +from torch import Tensor + + +def run_gaussian_estimator( + estimator: Callable[[Tensor], Tuple[Tensor, Union[Tensor, float, int]]], + sqrt_cov: Tensor, + num_samples: int, + batch_limit: Optional[int] = None, + seed: Optional[int] = None, +) -> Tensor: + + if batch_limit is None: + batch_limit = num_samples + + ndim = sqrt_cov.shape[-1] + tkwargs = {"dtype": sqrt_cov.dtype, "device": sqrt_cov.device} + counter = 0 + numerator = 0 + denominator = 0 + + with torch.random.fork_rng(): + if seed: + torch.random.manual_seed(seed) + + while counter < num_samples: + batch_size = min(batch_limit, num_samples - counter) + samples = torch.tensordot( + torch.randn(batch_size, ndim, **tkwargs), + sqrt_cov, + dims=([1], [-1]), + ) + + batch_numerator, batch_denominator = estimator(samples) + counter = counter + batch_size + numerator = numerator + batch_numerator + denominator = denominator + batch_denominator + + return numerator / denominator, denominator + + +class TestBVN(TestCase): + def setUp( + self, + nprobs_per_coeff: int = 3, + bound_range: Tuple[float, float] = (-3.0, 3.0), + mc_num_samples: int = 10000, + mc_batch_limit: int = 1000, + mc_atol_multiplier: float = 4.0, + seed: int = 1, + dtype: torch.dtype = torch.float64, + device: Optional[torch.device] = None, + ): + self.seed = seed + self.dtype = dtype + self.device = device + self.nprobs_per_coeff = nprobs_per_coeff + self.mc_num_samples = mc_num_samples + self.mc_batch_limit = mc_batch_limit + self.mc_atol_multiplier = mc_atol_multiplier + + pos_coeffs = torch.cat( + [ + torch.linspace(0, 1, 5, **self.tkwargs), + torch.tensor([0.01, 0.05, 0.924, 0.925, 0.99], **self.tkwargs), + ] + ) + self.correlations = torch.cat([pos_coeffs, -pos_coeffs[1:]]) + + with torch.random.fork_rng(): + torch.manual_seed(0) + _lower = torch.rand( + nprobs_per_coeff, len(self.correlations), 2, **self.tkwargs + ) + _upper = _lower + (1 - _lower) * torch.rand_like(_lower) + + self.lower_bounds = bound_range[0] + (bound_range[1] - bound_range[0]) * _lower + self.upper_bounds = bound_range[0] + (bound_range[1] - bound_range[0]) * _upper + + self.sqrt_covariances = torch.zeros( + len(self.correlations), 2, 2, **self.tkwargs + ) + self.sqrt_covariances[:, 0, 0] = 1 + self.sqrt_covariances[:, 1, 0] = self.correlations + self.sqrt_covariances[:, 1, 1] = (1 - self.correlations**2) ** 0.5 + + def gen_seed(self, low: int = 0, high: int = 2**30) -> int: + with torch.random.fork_rng(): + torch.random.manual_seed(self.seed) + seed = torch.randint(low=low, high=high, size=()) + self.seed += 1 + return int(seed) + + @property + def tkwargs(self) -> Dict[str, Any]: + return {"dtype": self.dtype, "device": self.device} + + @property + def xl(self): + return self.lower_bounds[..., 0] + + @property + def xu(self): + return self.upper_bounds[..., 0] + + @property + def yl(self): + return self.lower_bounds[..., 1] + + @property + def yu(self): + return self.upper_bounds[..., 1] + + def test_bvnu_polar(self) -> None: + r"""Test special cases where bvnu admits closed-form solutions. + + Note: inf should not be passed to _bvnu as bounds, use big numbers instead. 
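+
+ The polar branch is exercised for correlations with `|r| < 0.925`; the
+ complementary regime is covered by `test_bvnu_taylor` below.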
+ """ + use_polar = self.correlations.abs() < 0.925 + r = self.correlations[use_polar] + xl = self.xl[..., use_polar] + yl = self.yl[..., use_polar] + with self.subTest(msg="exact_unconstrained"): + prob = _bvnu_polar(r, torch.full_like(r, -1e16), torch.full_like(r, -1e16)) + self.assertTrue(torch.allclose(prob, torch.ones_like(prob))) + + with self.subTest(msg="exact_marginal"): + prob = _bvnu_polar( + r.expand_as(yl), + torch.full_like(xl, -1e16), + yl, + ) + test = Phi(-yl) # same as: 1 - P(y < yl) + self.assertTrue(torch.allclose(prob, test)) + + with self.subTest(msg="exact_independent"): + prob = _bvnu_polar(torch.zeros_like(xl), xl, yl) + test = Phi(-xl) * Phi(-yl) + self.assertTrue(torch.allclose(prob, test)) + + def test_bvnu_taylor(self) -> None: + r"""Test special cases where bvnu admits closed-form solutions. + + Note: inf should not be passed to _bvnu as bounds, use big numbers instead. + """ + use_taylor = self.correlations.abs() >= 0.925 + r = self.correlations[use_taylor] + xl = self.xl[..., use_taylor] + yl = self.yl[..., use_taylor] + with self.subTest(msg="exact_unconstrained"): + prob = _bvnu_taylor(r, torch.full_like(r, -1e16), torch.full_like(r, -1e16)) + self.assertTrue(torch.allclose(prob, torch.ones_like(prob))) + + with self.subTest(msg="exact_marginal"): + prob = _bvnu_taylor( + r.expand_as(yl), + torch.full_like(xl, -1e16), + yl, + ) + test = Phi(-yl) # same as: 1 - P(y < yl) + self.assertTrue(torch.allclose(prob, test)) + + with self.subTest(msg="exact_independent"): + prob = _bvnu_polar(torch.zeros_like(xl), xl, yl) + test = Phi(-xl) * Phi(-yl) + self.assertTrue(torch.allclose(prob, test)) + + def test_bvn(self): + r"""Monte Carlo unit test for `bvn`.""" + r = self.correlations.repeat(self.nprobs_per_coeff, 1) + solves = bvn(r, self.xl, self.yl, self.xu, self.yu) + with self.assertRaisesRegex(UnsupportedError, "same shape"): + bvn(r[..., :1], self.xl, self.yl, self.xu, self.yu) + + with self.assertRaisesRegex(UnsupportedError, "same shape"): + bvnu(r[..., :1], r, r) + + def _estimator(samples): + accept = torch.logical_and( + (samples > self.lower_bounds.unsqueeze(1)).all(-1), + (samples < self.upper_bounds.unsqueeze(1)).all(-1), + ) + numerator = torch.count_nonzero(accept, dim=1).double() + denominator = len(samples) + return numerator, denominator + + estimates, _ = run_gaussian_estimator( + estimator=_estimator, + sqrt_cov=self.sqrt_covariances, + num_samples=self.mc_num_samples, + batch_limit=self.mc_batch_limit, + seed=self.gen_seed(), + ) + + atol = self.mc_atol_multiplier * (self.mc_num_samples**-0.5) + self.assertTrue(torch.allclose(estimates, solves, rtol=0, atol=atol)) + + def test_bvnmom(self): + r"""Monte Carlo unit test for `bvn`.""" + r = self.correlations.repeat(self.nprobs_per_coeff, 1) + Ex, Ey = bvnmom(r, self.xl, self.yl, self.xu, self.yu) + with self.assertRaisesRegex(UnsupportedError, "same shape"): + bvnmom(r[..., :1], self.xl, self.yl, self.xu, self.yu) + + def _estimator(samples): + accept = torch.logical_and( + (samples > self.lower_bounds.unsqueeze(1)).all(-1), + (samples < self.upper_bounds.unsqueeze(1)).all(-1), + ) + numerator = torch.einsum("snd,psn->pnd", samples, accept.to(samples.dtype)) + denominator = torch.count_nonzero(accept, dim=1).to(samples.dtype) + return numerator, denominator.unsqueeze(-1) + + estimates, num_samples = run_gaussian_estimator( + estimator=_estimator, + sqrt_cov=self.sqrt_covariances, + num_samples=self.mc_num_samples, + batch_limit=self.mc_batch_limit, + seed=self.gen_seed(), + ) + for n, ex, ey, _ex, 
_ey in zip(
+ *map(torch.ravel, (num_samples.squeeze(-1), Ex, Ey, *estimates.unbind(-1)))
+ ):
+ if n:
+ atol = self.mc_atol_multiplier * (n**-0.5)
+ self.assertTrue(torch.allclose(ex, _ex, rtol=0, atol=atol))
+ self.assertTrue(torch.allclose(ey, _ey, rtol=0, atol=atol))
diff --git a/test/utils/probability/test_lin_ess.py b/test/utils/probability/test_lin_ess.py
new file mode 100644
index 0000000000..3d725cd39b
--- /dev/null
+++ b/test/utils/probability/test_lin_ess.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import torch
+from botorch.exceptions.errors import BotorchError
+from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
+from botorch.utils.testing import BotorchTestCase
+
+
+class TestLinearEllipticalSliceSampler(BotorchTestCase):
+ def test_univariate(self):
+ for dtype in (torch.float, torch.double):
+ tkwargs = {"device": self.device, "dtype": dtype}
+ # test input validation
+ with self.assertRaises(BotorchError) as e:
+ LinearEllipticalSliceSampler()
+ self.assertTrue(
+ "requires either inequality constraints or bounds" in str(e.exception)
+ )
+ # special case: N(0, 1) truncated to negative numbers
+ A = torch.ones(1, 1, **tkwargs)
+ b = torch.zeros(1, 1, **tkwargs)
+ x0 = -torch.rand(1, 1, **tkwargs)
+ sampler = LinearEllipticalSliceSampler(
+ inequality_constraints=(A, b), interior_point=x0
+ )
+ self.assertIsNone(sampler._mean)
+ self.assertIsNone(sampler._covariance_root)
+ self.assertTrue(torch.equal(sampler._x, x0))
+ self.assertTrue(torch.equal(sampler.x0, x0))
+ samples = sampler.draw(n=3)
+ self.assertEqual(samples.shape, torch.Size([3, 1]))
+ self.assertLessEqual(samples.max().item(), 0.0)
+ self.assertFalse(torch.equal(sampler._x, x0))
+ # same case as above, but instantiated with bounds
+ sampler = LinearEllipticalSliceSampler(
+ bounds=torch.tensor([[-float("inf")], [0.0]], **tkwargs),
+ interior_point=x0,
+ )
+ self.assertIsNone(sampler._mean)
+ self.assertIsNone(sampler._covariance_root)
+ self.assertTrue(torch.equal(sampler._x, x0))
+ self.assertTrue(torch.equal(sampler.x0, x0))
+ samples = sampler.draw(n=3)
+ self.assertEqual(samples.shape, torch.Size([3, 1]))
+ self.assertLessEqual(samples.max().item(), 0.0)
+ self.assertFalse(torch.equal(sampler._x, x0))
+ # same case as above, but with redundant constraints
+ sampler = LinearEllipticalSliceSampler(
+ inequality_constraints=(A, b),
+ bounds=torch.tensor([[-float("inf")], [1.0]], **tkwargs),
+ interior_point=x0,
+ )
+ self.assertIsNone(sampler._mean)
+ self.assertIsNone(sampler._covariance_root)
+ self.assertTrue(torch.equal(sampler._x, x0))
+ self.assertTrue(torch.equal(sampler.x0, x0))
+ samples = sampler.draw(n=3)
+ self.assertEqual(samples.shape, torch.Size([3, 1]))
+ self.assertLessEqual(samples.max().item(), 0.0)
+ self.assertFalse(torch.equal(sampler._x, x0))
+ # narrow feasible region, automatically find interior point
+ sampler = LinearEllipticalSliceSampler(
+ inequality_constraints=(A, b),
+ bounds=torch.tensor([[-0.25], [float("inf")]], **tkwargs),
+ )
+ self.assertIsNone(sampler._mean)
+ self.assertIsNone(sampler._covariance_root)
+ self.assertTrue(torch.all(sampler._is_feasible(sampler.x0)))
+ samples = sampler.draw(n=3)
+ self.assertEqual(samples.shape, torch.Size([3, 1]))
+ self.assertLessEqual(samples.max().item(), 0.0)
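+ # Feasible region is the interval [-0.25, 0]: draws must satisfy both the
+ # inequality constraint (x <= 0) and the lower bound (x >= -0.25).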
+ self.assertGreaterEqual(samples.min().item(), -0.25)
+ self.assertFalse(torch.equal(sampler._x, x0))
+ # non-standard mean / variance
+ mean = torch.tensor([[0.25]], **tkwargs)
+ covariance_matrix = torch.tensor([[4.0]], **tkwargs)
+ with self.assertRaises(ValueError) as e:
+ LinearEllipticalSliceSampler(
+ bounds=torch.tensor([[0.0], [float("inf")]], **tkwargs),
+ covariance_matrix=covariance_matrix,
+ covariance_root=covariance_matrix.sqrt(),
+ )
+ self.assertTrue(
+ "either covariance_matrix or covariance_root, not both" in str(e.exception)
+ )
+ with self.assertRaises(ValueError) as e:
+ LinearEllipticalSliceSampler(
+ bounds=torch.tensor([[0.0], [float("inf")]], **tkwargs),
+ covariance_matrix=-covariance_matrix,
+ )
+ self.assertTrue(
+ "Covariance matrix is not positive definite" in str(e.exception)
+ )
+ sampler = LinearEllipticalSliceSampler(
+ bounds=torch.tensor([[0.0], [float("inf")]], **tkwargs),
+ mean=mean,
+ covariance_matrix=covariance_matrix,
+ )
+ self.assertTrue(torch.equal(sampler._mean, mean))
+ self.assertTrue(
+ torch.equal(sampler._covariance_root, covariance_matrix.sqrt())
+ )
+ self.assertTrue(torch.all(sampler._is_feasible(sampler.x0)))
+ samples = sampler.draw(n=4)
+ self.assertEqual(samples.shape, torch.Size([4, 1]))
+ self.assertGreaterEqual(samples.min().item(), 0.0)
+ self.assertFalse(torch.equal(sampler._x, x0))
+
+ def test_bivariate(self):
+ for dtype in (torch.float, torch.double):
+ tkwargs = {"device": self.device, "dtype": dtype}
+ # special case: N(0, I) truncated to positive numbers
+ A = -torch.eye(2, **tkwargs)
+ b = torch.zeros(2, 1, **tkwargs)
+ sampler = LinearEllipticalSliceSampler(inequality_constraints=(A, b))
+ self.assertIsNone(sampler._mean)
+ self.assertIsNone(sampler._covariance_root)
+ self.assertTrue(torch.all(sampler._is_feasible(sampler.x0)))
+ samples = sampler.draw(n=3)
+ self.assertEqual(samples.shape, torch.Size([3, 2]))
+ self.assertGreaterEqual(samples.min().item(), 0.0)
+ self.assertFalse(torch.equal(sampler._x, sampler.x0))
+ # same case as above, but instantiated with bounds
+ sampler = LinearEllipticalSliceSampler(
+ bounds=torch.tensor(
+ [[0.0, 0.0], [float("inf"), float("inf")]], **tkwargs
+ ),
+ )
+ self.assertIsNone(sampler._mean)
+ self.assertIsNone(sampler._covariance_root)
+ self.assertTrue(torch.all(sampler._is_feasible(sampler.x0)))
+ samples = sampler.draw(n=3)
+ self.assertEqual(samples.shape, torch.Size([3, 2]))
+ self.assertGreaterEqual(samples.min().item(), 0.0)
+ self.assertFalse(torch.equal(sampler._x, sampler.x0))
+ # A case with bounded domain and non-standard mean and covariance
+ mean = -3.0 * torch.ones(2, 1, **tkwargs)
+ covariance_matrix = torch.tensor([[4.0, 2.0], [2.0, 2.0]], **tkwargs)
+ bounds = torch.tensor(
+ [[-float("inf"), -float("inf")], [0.0, 0.0]], **tkwargs
+ )
+ A = torch.ones(1, 2, **tkwargs)
+ b = torch.tensor([[-2.0]], **tkwargs)
+ sampler = LinearEllipticalSliceSampler(
+ inequality_constraints=(A, b),
+ bounds=bounds,
+ mean=mean,
+ covariance_matrix=covariance_matrix,
+ )
+ self.assertTrue(torch.equal(sampler._mean, mean))
+ covar_root_xpct = torch.tensor([[2.0, 0.0], [1.0, 1.0]], **tkwargs)
+ self.assertTrue(torch.equal(sampler._covariance_root, covar_root_xpct))
+ samples = sampler.draw(n=3)
+ self.assertEqual(samples.shape, torch.Size([3, 2]))
+ self.assertTrue(sampler._is_feasible(samples.t()).all())
+ self.assertFalse(torch.equal(sampler._x, sampler.x0))
diff --git a/test/utils/probability/test_linalg.py b/test/utils/probability/test_linalg.py
new file mode 100644
index 0000000000..d320972141
--- /dev/null +++ b/test/utils/probability/test_linalg.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from copy import deepcopy +from unittest import TestCase + +import torch +from botorch.utils.probability.linalg import PivotedCholesky + + +class TestPivotedCholesky(TestCase): + def setUp(self): + n = 5 + with torch.random.fork_rng(): + torch.random.manual_seed(0) + + matrix = torch.randn(2, n, n) + matrix = matrix @ matrix.transpose(-1, -2) + + diag = matrix.diagonal(dim1=-2, dim2=-1).sqrt() + idiag = diag.reciprocal().unsqueeze(-1) + + piv_chol = PivotedCholesky( + step=0, + tril=(idiag * matrix * idiag.transpose(-2, -1)).tril(), + perm=torch.arange(n)[None].expand(len(matrix), n).contiguous(), + diag=diag.clone(), + ) + + self.diag = diag + self.matrix = matrix + self.piv_chol = piv_chol + + self.piv_chol.update_() + self.piv_chol.pivot_(torch.tensor([2, 3])) + self.piv_chol.update_() + + def test_update_(self): + # Construct permuted matrices A + n = self.matrix.shape[-1] + A = (1 / self.diag).unsqueeze(-1) * self.matrix * (1 / self.diag).unsqueeze(-2) + A = A.gather(-1, self.piv_chol.perm.unsqueeze(-2).repeat(1, n, 1)) + A = A.gather(-2, self.piv_chol.perm.unsqueeze(-1).repeat(1, 1, n)) + + # Test upper left block + L = torch.linalg.cholesky(A[..., :2, :2]) + self.assertTrue(L.allclose(self.piv_chol.tril[..., :2, :2])) + + # Test lower left block + beta = torch.linalg.solve_triangular(L, A[..., :2:, 2:], upper=False) + self.assertTrue( + beta.transpose(-1, -2).allclose(self.piv_chol.tril[..., 2:, :2]) + ) + + # Test lower right block + schur = A[..., 2:, 2:] - beta.transpose(-1, -2) @ beta + self.assertTrue(schur.tril().allclose(self.piv_chol.tril[..., 2:, 2:])) + + def test_pivot_(self): + piv_chol = deepcopy(self.piv_chol) + self.assertEqual(piv_chol.perm.tolist(), [[0, 2, 1, 3, 4], [0, 3, 2, 1, 4]]) + + piv_chol.pivot_(torch.tensor([2, 3])) + self.assertEqual(piv_chol.perm.tolist(), [[0, 2, 1, 3, 4], [0, 3, 1, 2, 4]]) + self.assertTrue(piv_chol.tril[0].equal(self.piv_chol.tril[0])) + + A = self.piv_chol.tril[1] + B = piv_chol.tril[1] + self.assertTrue(A[2:4, :2].equal(B[2:4, :2].roll(1, 0))) + self.assertTrue(A[4:, 2:4].equal(B[4:, 2:4].roll(1, 1))) + + def test_concat(self): + A = self.piv_chol.expand(2, 2) + B = self.piv_chol.expand(1, 2) + B = B.concat(B, dim=0) + for key in ("tril", "perm", "diag"): + self.assertTrue(getattr(A, key).equal(getattr(B, key))) + + def test_clone(self): + self.piv_chol.diag.requires_grad_(True) + try: + other = self.piv_chol.clone() + for key in ("tril", "perm", "diag"): + a = getattr(self.piv_chol, key) + b = getattr(other, key) + self.assertTrue(a.equal(b)) + self.assertFalse(a is b) + + other.diag.sum().backward() + self.assertTrue(self.piv_chol.diag.grad.eq(1).all()) + finally: + self.piv_chol.diag.requires_grad_(False) + + def test_detach(self): + self.piv_chol.diag.requires_grad_(True) + try: + other = self.piv_chol.detach() + for key in ("tril", "perm", "diag"): + a = getattr(self.piv_chol, key) + b = getattr(other, key) + self.assertTrue(a.equal(b)) + self.assertFalse(a is b) + + with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"): + other.diag.sum().backward() + + finally: + self.piv_chol.diag.requires_grad_(False) + + def test_expand(self): + other = self.piv_chol.expand(3, 2) + for key in ("tril", 
"perm", "diag"): + a = getattr(self.piv_chol, key) + b = getattr(other, key) + self.assertEqual(b.shape[: -a.ndim], (3,)) + self.assertTrue(b._base is a) diff --git a/test/utils/probability/test_mvnxpb.py b/test/utils/probability/test_mvnxpb.py new file mode 100644 index 0000000000..d25bc9d088 --- /dev/null +++ b/test/utils/probability/test_mvnxpb.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from functools import partial +from itertools import count +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from unittest import TestCase + +import torch +from botorch.utils.probability.linalg import PivotedCholesky +from botorch.utils.probability.mvnxpb import MVNXPB +from torch import Tensor + + +def run_gaussian_estimator( + estimator: Callable[[Tensor], Tuple[Tensor, Union[Tensor, float, int]]], + sqrt_cov: Tensor, + num_samples: int, + batch_limit: Optional[int] = None, + seed: Optional[int] = None, +) -> Tensor: + + if batch_limit is None: + batch_limit = num_samples + + ndim = sqrt_cov.shape[-1] + tkwargs = {"dtype": sqrt_cov.dtype, "device": sqrt_cov.device} + counter = 0 + numerator = 0 + denominator = 0 + with torch.random.fork_rng(): + if seed: + torch.random.manual_seed(seed) + + while counter < num_samples: + batch_size = min(batch_limit, num_samples - counter) + samples = torch.tensordot( + torch.randn(batch_size, ndim, **tkwargs), + sqrt_cov, + dims=([1], [-1]), + ) + + batch_numerator, batch_denominator = estimator(samples) + counter = counter + batch_size + numerator = numerator + batch_numerator + denominator = denominator + batch_denominator + + return numerator / denominator, denominator + + +class TestMVNXPB(TestCase): + def setUp( + self, + ndims: Sequence[int] = (4, 8, 16), + batch_shape: Sequence[int] = (4,), + bound_range: Tuple[float, float] = (-5.0, 5.0), + mc_num_samples: int = 100000, + mc_batch_limit: int = 10000, + mc_atol_multiplier: float = 4.0, + seed: int = 1, + dtype: torch.dtype = torch.float64, + device: Optional[torch.device] = None, + ): + self.seed = seed + self.dtype = dtype + self.device = device + self.mc_num_samples = mc_num_samples + self.mc_batch_limit = mc_batch_limit + self.mc_atol_multiplier = mc_atol_multiplier + + self.bounds = [] + self.sqrt_covariances = [] + with torch.random.fork_rng(): + torch.random.manual_seed(self.gen_seed()) + for n in ndims: + self.bounds.append(self.gen_bounds(n, batch_shape, bound_range)) + self.sqrt_covariances.append( + self.gen_covariances(n, batch_shape, as_sqrt=True) + ) + + # Create a toy MVNXPB instance for API testing + tril = torch.rand([4, 2, 3, 3]) + diag = torch.rand([4, 2, 3]) + perm = torch.stack([torch.randperm(3) for _ in range(8)]).reshape(4, 2, 3) + + self.toy_solver = MVNXPB( + covariance_matrix=None, + perm=perm, + bounds=torch.rand(4, 2, 3, 2).cumsum(dim=-1), + piv_chol=PivotedCholesky(tril=tril, perm=perm, diag=diag, step=0), + plug_ins=torch.randn(4, 2, 3), + log_prob=torch.rand(4, 2), + ) + + def gen_seed(self, low: int = 0, high: int = 2**30) -> int: + with torch.random.fork_rng(): + torch.random.manual_seed(self.seed) + seed = torch.randint(low=low, high=high, size=()) + self.seed += 1 + return int(seed) + + def gen_covariances( + self, + ndim: int, + batch_shape: Sequence[int] = (), + as_sqrt: bool = False, + ) -> Tensor: + shape = tuple(batch_shape) 
+ (ndim, ndim)
+ eigvals = -torch.rand(shape[:-1], **self.tkwargs).log() # exponential rvs
+ orthmat = torch.linalg.svd(torch.randn(shape, **self.tkwargs)).U
+ sqrt_covar = orthmat * torch.sqrt(eigvals).unsqueeze(-2)
+ return sqrt_covar if as_sqrt else sqrt_covar @ sqrt_covar.transpose(-2, -1)
+
+ def gen_bounds(
+ self,
+ ndim: int,
+ batch_shape: Sequence[int] = (),
+ bound_range: Optional[Tuple[float, float]] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ shape = tuple(batch_shape) + (ndim,)
+ lower = torch.rand(shape, **self.tkwargs)
+ upper = lower + (1 - lower) * torch.rand_like(lower)
+ if bound_range is not None:
+ lower = bound_range[0] + (bound_range[1] - bound_range[0]) * lower
+ upper = bound_range[0] + (bound_range[1] - bound_range[0]) * upper
+
+ return torch.stack([lower, upper], dim=-1)
+
+ @property
+ def tkwargs(self) -> Dict[str, Any]:
+ return {"dtype": self.dtype, "device": self.device}
+
+ def assertEqualMXNBPB(self, A: MVNXPB, B: MVNXPB):
+ for key, a in A.asdict().items():
+ b = getattr(B, key)
+ if isinstance(a, PivotedCholesky):
+ continue
+ elif isinstance(a, torch.Tensor):
+ self.assertTrue(a.equal(b))
+ else:
+ self.assertEqual(a, b)
+
+ for key in ("perm", "tril", "diag"):
+ a = getattr(A.piv_chol, key)
+ b = getattr(B.piv_chol, key)
+ self.assertTrue(a.equal(b))
+
+ def test_solve(self):
+ r"""Monte Carlo unit test for `solve`."""
+
+ def _estimator(samples, bounds):
+ accept = torch.logical_and(
+ (samples > bounds[..., 0]).all(-1),
+ (samples < bounds[..., 1]).all(-1),
+ )
+ numerator = torch.count_nonzero(accept, dim=0).double()
+ denominator = len(samples)
+ return numerator, denominator
+
+ base_seed = self.gen_seed()
+ for i, sqrt_cov, bounds in zip(count(), self.sqrt_covariances, self.bounds):
+ estimates, _ = run_gaussian_estimator(
+ estimator=partial(_estimator, bounds=bounds),
+ sqrt_cov=sqrt_cov,
+ num_samples=self.mc_num_samples,
+ batch_limit=self.mc_batch_limit,
+ seed=base_seed + i,
+ )
+
+ cov = sqrt_cov @ sqrt_cov.transpose(-2, -1)
+ solver = MVNXPB(cov, bounds)
+ solver.solve()
+
+ atol = self.mc_atol_multiplier * (self.mc_num_samples**-0.5)
+ for est, prob in zip(estimates, solver.log_prob.exp()):
+ if est == 0.0:
+ continue
+
+ self.assertTrue(torch.allclose(est, prob, rtol=0, atol=atol))
+
+ def test_augment(self):
+ r"""Test `augment`."""
+ with torch.random.fork_rng():
+ torch.random.manual_seed(self.gen_seed())
+
+ # Pick a set of subproblems at random
+ index = torch.randint(low=0, high=len(self.sqrt_covariances), size=())
+ sqrt_cov = self.sqrt_covariances[index]
+ cov = sqrt_cov @ sqrt_cov.transpose(-2, -1)
+ bounds = self.bounds[index]
+
+ # Partially solve for `N`-dimensional integral
+ N = cov.shape[-1]
+ n = torch.randint(low=1, high=N - 1, size=())
+ full = MVNXPB(cov, bounds=bounds) # , should_solve=False)
+ full.solve(num_steps=n)
+
+ # Reorder terms according to `full.perm`
+ perm = full.perm.detach().clone()
+ _cov = cov.gather(-2, perm.unsqueeze(-1).repeat(1, 1, N))
+ _cov = _cov.gather(-1, perm.unsqueeze(-2).repeat(1, N, 1))
+ _bounds = bounds.gather(-2, perm.unsqueeze(-1).repeat(1, 1, 2))
+
+ # Solve for same `n`-dimensional integral as `full.solve(num_steps=n)`
+ init = MVNXPB(_cov[..., :n, :n], _bounds[..., :n, :])
+ init.solve()
+
+ # Augment solver to include remaining `m = N - n` random variables
+ augm = init.augment(
+ covariance_matrix=_cov[..., n:, n:],
+ cross_covariance_matrix=_cov[..., n:, :n],
+ bounds=_bounds[..., n:, :],
+ )
+
+ def _compare(a: MVNXPB, b: MVNXPB, tril_only: bool = False) -> None:
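+ # Check that the augmented solver matches the full solver's state (step,
+ # bounds, plug-ins, log-prob, pivoted Cholesky) modulo the permutation.
+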
self.assertTrue(a.step == b.step) + self.assertTrue(a.perm.equal(perm.gather(-1, b.perm))) + self.assertTrue(a.bounds.allclose(b.bounds)) + self.assertTrue(a.plug_ins.allclose(b.plug_ins, equal_nan=True)) + self.assertTrue(a.log_prob.allclose(b.log_prob, equal_nan=True)) + if a.log_prob_iter is not None or b.log_prob_iter is not None: + self.assertTrue( + a.log_prob_iter.allclose(b.log_prob_iter, equal_nan=True) + ) + + _a = a.piv_chol + _b = b.piv_chol + self.assertTrue(_a.step == _b.step) + self.assertTrue(_a.perm.equal(perm.gather(-1, _b.perm))) + self.assertTrue(_a.tril.allclose(_b.tril)) + + with self.subTest(msg="initialization"): + _compare(full, augm, tril_only=True) + + # TODO: This may fail for higher-dim integrals, where the downstream + # effects of numerical differences will accumulate. + augm.piv_chol.tril = full.piv_chol.tril.detach().clone() + with self.subTest(msg="solve"): + full.solve() + augm.solve() + _compare(full, augm, tril_only=False) + + def test_getitem(self): + with torch.random.fork_rng(): + torch.random.manual_seed(1) + mask = torch.rand(self.toy_solver.log_prob.shape) > 0.5 + + other = self.toy_solver[mask] + for key, b in other.asdict().items(): + a = getattr(self.toy_solver, key) + if isinstance(b, PivotedCholesky): + continue + elif isinstance(b, torch.Tensor): + self.assertTrue(a[mask].equal(b)) + else: + self.assertEqual(a, b) + + for key in ("perm", "tril", "diag"): + a = getattr(self.toy_solver.piv_chol, key)[mask] + b = getattr(other.piv_chol, key) + self.assertTrue(a.equal(b)) + + def test_concat(self): + split = len(self.toy_solver.log_prob) // 2 + other = self.toy_solver[:split].concat(self.toy_solver[split:], dim=0) + self.assertEqualMXNBPB(self.toy_solver, other) + + def test_clone(self): + self.toy_solver.bounds.requires_grad_(True) + try: + other = self.toy_solver.clone() + self.assertEqualMXNBPB(self.toy_solver, other) + for key, a in self.toy_solver.asdict().items(): + if a is None or isinstance(a, int): + continue + b = getattr(other, key) + self.assertFalse(a is b) + + other.bounds.sum().backward() + self.assertTrue(self.toy_solver.bounds.grad.eq(1).all()) + finally: + self.toy_solver.bounds.requires_grad_(False) + + def test_detach(self): + self.toy_solver.bounds.requires_grad_(True) + try: + other = self.toy_solver.detach() + self.assertEqualMXNBPB(self.toy_solver, other) + for key, a in self.toy_solver.asdict().items(): + if a is None or isinstance(a, int): + continue + b = getattr(other, key) + self.assertFalse(a is b) + + with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"): + other.bounds.sum().backward() + finally: + self.toy_solver.bounds.requires_grad_(False) + + def test_expand(self): + other = self.toy_solver.expand(2, 4, 2) + self.assertEqualMXNBPB(self.toy_solver, other[0]) + self.assertEqualMXNBPB(self.toy_solver, other[1]) + + def test_asdict(self): + for key, val in self.toy_solver.asdict().items(): + self.assertTrue(val is getattr(self.toy_solver, key)) + + def test_from_dict(self): + other = MVNXPB.from_dict(self.toy_solver.asdict()) + self.assertEqualMXNBPB(self.toy_solver, other) diff --git a/test/utils/probability/test_truncated_multivariate_normal.py b/test/utils/probability/test_truncated_multivariate_normal.py new file mode 100644 index 0000000000..7519bc329c --- /dev/null +++ b/test/utils/probability/test_truncated_multivariate_normal.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Sequence, Tuple + +import torch +from botorch.utils.probability.mvnxpb import MVNXPB +from botorch.utils.probability.truncated_multivariate_normal import ( + TruncatedMultivariateNormal, +) +from botorch.utils.testing import BotorchTestCase +from torch import Tensor +from torch.distributions import MultivariateNormal +from torch.special import ndtri + + +class TestTruncatedMultivariateNormal(BotorchTestCase): + def setUp( + self, + ndims: Sequence[Tuple[int, int]] = (2, 4), + lower_quantile_max: float = 0.9, # if these get too far into the tail, naive + upper_quantile_min: float = 0.1, # MC methods will not produce any samples. + num_log_probs: int = 4, + seed: int = 1, + ): + self.seed = seed + self.num_log_probs = num_log_probs + + tkwargs = {"dtype": torch.float64} + self.distributions = [] + self.sqrt_covariances = [] + with torch.random.fork_rng(): + torch.random.manual_seed(self.gen_seed()) + for ndim in ndims: + loc = torch.randn(ndim, **tkwargs) + sqrt_covariance = self.gen_covariances(ndim, as_sqrt=True).to(**tkwargs) + covariance_matrix = sqrt_covariance @ sqrt_covariance.transpose(-1, -2) + std = covariance_matrix.diag().sqrt() + + lb = lower_quantile_max * torch.rand(ndim, **tkwargs) + ub = lb.clip(min=upper_quantile_min) # scratch variable + ub = ub + (1 - ub) * torch.rand(ndim, **tkwargs) + bounds = loc.unsqueeze(-1) + std.unsqueeze(-1) * ndtri( + torch.stack([lb, ub], dim=-1) + ) + + self.distributions.append( + TruncatedMultivariateNormal( + loc=loc, + covariance_matrix=covariance_matrix, + bounds=bounds, + validate_args=True, + ) + ) + self.sqrt_covariances.append(sqrt_covariance) + + def gen_seed(self, low: int = 0, high: int = 2**30) -> int: + with torch.random.fork_rng(): + torch.random.manual_seed(self.seed) + seed = torch.randint(low=low, high=high, size=()) + self.seed += 1 + return int(seed) + + def gen_covariances( + self, + ndim: int, + batch_shape: Sequence[int] = (), + as_sqrt: bool = False, + ) -> Tensor: + shape = tuple(batch_shape) + (ndim, ndim) + eigvals = -torch.rand(shape[:-1]).log() # exponential rvs + orthmat = torch.linalg.svd(torch.randn(shape)).U + sqrt_covar = orthmat * torch.sqrt(eigvals).unsqueeze(-2) + return sqrt_covar if as_sqrt else sqrt_covar @ sqrt_covar.transpose(-2, -1) + + def test_init(self): + trunc = next(iter(self.distributions)) + with self.assertRaisesRegex(SyntaxError, "Missing required argument `bounds`"): + TruncatedMultivariateNormal( + loc=trunc.loc, covariance_matrix=trunc.covariance_matrix + ) + + with self.assertRaisesRegex(ValueError, r"Expected bounds.shape\[-1\] to be 2"): + TruncatedMultivariateNormal( + loc=trunc.loc, + covariance_matrix=trunc.covariance_matrix, + bounds=torch.empty(trunc.covariance_matrix.shape[:-1] + (1,)), + ) + + with self.assertRaisesRegex(ValueError, "`bounds` must be strictly increasing"): + TruncatedMultivariateNormal( + loc=trunc.loc, + covariance_matrix=trunc.covariance_matrix, + bounds=trunc.bounds.roll(shifts=1, dims=-1), + ) + + def test_solver(self): + for trunc in self.distributions: + # Test that solver was setup properly + solver = trunc.solver + self.assertIsInstance(solver, MVNXPB) + self.assertTrue(solver.perm.equal(solver.piv_chol.perm)) + self.assertEqual(solver.step, trunc.covariance_matrix.shape[-1]) + + bounds = torch.gather( + trunc.covariance_matrix.diag().rsqrt().unsqueeze(-1) + * 
(trunc.bounds - trunc.loc.unsqueeze(-1)), + dim=-2, + index=solver.perm.unsqueeze(-1).expand(*trunc.bounds.shape), + ) + self.assertTrue(solver.bounds.allclose(bounds)) + + # Test that (permuted) covariance matrices match + A = solver.piv_chol.diag.unsqueeze(-1) * solver.piv_chol.tril + A = A @ A.transpose(-2, -1) + + n = A.shape[-1] + B = trunc.covariance_matrix + B = B.gather(-1, solver.perm.unsqueeze(-2).repeat(n, 1)) + B = B.gather(-2, solver.perm.unsqueeze(-1).repeat(1, n)) + self.assertTrue(A.allclose(B)) + + def test_log_prob(self): + with torch.random.fork_rng(): + torch.random.manual_seed(self.gen_seed()) + for trunc in self.distributions: + # Test generic values + vals = trunc.rsample(sample_shape=torch.Size([self.num_log_probs])) + test = MultivariateNormal.log_prob(trunc, vals) - trunc.log_partition + self.assertTrue(test.equal(trunc.log_prob(vals))) + + # Test out of bounds + m = trunc.bounds.shape[-2] // 2 + oob = torch.concat( + [trunc.bounds[..., :m, 0] - 1, trunc.bounds[..., m:, 1] + 1], dim=-1 + ) + self.assertTrue(trunc.log_prob(oob).eq(-float("inf")).all()) + + def test_expand(self): + trunc = next(iter(self.distributions)) + other = trunc.expand(torch.Size([2])) + for key in ("loc", "covariance_matrix", "bounds", "log_partition"): + a = getattr(trunc, key) + self.assertTrue(all(a.equal(b) for b in getattr(other, key).unbind())) diff --git a/test/utils/probability/test_unified_skew_normal.py b/test/utils/probability/test_unified_skew_normal.py new file mode 100644 index 0000000000..d6ae873027 --- /dev/null +++ b/test/utils/probability/test_unified_skew_normal.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Any, Dict, Optional, Sequence, Tuple + +import torch +from botorch.utils.probability.mvnxpb import MVNXPB +from botorch.utils.probability.truncated_multivariate_normal import ( + TruncatedMultivariateNormal, +) +from botorch.utils.probability.unified_skew_normal import UnifiedSkewNormal +from botorch.utils.testing import BotorchTestCase +from torch import Tensor +from torch.distributions import MultivariateNormal +from torch.special import ndtri + + +class TestUnifiedSkewNormal(BotorchTestCase): + def setUp( + self, + ndims: Sequence[Tuple[int, int]] = ((1, 1), (2, 3), (3, 2), (3, 3)), + lower_quantile_max: float = 0.9, # if these get too far into the tail, naive + upper_quantile_min: float = 0.1, # MC methods will not produce any samples. 
+ num_log_probs: int = 4, + mc_num_samples: int = 100000, + mc_num_rsamples: int = 1000, + mc_atol_multiplier: float = 4.0, + seed: int = 1, + dtype: torch.dtype = torch.float64, + device: Optional[torch.device] = None, + ): + self.seed = seed + self.dtype = dtype + self.device = device + self.num_log_probs = num_log_probs + self.mc_num_samples = mc_num_samples + self.mc_num_rsamples = mc_num_rsamples + self.mc_atol_multiplier = mc_atol_multiplier + + self.distributions = [] + self.sqrt_covariances = [] + with torch.random.fork_rng(): + torch.random.manual_seed(self.gen_seed()) + for ndim_x, ndim_y in ndims: + ndim_xy = ndim_x + ndim_y + sqrt_covariance = self.gen_covariances(ndim_xy, as_sqrt=True) + covariance = sqrt_covariance @ sqrt_covariance.transpose(-1, -2) + + loc_x = torch.randn(ndim_x, **self.tkwargs) + cov_x = covariance[:ndim_x, :ndim_x] + std_x = cov_x.diag().sqrt() + lb = lower_quantile_max * torch.rand(ndim_x, **self.tkwargs) + ub = lb.clip(min=upper_quantile_min) # scratch variable + ub = ub + (1 - ub) * torch.rand(ndim_x, **self.tkwargs) + bounds_x = loc_x.unsqueeze(-1) + std_x.unsqueeze(-1) * ndtri( + torch.stack([lb, ub], dim=-1) + ) + + xcov = covariance[:ndim_x, ndim_x:] + trunc = TruncatedMultivariateNormal( + loc=loc_x, + covariance_matrix=cov_x, + bounds=bounds_x, + validate_args=True, + ) + + gauss = MultivariateNormal( + loc=torch.randn(ndim_y, **self.tkwargs), + covariance_matrix=covariance[ndim_x:, ndim_x:], + ) + + self.sqrt_covariances.append(sqrt_covariance) + self.distributions.append( + UnifiedSkewNormal( + trunc=trunc, gauss=gauss, cross_covariance_matrix=xcov + ) + ) + + @property + def tkwargs(self) -> Dict[str, Any]: + return {"dtype": self.dtype, "device": self.device} + + def gen_seed(self, low: int = 0, high: int = 2**30) -> int: + with torch.random.fork_rng(): + torch.random.manual_seed(self.seed) + seed = torch.randint(low=low, high=high, size=()) + self.seed += 1 + return int(seed) + + def gen_covariances( + self, + ndim: int, + batch_shape: Sequence[int] = (), + as_sqrt: bool = False, + ) -> Tensor: + shape = tuple(batch_shape) + (ndim, ndim) + eigvals = -torch.rand(shape[:-1], **self.tkwargs).log() # exponential rvs + orthmat = torch.linalg.svd(torch.randn(shape, **self.tkwargs)).U + sqrt_covar = orthmat * torch.sqrt(eigvals).unsqueeze(-2) + return sqrt_covar if as_sqrt else sqrt_covar @ sqrt_covar.transpose(-2, -1) + + def test_log_prob(self): + with torch.random.fork_rng(): + torch.random.manual_seed(self.gen_seed()) + for usn in self.distributions: + shape = torch.Size([self.num_log_probs]) + vals = usn.gauss.rsample(sample_shape=shape) + + # Manually compute log probabilities + alpha = torch.cholesky_solve( + usn.cross_covariance_matrix.T, usn.gauss.scale_tril + ) + loc_condx = usn.trunc.loc + (vals - usn.gauss.loc) @ alpha + cov_condx = ( + usn.trunc.covariance_matrix - usn.cross_covariance_matrix @ alpha + ) + solver = MVNXPB( + covariance_matrix=cov_condx.repeat(self.num_log_probs, 1, 1), + bounds=usn.trunc.bounds - loc_condx.unsqueeze(-1), + ) + log_probs = ( + solver.solve() + usn.gauss.log_prob(vals) - usn.trunc.log_partition + ) + + # Compare with log probabilities returned by class + self.assertTrue(log_probs.allclose(usn.log_prob(vals))) + + def test_rsample(self): + # TODO: Replace with e.g. two-sample test. 
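+ # Until then: compare sample mean/covariance of `rsample` draws against a
+ # brute-force rejection sampler over the joint Gaussian, accepting draws
+ # whose X-block lands inside `trunc`'s bounds.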
+ with torch.random.fork_rng(): + torch.random.manual_seed(self.gen_seed()) + + # Pick a USN distribution at random + index = torch.randint(low=0, high=len(self.distributions), size=()) + usn = self.distributions[index] + sqrt_covariance = self.sqrt_covariances[index] + + # Generate draws using `rsample` + samples_y = usn.rsample(sample_shape=torch.Size([self.mc_num_rsamples])) + means = samples_y.mean(0) + covar = samples_y.T.cov() + + # Generate draws using rejection sampling + ndim = sqrt_covariance.shape[-1] + base_rvs = torch.randn(self.mc_num_samples, ndim, **self.tkwargs) + _samples_x, _samples_y = (base_rvs @ sqrt_covariance.T).split( + usn.trunc.event_shape + usn.gauss.event_shape, dim=-1 + ) + + _accept = torch.logical_and( + (_samples_x > usn.trunc.bounds[..., 0] - usn.trunc.loc).all(-1), + (_samples_x < usn.trunc.bounds[..., 1] - usn.trunc.loc).all(-1), + ) + + _means = usn.gauss.loc + _samples_y[_accept].mean(0) + _covar = _samples_y[_accept].T.cov() + + atol = self.mc_atol_multiplier * ( + _accept.count_nonzero() ** -0.5 + self.mc_num_rsamples**-0.5 + ) + + self.assertTrue(torch.allclose(_means, means, rtol=0, atol=atol)) + self.assertTrue(torch.allclose(_covar, covar, rtol=0, atol=atol)) + + def test_expand(self): + usn = next(iter(self.distributions)) + other = usn.expand(torch.Size([2])) + for key in ("loc", "covariance_matrix"): + a = getattr(usn.gauss, key) + self.assertTrue(all(a.equal(b) for b in getattr(other.gauss, key).unbind())) + + for key in ("loc", "covariance_matrix", "bounds", "log_partition"): + a = getattr(usn.trunc, key) + self.assertTrue(all(a.equal(b) for b in getattr(other.trunc, key).unbind())) + + for b in other.cross_covariance_matrix.unbind(): + self.assertTrue(usn.cross_covariance_matrix.equal(b)) diff --git a/test/utils/probability/test_utils.py b/test/utils/probability/test_utils.py new file mode 100644 index 0000000000..c9004930cb --- /dev/null +++ b/test/utils/probability/test_utils.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
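For reference, a minimal end-to-end sketch of the distributions exercised above
(a block-diagonal joint covariance is chosen purely so the example inputs are
trivially valid; all values are illustrative):

    import torch
    from torch.distributions import MultivariateNormal
    from botorch.utils.probability import (
        TruncatedMultivariateNormal,
        UnifiedSkewNormal,
    )

    trunc = TruncatedMultivariateNormal(
        loc=torch.zeros(2),
        covariance_matrix=torch.eye(2),
        bounds=torch.tensor([[-1.0, 1.0], [-1.0, 1.0]]),  # rows strictly increase
    )
    gauss = MultivariateNormal(loc=torch.zeros(3), covariance_matrix=torch.eye(3))
    usn = UnifiedSkewNormal(
        trunc=trunc, gauss=gauss, cross_covariance_matrix=torch.zeros(2, 3)
    )
    samples = usn.rsample(sample_shape=torch.Size([4]))  # shape: 4 x 3
    log_probs = usn.log_prob(samples)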
+
+from __future__ import annotations
+
+from unittest import TestCase
+
+import torch
+from botorch.utils.probability import utils
+from numpy.polynomial.legendre import leggauss as numpy_leggauss
+
+
+class TestProbabilityUtils(TestCase):
+ def test_case_dispatcher(self):
+ with torch.random.fork_rng():
+ torch.random.manual_seed(0)
+ values = torch.rand([32])
+
+ # Test default case
+ output = utils.case_dispatcher(
+ out=torch.full_like(values, float("nan")),
+ default=lambda mask: 0,
+ )
+ self.assertTrue(output.eq(0).all())
+
+ # Test randomized value assignments
+ levels = 0.25, 0.5, 0.75
+ cases = [ # switching cases
+ (lambda level=level: values < level, lambda mask, i=i: i)
+ for i, level in enumerate(levels)
+ ]
+
+ cases.append( # dummy case whose predicate is always False
+ (lambda: torch.full(values.shape, False), lambda mask: float("nan"))
+ )
+
+ output = utils.case_dispatcher(
+ out=torch.full_like(values, float("nan")),
+ cases=cases,
+ default=lambda mask: len(levels),
+ )
+
+ self.assertTrue(output.isfinite().all())
+ active = torch.full(values.shape, True)
+ for i, level in enumerate(levels):
+ mask = active & (values < level)
+ self.assertTrue(output[mask].eq(i).all())
+ active[mask] = False
+ self.assertTrue(~active.any() or output[active].eq(len(levels)).all())
+
+ def test_build_positional_indices(self):
+ with torch.random.fork_rng():
+ torch.random.manual_seed(0)
+ values = torch.rand(3, 2, 5)
+
+ for dim in (values.ndim, -values.ndim - 1):
+ with self.assertRaisesRegex(ValueError, r"dim=(-?\d+) invalid for shape"):
+ utils.build_positional_indices(shape=values.shape, dim=dim)
+
+ start = utils.build_positional_indices(shape=values.shape, dim=-2)
+ self.assertEqual(start.shape, values.shape[:-1])
+ self.assertTrue(start.remainder(values.shape[-1]).eq(0).all())
+
+ max_values, max_indices = values.max(dim=-1)
+ self.assertTrue(values.view(-1)[start + max_indices].equal(max_values))
+
+ def test_leggauss(self):
+ for a, b in zip(utils.leggauss(20, dtype=torch.float64), numpy_leggauss(20)):
+ self.assertEqual(a.dtype, torch.float64)
+ self.assertTrue((a.numpy() == b).all())
+
+ def test_swap_along_dim_(self):
+ with torch.random.fork_rng():
+ torch.random.manual_seed(0)
+ values = torch.rand(3, 2, 5)
+
+ start = utils.build_positional_indices(shape=values.shape, dim=-2)
+ min_values, min_indices = values.min(dim=-1)
+ max_values, max_indices = values.max(dim=-1)
+
+ i = (start + min_indices).ravel()
+ j = (start + max_indices).ravel()
+ out = utils.swap_along_dim_(values.view(-1).clone(), i=i, j=j, dim=0)
+
+ # Verify that positions of minimum and maximum values were swapped
+ for vec, min_val, min_idx, max_val, max_idx in zip(
+ out.view(-1, values.shape[-1]),
+ min_values.ravel(),
+ min_indices.ravel(),
+ max_values.ravel(),
+ max_indices.ravel(),
+ ):
+ self.assertEqual(vec[min_idx], max_val)
+ self.assertEqual(vec[max_idx], min_val)
+
+ # Test passing in a pre-allocated copy buffer
+ temp = values.view(-1).clone()[i]
+ buff = torch.empty_like(temp)
+ out2 = utils.swap_along_dim_(
+ values.view(-1).clone(), i=i, j=j, dim=0, buffer=buff
+ )
+ self.assertTrue(out.equal(out2))
+ self.assertTrue(temp.equal(buff))
diff --git a/test/utils/test_constants.py b/test/utils/test_constants.py
new file mode 100644
index 0000000000..bdb16b03de
--- /dev/null
+++ b/test/utils/test_constants.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from unittest import TestCase +from unittest.mock import patch + +import torch +from botorch.utils import constants + + +class TestConstants(TestCase): + def test_get_constants(self): + const = constants.get_constants(0.123, dtype=torch.float16) + self.assertEqual(const, 0.123) + self.assertEqual(const.dtype, torch.float16) + + try: # test in-place modification + const.add_(1) + const2 = constants.get_constants(0.123, dtype=torch.float16) + self.assertEqual(const2, 1.123) + self.assertEqual(const2.dtype, torch.float16) + finally: + const.sub_(1) + + # Test fetching of multiple constants + const_tuple = constants.get_constants(values=(0, 1, 2), dtype=torch.float16) + self.assertIsInstance(const_tuple, tuple) + self.assertEqual(len(const_tuple), 3) + for i, const in enumerate(const_tuple): + self.assertEqual(const, i) + + def test_get_constants_like(self): + def mock_get_constants(values: torch.Tensor, **kwargs): + return kwargs + + with patch.object(constants, "get_constants", new=mock_get_constants): + ref = torch.tensor([123], dtype=torch.float16) + self.assertEqual( + constants.get_constants_like(0.123, ref=ref), + {"device": ref.device, "dtype": ref.dtype}, + ) diff --git a/test/utils/test_safe_math.py b/test/utils/test_safe_math.py new file mode 100644 index 0000000000..98b00c005c --- /dev/null +++ b/test/utils/test_safe_math.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import math +from abc import abstractmethod +from itertools import combinations, product +from typing import Callable +from unittest import TestCase + +import torch +from botorch.utils import safe_math +from botorch.utils.constants import get_constants_like +from torch import finfo, Tensor + +INF = float("inf") + + +class UnaryOpTestMixin: + op: Callable[[Tensor], Tensor] + safe_op: Callable[[Tensor], Tensor] + + def __init_subclass__(cls, op: Callable, safe_op: Callable): + cls.op = staticmethod(op) + cls.safe_op = staticmethod(safe_op) + + def test_generic(self, m: int = 3, n: int = 4): + for dtype in (torch.float32, torch.float64): + # Test forward + x = torch.rand(n, m, dtype=dtype, requires_grad=True) + y = self.safe_op(x) + + _x = x.detach().clone().requires_grad_(True) + _y = self.op(_x) + self.assertTrue(y.equal(_y)) + + # Test backward + y.sum().backward() + _y.sum().backward() + self.assertTrue(x.grad.equal(_x.grad)) + + # Test passing in pre-allocated `out` + with torch.no_grad(): + y.zero_() + self.safe_op(x, out=y) + self.assertTrue(y.equal(_y)) + + @abstractmethod + def test_special(self): + pass # pragma: no cover + + +class BinaryOpTestMixin: + op: Callable[[Tensor, Tensor], Tensor] + safe_op: Callable[[Tensor, Tensor], Tensor] + + def __init_subclass__(cls, op: Callable, safe_op: Callable): + cls.op = staticmethod(op) + cls.safe_op = staticmethod(safe_op) + + def test_generic(self, m: int = 3, n: int = 4): + for dtype in (torch.float32, torch.float64): + # Test equality for generic cases + a = torch.rand(n, m, dtype=dtype, requires_grad=True) + b = torch.rand(n, m, dtype=dtype, requires_grad=True) + y = self.safe_op(a, b) + + _a = a.detach().clone().requires_grad_(True) + _b = b.detach().clone().requires_grad_(True) 
+ _y = self.op(_a, _b) + self.assertTrue(y.equal(_y)) + + # Test backward + y.sum().backward() + _y.sum().backward() + self.assertTrue(a.grad.equal(_a.grad)) + self.assertTrue(b.grad.equal(_b.grad)) + + @abstractmethod + def test_special(self): + pass # pragma: no cover + + +class TestSafeExp(TestCase, UnaryOpTestMixin, op=torch.exp, safe_op=safe_math.exp): + def test_special(self): + for dtype in (torch.float32, torch.float64): + x = torch.full([], INF, dtype=dtype, requires_grad=True) + y = self.safe_op(x) + self.assertEqual( + y, get_constants_like(math.log(finfo(dtype).max) - 1e-4, x).exp() + ) + + y.backward() + self.assertEqual(x.grad, 0) + + +class TestSafeLog(TestCase, UnaryOpTestMixin, op=torch.log, safe_op=safe_math.log): + def test_special(self): + for dtype in (torch.float32, torch.float64): + x = torch.zeros([], dtype=dtype, requires_grad=True) + y = self.safe_op(x) + self.assertEqual(y, math.log(finfo(dtype).tiny)) + + y.backward() + self.assertEqual(x.grad, 0) + + +class TestSafeAdd(TestCase, BinaryOpTestMixin, op=torch.add, safe_op=safe_math.add): + def test_special(self): + for dtype in (torch.float32, torch.float64): + for _a in (INF, -INF): + a = torch.tensor(_a, dtype=dtype, requires_grad=True) + b = torch.tensor(INF, dtype=dtype, requires_grad=True) + + out = self.safe_op(a, b) + self.assertEqual(out, 0 if a != b else b) + + out.backward() + self.assertEqual(a.grad, 0 if a != b else 1) + self.assertEqual(b.grad, 0 if a != b else 1) + + +class TestSafeSub(TestCase, BinaryOpTestMixin, op=torch.sub, safe_op=safe_math.sub): + def test_special(self): + for dtype in (torch.float32, torch.float64): + for _a in (INF, -INF): + a = torch.tensor(_a, dtype=dtype, requires_grad=True) + b = torch.tensor(INF, dtype=dtype, requires_grad=True) + + out = self.safe_op(a, b) + self.assertEqual(out, 0 if a == b else -b) + + out.backward() + self.assertEqual(a.grad, 0 if a == b else 1) + self.assertEqual(b.grad, 0 if a == b else -1) + + +class TestSafeMul(TestCase, BinaryOpTestMixin, op=torch.mul, safe_op=safe_math.mul): + def test_special(self): + for dtype in (torch.float32, torch.float64): + for _a, _b in product([0, 2], [INF, -INF]): + a = torch.tensor(_a, dtype=dtype, requires_grad=True) + b = torch.tensor(_b, dtype=dtype, requires_grad=True) + + out = self.safe_op(a, b) + self.assertEqual(out, a if a == 0 else b) + + out.backward() + self.assertEqual(a.grad, 0 if a == 0 else b) + self.assertEqual(b.grad, 0 if a == 0 else a) + + +class TestSafeDiv(TestCase, BinaryOpTestMixin, op=torch.div, safe_op=safe_math.div): + def test_special(self): + for dtype in (torch.float32, torch.float64): + for _a, _b in combinations([0, INF, -INF], 2): + a = torch.tensor(_a, dtype=dtype, requires_grad=True) + b = torch.tensor(_b, dtype=dtype, requires_grad=True) + + out = self.safe_op(a, b) + if a == b: + self.assertEqual(out, 1) + elif a == -b: + self.assertEqual(out, -1) + else: + self.assertEqual(out, a / b) + + out.backward() + if ((a == 0) & (b == 0)) | (a.isinf() & b.isinf()): + self.assertEqual(a.grad, 0) + self.assertEqual(b.grad, 0) + else: + self.assertEqual(a.grad, 1 / b) + self.assertEqual(b.grad, -a * b**-2)
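Finally, a minimal sketch of the approximate CDF solver at the center of these
utilities (toy inputs; `solve` returns an approximate log-probability):

    import torch
    from botorch.utils.probability import MVNXPB

    covariance_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.0]])
    bounds = torch.tensor([[-1.0, 1.0], [-1.0, 1.0]])  # per-dim [lower, upper]
    solver = MVNXPB(covariance_matrix=covariance_matrix, bounds=bounds)
    log_prob = solver.solve()  # approximates log P(-1 < X_i < 1 for all i)
    prob = log_prob.exp()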