Skip to content

Commit

Permalink
Merge pull request #21 from DrJonnyT/predict-mps
Browse files Browse the repository at this point in the history
Add dtype_lik option to nps.predict
  • Loading branch information
wesselb authored Apr 20, 2024
2 parents 56b673b + ed4c08c commit b1ef055
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions neuralprocesses/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def predict(
*,
num_samples=50,
batch_size=16,
dtype_lik=None
):
"""Use a model to predict.
Expand All @@ -29,6 +30,8 @@ def predict(
xt (input): Inputs of the target set.
num_samples (int, optional): Number of samples to produce. Defaults to 50.
batch_size (int, optional): Batch size. Defaults to 16.
dtype_lik (dtype, optional): Data type to use for the likelihood computation.
Defaults to the 64-bit variant of the data type of `xt`.
Returns:
random state, optional: Random state.
Expand All @@ -38,7 +41,11 @@ def predict(
tensor: `num_samples` noisy samples.
"""
float = B.dtype_float(xt)
float64 = B.promote_dtypes(float, np.float64)

# For the likelihood computation, default to using a 64-bit version of the data
# type of `xt`.
if not dtype_lik:
dtype_lik = B.promote_dtypes(float, np.float64)

# Collect noiseless samples, noisy samples, first moments, and second moments.
ft, yt = [], []
Expand All @@ -54,8 +61,7 @@ def predict(
contexts,
xt,
dtype_enc_sample=float,
# Run likelihood with `float64`s to ease the numerics as much as possible.
dtype_lik=float64,
dtype_lik=dtype_lik,
num_samples=this_num_samples,
)

Expand Down

0 comments on commit b1ef055

Please sign in to comment.