Namespace cleanup #408

Merged · 15 commits · Nov 30, 2023
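In short, this PR removes the flat top-level re-exports (`gpx.Prior`, `gpx.Gaussian`, `gpx.RBF`, ...) in favour of access through their home submodules (`gpx.gps`, `gpx.likelihoods`, `gpx.kernels`, `gpx.mean_functions`, `gpx.objectives`, `gpx.variational_families`, `gpx.base`). A minimal before/after sketch of the pattern, with toy data assumed purely for illustration:

```python
import jax.numpy as jnp
import gpjax as gpx

# Toy dataset, assumed for illustration only
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x)
D = gpx.Dataset(X=x, y=y)

# Before this PR: classes accessed from the package root
#   prior = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF())
#   likelihood = gpx.Gaussian(num_datapoints=D.n)

# After this PR: classes accessed via their submodules
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood
```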
README.md: 4 changes (2 additions & 2 deletions)

@@ -135,10 +135,10 @@ D = gpx.Dataset(X=x, y=y)
# Construct the prior
meanf = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)

# Define a likelihood
-likelihood = gpx.Gaussian(num_datapoints=n)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)

# Construct the posterior
posterior = prior * likelihood
benchmarks/objectives.py: 6 changes (3 additions & 3 deletions)

@@ -22,7 +22,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.objective = gpx.ConjugateMLL()
self.posterior = self.prior * self.likelihood
@@ -48,7 +48,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
@@ -75,7 +75,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
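Note that these benchmark hunks leave the objective constructors untouched: `gpx.ConjugateMLL()` and `gpx.LogPosteriorDensity()` are still spelled from the package root, while the docs diffs below move to `gpx.objectives.*`. Assuming both spellings still resolve to the same classes at this commit, a small sketch of the implied equivalence:

```python
import gpjax as gpx

# Assumption: the top-level alias survives this PR for the objectives,
# so both of these construct the same objective class.
mll_flat = gpx.ConjugateMLL(negative=True)                   # spelling kept in benchmarks/
mll_namespaced = gpx.objectives.ConjugateMLL(negative=True)  # spelling used in docs/
assert type(mll_flat) is type(mll_namespaced)
```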
benchmarks/predictions.py: 6 changes (3 additions & 3 deletions)

@@ -21,7 +21,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
@@ -46,7 +46,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
@@ -71,7 +71,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
benchmarks/sparse.py: 2 changes (1 addition & 1 deletion)

@@ -19,7 +19,7 @@ def setup(self, n_datapoints: int, n_inducing: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood

benchmarks/stochastic.py: 2 changes (1 addition & 1 deletion)

@@ -20,7 +20,7 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
-self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood

docs/examples/README.md: 2 changes (1 addition & 1 deletion)

@@ -67,7 +67,7 @@ class Prior(AbstractPrior):
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
->>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)

Attributes:
kernel (Kernel): The kernel function used to parameterise the prior.
docs/examples/barycentres.py: 9 changes (7 additions & 2 deletions)

@@ -134,8 +134,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
y = y.reshape(-1, 1)
D = gpx.Dataset(X=x, y=y)

-likelihood = gpx.Gaussian(num_datapoints=n)
-posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+posterior = (
+    gpx.gps.Prior(
+        mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
+    )
+    * likelihood
+)

opt_posterior, _ = gpx.fit_scipy(
model=posterior,
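The diff collapses the rest of this `fit_scipy` call; judging from the identical pattern in `docs/examples/constructing_new_kernels.py` later in this PR, the completed call presumably has the following shape. This is a hedged reconstruction, not part of the diff:

```python
# Hedged reconstruction of the collapsed call, following the pattern in
# constructing_new_kernels.py below; the remaining arguments actually used
# in barycentres.py are not visible in this diff.
opt_posterior, _ = gpx.fit_scipy(
    model=posterior,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    train_data=D,
)
```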
docs/examples/bayesian_optimisation.py: 17 changes (9 additions & 8 deletions)

@@ -201,9 +201,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:

# %%
def return_optimised_posterior(
-    data: gpx.Dataset, prior: gpx.Module, key: Array
-) -> gpx.Module:
-    likelihood = gpx.Gaussian(
+    data: gpx.Dataset, prior: gpx.base.Module, key: Array
+) -> gpx.base.Module:
+    likelihood = gpx.likelihoods.Gaussian(
num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
likelihood = likelihood.replace_trainable(obs_stddev=False)
@@ -230,7 +230,7 @@ def return_optimised_posterior(

mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, key)

# %% [markdown]
@@ -315,7 +315,7 @@ def optimise_sample(

# %%
def plot_bayes_opt(
-    posterior: gpx.Module,
+    posterior: gpx.base.Module,
sample: FunctionalSample,
dataset: gpx.Dataset,
queried_x: ScalarFloat,
@@ -401,7 +401,7 @@ def plot_bayes_opt(
# Generate optimised posterior using previously observed data
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, subkey)

# Draw a sample from the posterior, and find the minimiser of it
@@ -543,7 +543,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
kernel = gpx.kernels.Matern52(
active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0
)
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, subkey)

# Draw a sample from the posterior, and find the minimiser of it
@@ -561,7 +561,8 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
# Evaluate the black-box function at the best point observed so far, and add it to the dataset
y_star = six_hump_camel(x_star)
print(
f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value: {y_star}"
f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
f" {y_star}"
)
D = D + gpx.Dataset(X=x_star, y=y_star)
bo_experiment_results.append(D)
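Aside from the renames, this file exercises a pattern worth calling out: for a noise-free black-box function, the observation noise is pinned to a tiny value and then frozen during optimisation. A sketch of that idiom under the new namespaces, using the `replace_trainable` call shown in the first hunk (toy data assumed for illustration):

```python
import jax.numpy as jnp
import gpjax as gpx

# Toy data, assumed purely for illustration
x = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
data = gpx.Dataset(X=x, y=jnp.sin(x))

# Noise-free objective: pin obs_stddev near zero and exclude it from
# training, exactly as the first hunk of this file does
likelihood = gpx.likelihoods.Gaussian(
    num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
)
likelihood = likelihood.replace_trainable(obs_stddev=False)
```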
docs/examples/classification.py: 12 changes (6 additions & 6 deletions)

@@ -89,10 +89,10 @@
# choose a Bernoulli likelihood with a probit link function.

# %%
-kernel = gpx.RBF()
-meanf = gpx.Constant()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Bernoulli(num_datapoints=D.n)
+kernel = gpx.kernels.RBF()
+meanf = gpx.mean_functions.Constant()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)

# %% [markdown]
# We construct the posterior through the product of our prior and likelihood.
@@ -116,7 +116,7 @@
# Optax's optimisers.

# %%
-negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True))
+negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))

optimiser = ox.adam(learning_rate=0.01)

@@ -345,7 +345,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
num_adapt = 500
num_samples = 500

-lpd = jax.jit(gpx.LogPosteriorDensity(negative=False))
+lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))

adapt = blackjax.window_adaptation(
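One detail across these hunks: the same objective is built twice, once with `negative=True` for gradient-based optimisation (optimisers minimise) and once with `negative=False` to hand the actual log posterior density to the MCMC sampler. A minimal sketch of the assumed relationship between the two:

```python
import jax
import gpjax as gpx

negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))
lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))

# For any posterior `model` and dataset `D`, these are assumed to satisfy
#   negative_lpd(model, D) == -lpd(model, D)
# i.e. the flag only flips the sign of the returned scalar.
```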
docs/examples/collapsed_vi.py: 24 changes (13 additions & 11 deletions)

@@ -106,10 +106,10 @@
# this, it is intractable to evaluate.

# %%
-meanf = gpx.Constant()
-kernel = gpx.RBF()
-likelihood = gpx.Gaussian(num_datapoints=D.n)
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+meanf = gpx.mean_functions.Constant()
+kernel = gpx.kernels.RBF()
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood

# %% [markdown]
@@ -119,15 +119,17 @@
# inducing points into the constructor as arguments.

# %%
-q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
+q = gpx.variational_families.CollapsedVariationalGaussian(
+    posterior=posterior, inducing_inputs=z
+)

# %% [markdown]
# We define our variational inference algorithm through `CollapsedVI`. This defines
# the collapsed variational free energy bound considered in
# <strong data-cite="titsias2009">Titsias (2009)</strong>.

# %%
-elbo = gpx.CollapsedELBO(negative=True)
+elbo = gpx.objectives.CollapsedELBO(negative=True)

# %% [markdown]
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
@@ -241,14 +243,14 @@
# full model.

# %%
-full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
-    num_datapoints=D.n
-)
-negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
+full_rank_model = gpx.gps.Prior(
+    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
+) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
+negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step)
# %timeit negative_mll(full_rank_model, D).block_until_ready()

# %%
-negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
+negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step)
# %timeit negative_elbo(q, D).block_until_ready()
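Collecting this file's renamed pieces in one place, a hedged sketch of the collapsed-VI setup as it reads after this PR (training data `D` and inducing inputs `z` are assumed to be in scope, as in the example):

```python
import gpjax as gpx

# Conjugate model: GP prior paired with a Gaussian likelihood
posterior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
) * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Collapsed variational family over the inducing inputs z
q = gpx.variational_families.CollapsedVariationalGaussian(
    posterior=posterior, inducing_inputs=z
)

# Collapsed ELBO of Titsias (2009); negative=True for minimisation
elbo = gpx.objectives.CollapsedELBO(negative=True)
```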

# %% [markdown]
docs/examples/constructing_new_kernels.py: 8 changes (4 additions & 4 deletions)

@@ -89,7 +89,7 @@
meanf = gpx.mean_functions.Zero()

for k, ax in zip(kernels, axes.ravel()):
-prior = gpx.Prior(mean_function=meanf, kernel=k)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
ax.plot(x, y.T, alpha=0.7)
@@ -263,13 +263,13 @@ def __call__(
# Define polar Gaussian process
PKern = Polar()
meanf = gpx.mean_functions.Zero()
-likelihood = gpx.Gaussian(num_datapoints=n)
-circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+circular_posterior = gpx.gps.Prior(mean_function=meanf, kernel=PKern) * likelihood

# Optimise GP's marginal log-likelihood using BFGS
opt_posterior, history = gpx.fit_scipy(
model=circular_posterior,
-    objective=jit(gpx.ConjugateMLL(negative=True)),
+    objective=jit(gpx.objectives.ConjugateMLL(negative=True)),
train_data=D,
)

docs/examples/decision_making.py: 10 changes (5 additions & 5 deletions)

@@ -136,9 +136,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# mean function and kernel for the job at hand:

# %%
-mean = gpx.Zero()
-kernel = gpx.Matern52()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+mean = gpx.mean_functions.Zero()
+kernel = gpx.kernels.Matern52()
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

# %% [markdown]
# One difference from GPJax is the way in which we define our likelihood. In GPJax, we
@@ -153,7 +153,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# with the correct number of datapoints:

# %%
-likelihood_builder = lambda n: gpx.Gaussian(
+likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
num_datapoints=n, obs_stddev=jnp.array(1e-3)
)

@@ -174,7 +174,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
posterior_handler = PosteriorHandler(
prior,
likelihood_builder=likelihood_builder,
-    optimization_objective=gpx.ConjugateMLL(negative=True),
+    optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
optimizer=ox.adam(learning_rate=0.01),
num_optimization_iters=1000,
)
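The builder indirection above exists because the decision-making loop regrows the dataset on each iteration, and the Gaussian likelihood is parameterised by the number of datapoints. A small sketch of why a closure, rather than a fixed likelihood object, is needed:

```python
import jax.numpy as jnp
import gpjax as gpx

# A fixed likelihood would bake in one dataset size; the builder defers
# construction until the current size n is known.
likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
    num_datapoints=n, obs_stddev=jnp.array(1e-3)
)

# Hypothetical usage at two stages of the loop:
lik_initial = likelihood_builder(5)   # after the initial design
lik_later = likelihood_builder(25)    # after 20 more acquisitions
```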
docs/examples/deep_kernels.py: 10 changes (5 additions & 5 deletions)

@@ -163,16 +163,16 @@ def __call__(self, x):
# kernel and assume a Gaussian likelihood.

# %%
-base_kernel = gpx.Matern52(
+base_kernel = gpx.kernels.Matern52(
active_dims=list(range(feature_space_dim)),
lengthscale=jnp.ones((feature_space_dim,)),
)
kernel = DeepKernelFunction(
network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x
)
-meanf = gpx.Zero()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Gaussian(num_datapoints=D.n)
+meanf = gpx.mean_functions.Zero()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood
# %% [markdown]
# ### Optimisation
@@ -207,7 +207,7 @@ def __call__(self, x):

opt_posterior, history = gpx.fit(
model=posterior,
-    objective=jax.jit(gpx.ConjugateMLL(negative=True)),
+    objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)),
train_data=D,
optim=optimiser,
num_iters=800,
docs/examples/graph_kernels.py: 10 changes (5 additions & 5 deletions)

@@ -94,13 +94,13 @@
# %%
x = jnp.arange(G.number_of_nodes()).reshape(-1, 1)

-true_kernel = gpx.GraphKernel(
+true_kernel = gpx.kernels.GraphKernel(
laplacian=L,
lengthscale=2.3,
variance=3.2,
smoothness=6.1,
)
-prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel)
+prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel)

fx = prior(x)
y = fx.sample(seed=key, sample_shape=(1,)).reshape(-1, 1)
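The `laplacian=L` argument these constructors consume is not built in the hunks shown here; presumably it is the graph Laplacian of the networkx graph `G` referenced just above. A hedged sketch of one standard way to obtain it (the stand-in graph and the exact normalisation used by the example are assumptions, not shown in this diff):

```python
import jax.numpy as jnp
import networkx as nx

# Hypothetical stand-in for the graph G built earlier in this example
G = nx.barbell_graph(5, 2)

# Dense (unnormalised) graph Laplacian as a JAX array
L = jnp.asarray(nx.laplacian_matrix(G).toarray())
```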
@@ -136,9 +136,9 @@
# We do this using the BFGS optimiser provided in `scipy` via 'jaxopt'.

# %%
-likelihood = gpx.Gaussian(num_datapoints=D.n)
-kernel = gpx.GraphKernel(laplacian=L)
-prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
+kernel = gpx.kernels.GraphKernel(laplacian=L)
+prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=kernel)
posterior = prior * likelihood

# %% [markdown]