Skip to content

Commit

Permalink
Add pymc support to predictive explorer (arviz-devs#450)
Browse files Browse the repository at this point in the history
* Add pymc support to predictive explorer

* Clean and correct py version

* Add pymc to requirements-docs.txt

* Shift pymc requirement as optional

* Disable logging for pymc sampling, Update plot helper
  • Loading branch information
rohanbabbar04 authored Jun 3, 2024
1 parent 2a44f9d commit e6ea72e
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 27 deletions.
13 changes: 13 additions & 0 deletions preliz/internal/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import logging

from contextlib import contextmanager


@contextmanager
def disable_pymc_sampling_logs(logger: logging.Logger = logging.getLogger("pymc")):
    """
    Temporarily silence everything below ``logging.ERROR`` on ``logger``.

    Parameters
    ----------
    logger : logging.Logger, optional
        Logger to quiet while the ``with`` body runs. Defaults to the
        ``"pymc"`` logger.

    Yields
    ------
    None
        Control is handed to the ``with`` body while the logger is set to
        ``logging.ERROR``. The previous level is restored on exit, even when
        the body raises.
    """
    # Remember the logger's *own* level rather than its effective level.
    # A logger left at NOTSET inherits from its ancestors; restoring the
    # effective level would pin an explicit level on it and silently stop it
    # from following later changes to the parent's configuration.
    previous_level = logger.level
    logger.setLevel(logging.ERROR)
    try:
        yield
    finally:
        logger.setLevel(previous_level)
24 changes: 22 additions & 2 deletions preliz/internal/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def inspect_source(fmodel):
return source, signature


def parse_function_for_pred_textboxes(source, signature):
def parse_function_for_pred_textboxes(source, signature, engine="preliz"):
model = {}

slidify = list(signature.parameters.keys())
Expand All @@ -29,7 +29,10 @@ def parse_function_for_pred_textboxes(source, signature):
for match in matches:
dist_name_str = match.group(2)
arguments = [s.strip() for s in match.group(3).split(",")]
args = parse_arguments(arguments, regex)
if engine == "pymc":
args = pymc_parse_arguments(arguments, regex)
else:
args = parse_arguments(arguments, regex)
for arg in args:
if arg:
func, var, idx = arg
Expand All @@ -56,6 +59,23 @@ def parse_arguments(lst, regex):
return result


def pymc_parse_arguments(lst, regex):
    """
    Find the argument strings in ``lst`` that mention a name matched by ``regex``.

    Parameters
    ----------
    lst : list of str
        Argument strings of a single distribution call.
    regex : str or compiled pattern
        Pattern searched for in every argument string.

    Returns
    -------
    list of tuple
        One ``(label, name, index)`` tuple per matching argument, where
        ``name`` is the matched text and ``index`` is the argument's position
        in ``lst`` minus one (presumably to skip the leading PyMC variable
        name -- confirm against the caller). ``label`` is ``None`` when the
        argument is a bare identifier, the text after ``**`` when the
        argument contains a power, and otherwise the last dotted component
        of whatever precedes the first ``(``.
    """
    parsed = []
    for position, text in enumerate(lst):
        found = re.search(regex, text)
        if not found:
            continue

        if text.isidentifier():
            label = None
        elif "**" in text:
            # Keep the exponent, e.g. "sigma**2" -> "2".
            label = text.split("**")[1].strip()
        else:
            # Keep the unqualified function name, e.g. "np.exp(mu)" -> "exp".
            label = text.split("(")[0].split(".")[-1]

        parsed.append((label, found.group(0), position - 1))
    return parsed


def get_prior_pp_samples(fmodel, variables, draws, engine=None, values=None):
if values is None:
values = []
Expand Down
73 changes: 72 additions & 1 deletion preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
try:
from IPython import get_ipython
from ipywidgets import FloatSlider, IntSlider, FloatText, IntText, Checkbox, ToggleButton
from pymc import sample_prior_predictive
except ImportError:
pass

from arviz import plot_kde, plot_ecdf, hdi
from arviz import plot_kde, plot_ecdf, hdi, extract
from arviz.stats.density_utils import _kde_linear
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import _pylab_helpers, get_backend
from matplotlib.ticker import MaxNLocator
from .logging import disable_pymc_sampling_logs


def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, ax=None):
Expand Down Expand Up @@ -425,6 +427,33 @@ def looper(*args, **kwargs):
return looper


def pymc_plot_decorator(func, iterations, kind_plot, references, plot_func):
    """
    Wrap a PyMC model factory so that each call samples and plots its prior predictive.

    Parameters
    ----------
    func : callable
        Called with the remaining positional and keyword arguments; its return
        value is used as a context manager and must expose ``observed_RVs``.
    iterations : int
        Passed as ``samples`` to ``sample_prior_predictive`` and forwarded to
        ``pymc_plot_repr``.
    kind_plot : str
        Forwarded to ``pymc_plot_repr``.
    references
        Forwarded to ``pymc_plot_repr``.
    plot_func : callable or None
        When not ``None`` it is called as ``plot_func(results, ax)`` instead of
        ``pymc_plot_repr``.

    Returns
    -------
    callable
        ``looper(*args, **kwargs)``, which requires the keyword arguments
        ``__resample__``, ``__x_min__``, ``__x_max__`` and ``__set_xlim__``
        and returns ``None``.
    """

    def looper(*args, **kwargs):
        results = []

        # Strip the control-only keywords so they never reach ``func``.
        kwargs.pop("__resample__")
        x_min = kwargs.pop("__x_min__")
        x_max = kwargs.pop("__x_max__")
        auto = not kwargs.pop("__set_xlim__")
        if auto:
            # Limits are ignored unless the user asked for them explicitly.
            x_min = x_max = None

        with func(*args, **kwargs) as model:
            # Only the first observed variable is plotted.
            obs_name = model.observed_RVs[0].name
            with disable_pymc_sampling_logs():
                idata = sample_prior_predictive(samples=iterations)
            prior_predictive = extract(idata, group="prior_predictive")
            results = prior_predictive[obs_name].values

        _, ax = plt.subplots()
        ax.set_xlim(x_min, x_max, auto=auto)
        if plot_func is not None:
            plot_func(results, ax)
        else:
            pymc_plot_repr(results, kind_plot, references, iterations, ax)

    return looper


def plot_repr(results, kind_plot, references, iterations, ax):
alpha = max(0.01, 1 - iterations * 0.009)

Expand Down Expand Up @@ -467,6 +496,48 @@ def plot_repr(results, kind_plot, references, iterations, ax):
plot_references(references, ax)


def pymc_plot_repr(results, kind_plot, references, iterations, ax):
    """
    Draw ``results`` on ``ax`` as step histograms, KDE curves or ECDF curves.

    Parameters
    ----------
    results : array-like
        Sampled values. The ``"ecdf"`` branch sorts along ``axis=1``, so a
        two-dimensional array is expected there.
    kind_plot : str
        ``"hist"``, ``"kde"`` or ``"ecdf"``. Any other value draws only the
        references.
    references
        Forwarded to ``plot_references``.
    iterations : int
        Sets the transparency of the grey curves and the number of grey
        colours handed to ``ax.hist``.
    ax : matplotlib axes
        Axes that receive the plot.
    """
    # More iterations -> fainter individual curves, never below 0.01.
    alpha = max(0.01, 1 - iterations * 0.009)

    if kind_plot == "hist":
        bins = "auto"
        if results[0].dtype.kind == "i":
            # Integer data: unit-width bins centred on each integer value.
            bins = np.arange(np.min(results), np.max(results) + 1.5) - 0.5
            if len(bins) < 30:
                ax.set_xticks(bins + 0.5)

        shared = {"density": True, "bins": bins, "histtype": "step"}
        ax.hist(results, alpha=alpha, color=["0.5"] * iterations, **shared)
        # Dashed black outline for all values pooled together.
        ax.hist(np.concatenate(results), color="k", ls="--", **shared)
    elif kind_plot == "kde":
        for row in results:
            curve = _kde_linear(row, grid_len=100)
            ax.plot(*curve, "0.5", alpha=alpha)
        pooled_curve = _kde_linear(np.concatenate(results), grid_len=100)
        ax.plot(*pooled_curve, "k--")
    elif kind_plot == "ecdf":
        sorted_rows = np.sort(results, axis=1).T
        row_probs = np.linspace(0, 1, len(results[0]), endpoint=False)
        ax.plot(sorted_rows, row_probs, color="0.5")

        pooled = np.concatenate(results)
        pooled_probs = np.linspace(0, 1, len(pooled), endpoint=False)
        ax.plot(np.sort(pooled), pooled_probs, "k--")

    plot_references(references, ax)


def plot_pp_samples(pp_samples, pp_samples_idxs, references, kind="pdf", sharex=True, fig=None):
row_colum = int(np.ceil(len(pp_samples_idxs) ** 0.5))

Expand Down
21 changes: 13 additions & 8 deletions preliz/predictive/predictive_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
except ImportError:
pass
from preliz.internal.parser import inspect_source, parse_function_for_pred_textboxes
from preliz.internal.plot_helper import get_textboxes, plot_decorator
from preliz.internal.plot_helper import get_textboxes, plot_decorator, pymc_plot_decorator


def predictive_explorer(fmodel, samples=50, kind_plot="ecdf", references=None, plot_func=None):
def predictive_explorer(
fmodel, samples=50, kind_plot="ecdf", references=None, plot_func=None, engine="preliz"
):
"""
Create textboxes and plot a set of samples returned by a function relating one or more
PreliZ distributions.
Expand All @@ -18,7 +20,8 @@ def predictive_explorer(fmodel, samples=50, kind_plot="ecdf", references=None, p
Parameters
----------
fmodel : callable
A function with PreliZ distributions. The distributions should call their rvs methods.
A function with PreliZ distributions or PyMC distributions, depending on the selected
engine. The PreliZ distributions should call their rvs method.
samples : int, optional
The number of samples to draw from the prior predictive distribution (default is 50).
kind_plot : str, optional
Expand All @@ -30,14 +33,16 @@ def predictive_explorer(fmodel, samples=50, kind_plot="ecdf", references=None, p
plot_func : function
Custom matplotlib code. Defaults to None. ``kind_plot`` and ``references`` are ignored
if ``plot_func`` is specified.
engine : str, optional
Library used to define the fmodel. Either `pymc` or `preliz`. Defaults to `preliz`.
"""
source, signature = inspect_source(fmodel)

model = parse_function_for_pred_textboxes(source, signature)
model = parse_function_for_pred_textboxes(source, signature, engine)
textboxes = get_textboxes(signature, model)

new_fmodel = plot_decorator(fmodel, samples, kind_plot, references, plot_func)

if engine == "pymc":
new_fmodel = pymc_plot_decorator(fmodel, samples, kind_plot, references, plot_func)
else:
new_fmodel = plot_decorator(fmodel, samples, kind_plot, references, plot_func)
out = interactive_output(new_fmodel, textboxes)
default_names = ["__set_xlim__", "__x_min__", "__x_max__", "__resample__"]
default_controls = [textboxes[name] for name in default_names]
Expand Down
43 changes: 27 additions & 16 deletions preliz/tests/predictive_explorer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"\n",
"import numpy as np\n",
"import arviz as az\n",
"import pymc as pm\n",
"from preliz.distributions import Normal, Gamma\n",
"from preliz import predictive_explorer"
]
Expand Down Expand Up @@ -45,25 +46,14 @@
" (10, \"ecdf\"),\n",
"])\n",
"def test_predictive_explorer(model, iterations, kind_plot):\n",
" result = predictive_explorer(model, iterations, kind_plot)\n",
" result._ipython_display_()\n",
" slider0, slider1, slider2, plot_data = result.children\n",
" slider0.value = -4\n",
" slider1.value = 0.3\n",
" slider2[2].value = 0.1\n",
" assert 'image/png' in plot_data.outputs[0][\"data\"]\n",
" predictive_explorer(model, iterations, kind_plot)\n",
"\n",
"def lin_reg(predictions, ax):\n",
" ax.plot(x, predictions.T, \"k.\")\n",
"\n",
"def test_predictive_explorer_custom_plot(model, iterations, lin_reg):\n",
" result = predictive_explorer(model, iterations, plot_func=lin_reg)\n",
" result._ipython_display_()\n",
" slider0, slider1, slider2, plot_data = result.children\n",
" slider0.value = -4\n",
" slider1.value = 0.3\n",
" slider2[2].value = 0.1\n",
" assert 'image/png' in plot_data.outputs[0][\"data\"]"
"def test_predictive_explorer_custom_plot(model):\n",
" predictive_explorer(model, 50, plot_func=lin_reg)\n",
" "
]
},
{
Expand All @@ -72,7 +62,28 @@
"id": "e006886c",
"metadata": {},
"outputs": [],
"source": []
"source": [
"%%ipytest\n",
"\n",
"@pytest.fixture\n",
"def model():\n",
" def a_pymc_model(a_mu, b_sigma=1):\n",
" with pm.Model() as model:\n",
" a = pm.Normal(\"a\", a_mu, 1)\n",
" b = pm.HalfNormal(\"b\", b_sigma)\n",
" c = pm.Normal(\"c\", a, b, observed=[0]*100)\n",
" return model\n",
" return a_pymc_model\n",
" \n",
" \n",
"@pytest.mark.parametrize(\"iterations, kind_plot\", [\n",
" (50, \"hist\"),\n",
" (10, \"kde\"),\n",
" (10, \"ecdf\"),\n",
"])\n",
"def test_predictive_explorer(model, iterations, kind_plot):\n",
" predictive_explorer(model, iterations, kind_plot, engine=\"pymc\")"
]
}
],
"metadata": {
Expand Down

0 comments on commit e6ea72e

Please sign in to comment.