Skip to content

Commit

Permalink
Generalize legend to accept arbitrary strings (#644)
Browse files Browse the repository at this point in the history
* generalize legend to accept arbitrary strings

* fix lint issues
  • Loading branch information
aloctavodia authored Feb 13, 2025
1 parent 1209225 commit 7caa9bc
Show file tree
Hide file tree
Showing 24 changed files with 84 additions and 77 deletions.
1 change: 1 addition & 0 deletions preliz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Exploring and eliciting probability distributions
"""

from os import path as os_path

from matplotlib import rcParams as mpl_rcParams
Expand Down
1 change: 1 addition & 0 deletions preliz/distributions/betabinomial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""BetaBinomial probability distribution."""

import numba as nb
import numpy as np

Expand Down
1 change: 1 addition & 0 deletions preliz/distributions/continuous_multivariate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Continuous multivariate probability distributions."""

from copy import copy

import numpy as np
Expand Down
7 changes: 5 additions & 2 deletions preliz/distributions/distributions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Parent classes for all families."""

import warnings
from collections import namedtuple
from copy import copy
Expand Down Expand Up @@ -68,7 +69,7 @@ def params_dict(self):

def summary(self, mass=None, interval=None, fmt=".2f"):
"""
Namedtuple with the mean, median, standard deviation, and lower and upper bounds of the equal-tailed interval.
Namedtuple with the mean, median, sd, and lower and upper bounds.
Parameters
----------
Expand Down Expand Up @@ -459,7 +460,9 @@ def _finite_endpoints(self, support):

def xvals(self, support, n_points=None):
"""
Provide x values in the support of the distribution. This is useful for example when plotting.
Provide x values in the support of the distribution.
This is useful for example when plotting.
Parameters
----------
Expand Down
11 changes: 9 additions & 2 deletions preliz/distributions/distributions_multivariate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Parent classes for multivariate families."""

from collections import namedtuple

import numpy as np
Expand Down Expand Up @@ -150,7 +151,9 @@ def __init__(self):

def xvals(self, support):
"""
Provide x values in the support of the distribution. This is useful for example when plotting.
Provide x values in the support of the distribution.
This is useful for example when plotting.
Parameters
----------
Expand Down Expand Up @@ -193,7 +196,11 @@ def __init__(self):
self.kind = "discrete"

def xvals(self, support):
"""Provide x values in the support of the distribution. This is useful for example when plotting."""
"""
Provide x values in the support of the distribution.
This is useful for example when plotting.
"""
lower_ep, upper_ep = self._finite_endpoints(support)
x_vals = np.arange(lower_ep, upper_ep + 1, dtype=int)
return x_vals
Expand Down
1 change: 0 additions & 1 deletion preliz/distributions/hurdle.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def pdf(self, x):
)

def cdf(self, x):

if self.dist == "discrete":
return np.where(
x <= 0, 1 - self.psi, 1 - self.psi * (1 - self.dist.cdf(x)) / (1 - self.dist.cdf(0))
Expand Down
5 changes: 3 additions & 2 deletions preliz/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def _parametrization(self, weights=None):
self.param_names.append("weights")
self.params_support.append((0, 1))

self.support = np.min([dist.support[0] for dist in self.dist]), np.max(
[dist.support[1] for dist in self.dist]
self.support = (
np.min([dist.support[0] for dist in self.dist]),
np.max([dist.support[1] for dist in self.dist]),
)
self.weights = np.asarray(weights)
self.weights = self.weights / np.sum(self.weights)
Expand Down
4 changes: 2 additions & 2 deletions preliz/distributions/rice.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def entropy(self):
return -np.trapz(np.exp(logpdf) * logpdf, x_values)

def mean(self):
return self.sigma * np.sqrt(np.pi / 2) * _l_half(-self.nu**2 / (2 * self.sigma**2))
return self.sigma * np.sqrt(np.pi / 2) * _l_half(-(self.nu**2) / (2 * self.sigma**2))

def median(self):
return self.ppf(0.5)
Expand All @@ -136,7 +136,7 @@ def var(self):
return (
2 * self.sigma**2
+ self.nu**2
- np.pi / 2 * self.sigma**2 * _l_half(-self.nu**2 / (2 * self.sigma**2)) ** 2
- np.pi / 2 * self.sigma**2 * _l_half(-(self.nu**2) / (2 * self.sigma**2)) ** 2
)

def std(self):
Expand Down
1 change: 0 additions & 1 deletion preliz/internal/distribution_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def num_kurtosis(dist):


def get_distributions(dist_names=None):

if dist_names is None:
all_distributions = modules["preliz.distributions"].__all__
else:
Expand Down
1 change: 1 addition & 0 deletions preliz/internal/narviz.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions originally imported from ArviZ."""

import warnings

import numpy as np
Expand Down
6 changes: 4 additions & 2 deletions preliz/internal/optimization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Optimization routines and utilities."""

import warnings
from copy import copy

Expand Down Expand Up @@ -299,9 +300,10 @@ def relative_error(dist, lower, upper, required_mass):

def fit_to_epdf(selected_distributions, x_vals, epdf, mean, std, x_min, x_max, extra_pros):
"""
Minimize the difference between the pdf and the epdf over a grid of values defined by x_min and x_max.
Minimize the difference between the pdf and the epdf.
Note: This function is intended to be used with pz.roulette
Minimization is done over a grid of values defined by x_min and x_max. This function is
intended to be used with pz.roulette
"""
fitted = Loss(len(selected_distributions))
for dist in selected_distributions:
Expand Down
76 changes: 31 additions & 45 deletions preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,7 @@ def plot_pdfpmf(
dist, moments, pointinterval, interval, levels, support, legend, color, alpha, figsize, ax
):
ax = get_ax(ax, figsize)
if legend is not None:
label = repr_to_matplotlib(dist)

if moments is not None:
label += get_moments(dist, moments)

if legend == "title":
ax.set_title(label)
label = None
else:
label = None
label = set_label(dist, legend, moments, ax)

x = dist.xvals(support)
if dist.kind == "continuous":
Expand Down Expand Up @@ -181,8 +171,7 @@ def plot_pdfpmf(
if pointinterval:
plot_pointinterval(dist, interval, levels, ax=ax)

if legend == "legend":
side_legend(ax)
side_legend(legend, ax)

return ax

Expand All @@ -191,17 +180,7 @@ def plot_cdf(
dist, moments, pointinterval, interval, levels, support, legend, color, alpha, figsize, ax
):
ax = get_ax(ax, figsize)
if legend is not None:
label = repr_to_matplotlib(dist)

if moments is not None:
label += get_moments(dist, moments)

if legend == "title":
ax.set_title(label)
label = None
else:
label = None
label = set_label(dist, legend, moments, ax)

ax.set_ylim(-0.05, 1.05)
x = dist.xvals(support)
Expand All @@ -221,24 +200,14 @@ def plot_cdf(
if pointinterval:
plot_pointinterval(dist, interval, levels, ax=ax)

if legend == "legend":
side_legend(ax)
side_legend(legend, ax)

return ax


def plot_ppf(dist, moments, pointinterval, interval, levels, legend, color, alpha, figsize, ax):
ax = get_ax(ax, figsize)
if legend is not None:
label = repr_to_matplotlib(dist)

if moments is not None:
label += get_moments(dist, moments)

if legend == "title":
ax.set_title(label)
label = None
else:
label = None
label = set_label(dist, legend, moments, ax)

x = np.linspace(0, 1, 1000)
ax.plot(x, dist.ppf(x), label=label, color=color, alpha=alpha)
Expand All @@ -248,8 +217,8 @@ def plot_ppf(dist, moments, pointinterval, interval, levels, legend, color, alph
if pointinterval:
plot_pointinterval(dist, interval, levels, rotated=True, ax=ax)

if legend == "legend":
side_legend(ax)
side_legend(legend, ax)

return ax


Expand All @@ -263,10 +232,11 @@ def get_ax(ax, figsize):
return ax


def side_legend(ax):
bbox = ax.get_position()
ax.set_position([bbox.x0, bbox.y0, bbox.width, bbox.height])
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
def side_legend(legend, ax):
if isinstance(legend, str) and legend != "title":
bbox = ax.get_position()
ax.set_position([bbox.x0, bbox.y0, bbox.width, bbox.height])
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))


def repr_to_matplotlib(distribution):
Expand Down Expand Up @@ -297,7 +267,6 @@ def get_moments(dist, moments):


def get_slider(name, value, lower, upper):

min_v, max_v, step = generate_range(value, lower, upper)

if isinstance(value, float):
Expand Down Expand Up @@ -337,7 +306,6 @@ def generate_range(value, lower, upper):


def get_boxes(name, value, lower, upper):

if isinstance(value, float):
text_type = FloatText
step = 0.1
Expand Down Expand Up @@ -654,3 +622,21 @@ def reset_dist_panel(ax, yticks):
ax.set_yticks([])
ax.relim()
ax.autoscale_view()


def set_label(dist, legend, moments, ax):
if legend is not None:
if isinstance(legend, str) and legend not in ["title", "legend"]:
label = legend
else:
label = repr_to_matplotlib(dist)

if moments is not None:
label += get_moments(dist, moments)

if legend == "title":
ax.set_title(label)
label = None
else:
label = None
return label
4 changes: 3 additions & 1 deletion preliz/internal/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ def betaincinv(a, b, p):
@nb.njit(cache=True)
def garcia_approximation(mean, sigma):
"""
Approximate method of moments for Weibull distribution, provides good results for values of alpha larger than 0.83.
Approximate method of moments for Weibull distribution.
The approximation is good for values of alpha larger than 0.83.
Oscar Garcia. Simplified method-of-moments estimation for the Weibull distribution. 1981.
New Zealand Journal of Forestry Science 11:304-306
Expand Down
5 changes: 4 additions & 1 deletion preliz/multidimensional/dirichlet_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

def dirichlet_mode(mode, mass=None, bound=0.01, plot=None, plot_kwargs=None, ax=None):
"""
Return a Dirichlet distribution where the marginals have the specified mode and mass and their masses lie within the range mode ± bound.
Elicit a Dirichlet distribution with a given mode and mass.
Computes a Dirichlet distribution where the marginals have the specified mode
and mass and their masses lie within the range mode ± bound.
Parameters
----------
Expand Down
1 change: 0 additions & 1 deletion preliz/ppls/agnostic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Functions to communicate with PPLs."""


import inspect
import logging
import re
Expand Down
14 changes: 9 additions & 5 deletions preliz/ppls/pymc_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def back_fitting_pymc(prior, preliz_model, var_info):


def compile_mllk(model):
"""Compile the log-likelihood function for the model to be able to condition on both data and parameters."""
"""
Compile the log-likelihood for a pymc model.
The compiled function allows us to condition on both data and parameters.
"""
obs_rvs = model.observed_RVs[0]
old_y_value = model.rvs_to_values[obs_rvs]
new_y_value = obs_rvs.type()
Expand Down Expand Up @@ -108,7 +112,7 @@ def extract_preliz_distributions(model):


def retrieve_variable_info(model):
"""Get the shape, size, transformation and parents of each free random variable in a PyMC model."""
"""Get shape, size, transformation and parents of each free RV in a PyMC model."""
var_info = {}
initial_point = model.initial_point()
for v_var in model.value_vars:
Expand Down Expand Up @@ -174,9 +178,9 @@ def write_pymc_string(new_priors, var_info):
size = var_info[nkey][1]
if size > 1:
dist_params = dist_params.split(")")[0]
variables[
i
] = f' {nkey:} = pm.{dist_name}("{nkey}", {dist_params}, shape={size})\n'
# fmt: off
variables[i] = f' {nkey:} = pm.{dist_name}("{nkey}", {dist_params}, shape={size})\n' # noqa: E501
# fmt: on
else:
variables[i] = f' {nkey:} = pm.{dist_name}("{nkey}", {dist_params}\n'

Expand Down
8 changes: 3 additions & 5 deletions preliz/predictive/ppa.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,11 @@ def __init__(self, fmodel, draws, references, boundaries, target, new_families,
self.axes = None # axes used to display the pp_samples

def __call__(self):

if self.engine == "preliz":
variables, self.model = from_preliz(self.fmodel)
elif self.engine == "bambi":
self.fmodel, variables, self.model = from_bambi(self.fmodel, self.draws)

print(variables, self.model)
self.pp_samples, self.prior_samples = get_prior_pp_samples(
self.fmodel, variables, self.draws, self.engine
)
Expand All @@ -208,9 +206,10 @@ def add_target_dist(self):

def compute_octiles(self):
"""
Compute the octiles for the prior predictive samples. This is used to find similar distributions using a KDTree.
Compute the octiles for the prior predictive samples.
We have empirically found that octiles are a good choice, but this could be the consequence of limited testing.
This is used to find similar distributions using a KDTree. We have empirically found that
octiles are a good choice, but this could be the consequence of limited testing.
"""
pp_octiles = np.quantile(
self.pp_samples, [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875], axis=1
Expand Down Expand Up @@ -386,7 +385,6 @@ def carry_on(self, kind, sharex):
self.fig.canvas.draw()

def on_return_prior(self):

selected = list(self.selected)

if len(selected) > 4:
Expand Down
2 changes: 1 addition & 1 deletion preliz/predictive/predictive_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def predictive_explorer(
fmodel, samples=50, kind_plot="ecdf", references=None, plot_func=None, engine="auto"
):
"""
Create textboxes and plot a set of samples returned by a function relating one or more PreliZ distributions.
Explore how changing parameters in the prior affects the prior predictive distribution.
Use this function to interactively explore how a prior predictive distribution changes when the
priors are changed.
Expand Down
Loading

0 comments on commit 7caa9bc

Please sign in to comment.