diff --git a/neuralprocesses/model/predict.py b/neuralprocesses/model/predict.py index d32c15ee..61a8673a 100644 --- a/neuralprocesses/model/predict.py +++ b/neuralprocesses/model/predict.py @@ -18,7 +18,7 @@ def predict( *, num_samples=50, batch_size=16, - dtype_lik=None + dtype_lik=None, ): """Use a model to predict.