diff --git a/preliz/internal/logging.py b/preliz/internal/logging.py new file mode 100644 index 00000000..9f248ff4 --- /dev/null +++ b/preliz/internal/logging.py @@ -0,0 +1,13 @@ +import logging + +from contextlib import contextmanager + + +@contextmanager +def disable_pymc_sampling_logs(logger: logging.Logger = logging.getLogger("pymc")): + effective_level = logger.getEffectiveLevel() + logger.setLevel(logging.ERROR) + try: + yield + finally: + logger.setLevel(effective_level) diff --git a/preliz/internal/parser.py b/preliz/internal/parser.py index 715268f3..bf08b629 100644 --- a/preliz/internal/parser.py +++ b/preliz/internal/parser.py @@ -17,7 +17,7 @@ def inspect_source(fmodel): return source, signature -def parse_function_for_pred_textboxes(source, signature): +def parse_function_for_pred_textboxes(source, signature, engine="preliz"): model = {} slidify = list(signature.parameters.keys()) @@ -29,7 +29,10 @@ def parse_function_for_pred_textboxes(source, signature): for match in matches: dist_name_str = match.group(2) arguments = [s.strip() for s in match.group(3).split(",")] - args = parse_arguments(arguments, regex) + if engine == "pymc": + args = pymc_parse_arguments(arguments, regex) + else: + args = parse_arguments(arguments, regex) for arg in args: if arg: func, var, idx = arg @@ -56,6 +59,23 @@ def parse_arguments(lst, regex): return result +def pymc_parse_arguments(lst, regex): + result = [] + for idx, item in enumerate(lst): + match = re.search(regex, item) + if match: + if item.isidentifier(): + result.append((None, match.group(0), idx - 1)) + else: + if "**" in item: + power = item.split("**")[1].strip() + result.append((power, match.group(0), idx - 1)) + else: + func = item.split("(")[0].split(".")[-1] + result.append((func, match.group(0), idx - 1)) + return result + + def get_prior_pp_samples(fmodel, variables, draws, engine=None, values=None): if values is None: values = [] diff --git a/preliz/internal/plot_helper.py b/preliz/internal/plot_helper.py index 616639b5..b88e5ffb 100644 --- a/preliz/internal/plot_helper.py +++ b/preliz/internal/plot_helper.py @@ -5,15 +5,17 @@ try: from IPython import get_ipython from ipywidgets import FloatSlider, IntSlider, FloatText, IntText, Checkbox, ToggleButton + from pymc import sample_prior_predictive except ImportError: pass -from arviz import plot_kde, plot_ecdf, hdi +from arviz import plot_kde, plot_ecdf, hdi, extract from arviz.stats.density_utils import _kde_linear import numpy as np import matplotlib.pyplot as plt from matplotlib import _pylab_helpers, get_backend from matplotlib.ticker import MaxNLocator +from .logging import disable_pymc_sampling_logs def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, ax=None): @@ -425,6 +427,33 @@ def looper(*args, **kwargs): return looper +def pymc_plot_decorator(func, iterations, kind_plot, references, plot_func): + def looper(*args, **kwargs): + results = [] + kwargs.pop("__resample__") + x_min = kwargs.pop("__x_min__") + x_max = kwargs.pop("__x_max__") + if not kwargs.pop("__set_xlim__"): + x_min = None + x_max = None + auto = True + else: + auto = False + with func(*args, **kwargs) as model: + obs_name = model.observed_RVs[0].name + with disable_pymc_sampling_logs(): + idata = sample_prior_predictive(samples=iterations) + results = extract(idata, group="prior_predictive")[obs_name].values + _, ax = plt.subplots() + ax.set_xlim(x_min, x_max, auto=auto) + if plot_func is None: + pymc_plot_repr(results, kind_plot, references, iterations, ax) + else: + plot_func(results, ax) + + return looper + + def plot_repr(results, kind_plot, references, iterations, ax): alpha = max(0.01, 1 - iterations * 0.009) @@ -467,6 +496,48 @@ def plot_repr(results, kind_plot, references, iterations, ax): plot_references(references, ax) +def pymc_plot_repr(results, kind_plot, references, iterations, ax): + alpha = max(0.01, 1 - iterations * 0.009) + + if kind_plot == "hist": + if results[0].dtype.kind == "i": + bins = np.arange(np.min(results), np.max(results) + 1.5) - 0.5 + if len(bins) < 30: + ax.set_xticks(bins + 0.5) + else: + bins = "auto" + ax.hist( + results, + alpha=alpha, + density=True, + color=["0.5"] * iterations, + bins=bins, + histtype="step", + ) + ax.hist( + np.concatenate(results), + density=True, + bins=bins, + color="k", + ls="--", + histtype="step", + ) + elif kind_plot == "kde": + for result in results: + ax.plot(*_kde_linear(result, grid_len=100), "0.5", alpha=alpha) + ax.plot(*_kde_linear(np.concatenate(results), grid_len=100), "k--") + elif kind_plot == "ecdf": + ax.plot( + np.sort(results, axis=1).T, + np.linspace(0, 1, len(results[0]), endpoint=False), + color="0.5", + ) + a = np.concatenate(results) + ax.plot(np.sort(a), np.linspace(0, 1, len(a), endpoint=False), "k--") + + plot_references(references, ax) + + def plot_pp_samples(pp_samples, pp_samples_idxs, references, kind="pdf", sharex=True, fig=None): row_colum = int(np.ceil(len(pp_samples_idxs) ** 0.5)) diff --git a/preliz/predictive/predictive_explorer.py b/preliz/predictive/predictive_explorer.py index f53c7340..416b1c2b 100644 --- a/preliz/predictive/predictive_explorer.py +++ b/preliz/predictive/predictive_explorer.py @@ -4,10 +4,12 @@ except ImportError: pass from preliz.internal.parser import inspect_source, parse_function_for_pred_textboxes -from preliz.internal.plot_helper import get_textboxes, plot_decorator +from preliz.internal.plot_helper import get_textboxes, plot_decorator, pymc_plot_decorator -def predictive_explorer(fmodel, samples=50, kind_plot="ecdf", references=None, plot_func=None): +def predictive_explorer( + fmodel, samples=50, kind_plot="ecdf", references=None, plot_func=None, engine="preliz" +): """ Create textboxes and plot a set of samples returned by a function relating one or more PreliZ distributions. @@ -18,7 +20,8 @@ def predictive_explorer(fmodel, samples=50, kind_plot="ecdf", references=None, p Parameters ---------- fmodel : callable - A function with PreliZ distributions. The distributions should call their rvs methods. + A function with PreliZ distributions or PyMC distributions, depending on the selected + engine. The PreliZ distributions should call their rvs method. samples : int, optional The number of samples to draw from the prior predictive distribution (default is 50). kind_plot : str, optional @@ -30,14 +33,16 @@ def predictive_explorer(fmodel, samples=50, kind_plot="ecdf", references=None, p plot_func : function Custom matplotlib code. Defaults to None. ``kind_plot`` and ``references`` are ignored if ``plot_func`` is specified. + engine : str, optional + Library used to define the fmodel. Either `pymc` or `preliz`. Default to `preliz`. """ source, signature = inspect_source(fmodel) - - model = parse_function_for_pred_textboxes(source, signature) + model = parse_function_for_pred_textboxes(source, signature, engine) textboxes = get_textboxes(signature, model) - - new_fmodel = plot_decorator(fmodel, samples, kind_plot, references, plot_func) - + if engine == "pymc": + new_fmodel = pymc_plot_decorator(fmodel, samples, kind_plot, references, plot_func) + else: + new_fmodel = plot_decorator(fmodel, samples, kind_plot, references, plot_func) out = interactive_output(new_fmodel, textboxes) default_names = ["__set_xlim__", "__x_min__", "__x_max__", "__resample__"] default_controls = [textboxes[name] for name in default_names] diff --git a/preliz/tests/predictive_explorer.ipynb b/preliz/tests/predictive_explorer.ipynb index 3fb78283..37af7afc 100644 --- a/preliz/tests/predictive_explorer.ipynb +++ b/preliz/tests/predictive_explorer.ipynb @@ -14,6 +14,7 @@ "\n", "import numpy as np\n", "import arviz as az\n", + "import pymc as pm\n", "from preliz.distributions import Normal, Gamma\n", "from preliz import predictive_explorer" ] @@ -45,25 +46,14 @@ " (10, \"ecdf\"),\n", "])\n", "def test_predictive_explorer(model, iterations, kind_plot):\n", - " result = predictive_explorer(model, iterations, kind_plot)\n", - " result._ipython_display_()\n", - " slider0, slider1, slider2, plot_data = result.children\n", - " slider0.value = -4\n", - " slider1.value = 0.3\n", - " slider2[2].value = 0.1\n", - " assert 'image/png' in plot_data.outputs[0][\"data\"]\n", + " predictive_explorer(model, iterations, kind_plot)\n", "\n", "def lin_reg(predictions, ax):\n", " ax.plot(x, predictions.T, \"k.\")\n", "\n", - "def test_predictive_explorer_custom_plot(model, iterations, lin_reg):\n", - " result = predictive_explorer(model, iterations, plot_func=lin_reg)\n", - " result._ipython_display_()\n", - " slider0, slider1, slider2, plot_data = result.children\n", - " slider0.value = -4\n", - " slider1.value = 0.3\n", - " slider2[2].value = 0.1\n", - " assert 'image/png' in plot_data.outputs[0][\"data\"]" + "def test_predictive_explorer_custom_plot(model):\n", + " predictive_explorer(model, 50, plot_func=lin_reg)\n", + " " ] }, { @@ -72,7 +62,28 @@ "id": "e006886c", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "%%ipytest\n", + "\n", + "@pytest.fixture\n", + "def model():\n", + " def a_pymc_model(a_mu, b_sigma=1):\n", + " with pm.Model() as model:\n", + " a = pm.Normal(\"a\", a_mu, 1)\n", + " b = pm.HalfNormal(\"b\", b_sigma)\n", + " c = pm.Normal(\"c\", a, b, observed=[0]*100)\n", + " return model\n", + " return a_pymc_model\n", + " \n", + " \n", + "@pytest.mark.parametrize(\"iterations, kind_plot\", [\n", + " (50, \"hist\"),\n", + " (10, \"kde\"),\n", + " (10, \"ecdf\"),\n", + "])\n", + "def test_predictive_explorer(model, iterations, kind_plot):\n", + " predictive_explorer(model, iterations, kind_plot, engine=\"pymc\")" + ] } ], "metadata": {