diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 5f3c983b58..91f1c6f691 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -20,8 +20,10 @@ from collections.abc import Callable, Iterable, Sequence from typing import ( Any, + Literal, TypeAlias, cast, + overload, ) import numpy as np @@ -360,6 +362,28 @@ def observed_dependent_deterministics(model: Model, extra_observeds=None): ] +@overload +def sample_prior_predictive( + draws: int = 500, + model: Model | None = None, + var_names: Iterable[str] | None = None, + random_seed: RandomState = None, + return_inferencedata: Literal[True] = True, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, + samples: int | None = None, +) -> InferenceData: ... +@overload +def sample_prior_predictive( + draws: int = 500, + model: Model | None = None, + var_names: Iterable[str] | None = None, + random_seed: RandomState = None, + return_inferencedata: Literal[False] = False, + idata_kwargs: dict | None = None, + compile_kwargs: dict | None = None, + samples: int | None = None, +) -> dict[str, np.ndarray]: ... def sample_prior_predictive( draws: int = 500, model: Model | None = None, diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 7bbcdc42b1..d3b41bf667 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -857,16 +857,24 @@ def test_logging_sampled_basic_rvs_prior(self, caplog): y = pm.Deterministic("y", x + 1) z = pm.Normal("z", y, observed=0) + # all volatile RVs in model with m: pm.sample_prior_predictive(draws=1) assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")] caplog.clear() + # `x` has no dependencies so will be sampled by itself with m: pm.sample_prior_predictive(draws=1, var_names=["x"]) assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")] caplog.clear() + # `z` depends on `x` + with m: + pm.sample_prior_predictive(draws=1, var_names=["z"]) + assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")] + caplog.clear() + def test_logging_sampled_basic_rvs_posterior(self, caplog): with pm.Model() as m: x = pm.Normal("x")