Skip to content

Commit

Permalink
Fix broken docs (#424)
Browse files Browse the repository at this point in the history
* Fix broken docs

* Fix kernel ref
  • Loading branch information
thomaspinder authored Nov 30, 2023
1 parent 07d99db commit 55ecaac
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:

opt_posterior, _ = gpx.fit_scipy(
model=posterior,
objective=gpx.ConjugateMLL(negative=True),
objective=gpx.objectives.ConjugateMLL(negative=True),
train_data=D,
)
latent_dist = opt_posterior.predict(xtest, train_data=D)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
# %%
opt_posterior, training_history = gpx.fit(
model=posterior,
objective=gpx.ConjugateMLL(negative=True),
objective=gpx.objectives.ConjugateMLL(negative=True),
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=1000,
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/yacht.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@

# %%
n_train, n_covariates = scaled_Xtr.shape
kernel = gpx.RBF(
kernel = gpx.kernels.RBF(
active_dims=list(range(n_covariates)),
variance=np.var(scaled_ytr),
lengthscale=0.1 * np.ones((n_covariates,)),
Expand All @@ -188,7 +188,7 @@
# %%
training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr)

negative_mll = jit(gpx.ConjugateMLL(negative=True))
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

opt_posterior, history = gpx.fit_scipy(
model=posterior,
Expand Down

0 comments on commit 55ecaac

Please sign in to comment.