diff --git a/gpjax/bayes_opt/__init__.py b/gpjax/bayes_opt/__init__.py new file mode 100644 index 000000000..5937ffd33 --- /dev/null +++ b/gpjax/bayes_opt/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from gpjax.bayes_opt import search_space + +__all__ = [ + "search_space", +] diff --git a/gpjax/bayes_opt/search_space.py b/gpjax/bayes_opt/search_space.py new file mode 100644 index 000000000..18345c207 --- /dev/null +++ b/gpjax/bayes_opt/search_space.py @@ -0,0 +1,96 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from abc import ( + ABC, + abstractmethod, +) +from dataclasses import dataclass + +from jaxtyping import Float +import tensorflow_probability.substrates.jax as tfp + +from gpjax.typing import ( + Array, + KeyArray, +) + + +@dataclass +class AbstractSearchSpace(ABC): + """The `AbstractSearchSpace` class is an abstract base class for + search spaces, which are used to define domains for sampling and optimisation functionality in GPJax. + """ + + @abstractmethod + def sample(self, num_points: int, key: KeyArray) -> Float[Array, "N D"]: + """Sample points from the search space. + Args: + num_points (int): Number of points to be sampled from the search space. + key (KeyArray): JAX PRNG key. + Returns: + Float[Array, "N D"]: `num_points` points sampled from the search space. + """ + raise NotImplementedError + + @property + @abstractmethod + def dimensionality(self) -> int: + """Dimensionality of the search space. + Returns: + int: Dimensionality of the search space. + """ + raise NotImplementedError + + +@dataclass +class ContinuousSearchSpace(AbstractSearchSpace): + """The `ContinuousSearchSpace` class is used to bound the domain of continuous real functions of dimension $`D`$.""" + + lower_bounds: Float[Array, " D"] + upper_bounds: Float[Array, " D"] + + def __post_init__(self): + if not self.lower_bounds.dtype == self.upper_bounds.dtype: + raise ValueError("Lower and upper bounds must have the same dtype.") + if self.lower_bounds.shape != self.upper_bounds.shape: + raise ValueError("Lower and upper bounds must have the same shape.") + if self.lower_bounds.shape[0] == 0: + raise ValueError("Lower and upper bounds cannot be empty") + if not (self.lower_bounds <= self.upper_bounds).all(): + raise ValueError("Lower bounds must be less than upper bounds.") + + @property + def dimensionality(self) -> int: + return self.lower_bounds.shape[0] + + def sample(self, num_points: int, key: KeyArray) -> Float[Array, "N D"]: + """Sample points from the search space using a Halton sequence. + + Args: + num_points (int): Number of points to be sampled from the search space. + key (KeyArray): JAX PRNG key. + Returns: + Float[Array, "N D"]: `num_points` points sampled using the Halton sequence + from the search space. + """ + if num_points <= 0: + raise ValueError("Number of points must be greater than 0.") + + initial_sample = tfp.mcmc.sample_halton_sequence( + dim=self.dimensionality, num_results=num_points, seed=key + ) + return ( + self.lower_bounds + (self.upper_bounds - self.lower_bounds) * initial_sample + ) diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py index 178aeba17..89809adaf 100644 --- a/gpjax/kernels/__init__.py +++ b/gpjax/kernels/__init__.py @@ -27,7 +27,10 @@ DiagonalKernelComputation, EigenKernelComputation, ) -from gpjax.kernels.non_euclidean import GraphKernel, CatKernel +from gpjax.kernels.non_euclidean import ( + CatKernel, + GraphKernel, +) from gpjax.kernels.nonstationary import ( ArcCosine, Linear, diff --git a/gpjax/kernels/non_euclidean/__init__.py b/gpjax/kernels/non_euclidean/__init__.py index 1289f1d60..ee45287b0 100644 --- a/gpjax/kernels/non_euclidean/__init__.py +++ b/gpjax/kernels/non_euclidean/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from gpjax.kernels.non_euclidean.graph import GraphKernel from gpjax.kernels.non_euclidean.categorical import CatKernel +from gpjax.kernels.non_euclidean.graph import GraphKernel __all__ = ["GraphKernel", "CatKernel"] diff --git a/gpjax/kernels/non_euclidean/categorical.py b/gpjax/kernels/non_euclidean/categorical.py index e0f1e610b..1d376956f 100644 --- a/gpjax/kernels/non_euclidean/categorical.py +++ b/gpjax/kernels/non_euclidean/categorical.py @@ -15,9 +15,16 @@ from dataclasses import dataclass -from typing import NamedTuple, Union +from typing import ( + NamedTuple, + Union, +) + import jax.numpy as jnp -from jaxtyping import Float, Int +from jaxtyping import ( + Float, + Int, +) import tensorflow_probability.substrates.jax as tfp from gpjax.base import ( @@ -25,7 +32,6 @@ static_field, ) from gpjax.kernels.base import AbstractKernel - from gpjax.typing import ( Array, ScalarInt, diff --git a/tests/test_bayes_opt/__init__.py b/tests/test_bayes_opt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_bayes_opt/test_search_space.py b/tests/test_bayes_opt/test_search_space.py new file mode 100644 index 000000000..b63d43578 --- /dev/null +++ b/tests/test_bayes_opt/test_search_space.py @@ -0,0 +1,218 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jax.config import config +import jax.numpy as jnp +import jax.random as jr +from jaxtyping import ( + Array, + Float, +) +import pytest + +from gpjax.bayes_opt.search_space import ( + AbstractSearchSpace, + ContinuousSearchSpace, +) + +config.update("jax_enable_x64", True) + + +def test_abstract_search_space(): + with pytest.raises(TypeError): + AbstractSearchSpace() + + +def test_continuous_search_space_empty_bounds(): + with pytest.raises(ValueError): + ContinuousSearchSpace(lower_bounds=jnp.array([]), upper_bounds=jnp.array([])) + + +@pytest.mark.parametrize( + "lower_bounds, upper_bounds", + [ + (jnp.array([0.0], dtype=jnp.float64), jnp.array([1.0], jnp.float32)), + (jnp.array([0.0], dtype=jnp.float32), jnp.array([1.0], jnp.float64)), + ], +) +def test_continuous_search_space_dtype_consistency( + lower_bounds: Float[Array, " D"], upper_bounds: Float[Array, " D"] +): + with pytest.raises(ValueError): + ContinuousSearchSpace(lower_bounds=lower_bounds, upper_bounds=upper_bounds) + + +@pytest.mark.parametrize( + "lower_bounds, upper_bounds", + [ + (jnp.array([0.0]), jnp.array([1.0, 1.0])), + (jnp.array([0.0, 0.0]), jnp.array([1.0])), + ], +) +def test_continous_search_space_bounds_shape_consistency( + lower_bounds: Float[Array, " D"], upper_bounds: Float[Array, " D"] +): + with pytest.raises(ValueError): + ContinuousSearchSpace(lower_bounds=lower_bounds, upper_bounds=upper_bounds) + + +@pytest.mark.parametrize( + "lower_bounds, upper_bounds", + [ + (jnp.array([1.0]), jnp.array([0.0])), + (jnp.array([1.0, 1.0]), jnp.array([0.0, 2.0])), + (jnp.array([1.0, 1.0]), jnp.array([2.0, 0.0])), + ], +) +def test_continuous_search_space_bounds_values_consistency( + lower_bounds: Float[Array, " D"], upper_bounds: Float[Array, " D"] +): + with pytest.raises(ValueError): + ContinuousSearchSpace(lower_bounds=lower_bounds, upper_bounds=upper_bounds) + + +@pytest.mark.parametrize( + "continuous_search_space, dimensionality", + [ + (ContinuousSearchSpace(jnp.array([0.0]), jnp.array([1.0])), 1), + (ContinuousSearchSpace(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])), 2), + ( + ContinuousSearchSpace( + jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 1.0, 1.0]) + ), + 3, + ), + ], +) +def test_continuous_search_space_dimensionality( + continuous_search_space: ContinuousSearchSpace, dimensionality: int +): + assert continuous_search_space.dimensionality == dimensionality + + +@pytest.mark.parametrize( + "continuous_search_space", + [ + ContinuousSearchSpace(jnp.array([0.0]), jnp.array([1.0])), + ContinuousSearchSpace(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])), + ContinuousSearchSpace(jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 1.0, 1.0])), + ], +) +@pytest.mark.parametrize("num_points", [0, -1]) +def test_continous_search_space_invalid_sample_num_points( + continuous_search_space: ContinuousSearchSpace, num_points: int +): + with pytest.raises(ValueError): + continuous_search_space.sample(num_points=num_points, key=jr.PRNGKey(42)) + + +@pytest.mark.parametrize( + "continuous_search_space, dimensionality", + [ + (ContinuousSearchSpace(jnp.array([0.0]), jnp.array([1.0])), 1), + (ContinuousSearchSpace(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])), 2), + ( + ContinuousSearchSpace( + jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 1.0, 1.0]) + ), + 3, + ), + ], +) +@pytest.mark.parametrize("num_points", [1, 5, 50]) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_continuous_search_space_sample_shape( + continuous_search_space: ContinuousSearchSpace, dimensionality: int, num_points: int +): + samples = continuous_search_space.sample(num_points=num_points, key=jr.PRNGKey(42)) + assert samples.shape[0] == num_points + assert samples.shape[1] == dimensionality + + +@pytest.mark.parametrize( + "continuous_search_space", + [ + ContinuousSearchSpace(jnp.array([0.0]), jnp.array([1.0])), + ContinuousSearchSpace(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])), + ContinuousSearchSpace(jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 1.0, 1.0])), + ], +) +@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(5)]) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_continous_search_space_sample_same_key_same_samples( + continuous_search_space: ContinuousSearchSpace, key: jr.PRNGKey +): + sample_one = continuous_search_space.sample(num_points=100, key=key) + sample_two = continuous_search_space.sample(num_points=100, key=key) + assert jnp.array_equal(sample_one, sample_two) + + +@pytest.mark.parametrize( + "continuous_search_space", + [ + ContinuousSearchSpace(jnp.array([0.0]), jnp.array([1.0])), + ContinuousSearchSpace(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])), + ContinuousSearchSpace(jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 1.0, 1.0])), + ], +) +@pytest.mark.parametrize( + "key_one, key_two", + [(jr.PRNGKey(42), jr.PRNGKey(5)), (jr.PRNGKey(1), jr.PRNGKey(2))], +) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_continuous_search_space_different_keys_different_samples( + continuous_search_space: ContinuousSearchSpace, + key_one: jr.PRNGKey, + key_two: jr.PRNGKey, +): + sample_one = continuous_search_space.sample(num_points=100, key=key_one) + sample_two = continuous_search_space.sample(num_points=100, key=key_two) + assert not jnp.array_equal(sample_one, sample_two) + + +@pytest.mark.parametrize( + "continuous_search_space", + [ + ContinuousSearchSpace( + lower_bounds=jnp.array([0.0]), upper_bounds=jnp.array([1.0]) + ), + ContinuousSearchSpace( + lower_bounds=jnp.array([0.0, 0.0]), upper_bounds=jnp.array([1.0, 2.0]) + ), + ContinuousSearchSpace( + lower_bounds=jnp.array([0.0, 1.0]), upper_bounds=jnp.array([2.0, 2.0]) + ), + ContinuousSearchSpace( + lower_bounds=jnp.array([2.4, 1.7, 4.9]), + upper_bounds=jnp.array([5.6, 1.8, 6.0]), + ), + ], +) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_continuous_search_space_valid_sample_ranges( + continuous_search_space: ContinuousSearchSpace, +): + samples = continuous_search_space.sample(num_points=100, key=jr.PRNGKey(42)) + for i in range(continuous_search_space.dimensionality): + assert jnp.all(samples[:, i] >= continuous_search_space.lower_bounds[i]) + assert jnp.all(samples[:, i] <= continuous_search_space.upper_bounds[i]) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 38c277528..af119a5ac 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -185,9 +185,7 @@ def test_precision_warning( if prec_y != jnp.float64: expected_warnings += 1 - with pytest.warns( - UserWarning, match=".* is not of type float64.*" - ) as record: + with pytest.warns(UserWarning, match=".* is not of type float64.*") as record: Dataset(X=x, y=y) assert len(record) == expected_warnings diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index bdcb91b98..3ce5722b4 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -12,11 +12,14 @@ from jax.config import config import jax.numpy as jnp +import jax.random as jr import networkx as nx -from gpjax.kernels.non_euclidean import GraphKernel, CatKernel +from gpjax.kernels.non_euclidean import ( + CatKernel, + GraphKernel, +) from gpjax.linops import identity -import jax.random as jr # # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True)