From 758dc1c35eadc74c46861d9019992c2de2dc3d79 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 6 Nov 2023 07:19:48 +0100 Subject: [PATCH 01/13] Cleanup namespace --- README.md | 4 +- docs/examples/README.md | 2 +- docs/examples/barycentres.py | 11 +- docs/examples/bayesian_optimisation.py | 17 +-- docs/examples/classification.py | 12 +- docs/examples/collapsed_vi.py | 24 ++-- docs/examples/constructing_new_kernels.py | 8 +- docs/examples/decision_making.py | 10 +- docs/examples/deep_kernels.py | 10 +- docs/examples/graph_kernels.py | 12 +- docs/examples/intro_to_kernels.py | 21 +-- docs/examples/likelihoods_guide.py | 14 +- docs/examples/oceanmodelling.py | 4 +- docs/examples/poisson.py | 10 +- docs/examples/regression.py | 4 +- docs/examples/regression_mo.py | 4 +- docs/examples/spatial.py | 6 +- docs/examples/uncollapsed_vi.py | 8 +- docs/examples/yacht.py | 8 +- gpjax/__init__.py | 122 +++--------------- .../test_decision_maker.py | 10 +- tests/test_mean_functions.py | 6 +- tests/test_variational_families.py | 14 +- 23 files changed, 138 insertions(+), 203 deletions(-) diff --git a/README.md b/README.md index 64cb794b2..2490c9cf1 100644 --- a/README.md +++ b/README.md @@ -135,10 +135,10 @@ D = gpx.Dataset(X=x, y=y) # Construct the prior meanf = gpx.mean_functions.Zero() kernel = gpx.kernels.RBF() -prior = gpx.Prior(mean_function=meanf, kernel = kernel) +prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) # Define a likelihood -likelihood = gpx.Gaussian(num_datapoints = n) +likelihood = gpx.likelihoods.Gaussian(num_datapoints = n) # Construct the posterior posterior = prior * likelihood diff --git a/docs/examples/README.md b/docs/examples/README.md index a5188c35e..2b7d37c1a 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -67,7 +67,7 @@ class Prior(AbstractPrior): >>> >>> meanf = gpx.mean_functions.Zero() >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) Attributes: kernel (Kernel): The kernel function used to parameterise the prior. 
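The remaining diffs in this patch apply the same mechanical rewrite to every example notebook and test. A minimal before/after sketch of the pattern — illustrative only; `n` is a hypothetical datapoint count, and only names that appear elsewhere in this series are used:

```python
import gpjax as gpx

n = 100  # hypothetical number of datapoints

# Before this series (flat exports, now removed from gpjax/__init__.py):
#   prior = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF())
#   likelihood = gpx.Gaussian(num_datapoints=n)
#   mll = gpx.ConjugateMLL(negative=True)

# After this series (constructors reached through their submodules):
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
posterior = prior * likelihood
mll = gpx.objectives.ConjugateMLL(negative=True)
```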
diff --git a/docs/examples/barycentres.py b/docs/examples/barycentres.py index 9ab7a2d83..0444cd42f 100644 --- a/docs/examples/barycentres.py +++ b/docs/examples/barycentres.py @@ -134,12 +134,17 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance: y = y.reshape(-1, 1) D = gpx.Dataset(X=x, y=y) - likelihood = gpx.Gaussian(num_datapoints=n) - posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood + likelihood = gpx.likelihoods.Gaussian(num_datapoints=n) + posterior = ( + gpx.gps.Prior( + mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF() + ) + * likelihood + ) opt_posterior, _ = gpx.fit( model=posterior, - objective=jax.jit(gpx.ConjugateMLL(negative=True)), + objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)), train_data=D, optim=ox.adamw(learning_rate=0.01), num_iters=500, diff --git a/docs/examples/bayesian_optimisation.py b/docs/examples/bayesian_optimisation.py index 449b8f450..d16318cd7 100644 --- a/docs/examples/bayesian_optimisation.py +++ b/docs/examples/bayesian_optimisation.py @@ -201,9 +201,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: # %% def return_optimised_posterior( - data: gpx.Dataset, prior: gpx.Module, key: Array -) -> gpx.Module: - likelihood = gpx.Gaussian( + data: gpx.Dataset, prior: gpx.base.Module, key: Array +) -> gpx.base.Module: + likelihood = gpx.likelihoods.Gaussian( num_datapoints=data.n, obs_stddev=jnp.array(1e-3) ) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value likelihood = likelihood.replace_trainable(obs_stddev=False) @@ -230,7 +230,7 @@ def return_optimised_posterior( mean = gpx.mean_functions.Zero() kernel = gpx.kernels.Matern52() -prior = gpx.Prior(mean_function=mean, kernel=kernel) +prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) opt_posterior = return_optimised_posterior(D, prior, key) # %% [markdown] @@ -315,7 +315,7 @@ def optimise_sample( # %% def plot_bayes_opt( - posterior: gpx.Module, + posterior: gpx.base.Module, sample: FunctionalSample, dataset: gpx.Dataset, queried_x: ScalarFloat, @@ -401,7 +401,7 @@ def plot_bayes_opt( # Generate optimised posterior using previously observed data mean = gpx.mean_functions.Zero() kernel = gpx.kernels.Matern52() - prior = gpx.Prior(mean_function=mean, kernel=kernel) + prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) opt_posterior = return_optimised_posterior(D, prior, subkey) # Draw a sample from the posterior, and find the minimiser of it @@ -543,7 +543,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]: kernel = gpx.kernels.Matern52( active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0 ) - prior = gpx.Prior(mean_function=mean, kernel=kernel) + prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) opt_posterior = return_optimised_posterior(D, prior, subkey) # Draw a sample from the posterior, and find the minimiser of it @@ -561,7 +561,8 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]: # Evaluate the black-box function at the best point observed so far, and add it to the dataset y_star = six_hump_camel(x_star) print( - f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value: {y_star}" + f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:" + f" {y_star}" ) D = D + gpx.Dataset(X=x_star, y=y_star) bo_experiment_results.append(D) diff --git a/docs/examples/classification.py b/docs/examples/classification.py index 
0f85e8580..55c1dcb80 100644 --- a/docs/examples/classification.py +++ b/docs/examples/classification.py @@ -89,10 +89,10 @@ # choose a Bernoulli likelihood with a probit link function. # %% -kernel = gpx.RBF() -meanf = gpx.Constant() -prior = gpx.Prior(mean_function=meanf, kernel=kernel) -likelihood = gpx.Bernoulli(num_datapoints=D.n) +kernel = gpx.kernels.RBF() +meanf = gpx.mean_functions.Constant() +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) +likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n) # %% [markdown] # We construct the posterior through the product of our prior and likelihood. @@ -116,7 +116,7 @@ # Optax's optimisers. # %% -negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True)) +negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True)) optimiser = ox.adam(learning_rate=0.01) @@ -345,7 +345,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma num_adapt = 500 num_samples = 500 -lpd = jax.jit(gpx.LogPosteriorDensity(negative=False)) +lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False)) unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D)) adapt = blackjax.window_adaptation( diff --git a/docs/examples/collapsed_vi.py b/docs/examples/collapsed_vi.py index 80e0436af..ad2f1bf9f 100644 --- a/docs/examples/collapsed_vi.py +++ b/docs/examples/collapsed_vi.py @@ -106,10 +106,10 @@ # this, it is intractable to evaluate. # %% -meanf = gpx.Constant() -kernel = gpx.RBF() -likelihood = gpx.Gaussian(num_datapoints=D.n) -prior = gpx.Prior(mean_function=meanf, kernel=kernel) +meanf = gpx.mean_functions.Constant() +kernel = gpx.kernels.RBF() +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) posterior = prior * likelihood # %% [markdown] @@ -119,7 +119,9 @@ # inducing points into the constructor as arguments. # %% -q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z) +q = gpx.variational_families.CollapsedVariationalGaussian( + posterior=posterior, inducing_inputs=z +) # %% [markdown] # We define our variational inference algorithm through `CollapsedVI`. This defines @@ -127,7 +129,7 @@ # Titsias (2009). # %% -elbo = gpx.CollapsedELBO(negative=True) +elbo = gpx.objectives.CollapsedELBO(negative=True) # %% [markdown] # For researchers, GPJax has the capacity to print the bibtex citation for objects such @@ -241,14 +243,14 @@ # full model. 
# %% -full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian( - num_datapoints=D.n -) -negative_mll = jit(gpx.ConjugateMLL(negative=True).step) +full_rank_model = gpx.gps.Prior( + mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF() +) * gpx.likelihoods.Gaussian(num_datapoints=D.n) +negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step) # %timeit negative_mll(full_rank_model, D).block_until_ready() # %% -negative_elbo = jit(gpx.CollapsedELBO(negative=True).step) +negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step) # %timeit negative_elbo(q, D).block_until_ready() # %% [markdown] diff --git a/docs/examples/constructing_new_kernels.py b/docs/examples/constructing_new_kernels.py index 9355b614f..f6b60452c 100644 --- a/docs/examples/constructing_new_kernels.py +++ b/docs/examples/constructing_new_kernels.py @@ -90,7 +90,7 @@ meanf = gpx.mean_functions.Zero() for k, ax in zip(kernels, axes.ravel()): - prior = gpx.Prior(mean_function=meanf, kernel=k) + prior = gpx.gps.Prior(mean_function=meanf, kernel=k) rv = prior(x) y = rv.sample(seed=key, sample_shape=(10,)) ax.plot(x, y.T, alpha=0.7) @@ -264,13 +264,13 @@ def __call__( # Define polar Gaussian process PKern = Polar() meanf = gpx.mean_functions.Zero() -likelihood = gpx.Gaussian(num_datapoints=n) -circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood +likelihood = gpx.likelihoods.Gaussian(num_datapoints=n) +circular_posterior = gpx.gps.Prior(mean_function=meanf, kernel=PKern) * likelihood # Optimise GP's marginal log-likelihood using Adam opt_posterior, history = gpx.fit( model=circular_posterior, - objective=jit(gpx.ConjugateMLL(negative=True)), + objective=jit(gpx.objectives.ConjugateMLL(negative=True)), train_data=D, optim=ox.adamw(learning_rate=0.05), num_iters=500, diff --git a/docs/examples/decision_making.py b/docs/examples/decision_making.py index 0bf08e60c..66dc77a93 100644 --- a/docs/examples/decision_making.py +++ b/docs/examples/decision_making.py @@ -136,9 +136,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: # mean function and kernel for the job at hand: # %% -mean = gpx.Zero() -kernel = gpx.Matern52() -prior = gpx.Prior(mean_function=mean, kernel=kernel) +mean = gpx.mean_functions.Zero() +kernel = gpx.kernels.Matern52() +prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) # %% [markdown] # One difference from GPJax is the way in which we define our likelihood. In GPJax, we @@ -153,7 +153,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: # with the correct number of datapoints: # %% -likelihood_builder = lambda n: gpx.Gaussian( +likelihood_builder = lambda n: gpx.likelihoods.Gaussian( num_datapoints=n, obs_stddev=jnp.array(1e-3) ) @@ -174,7 +174,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: posterior_handler = PosteriorHandler( prior, likelihood_builder=likelihood_builder, - optimization_objective=gpx.ConjugateMLL(negative=True), + optimization_objective=gpx.objectives.ConjugateMLL(negative=True), optimizer=ox.adam(learning_rate=0.01), num_optimization_iters=1000, ) diff --git a/docs/examples/deep_kernels.py b/docs/examples/deep_kernels.py index 3346c958c..31ea4e89e 100644 --- a/docs/examples/deep_kernels.py +++ b/docs/examples/deep_kernels.py @@ -163,16 +163,16 @@ def __call__(self, x): # kernel and assume a Gaussian likelihood. 
# %% -base_kernel = gpx.Matern52( +base_kernel = gpx.kernels.Matern52( active_dims=list(range(feature_space_dim)), lengthscale=jnp.ones((feature_space_dim,)), ) kernel = DeepKernelFunction( network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x ) -meanf = gpx.Zero() -prior = gpx.Prior(mean_function=meanf, kernel=kernel) -likelihood = gpx.Gaussian(num_datapoints=D.n) +meanf = gpx.mean_functions.Zero() +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) posterior = prior * likelihood # %% [markdown] # ### Optimisation @@ -207,7 +207,7 @@ def __call__(self, x): opt_posterior, history = gpx.fit( model=posterior, - objective=jax.jit(gpx.ConjugateMLL(negative=True)), + objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)), train_data=D, optim=optimiser, num_iters=800, diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py index 2d77d79e5..bb427c03f 100644 --- a/docs/examples/graph_kernels.py +++ b/docs/examples/graph_kernels.py @@ -95,13 +95,13 @@ # %% x = jnp.arange(G.number_of_nodes()).reshape(-1, 1) -true_kernel = gpx.GraphKernel( +true_kernel = gpx.kernels.GraphKernel( laplacian=L, lengthscale=2.3, variance=3.2, smoothness=6.1, ) -prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel) +prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel) fx = prior(x) y = fx.sample(seed=key, sample_shape=(1,)).reshape(-1, 1) @@ -137,9 +137,9 @@ # We do this using the Adam optimiser provided in `optax`. # %% -likelihood = gpx.Gaussian(num_datapoints=D.n) -kernel = gpx.GraphKernel(laplacian=L) -prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) +kernel = gpx.kernels.GraphKernel(laplacian=L) +prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=kernel) posterior = prior * likelihood # %% [markdown] @@ -157,7 +157,7 @@ # %% opt_posterior, training_history = gpx.fit( model=posterior, - objective=jit(gpx.ConjugateMLL(negative=True)), + objective=jit(gpx.objectives.ConjugateMLL(negative=True)), train_data=D, optim=ox.adamw(learning_rate=0.01), num_iters=1000, diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py index 04e077274..d1c8f2a22 100644 --- a/docs/examples/intro_to_kernels.py +++ b/docs/examples/intro_to_kernels.py @@ -161,7 +161,7 @@ meanf = gpx.mean_functions.Zero() for k, ax in zip(kernels, axes.ravel()): - prior = gpx.Prior(mean_function=meanf, kernel=k) + prior = gpx.gps.Prior(mean_function=meanf, kernel=k) rv = prior(x) y = rv.sample(seed=key, sample_shape=(10,)) ax.plot(x, y.T, alpha=0.7) @@ -220,9 +220,9 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: lengthscale=jnp.array(2.0) ) # Initialise our kernel lengthscale to 2.0 -prior = gpx.Prior(mean_function=mean, kernel=kernel) +prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) -likelihood = gpx.Gaussian( +likelihood = gpx.likelihoods.Gaussian( num_datapoints=D.n, obs_stddev=jnp.array(1e-3) ) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value likelihood = likelihood.replace_trainable(obs_stddev=False) @@ -358,7 +358,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # %% mean = gpx.mean_functions.Zero() kernel = gpx.kernels.Periodic() -prior = gpx.Prior(mean_function=mean, kernel=kernel) +prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) rv = prior(x) 
@@ -381,7 +381,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # %% mean = gpx.mean_functions.Zero() kernel = gpx.kernels.Linear() -prior = gpx.Prior(mean_function=mean, kernel=kernel) +prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) rv = prior(x) @@ -417,7 +417,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: kernel_two = gpx.kernels.Periodic() sum_kernel = gpx.kernels.SumKernel(kernels=[kernel_one, kernel_two]) mean = gpx.mean_functions.Zero() -prior = gpx.Prior(mean_function=mean, kernel=sum_kernel) +prior = gpx.gps.Prior(mean_function=mean, kernel=sum_kernel) x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) rv = prior(x) @@ -442,7 +442,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: kernel_two = gpx.kernels.Periodic() sum_kernel = gpx.kernels.ProductKernel(kernels=[kernel_one, kernel_two]) mean = gpx.mean_functions.Zero() -prior = gpx.Prior(mean_function=mean, kernel=sum_kernel) +prior = gpx.gps.Prior(mean_function=mean, kernel=sum_kernel) x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1) rv = prior(x) @@ -528,8 +528,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: sum_kernel = gpx.kernels.SumKernel(kernels=[linear_kernel, periodic_kernel]) final_kernel = gpx.kernels.SumKernel(kernels=[rbf_kernel, sum_kernel]) -prior = gpx.Prior(mean_function=mean, kernel=final_kernel) -likelihood = gpx.Gaussian(num_datapoints=D.n) +prior = gpx.gps.Prior(mean_function=mean, kernel=final_kernel) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) posterior = prior * likelihood @@ -652,7 +652,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # %% print( - f"Periodic Kernel Period: {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}" + "Periodic Kernel Period:" + f" {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}" ) # %% [markdown] diff --git a/docs/examples/likelihoods_guide.py b/docs/examples/likelihoods_guide.py index 6107c5688..6f8487ee5 100644 --- a/docs/examples/likelihoods_guide.py +++ b/docs/examples/likelihoods_guide.py @@ -124,11 +124,11 @@ # $\mathbf{y}^{\star}$. # + -kernel = gpx.Matern32() -meanf = gpx.Zero() -prior = gpx.Prior(kernel=kernel, mean_function=meanf) +kernel = gpx.kernels.Matern32() +meanf = gpx.mean_functions.Zero() +prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) -likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=0.1) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.1) posterior = prior * likelihood @@ -158,7 +158,7 @@ # Similarly, for a Bernoulli likelihood function, the samples of $y$ would be binary. 
# + -likelihood = gpx.Bernoulli(num_datapoints=D.n) +likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n) fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(9, 2)) @@ -231,7 +231,7 @@ # + z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1) -q = gpx.VariationalGaussian(posterior=posterior, inducing_inputs=z) +q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z) def q_moments(x): @@ -251,7 +251,7 @@ def q_moments(x): # However, had we wanted to do this using quadrature, then we would have done the # following: -lquad = gpx.Gaussian( +lquad = gpx.likelihoods.Gaussian( num_datapoints=D.n, obs_stddev=jnp.array([0.1]), integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20), diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py index 3488d1640..0fef55dea 100644 --- a/docs/examples/oceanmodelling.py +++ b/docs/examples/oceanmodelling.py @@ -224,8 +224,8 @@ def __call__( # %% def initialise_gp(kernel, mean, dataset): - prior = gpx.Prior(mean_function=mean, kernel=kernel) - likelihood = gpx.Gaussian( + prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) + likelihood = gpx.likelihoods.Gaussian( num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64) ) posterior = prior * likelihood diff --git a/docs/examples/poisson.py b/docs/examples/poisson.py index da740665f..a671596f6 100644 --- a/docs/examples/poisson.py +++ b/docs/examples/poisson.py @@ -83,10 +83,10 @@ # kernel, chosen for the purpose of exposition. We adopt the Poisson likelihood available in GPJax. # %% -kernel = gpx.RBF() -meanf = gpx.Constant() -prior = gpx.Prior(mean_function=meanf, kernel=kernel) -likelihood = gpx.Poisson(num_datapoints=D.n) +kernel = gpx.kernels.RBF() +meanf = gpx.mean_functions.Constant() +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) +likelihood = gpx.likelihoods.Poisson(num_datapoints=D.n) # %% [markdown] # We construct the posterior through the product of our prior and likelihood. @@ -135,7 +135,7 @@ num_adapt = 100 num_samples = 200 -lpd = jax.jit(gpx.LogPosteriorDensity(negative=False)) +lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False)) unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D)) adapt = blackjax.window_adaptation( diff --git a/docs/examples/regression.py b/docs/examples/regression.py index bccbb8068..384ad09d2 100644 --- a/docs/examples/regression.py +++ b/docs/examples/regression.py @@ -108,7 +108,7 @@ # %% kernel = gpx.kernels.RBF() meanf = gpx.mean_functions.Zero() -prior = gpx.Prior(mean_function=meanf, kernel=kernel) +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) # %% [markdown] # @@ -152,7 +152,7 @@ # This is defined in GPJax through calling a `Gaussian` instance. # %% -likelihood = gpx.Gaussian(num_datapoints=D.n) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) # %% [markdown] # The posterior is proportional to the prior multiplied by the likelihood, written as diff --git a/docs/examples/regression_mo.py b/docs/examples/regression_mo.py index dc50fd7df..164cb90bd 100644 --- a/docs/examples/regression_mo.py +++ b/docs/examples/regression_mo.py @@ -124,7 +124,7 @@ ) # out_kernel = gpx.kernels.White(variance=1.0) meanf = gpx.mean_functions.Constant(jnp.array([0.0, 1.0])) -prior = gpx.Prior(mean_function=meanf, kernel=kernel, out_kernel=out_kernel) +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel, out_kernel=out_kernel) # %% [markdown] # @@ -169,7 +169,7 @@ # This is defined in GPJax through calling a `Gaussian` instance. 
# %% -likelihood = gpx.Gaussian(num_datapoints=D.n) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) # %% [markdown] # The posterior is proportional to the prior multiplied by the likelihood, written as diff --git a/docs/examples/spatial.py b/docs/examples/spatial.py index 72fe15c4f..0b266d039 100644 --- a/docs/examples/spatial.py +++ b/docs/examples/spatial.py @@ -151,7 +151,7 @@ # %% @dataclass -class MeanFunction(gpx.gps.AbstractMeanFunction): +class MeanFunction(gpx.mean_functions.AbstractMeanFunction): w: Float[Array, "1"] = param_field(jnp.array([0.0])) b: Float[Array, "1"] = param_field(jnp.array([0.0])) @@ -166,8 +166,8 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: # %% mean_function = MeanFunction() -prior = gpx.Prior(kernel=kernel, mean_function=mean_function) -likelihood = gpx.Gaussian(D.n) +prior = gpx.gps.Prior(kernel=kernel, mean_function=mean_function) +likelihood = gpx.likelihoods.Gaussian(D.n) # %% [markdown] # Finally, we construct the posterior. diff --git a/docs/examples/uncollapsed_vi.py b/docs/examples/uncollapsed_vi.py index 76eae5ec1..79a8c3ce2 100644 --- a/docs/examples/uncollapsed_vi.py +++ b/docs/examples/uncollapsed_vi.py @@ -203,10 +203,10 @@ # %% meanf = gpx.mean_functions.Zero() -likelihood = gpx.Gaussian(num_datapoints=n) -prior = gpx.Prior(mean_function=meanf, kernel=jk.RBF()) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=n) +prior = gpx.gps.Prior(mean_function=meanf, kernel=jk.RBF()) p = prior * likelihood -q = gpx.VariationalGaussian(posterior=p, inducing_inputs=z) +q = gpx.variational_families.VariationalGaussian(posterior=p, inducing_inputs=z) # %% [markdown] # Here, the variational process $q(\cdot)$ depends on the prior through @@ -232,7 +232,7 @@ # its negative. # %% -negative_elbo = gpx.ELBO(negative=True) +negative_elbo = gpx.objectives.ELBO(negative=True) # %% [markdown] # For researchers, GPJax has the capacity to print the bibtex citation for objects such diff --git a/docs/examples/yacht.py b/docs/examples/yacht.py index c4260a986..c4243c621 100644 --- a/docs/examples/yacht.py +++ b/docs/examples/yacht.py @@ -168,13 +168,13 @@ # %% n_train, n_covariates = scaled_Xtr.shape -kernel = gpx.RBF( +kernel = gpx.kernels.RBF( active_dims=list(range(n_covariates)), lengthscale=np.ones((n_covariates,)) ) meanf = gpx.mean_functions.Zero() -prior = gpx.Prior(mean_function=meanf, kernel=kernel) +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) -likelihood = gpx.Gaussian(num_datapoints=n_train) +likelihood = gpx.likelihoods.Gaussian(num_datapoints=n_train) posterior = prior * likelihood @@ -187,7 +187,7 @@ # %% training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr) -negative_mll = jit(gpx.ConjugateMLL(negative=True)) +negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True)) optimiser = ox.adamw(0.05) opt_posterior, history = gpx.fit( diff --git a/gpjax/__init__.py b/gpjax/__init__.py index eca71e399..03d7abaf2 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -13,120 +13,40 @@ # limitations under the License. 
# ============================================================================== from gpjax import ( + base, decision_making, + gps, integrators, + kernels, + likelihoods, + mean_functions, + objectives, + sklearn, + variational_families, ) -from gpjax.base import ( - Module, - param_field, -) -from gpjax.citation import cite from gpjax.dataset import Dataset +from gpjax.citation import cite from gpjax.fit import fit -from gpjax.gps import ( - Prior, - construct_posterior, -) -from gpjax.kernels import ( - RBF, - RFF, - AbstractKernel, - BasisFunctionComputation, - CatKernel, - ConstantDiagonalKernelComputation, - DenseKernelComputation, - DiagonalKernelComputation, - EigenKernelComputation, - GraphKernel, - Linear, - Matern12, - Matern32, - Matern52, - Periodic, - Polynomial, - PoweredExponential, - ProductKernel, - RationalQuadratic, - SumKernel, - White, -) -from gpjax.likelihoods import ( - Bernoulli, - Gaussian, - Poisson, -) -from gpjax.mean_functions import ( - Constant, - Zero, -) -from gpjax.objectives import ( - ELBO, - CollapsedELBO, - ConjugateMLL, - LogPosteriorDensity, - NonConjugateMLL, -) -from gpjax.variational_families import ( - CollapsedVariationalGaussian, - ExpectationVariationalGaussian, - NaturalVariationalGaussian, - VariationalGaussian, - WhitenedVariationalGaussian, -) + __license__ = "MIT" __description__ = "Didactic Gaussian processes in JAX" __url__ = "https://github.com/JaxGaussianProcesses/GPJax" __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors" -__version__ = "0.7.1" +__version__ = "0.7.2" __all__ = [ - "Module", - "param_field", - "cite", + "base", "decision_making", - "kernels", - "fit", - "Prior", - "construct_posterior", + "gps", "integrators", - "RBF", - "GraphKernel", - "Matern12", - "Matern32", - "Matern52", - "Polynomial", - "ProductKernel", - "SumKernel", - "Bernoulli", - "Gaussian", - "Poisson", - "Constant", - "Zero", + "kernels", + "likelihoods", + "mean_functions", + "objectives", + "sklearn", + "variational_families", "Dataset", - "CollapsedVariationalGaussian", - "ExpectationVariationalGaussian", - "NaturalVariationalGaussian", - "VariationalGaussian", - "WhitenedVariationalGaussian", - "CollapsedVI", - "StochasticVI", - "ConjugateMLL", - "NonConjugateMLL", - "LogPosteriorDensity", - "CollapsedELBO", - "ELBO", - "AbstractKernel", - "CatKernel", - "Linear", - "DenseKernelComputation", - "DiagonalKernelComputation", - "ConstantDiagonalKernelComputation", - "EigenKernelComputation", - "PoweredExponential", - "Periodic", - "RationalQuadratic", - "White", - "BasisFunctionComputation", - "RFF", + "cite", + "fit", ] diff --git a/tests/test_decision_making/test_decision_maker.py b/tests/test_decision_making/test_decision_maker.py index 4d79f8d52..23a87a2f1 100644 --- a/tests/test_decision_making/test_decision_maker.py +++ b/tests/test_decision_making/test_decision_maker.py @@ -61,16 +61,16 @@ def search_space() -> ContinuousSearchSpace: @pytest.fixture def posterior_handler() -> PosteriorHandler: - mean = gpx.Zero() - kernel = gpx.Matern52(lengthscale=jnp.array(1.0), variance=jnp.array(1.0)) - prior = gpx.Prior(mean_function=mean, kernel=kernel) - likelihood_builder = lambda x: gpx.Gaussian( + mean = gpx.mean_functions.Zero() + kernel = gpx.kernels.Matern52(lengthscale=jnp.array(1.0), variance=jnp.array(1.0)) + prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) + likelihood_builder = lambda x: gpx.likelihoods.Gaussian( num_datapoints=x, obs_stddev=jnp.array(1e-3) ) posterior_handler = PosteriorHandler( 
prior=prior, likelihood_builder=likelihood_builder, - optimization_objective=gpx.ConjugateMLL(negative=True), + optimization_objective=gpx.objectives.ConjugateMLL(negative=True), optimizer=ox.adam(learning_rate=0.01), num_optimization_iters=100, ) diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index fded647c0..2fb74283d 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -65,8 +65,10 @@ def test_zero_mean_remains_zero() -> None: constant=False ) # Prevent kernel from modelling non-zero mean meanf = Zero() - prior = gpx.Prior(mean_function=meanf, kernel=kernel) - likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array(1e-3)) + prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) + likelihood = gpx.likelihoods.Gaussian( + num_datapoints=D.n, obs_stddev=jnp.array(1e-3) + ) likelihood = likelihood.replace_trainable(obs_stddev=False) posterior = prior * likelihood diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 43473d5d3..2f6f2191f 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -114,8 +114,10 @@ def test_variational_gaussians( variational_family: AbstractVariationalFamily, ) -> None: # Initialise variational family: - prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant()) - likelihood = gpx.Gaussian(123) + prior = gpx.gps.Prior( + kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant() + ) + likelihood = gpx.likelihoods.Gaussian(123) inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) @@ -223,14 +225,16 @@ def test_collapsed_variational_gaussian( x = jnp.hstack([x] * point_dim) D = gpx.Dataset(X=x, y=y) - prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant()) + prior = gpx.gps.Prior( + kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant() + ) inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) test_inputs = jnp.hstack([test_inputs] * point_dim) - posterior = prior * gpx.Gaussian(num_datapoints=D.n) + posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n) variational_family = CollapsedVariationalGaussian( posterior=posterior, @@ -240,7 +244,7 @@ def test_collapsed_variational_gaussian( # We should raise an error for non-Gaussian likelihoods: with pytest.raises(TypeError): CollapsedVariationalGaussian( - posterior=prior * gpx.Bernoulli(num_datapoints=D.n), + posterior=prior * gpx.likelihoods.Bernoulli(num_datapoints=D.n), inducing_inputs=inducing_inputs, ) From a634945e08967534e58188aa05954049594cecf6 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 6 Nov 2023 07:20:13 +0100 Subject: [PATCH 02/13] Format --- README.md | 4 ++-- gpjax/__init__.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 2490c9cf1..4d44b2ad7 100644 --- a/README.md +++ b/README.md @@ -40,13 +40,13 @@ Another way you can contribute to GPJax is through [issue triaging](https://www.codetriage.com/what). This can include reproducing bug reports, asking for vital information such as version numbers and reproduction instructions, or identifying stale issues. If you would like to begin triaging issues, an easy way to get -started is to +started is to [subscribe to GPJax on CodeTriage](https://www.codetriage.com/jaxgaussianprocesses/gpjax). 
As a contributor to GPJax, you are expected to abide by our [code of conduct](docs/CODE_OF_CONDUCT.md). If you are feel that you have either experienced or witnessed behaviour that violates this standard, then we ask that you report any such -behaviours though [this form](https://jaxgaussianprocesses.com/contact/) or reach out to +behaviours though [this form](https://jaxgaussianprocesses.com/contact/) or reach out to one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles). Feel free to join our [Slack diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 03d7abaf2..c0a83eaa3 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -24,11 +24,10 @@ sklearn, variational_families, ) -from gpjax.dataset import Dataset from gpjax.citation import cite +from gpjax.dataset import Dataset from gpjax.fit import fit - __license__ = "MIT" __description__ = "Didactic Gaussian processes in JAX" __url__ = "https://github.com/JaxGaussianProcesses/GPJax" From 88d56418088e8f8ec9c6d859a3deb5c8e69b5567 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 6 Nov 2023 07:23:54 +0100 Subject: [PATCH 03/13] Fix mistake --- gpjax/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index c0a83eaa3..a02dd56d0 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -43,7 +43,6 @@ "likelihoods", "mean_functions", "objectives", - "sklearn", "variational_families", "Dataset", "cite", From 0ebaaf9910fafd71afd734693a9e0afeeff7c395 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 6 Nov 2023 07:24:39 +0100 Subject: [PATCH 04/13] Fix mistake --- gpjax/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index a02dd56d0..fa04b40b6 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -21,7 +21,6 @@ likelihoods, mean_functions, objectives, - sklearn, variational_families, ) from gpjax.citation import cite From 8d95268cb745cadc1f262bc37aeb1f5964bfd2d0 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 6 Nov 2023 07:37:28 +0100 Subject: [PATCH 05/13] Refactor docstrings --- benchmarks/objectives.py | 6 +++--- benchmarks/predictions.py | 6 +++--- benchmarks/sparse.py | 2 +- benchmarks/stochastic.py | 2 +- gpjax/fit.py | 8 ++++---- gpjax/gps.py | 14 +++++++------- gpjax/objectives.py | 8 ++++---- tests/test_objectives.py | 12 ++++++++---- 8 files changed, 31 insertions(+), 27 deletions(-) diff --git a/benchmarks/objectives.py b/benchmarks/objectives.py index dd217b42b..65b8569c2 100644 --- a/benchmarks/objectives.py +++ b/benchmarks/objectives.py @@ -22,7 +22,7 @@ def setup(self, n_datapoints: int, n_dims: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(n_dims))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n) self.objective = gpx.ConjugateMLL() self.posterior = self.prior * self.likelihood @@ -48,7 +48,7 @@ def setup(self, n_datapoints: int, n_dims: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(n_dims))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n) self.objective = gpx.LogPosteriorDensity() self.posterior = 
self.prior * self.likelihood @@ -75,7 +75,7 @@ def setup(self, n_datapoints: int, n_dims: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(n_dims))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n) self.objective = gpx.LogPosteriorDensity() self.posterior = self.prior * self.likelihood diff --git a/benchmarks/predictions.py b/benchmarks/predictions.py index eed35d66a..a3dd4fa8e 100644 --- a/benchmarks/predictions.py +++ b/benchmarks/predictions.py @@ -21,7 +21,7 @@ def setup(self, n_test: int, n_dims: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(n_dims))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n) self.posterior = self.prior * self.likelihood key, subkey = jr.split(key) @@ -46,7 +46,7 @@ def setup(self, n_test: int, n_dims: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(n_dims))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n) self.posterior = self.prior * self.likelihood key, subkey = jr.split(key) @@ -71,7 +71,7 @@ def setup(self, n_test: int, n_dims: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(n_dims))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n) self.posterior = self.prior * self.likelihood key, subkey = jr.split(key) diff --git a/benchmarks/sparse.py b/benchmarks/sparse.py index 759cac9bc..de7d47054 100644 --- a/benchmarks/sparse.py +++ b/benchmarks/sparse.py @@ -19,7 +19,7 @@ def setup(self, n_datapoints: int, n_inducing: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(1))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n) self.posterior = self.prior * self.likelihood diff --git a/benchmarks/stochastic.py b/benchmarks/stochastic.py index 14681535f..1e530c731 100644 --- a/benchmarks/stochastic.py +++ b/benchmarks/stochastic.py @@ -20,7 +20,7 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int): self.data = gpx.Dataset(X=self.X, y=self.y) kernel = gpx.kernels.RBF(active_dims=list(range(1))) meanf = gpx.mean_functions.Constant() - self.prior = gpx.Prior(kernel=kernel, mean_function=meanf) + self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf) self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n) self.posterior = self.prior * self.likelihood diff --git a/gpjax/fit.py b/gpjax/fit.py index 69b6a699a..3fcffae5c 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -70,9 +70,9 @@ def fit( # noqa: PLR0913 >>> D = gpx.Dataset(X, y) >>> >>> # (2) Define your model: - >>> class 
LinearModel(gpx.Module): - weight: float = gpx.param_field() - bias: float = gpx.param_field() + >>> class LinearModel(gpx.base.Module): + weight: float = gpx.base.param_field() + bias: float = gpx.base.param_field() def __call__(self, x): return self.weight * x + self.bias @@ -80,7 +80,7 @@ def __call__(self, x): >>> model = LinearModel(weight=1.0, bias=1.0) >>> >>> # (3) Define your loss function: - >>> class MeanSquareError(gpx.AbstractObjective): + >>> class MeanSquareError(gpx.objectives.AbstractObjective): def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float: return jnp.mean((train_data.y - model(train_data.X)) ** 2) >>> diff --git a/gpjax/gps.py b/gpjax/gps.py index 638f408cd..3c20c66b7 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -149,7 +149,7 @@ class Prior(AbstractPrior): >>> kernel = gpx.kernels.RBF() >>> meanf = gpx.mean_functions.Zero() - >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) ``` """ @@ -183,7 +183,7 @@ def __mul__(self, other): >>> >>> meanf = gpx.mean_functions.Zero() >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) >>> >>> prior * likelihood @@ -244,7 +244,7 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> ReshapedGaussianDistributio >>> >>> kernel = gpx.kernels.RBF() >>> meanf = gpx.mean_functions.Zero() - >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) >>> >>> prior.predict(jnp.linspace(0, 1, 100)) ``` @@ -310,7 +310,7 @@ def sample_approx( >>> >>> meanf = gpx.mean_functions.Zero() >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel) + >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) >>> >>> sample_fn = prior.sample_approx(10, key) >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) @@ -434,7 +434,7 @@ class ConjugatePosterior(AbstractPosterior): >>> import gpjax as gpx >>> import jax.numpy as jnp - >>> prior = gpx.Prior( + >>> prior = gpx.gps.Prior( mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF() ) @@ -482,8 +482,8 @@ def predict( >>> D = gpx.Dataset(X=xtrain, y=ytrain) >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) >>> - >>> prior = gpx.Prior(mean_function = gpx.Zero(), kernel = gpx.RBF()) - >>> posterior = prior * gpx.Gaussian(num_datapoints = D.n) + >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF()) + >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n) >>> predictive_dist = posterior(xtest, D) ``` diff --git a/gpjax/objectives.py b/gpjax/objectives.py index 887772147..861c01ae4 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -100,10 +100,10 @@ def step( >>> meanf = gpx.mean_functions.Constant() >>> kernel = gpx.kernels.RBF() >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) - >>> prior = gpx.Prior(mean_function = meanf, kernel=kernel) + >>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel) >>> posterior = prior * likelihood >>> - >>> mll = gpx.ConjugateMLL(negative=True) + >>> mll = gpx.objectives.ConjugateMLL(negative=True) >>> mll(posterior, train_data = D) ``` @@ -112,13 +112,13 @@ def step( marginal log-likelihood. 
This can be realised through ```python - mll = gpx.ConjugateMLL(negative=True) + mll = gpx.objectives.ConjugateMLL(negative=True) ``` For optimal performance, the marginal log-likelihood should be ``jax.jit`` compiled. ```python - mll = jit(gpx.ConjugateMLL(negative=True)) + mll = jit(gpx.objectives.ConjugateMLL(negative=True)) ``` Args: diff --git a/tests/test_objectives.py b/tests/test_objectives.py index 4478630c6..d19a13863 100644 --- a/tests/test_objectives.py +++ b/tests/test_objectives.py @@ -64,7 +64,8 @@ def test_conjugate_mll( # Build model p = Prior( - kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant() + kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), + mean_function=gpx.mean_functions.Constant(), ) likelihood = Gaussian(num_datapoints=num_datapoints) post = p * likelihood @@ -93,7 +94,8 @@ def test_non_conjugate_mll( # Build model p = Prior( - kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant() + kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), + mean_function=gpx.mean_functions.Constant(), ) likelihood = Bernoulli(num_datapoints=num_datapoints) post = p * likelihood @@ -129,7 +131,8 @@ def test_collapsed_elbo( ) p = Prior( - kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant() + kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), + mean_function=gpx.mean_functions.Constant(), ) likelihood = Gaussian(num_datapoints=num_datapoints) q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=z) @@ -169,7 +172,8 @@ def test_elbo( ) p = Prior( - kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant() + kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), + mean_function=gpx.mean_functions.Constant(), ) if binary: likelihood = Bernoulli(num_datapoints=num_datapoints) From 2513d93200db94fba13a5d517e49ae1d4af1e438 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 6 Nov 2023 07:44:19 +0100 Subject: [PATCH 06/13] Fix failing test --- tests/test_objectives.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/test_objectives.py b/tests/test_objectives.py index d19a13863..b51757168 100644 --- a/tests/test_objectives.py +++ b/tests/test_objectives.py @@ -5,11 +5,6 @@ import pytest import gpjax as gpx -from gpjax import ( - Bernoulli, - Gaussian, - Prior, -) from gpjax.dataset import Dataset from gpjax.objectives import ( ELBO, @@ -63,11 +58,11 @@ def test_conjugate_mll( D = build_data(num_datapoints, num_dims, key, binary=False) # Build model - p = Prior( + p = gpx.gps.Prior( kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), mean_function=gpx.mean_functions.Constant(), ) - likelihood = Gaussian(num_datapoints=num_datapoints) + likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints) post = p * likelihood mll = ConjugateMLL(negative=negative) @@ -93,11 +88,11 @@ def test_non_conjugate_mll( D = build_data(num_datapoints, num_dims, key, binary=True) # Build model - p = Prior( + p = gpx.gps.Prior( kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), mean_function=gpx.mean_functions.Constant(), ) - likelihood = Bernoulli(num_datapoints=num_datapoints) + likelihood = gpx.likelihoods.Bernoulli(num_datapoints=num_datapoints) post = p * likelihood mll = NonConjugateMLL(negative=negative) @@ -130,12 +125,14 @@ def test_collapsed_elbo( key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims) ) - p = Prior( + p = gpx.gps.Prior( 
kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), mean_function=gpx.mean_functions.Constant(), ) - likelihood = Gaussian(num_datapoints=num_datapoints) - q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=z) + likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints) + q = gpx.variational_families.CollapsedVariationalGaussian( + posterior=p * likelihood, inducing_inputs=z + ) negative_elbo = CollapsedELBO(negative=negative) @@ -171,17 +168,17 @@ def test_elbo( key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims) ) - p = Prior( + p = gpx.gps.Prior( kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), mean_function=gpx.mean_functions.Constant(), ) if binary: - likelihood = Bernoulli(num_datapoints=num_datapoints) + likelihood = gpx.likelihoods.Bernoulli(num_datapoints=num_datapoints) else: - likelihood = Gaussian(num_datapoints=num_datapoints) + likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints) post = p * likelihood - q = gpx.VariationalGaussian(posterior=post, inducing_inputs=z) + q = gpx.variational_families.VariationalGaussian(posterior=post, inducing_inputs=z) negative_elbo = ELBO( negative=negative, From bff89ef67f8a422e33098498009c429d1ef93f71 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 6 Nov 2023 07:54:19 +0100 Subject: [PATCH 07/13] Fix failing test --- tests/test_citations.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_citations.py b/tests/test_citations.py index 02eaf13ba..c1a1a7470 100644 --- a/tests/test_citations.py +++ b/tests/test_citations.py @@ -6,7 +6,6 @@ import jax.numpy as jnp import pytest -import gpjax as gpx from gpjax.citation import ( AbstractCitation, BookCitation, @@ -30,6 +29,13 @@ Matern32, Matern52, ) +from gpjax.objectives import ( + ELBO, + CollapsedELBO, + ConjugateMLL, + LogPosteriorDensity, + NonConjugateMLL, +) def _check_no_fallback(citation: AbstractCitation): @@ -103,7 +109,7 @@ def test_missing_citation(kernel): @pytest.mark.parametrize( - "mll", [gpx.ConjugateMLL(), gpx.NonConjugateMLL(), gpx.LogPosteriorDensity()] + "mll", [ConjugateMLL(), NonConjugateMLL(), LogPosteriorDensity()] ) def test_mlls(mll): citation = cite(mll) @@ -115,7 +121,7 @@ def test_mlls(mll): def test_uncollapsed_elbo(): - elbo = gpx.ELBO() + elbo = ELBO() citation = cite(elbo) assert isinstance(citation, PaperCitation) @@ -128,7 +134,7 @@ def test_uncollapsed_elbo(): def test_collapsed_elbo(): - elbo = gpx.CollapsedELBO() + elbo = CollapsedELBO() citation = cite(elbo) assert isinstance(citation, PaperCitation) @@ -158,7 +164,8 @@ def test_thompson_sampling(): ) assert ( citation.authors - == "Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and Mostowsky, Peter and Deisenroth, Marc" + == "Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and" + " Mostowsky, Peter and Deisenroth, Marc" ) assert citation.year == "2020" assert citation.booktitle == "International Conference on Machine Learning" @@ -205,7 +212,7 @@ def test_logarithmic_goldstein_price(): @pytest.mark.parametrize( "objective", - [gpx.ELBO(), gpx.CollapsedELBO(), gpx.LogPosteriorDensity(), gpx.ConjugateMLL()], + [ELBO(), CollapsedELBO(), LogPosteriorDensity(), ConjugateMLL()], ) def test_jitted_fallback(objective): with pytest.raises(RuntimeError): From 118af12dfd162a8e5f63c6711e3caf119fa1f6ef Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 30 Nov 2023 09:06:36 +0100 Subject: [PATCH 08/13] Update gpjax/__init__.py 
Co-authored-by: st-- Signed-off-by: Thomas Pinder --- gpjax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index fa04b40b6..fbc15583d 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -31,7 +31,7 @@ __description__ = "Didactic Gaussian processes in JAX" __url__ = "https://github.com/JaxGaussianProcesses/GPJax" __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors" -__version__ = "0.7.2" +__version__ = "0.8.0" __all__ = [ "base", From 65c29fb9df2a98575e5e873de0f91e26eee094a9 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 30 Nov 2023 09:11:28 +0100 Subject: [PATCH 09/13] Fix linting --- gpjax/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 4fadb1602..47ba74a30 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -23,10 +23,13 @@ objectives, variational_families, ) +from gpjax.base import ( + Module, + param_field, +) from gpjax.citation import cite from gpjax.dataset import Dataset from gpjax.fit import fit -from gpjax.base import Module, param_field __license__ = "MIT" __description__ = "Didactic Gaussian processes in JAX" @@ -48,5 +51,5 @@ "cite", "fit", "Module", - "param_field" + "param_field", ] From 25db8293f600c142116b3eefd9b9d2ab900f7b2c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 30 Nov 2023 09:24:40 +0100 Subject: [PATCH 10/13] Fix broken test --- README.md | 11 ----------- tests/test_objectives.py | 9 +++++++-- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 3dfe2de91..3917f14d7 100644 --- a/README.md +++ b/README.md @@ -46,11 +46,7 @@ started is to As a contributor to GPJax, you are expected to abide by our [code of conduct](docs/CODE_OF_CONDUCT.md). If you feel that you have either experienced or witnessed behaviour that violates this standard, then we ask that you report any such -<<<<<<< HEAD -behaviours though [this form](https://jaxgaussianprocesses.com/contact/) or reach out to -======= behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or reach out to ->>>>>>> ac475762faa0cb9c64a773d5d9d3506d1a3ebdf2 one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles). 
Feel free to join our [Slack @@ -139,17 +135,10 @@ D = gpx.Dataset(X=x, y=y) # Construct the prior meanf = gpx.mean_functions.Zero() kernel = gpx.kernels.RBF() -<<<<<<< HEAD prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel) # Define a likelihood likelihood = gpx.likelihoods.Gaussian(num_datapoints = n) -======= -prior = gpx.Prior(mean_function=meanf, kernel=kernel) - -# Define a likelihood -likelihood = gpx.Gaussian(num_datapoints=n) ->>>>>>> ac475762faa0cb9c64a773d5d9d3506d1a3ebdf2 # Construct the posterior posterior = prior * likelihood diff --git a/tests/test_objectives.py b/tests/test_objectives.py index 02aee2694..fbb53f14a 100644 --- a/tests/test_objectives.py +++ b/tests/test_objectives.py @@ -6,6 +6,8 @@ import gpjax as gpx from gpjax.dataset import Dataset +from gpjax.gps import Prior +from gpjax.likelihoods import Gaussian from gpjax.objectives import ( ELBO, AbstractObjective, @@ -90,7 +92,8 @@ def test_conjugate_loocv( # Build model p = Prior( - kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant() + kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))), + mean_function=gpx.mean_functions.Constant(), ) likelihood = Gaussian(num_datapoints=num_datapoints) post = p * likelihood @@ -176,7 +179,9 @@ def test_collapsed_elbo( assert evaluation.shape == () # Data on the full dataset should be the same as the marginal likelihood - q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=D.X) + q = gpx.variational_families.CollapsedVariationalGaussian( + posterior=p * likelihood, inducing_inputs=D.X + ) mll = ConjugateMLL(negative=negative) expected_value = mll(p * likelihood, D) actual_value = negative_elbo(q, D) From 3e475d8618a1f53e2051bf9c7dee5c951568275c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 30 Nov 2023 09:29:13 +0100 Subject: [PATCH 11/13] Fix docstring --- gpjax/objectives.py | 8 ++++---- gpjax/progress_bar.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gpjax/objectives.py b/gpjax/objectives.py index 1167d75ab..9b684f891 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -180,10 +180,10 @@ def step( >>> meanf = gpx.mean_functions.Constant() >>> kernel = gpx.kernels.RBF() >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n) - >>> prior = gpx.Prior(mean_function = meanf, kernel=kernel) + >>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel) >>> posterior = prior * likelihood >>> - >>> loocv = gpx.ConjugateLOOCV(negative=True) + >>> loocv = gpx.objectives.ConjugateLOOCV(negative=True) >>> loocv(posterior, train_data = D) ``` @@ -192,13 +192,13 @@ def step( leave-one-out log predictive probability. This can be realised through ```python - mll = gpx.ConjugateLOOCV(negative=True) + mll = gpx.objectives.ConjugateLOOCV(negative=True) ``` For optimal performance, the objective should be ``jax.jit`` compiled. ```python - mll = jit(gpx.ConjugateLOOCV(negative=True)) + mll = jit(gpx.objectives.ConjugateLOOCV(negative=True)) ``` Args: diff --git a/gpjax/progress_bar.py b/gpjax/progress_bar.py index 090a03ead..3072b71be 100644 --- a/gpjax/progress_bar.py +++ b/gpjax/progress_bar.py @@ -36,10 +36,10 @@ def progress_bar(num_iters: int, log_rate: int) -> Callable: >>> >>> carry = jnp.array(0.0) >>> iteration_numbers = jnp.arange(100) - + >>> >>> @progress_bar(num_iters=iteration_numbers.shape[0], log_rate=10) >>> def body_func(carry, x): - ... 
return carry, x + >>> return carry, x >>> >>> carry, _ = jax.lax.scan(body_func, carry, iteration_numbers) ``` From 9e35e3ec7c6b0e73390912d87675cd700d1f7c2e Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 30 Nov 2023 09:35:12 +0100 Subject: [PATCH 12/13] Fix scipy --- gpjax/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 47ba74a30..0fa71556d 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -29,7 +29,10 @@ ) from gpjax.citation import cite from gpjax.dataset import Dataset -from gpjax.fit import fit +from gpjax.fit import ( + fit, + fit_scipy, +) __license__ = "MIT" __description__ = "Didactic Gaussian processes in JAX" @@ -52,4 +55,5 @@ "fit", "Module", "param_field", + "fit_scipy", ] From 86490b1a8e0019eb2dbd7f1952a798e7989465da Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 30 Nov 2023 09:50:23 +0100 Subject: [PATCH 13/13] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6e2aa4537..198ccac6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gpjax" -version = "0.7.3" +version = "0.8.0" description = "Gaussian processes in JAX." authors = [ "Thomas Pinder ",
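Taken together, the series leaves `gpjax/__init__.py` exposing the submodules (`base`, `decision_making`, `gps`, `integrators`, `kernels`, `likelihoods`, `mean_functions`, `objectives`, `variational_families`) alongside a small set of top-level names (`Dataset`, `cite`, `fit`, `fit_scipy`, `Module`, `param_field`), and bumps the package to 0.8.0. An end-to-end sketch against that final surface — the data, PRNG key, and optimiser settings below are hypothetical stand-ins mirroring the example diffs above, not repository code:

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import optax as ox

import gpjax as gpx

# Hypothetical one-dimensional regression data.
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x)
D = gpx.Dataset(X=x, y=y)  # Dataset remains a top-level export

# Constructors are reached through their submodules after this series.
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# `fit` stays top-level; objectives now live under gpx.objectives.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)),
    train_data=D,
    optim=ox.adamw(learning_rate=0.01),
    num_iters=500,
    key=jr.PRNGKey(0),  # assumed key value, purely for reproducibility
)
```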