From 67514934b502ee27eaf81cdb58e5ca46cf4d6f1d Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Wed, 7 Jun 2023 09:12:43 -0700
Subject: [PATCH] Orthogonal Additive Kernels (#1869)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1869

OAKs were introduced in [Additive Gaussian Processes Revisited](https://arxiv.org/pdf/2206.09861.pdf) but were limited to Gaussian kernels with Gaussian data densities (which required the application of normalizing flows to the actual input data to make it look Gaussian). This commit introduces a generalization of OAKs that works with arbitrary base kernels by leveraging Gauss-Legendre quadrature rules for the associated one-dimensional integrals.

OAKs can be more sample-efficient than canonical kernels in higher dimensions and allow for more efficient relevance determination, because dimensions or interactions of dimensions can be pruned by setting their associated coefficients -- not just their lengthscales -- to zero.

Reviewed By: Balandat

Differential Revision: D45217852

fbshipit-source-id: c2ef5d82e3dff5c7a2ce34e678686a29cc236351
---
 .../kernels/orthogonal_additive_kernel.py | 299 ++++++++++++++++++
 sphinx/source/models.rst                  |   3 +
 .../test_orthogonal_additive_kernel.py    | 120 +++++++
 3 files changed, 422 insertions(+)
 create mode 100644 botorch/models/kernels/orthogonal_additive_kernel.py
 create mode 100644 test/models/kernels/test_orthogonal_additive_kernel.py

diff --git a/botorch/models/kernels/orthogonal_additive_kernel.py b/botorch/models/kernels/orthogonal_additive_kernel.py
new file mode 100644
index 0000000000..642780bd20
--- /dev/null
+++ b/botorch/models/kernels/orthogonal_additive_kernel.py
@@ -0,0 +1,299 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import List, Optional, Tuple
+
+import numpy
+import torch
+from botorch.exceptions.errors import UnsupportedError
+from gpytorch.constraints import Interval, Positive
+from gpytorch.kernels import Kernel
+
+from torch import nn, Tensor
+
+r"""Orthogonal Additive Kernel.
+
+References:
+
+.. [Lu2022additive]
+    X. Lu, A. Boukouvalas, and J. Hensman. Additive Gaussian processes revisited.
+    Proceedings of the 39th International Conference on Machine Learning. Jul 2022.
+
+This implementation is based on [Lu2022additive]_ but generalizes to arbitrary base
+kernels by using a Gauss-Legendre quadrature approximation to the one-dimensional
+integrals that are required for the orthogonalization of the base kernel.
+"""
+
+_positivity_constraint = Positive()
+
+
+class OrthogonalAdditiveKernel(Kernel):
+    r"""Orthogonal Additive Kernels (OAKs) were introduced in [Lu2022additive]_, though
+    only for the case of Gaussian base kernels with a Gaussian input data distribution.
+
+    The implementation here generalizes OAKs to arbitrary base kernels by using a
+    Gauss-Legendre quadrature approximation to the required one-dimensional integrals
+    involving the base kernels.
+    """
+
+    def __init__(
+        self,
+        base_kernel: Kernel,
+        dim: int,
+        quad_deg: int = 32,
+        second_order: bool = False,
+        batch_shape: Optional[torch.Size] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        coeff_constraint: Interval = _positivity_constraint,
+    ):
+        """
+        Args:
+            base_kernel: The kernel to orthogonalize and evaluate in `forward`.
+            dim: Input dimensionality of the kernel.
+            quad_deg: Number of integration nodes for orthogonalization.
+            second_order: Toggles second order interactions. If true, both the time and
+                space complexity of evaluating the kernel are quadratic in `dim`.
+            batch_shape: Optional batch shape for the kernel and its parameters.
+            dtype: Initialization dtype for required Tensors.
+            device: Initialization device for required Tensors.
+            coeff_constraint: Constraint on the coefficients of the additive kernel.
+        """
+        super().__init__(batch_shape=batch_shape)
+        self.base_kernel = base_kernel
+        # integration nodes, weights for [0, 1]
+        tkwargs = {"dtype": dtype, "device": device}
+        z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs)
+        self.z = z.unsqueeze(-1).expand(quad_deg, dim)  # deg x dim
+        self.w = w.unsqueeze(-1)
+        self.register_parameter(
+            name="raw_offset",
+            parameter=nn.Parameter(torch.zeros(self.batch_shape, **tkwargs)),
+        )
+        log_d = math.log(dim)
+        self.register_parameter(
+            name="raw_coeffs_1",
+            parameter=nn.Parameter(
+                torch.zeros(*self.batch_shape, dim, **tkwargs) - log_d
+            ),
+        )
+        self.register_parameter(
+            name="raw_coeffs_2",
+            parameter=nn.Parameter(
+                torch.zeros(*self.batch_shape, int(dim * (dim - 1) / 2), **tkwargs)
+                - 2 * log_d
+            )
+            if second_order
+            else None,
+        )
+        if second_order:
+            self._rev_triu_indices = torch.tensor(
+                _reverse_triu_indices(dim),
+                device=device,
+                dtype=torch.long,
+            )
+            # zero tensor for construction of upper-triangular coefficient matrix
+            self._quad_zero = torch.zeros(
+                tuple(1 for _ in range(len(batch_shape) + 1)), **tkwargs
+            ).expand(*batch_shape, 1)
+        self.coeff_constraint = coeff_constraint
+        self.dim = dim
+
+    def k(self, x1: Tensor, x2: Tensor) -> Tensor:
+        """Evaluates the kernel matrix base_kernel(x1, x2) on each input dimension
+        independently.
+
+        Args:
+            x1: `batch_shape x n1 x d`-dim Tensor in [0, 1]^dim.
+            x2: `batch_shape x n2 x d`-dim Tensor in [0, 1]^dim.
+
+        Returns:
+            A `batch_shape x d x n1 x n2`-dim Tensor of kernel matrices.
+        """
+        return self.base_kernel(x1, x2, last_dim_is_batch=True).to_dense()
+
+    @property
+    def offset(self) -> Tensor:
+        """Returns the `batch_shape`-dim Tensor of zeroth-order coefficients."""
+        return self.coeff_constraint.transform(self.raw_offset)
+
+    @property
+    def coeffs_1(self) -> Tensor:
+        """Returns the `batch_shape x d`-dim Tensor of first-order coefficients."""
+        return self.coeff_constraint.transform(self.raw_coeffs_1)
+
+    @property
+    def coeffs_2(self) -> Optional[Tensor]:
+        """Returns the upper-triangular tensor of second-order coefficients.
+
+        NOTE: We only keep track of the upper-triangular part of the raw second-order
+        coefficients, since the effect of the lower-triangular part would be identical,
+        and exclude the diagonal, since it is associated with first-order effects only.
+        While we could further exploit this structure in the forward pass, the
+        associated indexing and temporary allocations make it significantly less
+        efficient than the einsum-based implementation below.
+
+        Returns:
+            A `batch_shape x d x d`-dim Tensor of second-order coefficients.
+        """
+        if self.raw_coeffs_2 is not None:
+            C2 = self.coeff_constraint.transform(self.raw_coeffs_2)
+            C2 = torch.cat((C2, self._quad_zero), dim=-1)  # batch_shape x (d(d-1)/2+1)
+            C2 = C2.index_select(-1, self._rev_triu_indices)
+            return C2.reshape(*self.batch_shape, self.dim, self.dim)
+        else:
+            return None
+
+    def forward(
+        self,
+        x1: Tensor,
+        x2: Tensor,
+        diag: bool = False,
+        last_dim_is_batch: bool = False,
+    ) -> Tensor:
+        """Computes the kernel matrix k(x1, x2).
+
+        Args:
+            x1: `batch_shape x n1 x d`-dim Tensor in [0, 1]^dim.
+            x2: `batch_shape x n2 x d`-dim Tensor in [0, 1]^dim.
+            diag: If True, only returns the diagonal of the kernel matrix.
+            last_dim_is_batch: Not supported by this kernel.
+
+        Returns:
+            A `batch_shape x n1 x n2`-dim Tensor of kernel matrices.
+        """
+        if last_dim_is_batch:
+            raise UnsupportedError(
+                "OrthogonalAdditiveKernel does not support `last_dim_is_batch`."
+            )
+        K_ortho = self._orthogonal_base_kernels(x1, x2)  # batch_shape x d x n1 x n2
+
+        # contracting over d, leading to a `batch_shape x n x n`-dim tensor, i.e.:
+        # K1 = torch.sum(self.coeffs_1[..., None, None] * K_ortho, dim=-3)
+        K1 = torch.einsum(self.coeffs_1, [..., 0], K_ortho, [..., 0, 1, 2], [..., 1, 2])
+        # adding the non-batch dimensions to offset
+        K = K1 + self.offset[..., None, None]
+        if self.coeffs_2 is not None:
+            # Computing the tensor of second order interactions K2.
+            # NOTE: K2 here is equivalent to:
+            # K2 = K_ortho.unsqueeze(-4) * K_ortho.unsqueeze(-3)  # d x d x n x n
+            # K2 = (self.coeffs_2[..., None, None] * K2).sum(dim=(-4, -3))
+            # but avoids forming the `batch_shape x d x d x n x n`-dim tensor in memory.
+            # Reducing over the dimensions with the O(d^2) quadratic terms:
+            K2 = torch.einsum(
+                K_ortho,
+                [..., 0, 2, 3],
+                K_ortho,
+                [..., 1, 2, 3],
+                self.coeffs_2,
+                [..., 0, 1],
+                [..., 2, 3],  # i.e. contracting over the first two non-batch dims
+            )
+            K = K + K2
+
+        return K if not diag else K.diag()  # poor man's diag (TODO)
+
+    def _orthogonal_base_kernels(self, x1: Tensor, x2: Tensor) -> Tensor:
+        """Evaluates the set of `d` orthogonalized base kernels on (x1, x2).
+        Note that even if the base kernel is positive, the orthogonalized versions
+        can - and usually do - take negative values.
+
+        Args:
+            x1: `batch_shape x n1 x d`-dim inputs to the kernel.
+            x2: `batch_shape x n2 x d`-dim inputs to the kernel.
+
+        Returns:
+            A `batch_shape x d x n1 x n2`-dim Tensor.
+        """
+        _check_hypercube(x1, "x1")
+        if x1 is not x2:
+            _check_hypercube(x2, "x2")
+        Kx1x2 = self.k(x1, x2)  # d x n x n
+        # Overwriting allocated quadrature tensors with fitting dtype and device
+        # self.z, self.w = self.z.to(x1), self.w.to(x1)
+        # include normalization constant in weights
+        w = self.w / self.normalizer().sqrt()
+        Skx1 = self.k(x1, self.z) @ w  # batch_shape x d x n
+        Skx2 = Skx1 if (x1 is x2) else self.k(x2, self.z) @ w  # d x n
+        # this is a tensor of kernel matrices of orthogonal 1d kernels
+        K_ortho = (Kx1x2 - Skx1 @ Skx2.transpose(-2, -1)).to_dense()  # d x n x n
+        return K_ortho
+
+    def normalizer(self, eps: float = 1e-6) -> Tensor:
+        """Integrates the `d` orthogonalized base kernels over `[0, 1] x [0, 1]`.
+        NOTE: If the module is in train mode, this needs to re-compute the normalizer
+        each time because the underlying parameters might have changed.
+
+        Args:
+            eps: Minimum value constraint on the normalizers. Avoids division by zero.
+
+        Returns:
+            A `d`-dim tensor of normalization constants.
+        """
+        if self.training or getattr(self, "_normalizer", None) is None:
+            self._normalizer = (self.w.T @ self.k(self.z, self.z) @ self.w).clamp(eps)
+        return self._normalizer
+
+
+def leggauss(
+    deg: int,
+    a: float = -1.0,
+    b: float = 1.0,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Computes Gauss-Legendre quadrature nodes and weights. Wraps
+    `numpy.polynomial.legendre.leggauss` and returns Torch Tensors.
+
+    Args:
+        deg: Number of sample points and weights. Integrates polynomials of degree
+            `2 * deg - 1` exactly.
+        a, b: Lower and upper bounds of the integration domain.
+        dtype: Desired floating point type of the return Tensors.
+        device: Desired device type of the return Tensors.
+
+    Returns:
+        A tuple of Gauss-Legendre quadrature nodes and weights of length `deg`.
+    """
+    dtype = dtype if dtype is not None else torch.get_default_dtype()
+    x, w = numpy.polynomial.legendre.leggauss(deg=deg)
+    x = torch.as_tensor(x, dtype=dtype, device=device)
+    w = torch.as_tensor(w, dtype=dtype, device=device)
+    if not (a == -1 and b == 1):  # need to normalize for different domain
+        x = (b - a) * (x + 1) / 2 + a
+        w = w * ((b - a) / 2)
+    return x, w
+
+
+def _check_hypercube(x: Tensor, name: str) -> None:
+    if (x < 0).any() or (x > 1).any():
+        raise ValueError(name + " is not in hypercube [0, 1]^d.")
+
+
+def _reverse_triu_indices(d: int) -> List[int]:
+    """Computes a list of indices which, upon indexing a `d * (d - 1) / 2 + 1`-dim
+    Tensor whose last element is zero, will lead to a vectorized representation of
+    an upper-triangular matrix, whose diagonal is set to zero and whose super-diagonal
+    elements are set to the `d * (d - 1) / 2` values in the original tensor.
+
+    NOTE: This is a helper function for Orthogonal Additive Kernels, and allows the
+    implementation to only register `d * (d - 1) / 2` parameters to model the second
+    order interactions, instead of the full `d^2` redundant terms.
+
+    Args:
+        d: Dimensionality that gives rise to the `d * (d - 1) / 2` quadratic terms.
+
+    Returns:
+        A list of integer indices in `[0, d * (d - 1) / 2]`. See above for details.
+    """
+    indices = []
+    j = 0
+    d2 = int(d * (d - 1) / 2)
+    for i in range(d):
+        indices.extend(d2 for _ in range(i + 1))  # indexing zero (sub-diagonal)
+        indices.extend(range(j, j + d - i - 1))  # indexing coeffs (super-diagonal)
+        j += d - i - 1
+    return indices
diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst
index c7f3c12730..79dddbec37 100644
--- a/sphinx/source/models.rst
+++ b/sphinx/source/models.rst
@@ -123,6 +123,9 @@ Kernels
 .. automodule:: botorch.models.kernels.contextual_sac
 .. autoclass:: SACKernel
 
+.. automodule:: botorch.models.kernels.orthogonal_additive_kernel
+.. autoclass:: OrthogonalAdditiveKernel
+
 Likelihoods
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: botorch.models.likelihoods.pairwise
diff --git a/test/models/kernels/test_orthogonal_additive_kernel.py b/test/models/kernels/test_orthogonal_additive_kernel.py
new file mode 100644
index 0000000000..6420b2c218
--- /dev/null
+++ b/test/models/kernels/test_orthogonal_additive_kernel.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from botorch.exceptions.errors import UnsupportedError
+from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel
+from botorch.utils.testing import BotorchTestCase
+from gpytorch.kernels import MaternKernel
+from gpytorch.lazy import LazyEvaluatedKernelTensor
+from torch import nn, Tensor
+
+
+class TestOrthogonalAdditiveKernel(BotorchTestCase):
+    def test_kernel(self):
+        n, d = 3, 5
+        dtypes = [torch.float, torch.double]
+        batch_shapes = [(), (2,), (7, 2)]
+        for dtype in dtypes:
+            tkwargs = {"dtype": dtype, "device": self.device}
+            for batch_shape in batch_shapes:
+                X = torch.rand(*batch_shape, n, d, **tkwargs)
+                base_kernel = MaternKernel().to(device=self.device)
+                oak = OrthogonalAdditiveKernel(
+                    base_kernel,
+                    dim=d,
+                    second_order=False,
+                    batch_shape=batch_shape,
+                    **tkwargs,
+                )
+                KL = oak(X)
+                self.assertIsInstance(KL, LazyEvaluatedKernelTensor)
+                KM = KL.to_dense()
+                self.assertIsInstance(KM, Tensor)
+                self.assertEqual(KM.shape, (*batch_shape, n, n))
+                self.assertEqual(KM.dtype, dtype)
+                self.assertEqual(KM.device.type, self.device.type)
+                # symmetry
+                self.assertTrue(torch.allclose(KM, KM.transpose(-2, -1)))
+                # positivity
+                self.assertTrue(isposdef(KM))
+
+                # testing differentiability
+                X.requires_grad = True
+                oak(X).to_dense().sum().backward()
+                self.assertFalse(X.grad.isnan().any())
+                self.assertFalse(X.grad.isinf().any())
+
+                X_out_of_hypercube = torch.rand(n, d, **tkwargs) + 1
+                with self.assertRaisesRegex(ValueError, r"x1.*hypercube"):
+                    oak(X_out_of_hypercube, X).to_dense()
+
+                with self.assertRaisesRegex(ValueError, r"x2.*hypercube"):
+                    oak(X, X_out_of_hypercube).to_dense()
+
+                with self.assertRaisesRegex(UnsupportedError, "does not support"):
+                    oak.forward(x1=X, x2=X, last_dim_is_batch=True)
+
+                oak_2nd = OrthogonalAdditiveKernel(
+                    base_kernel,
+                    dim=d,
+                    second_order=True,
+                    batch_shape=batch_shape,
+                    **tkwargs,
+                )
+                KL2 = oak_2nd(X)
+                self.assertIsInstance(KL2, LazyEvaluatedKernelTensor)
+                KM2 = KL2.to_dense()
+                self.assertIsInstance(KM2, Tensor)
+                self.assertEqual(KM2.shape, (*batch_shape, n, n))
+                # symmetry
+                self.assertTrue(torch.allclose(KM2, KM2.transpose(-2, -1)))
+                # positivity
+                self.assertTrue(isposdef(KM2))
+                self.assertEqual(KM2.dtype, dtype)
+                self.assertEqual(KM2.device.type, self.device.type)
+
+                # testing that second-order coefficient matrices are upper-triangular
+                # and contain the transformed values in oak_2nd.raw_coeffs_2
+                oak_2nd.raw_coeffs_2 = nn.Parameter(
+                    torch.randn_like(oak_2nd.raw_coeffs_2)
+                )
+                C2 = oak_2nd.coeffs_2
+                self.assertTrue(C2.shape == (*batch_shape, d, d))
+                self.assertTrue((C2.tril() == 0).all())
+                c2 = oak_2nd.coeff_constraint.transform(oak_2nd.raw_coeffs_2)
+                i, j = torch.triu_indices(d, d, offset=1)
+                self.assertTrue(torch.allclose(C2[..., i, j], c2))
+
+                # second order effects change the correlation structure
+                self.assertFalse(torch.allclose(KM, KM2))
+
+                # check orthogonality of base kernels
+                n_test = 7
+                # inputs on which to evaluate orthogonality
+                X_ortho = torch.rand(n_test, d, **tkwargs)
+                # d x n_test x quad_deg
+                K_ortho = oak._orthogonal_base_kernels(X_ortho, oak.z)
+
+                # NOTE: at each random test input x_i and for each dimension d,
+                # sum_j k_d(x_i, z_j) * w_j = 0.
+                # Note that this implies the GP mean will be orthogonal as well:
+                # mean(x) = sum_j k(x, x_j) alpha_j
+                # so
+                # sum_i mean(z_i) w_i
+                # = sum_j alpha_j (sum_i k(z_i, x_j) w_i)  // exchanging the order of summation
+                # = sum_j alpha_j (0)  // by symmetry of k and the orthogonality above
+                # = 0
+                tol = 1e-5
+                self.assertTrue(((K_ortho @ oak.w).squeeze(-1).abs() < tol).all())
+
+
+def isposdef(A: Tensor) -> bool:
+    """Determines whether A is positive definite by attempting a Cholesky
+    decomposition. Expects batches of square matrices; non-square inputs raise.
+    """
+    _, info = torch.linalg.cholesky_ex(A)
+    return not torch.any(info)
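---

For reference, a minimal usage sketch of the new kernel follows. This is an illustration assuming a botorch build that includes this diff; it mirrors the unit test above rather than prescribing any API beyond it. It builds a second-order OAK over a Matern base kernel, evaluates it on inputs in the unit hypercube, and checks the defining orthogonality property: each orthogonalized one-dimensional base kernel integrates to (numerically) zero against the Gauss-Legendre quadrature weights.

```python
import torch
from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel
from gpytorch.kernels import MaternKernel

d = 5
oak = OrthogonalAdditiveKernel(
    MaternKernel().to(torch.double),  # arbitrary base kernel, per this diff
    dim=d,
    quad_deg=32,  # number of Gauss-Legendre nodes used for orthogonalization
    second_order=True,  # adds the d * (d - 1) / 2 pairwise interaction terms
    batch_shape=torch.Size([]),
    dtype=torch.double,
)

X = torch.rand(8, d, dtype=torch.double)  # inputs must lie in [0, 1]^d
K = oak(X).to_dense()  # 8 x 8 additive kernel matrix

# Additive structure: scalar offset, d first-order coefficients, and an
# upper-triangular (zero-diagonal) d x d matrix of second-order coefficients.
print(oak.offset.shape, oak.coeffs_1.shape, oak.coeffs_2.shape)

# Orthogonality holds at the quadrature level by construction: for each
# dimension, sum_j k_ortho(x, z_j) * w_j vanishes up to floating-point error.
K_ortho = oak._orthogonal_base_kernels(X, oak.z)  # d x 8 x quad_deg
assert ((K_ortho @ oak.w).abs() < 1e-5).all()
```

Because the coefficients are constrained to be positive and enter the kernel linearly, driving entries of `coeffs_1` or `coeffs_2` to zero prunes entire dimensions or pairwise interactions, which is the relevance-determination mechanism described in the summary.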