diff --git a/bambi/backend/model_components.py b/bambi/backend/model_components.py index fadebbf26..da8dd659d 100644 --- a/bambi/backend/model_components.py +++ b/bambi/backend/model_components.py @@ -5,28 +5,40 @@ from bambi.backend.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, InterceptTerm, ResponseTerm from bambi.backend.utils import get_distribution_from_prior from bambi.families.multivariate import MultivariateFamily -from bambi.families.univariate import Categorical +from bambi.families.univariate import Categorical, Cumulative, StoppingRatio from bambi.utils import get_aliased_name +ORDINAL_FAMILIES = (Cumulative, StoppingRatio) + + class ConstantComponent: def __init__(self, component): self.component = component self.output = 0 - def build(self, pymc_backend, bmb_model): # pylint: disable = unused-argument + def build(self, pymc_backend, bmb_model): + extra_args = {} + + if self.component.alias: + label = self.component.alias + else: + label = self.component.name + + if isinstance(bmb_model.family, ORDINAL_FAMILIES): + threshold_dim = label + "_dim" + threshold_values = np.arange(len(bmb_model.response_component.response_term.levels) - 1) + extra_args["dims"] = threshold_dim + pymc_backend.model.add_coords({threshold_dim: threshold_values}) + with pymc_backend.model: - if self.component.alias: - label = self.component.alias - else: - label = self.component.name - # It's set to a constant value + # Set to a constant value if isinstance(self.component.prior, (int, float)): self.output = self.component.prior # Set to a distribution else: dist = get_distribution_from_prior(self.component.prior) - self.output = dist(label, **self.component.prior.args) + self.output = dist(label, **self.component.prior.args, **extra_args) class DistributionalComponent: diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index 575080fd4..1bb59791d 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -225,13 +225,7 @@ def build(self, pymc_backend, bmb_model): data = np.squeeze(self.term.data) parent = self.family.likelihood.parent - # The linear predictor for the parent parameter (usually the mean) - nu = pymc_backend.distributional_components[self.term.name].output - - if hasattr(self.family, "transform_backend_nu"): - nu = self.family.transform_backend_nu(nu, data) - - # Add auxiliary parameters + # Auxiliary parameters kwargs = {} # Constant parameters. No link function is used. @@ -252,13 +246,21 @@ def build(self, pymc_backend, bmb_model): f"{self.name}_{aliased_name}", linkinv(component.output), dims=dims ) + # Add observed and dims + kwargs["observed"] = data + kwargs["dims"] = dims + + # The linear predictor for the parent parameter (usually the mean) + eta = pymc_backend.distributional_components[self.term.name].output + + if hasattr(self.family, "transform_backend_eta"): + eta = self.family.transform_backend_eta(eta, kwargs) + # Take the inverse link function that maps from linear predictor to the parent of likelihood linkinv = get_linkinv(self.family.link[parent], pymc_backend.INVLINKS) - # Add parent parameter and observed data. We don't need to pass dims. 
- kwargs[parent] = linkinv(nu) - kwargs["observed"] = data - kwargs["dims"] = dims + # Add parent parameter after the applying the linkinv transformation + kwargs[parent] = linkinv(eta) # Build the response distribution dist = self.build_response_distribution(kwargs, pymc_backend) diff --git a/bambi/backend/utils.py b/bambi/backend/utils.py index b87b97964..dfd6ed1d7 100644 --- a/bambi/backend/utils.py +++ b/bambi/backend/utils.py @@ -1,11 +1,15 @@ import pytensor.tensor as pt import pymc as pm +MAPPING = {"Cumulative": pm.Categorical, "StoppingRatio": pm.Categorical} + def get_distribution(dist): """Return a PyMC distribution.""" if isinstance(dist, str): - if hasattr(pm, dist): + if dist in MAPPING: + dist = MAPPING[dist] + elif hasattr(pm, dist): dist = getattr(pm, dist) else: raise ValueError(f"The Distribution '{dist}' was not found in PyMC") diff --git a/bambi/defaults/families.py b/bambi/defaults/families.py index 63e59dbc4..ce4729e5b 100644 --- a/bambi/defaults/families.py +++ b/bambi/defaults/families.py @@ -6,6 +6,7 @@ BetaBinomial, Binomial, Categorical, + Cumulative, Gamma, Gaussian, HurdleGamma, @@ -15,6 +16,7 @@ NegativeBinomial, Laplace, Poisson, + StoppingRatio, StudentT, VonMises, Wald, @@ -37,7 +39,6 @@ "family": AsymmetricLaplace, "default_priors": {"b": "HalfNormal", "kappa": "HalfNormal"} }, - "bernoulli": { "likelihood": { "name": "Bernoulli", @@ -85,6 +86,16 @@ "link": {"p": "softmax"}, "family": Categorical, }, + "cumulative": { + "likelihood": { + "name": "Cumulative", + "params": ["p", "threshold"], + "parent": "p", + }, + "link": {"p": "logit", "threshold": "identity"}, + "family": Cumulative, + "default_priors": {"threshold": "Normal"}, + }, "dirichlet_multinomial": { "likelihood": { "name": "DirichletMultinomial", @@ -192,6 +203,16 @@ "link": {"mu": "log"}, "family": Poisson, }, + "sratio": { + "likelihood": { + "name": "StoppingRatio", + "params": ["p", "threshold"], + "parent": "p", + }, + "link": {"p": "logit", "threshold": "identity"}, + "family": StoppingRatio, + "default_priors": {"threshold": "Normal"}, + }, "t": { "likelihood": { "name": "StudentT", diff --git a/bambi/families/family.py b/bambi/families/family.py index 714cfbfba..3d594d16a 100644 --- a/bambi/families/family.py +++ b/bambi/families/family.py @@ -170,14 +170,15 @@ def posterior_predictive(self, model, posterior, **kwargs): continue kwargs[key] = expand_array(values, ndims_max) - if hasattr(model.family, "transform_backend_kwargs"): - kwargs = model.family.transform_backend_kwargs(kwargs) + if hasattr(model.family, "transform_kwargs"): + kwargs = model.family.transform_kwargs(kwargs) output_array = pm.draw(response_dist.dist(**kwargs)) output_coords_all = xr.merge(output_dataset_list).coords coord_names = ["chain", "draw", response_aliased_name + "_obs"] - if hasattr(model.family, "KIND") and model.family.KIND == "Multivariate": + is_multivariate = hasattr(model.family, "KIND") and model.family.KIND == "Multivariate" + if is_multivariate: coord_names.append(response_aliased_name + "_dim") output_coords = {} @@ -206,8 +207,12 @@ def get_response_dist(family): pm.Distribution The response distribution """ + mapping = {"Cumulative": pm.Categorical, "StoppingRatio": pm.Categorical} + if family.likelihood.dist: dist = family.likelihood.dist + elif family.likelihood.name in mapping: + dist = mapping[family.likelihood.name] else: dist = getattr(pm, family.likelihood.name) return dist diff --git a/bambi/families/likelihood.py b/bambi/families/likelihood.py index 324b49ed0..99c8af311 100644 --- 
a/bambi/families/likelihood.py +++ b/bambi/families/likelihood.py @@ -11,6 +11,7 @@ "BetaBinomial": DistSettings(params=("mu", "kappa"), parent="mu"), "Binomial": DistSettings(params=("p",), parent="p"), "Categorical": DistSettings(params=("p",), parent="p"), + "Cumulative": DistSettings(params=("p", "threshold"), parent="p"), "DirichletMultinomial": DistSettings(params=("a",), parent="a"), "Gamma": DistSettings(params=("mu", "alpha"), parent="mu"), "Multinomial": DistSettings(params=("p",), parent="p"), diff --git a/bambi/families/multivariate.py b/bambi/families/multivariate.py index e7076d9a5..14ec7acd6 100644 --- a/bambi/families/multivariate.py +++ b/bambi/families/multivariate.py @@ -1,6 +1,7 @@ # pylint: disable=unused-argument import numpy as np import pytensor.tensor as pt +import xarray as xr from bambi.families.family import Family from bambi.transformations import transformations_namespace @@ -15,7 +16,10 @@ class Multinomial(MultivariateFamily): SUPPORTED_LINKS = {"p": ["softmax"]} INVLINK_KWARGS = {"axis": -1} - def transform_linear_predictor(self, model, linear_predictor): + @staticmethod + def transform_linear_predictor( + model, linear_predictor: xr.DataArray, posterior: xr.DataArray + ) -> xr.DataArray: # pylint: disable = unused-variable response_name = get_aliased_name(model.response_component.response_term) response_levels_dim = response_name + "_reduced_dim" linear_predictor = linear_predictor.pad({response_levels_dim: (1, 0)}, constant_values=0) @@ -50,19 +54,20 @@ def get_levels(self, response): @staticmethod def transform_backend_kwargs(kwargs): - if "observed" in kwargs: - kwargs["n"] = kwargs["observed"].sum(axis=1).astype(int) + kwargs["n"] = kwargs["observed"].sum(axis=1).astype(int) return kwargs @staticmethod - def transform_backend_nu(nu, data): + def transform_backend_eta(eta, kwargs): + data = kwargs["observed"] + # Add column of zeros to the linear predictor for the reference level (the first one) shape = (data.shape[0], 1) # The first line makes sure the intercept-only models work - nu = np.ones(shape) * nu # (response_levels, ) -> (n, response_levels) - nu = pt.concatenate([np.zeros(shape), nu], axis=1) - return nu + eta = np.ones(shape) * eta # (response_levels, ) -> (n, response_levels) + eta = pt.concatenate([np.zeros(shape), eta], axis=1) + return eta class DirichletMultinomial(MultivariateFamily): @@ -86,6 +91,5 @@ def get_levels(self, response): @staticmethod def transform_backend_kwargs(kwargs): - if "observed" in kwargs: - kwargs["n"] = kwargs["observed"].sum(axis=1).astype(int) + kwargs["n"] = kwargs["observed"].sum(axis=1).astype(int) return kwargs diff --git a/bambi/families/univariate.py b/bambi/families/univariate.py index c977becce..c78cdfe79 100644 --- a/bambi/families/univariate.py +++ b/bambi/families/univariate.py @@ -1,5 +1,6 @@ import numpy as np import pytensor.tensor as pt +import xarray as xr from bambi.families.family import Family from bambi.utils import get_aliased_name @@ -22,11 +23,9 @@ def posterior_predictive(self, model, posterior, **kwargs): @staticmethod def transform_backend_kwargs(kwargs): - # Only used when fitting data, not when getting draws from posterior predictive distribution - if "observed" in kwargs: - observed = kwargs.pop("observed") - kwargs["observed"] = observed[:, 0].squeeze() - kwargs["n"] = observed[:, 1].squeeze() + observed = kwargs.pop("observed") + kwargs["observed"] = observed[:, 0].squeeze() + kwargs["n"] = observed[:, 1].squeeze() return kwargs @@ -70,6 +69,14 @@ def 
transform_backend_kwargs(kwargs): kwargs["beta"] = (1 - mu) * kappa return kwargs + @staticmethod + def transform_kwargs(kwargs): + mu = kwargs.pop("mu") + kappa = kwargs.pop("kappa") + kwargs["alpha"] = mu * kappa + kwargs["beta"] = (1 - mu) * kappa + return kwargs + + class BetaBinomial(BinomialBaseFamily): """BetaBinomial family @@ -89,6 +96,15 @@ def transform_backend_kwargs(kwargs): # Then transform the parameters of the binomial component return BinomialBaseFamily.transform_backend_kwargs(kwargs) + @staticmethod + def transform_kwargs(kwargs): + # Transform the parameters of the beta component + mu = kwargs.pop("mu") + kappa = kwargs.pop("kappa") + kwargs["alpha"] = mu * kappa + kwargs["beta"] = (1 - mu) * kappa + return kwargs + + class Binomial(BinomialBaseFamily): SUPPORTED_LINKS = {"p": ["identity", "logit", "probit", "cloglog"]} @@ -98,7 +114,11 @@ class Categorical(UnivariateFamily): SUPPORTED_LINKS = {"p": ["softmax"]} INVLINK_KWARGS = {"axis": -1} - def transform_linear_predictor(self, model, linear_predictor): + # pylint: disable = unused-argument + @staticmethod + def transform_linear_predictor( + model, linear_predictor: xr.DataArray, posterior: xr.DataArray + ) -> xr.DataArray: response_name = get_aliased_name(model.response_component.response_term) response_levels_dim = response_name + "_reduced_dim" linear_predictor = linear_predictor.pad({response_levels_dim: (1, 0)}, constant_values=0) @@ -125,14 +145,86 @@ def get_reference(self, response): return get_reference_level(response.term) @staticmethod - def transform_backend_nu(nu, data): + def transform_backend_eta(eta, kwargs): + data = kwargs["observed"] + + # Add column of zeros to the linear predictor for the reference level (the first one) shape = (data.shape[0], 1) # The first line makes sure the intercept-only models work - nu = np.ones(shape) * nu # (response_levels, ) -> (n, response_levels) - nu = pt.concatenate([np.zeros(shape), nu], axis=1) - return nu + eta = np.ones(shape) * eta # (response_levels, ) -> (n, response_levels) + eta = pt.concatenate([np.zeros(shape), eta], axis=1) + return eta + + +class Cumulative(UnivariateFamily): + SUPPORTED_LINKS = {"p": ["logit", "probit", "cloglog"], "threshold": ["identity"]} + + def get_data(self, response): + return np.nonzero(response.term.data)[1] + + @staticmethod + def transform_linear_predictor( + model, linear_predictor: xr.DataArray, posterior: xr.DataArray + ) -> xr.DataArray: + """Computes threshold_k - eta""" + threshold_component = model.components["threshold"] + response_name = get_aliased_name(model.response_component.response_term) + threshold_name = threshold_component.alias if threshold_component.alias else "threshold" + threshold_name = f"{response_name}_{threshold_name}" + threshold = posterior[threshold_name] + return threshold - linear_predictor + + @staticmethod + def transform_mean(model, mean: xr.DataArray) -> xr.DataArray: + """Computes P(Y = k) = F(threshold_k - eta) - F(threshold_{k - 1} - eta)""" + threshold_component = model.components["threshold"] + response_name = get_aliased_name(model.response_component.response_term) + threshold_name = threshold_component.alias if threshold_component.alias else "threshold" + threshold_dim = f"{response_name}_{threshold_name}_dim" + response_dim = response_name + "_dim" + mean = xr.concat( + [ + mean.isel({threshold_dim: 0}), + mean.diff(threshold_dim), + 1 - mean.isel({threshold_dim: -1}), + ], + dim=threshold_dim, + ) + mean = mean.rename({threshold_dim: response_dim}) + mean =
mean.assign_coords({response_dim: model.response_component.response_term.levels}) + mean = mean.transpose(..., response_dim) # make sure response levels is the last dim + return mean + + @staticmethod + def transform_backend_eta(eta, kwargs): + # shape(threshold) = (K, ) + # shape(eta) = (n, ) + # shape(threshold - shape_padright(eta)) = (n, K) + threshold = kwargs["threshold"] + eta_shifted = threshold - pt.shape_padright(eta) + return eta_shifted + + @staticmethod + def transform_backend_kwargs(kwargs): + # P(Y = k) = F(threshold_k - eta) - F(threshold_{k - 1} - eta) + p = kwargs.pop("p") + p = pt.concatenate( + [ + pt.shape_padright(p[..., 0]), + p[..., 1:] - p[..., :-1], + pt.shape_padright(1 - p[..., -1]), + ], + axis=-1, + ) + kwargs["p"] = p + kwargs.pop("threshold", None) # this is not passed to the likelihood function + return kwargs + + @staticmethod + def transform_kwargs(kwargs): + kwargs.pop("threshold", None) # this is not passed to the likelihood function + return kwargs + + class Gamma(UnivariateFamily): @@ -146,6 +238,12 @@ def transform_backend_kwargs(kwargs): kwargs["sigma"] = kwargs["mu"] / (alpha**0.5) return kwargs + @staticmethod + def transform_kwargs(kwargs): + alpha = kwargs.pop("alpha") + kwargs["sigma"] = kwargs["mu"] / (alpha**0.5) + return kwargs + class Gaussian(UnivariateFamily): SUPPORTED_LINKS = {"mu": ["identity", "log", "inverse"], "sigma": ["log"]} @@ -164,6 +262,12 @@ def transform_backend_kwargs(kwargs): kwargs["sigma"] = kwargs["mu"] / (alpha**0.5) return kwargs + @staticmethod + def transform_kwargs(kwargs): + alpha = kwargs.pop("alpha") + kwargs["sigma"] = kwargs["mu"] / (alpha**0.5) + return kwargs + class HurdleLogNormal(UnivariateFamily): SUPPORTED_LINKS = { @@ -197,6 +301,89 @@ class Poisson(UnivariateFamily): SUPPORTED_LINKS = {"mu": ["identity", "log"]} +class StoppingRatio(UnivariateFamily): + SUPPORTED_LINKS = {"p": ["logit", "probit", "cloglog"], "threshold": ["identity"]} + + def get_data(self, response): + return np.nonzero(response.term.data)[1] + + @staticmethod + def transform_linear_predictor( + model, linear_predictor: xr.DataArray, posterior: xr.DataArray + ) -> xr.DataArray: + """Computes threshold_k - eta""" + threshold_component = model.components["threshold"] + response_name = get_aliased_name(model.response_component.response_term) + threshold_name = threshold_component.alias if threshold_component.alias else "threshold" + threshold_name = f"{response_name}_{threshold_name}" + threshold = posterior[threshold_name] + return threshold - linear_predictor + + @staticmethod + def transform_mean(model, mean: xr.DataArray) -> xr.DataArray: + """Computes P(Y = k) = F(threshold_k - eta) * \prod_{j=1}^{k-1}{1 - F(threshold_j - eta)}""" + threshold_component = model.components["threshold"] + response_name = get_aliased_name(model.response_component.response_term) + threshold_name = threshold_component.alias if threshold_component.alias else "threshold" + threshold_dim = f"{response_name}_{threshold_name}_dim" + response_dim = response_name + "_dim" + threshold_n = len(mean[threshold_dim]) + + # the `.assign_coords` is needed for the concat to work + mean = xr.concat( + [ + mean.isel({threshold_dim: 0}), + *[ + ( + mean.isel({threshold_dim: j}) + * (1 - mean).isel({threshold_dim: slice(None, j)}).prod(threshold_dim) + ) + for j in range(1, threshold_n) + ], + (1 - mean).prod(threshold_dim).assign_coords({threshold_dim: threshold_n + 1}), + ], + dim=threshold_dim, + ) + mean = mean.rename({threshold_dim: response_dim}) + mean =
mean.assign_coords({response_dim: model.response_component.response_term.levels}) + mean = mean.transpose(..., response_dim) # make sure response levels is the last dim + return mean + + @staticmethod + def transform_backend_eta(eta, kwargs): + # shape(threshold) = (K, ) + # shape(eta) = (n, ) + # shape(threshold - shape_padright(eta)) = (n, K) + threshold = kwargs["threshold"] + eta_shifted = threshold - pt.shape_padright(eta) + return eta_shifted + + @staticmethod + def transform_backend_kwargs(kwargs): + # P(Y = k) = F(threshold_k - eta) * \prod_{j=1}^{k-1}{1 - F(threshold_j - eta)} + p = kwargs.pop("p") + n_columns = p.type.shape[-1] + p = pt.concatenate( + [ + pt.shape_padright(p[..., 0]), + *[ + pt.shape_padright(p[..., j] * pt.prod(1 - p[..., :j], axis=-1)) + for j in range(1, n_columns) + ], + pt.shape_padright(pt.prod(1 - p, axis=-1)), + ], + axis=-1, + ) + kwargs["p"] = p + kwargs.pop("threshold", None) # this is not passed to the likelihood function + return kwargs + + @staticmethod + def transform_kwargs(kwargs): + kwargs.pop("threshold", None) # this is not passed to the likelihood function + return kwargs + + class StudentT(UnivariateFamily): SUPPORTED_LINKS = {"mu": ["identity", "log", "inverse"], "sigma": ["log"], "nu": ["log"]} diff --git a/bambi/formula.py b/bambi/formula.py index 7614f0723..83557e636 100644 --- a/bambi/formula.py +++ b/bambi/formula.py @@ -1,7 +1,8 @@ +import warnings + from typing import Sequence -from formulae import model_description -from formulae.terms.variable import Variable +import formulae as fm class Formula: @@ -56,14 +57,14 @@ def check_additional(self, additional: str): ValueError If the response term is not a plain name """ - response = model_description(additional).response + response = fm.model_description(additional).response # There's a response in the formula if response is None: raise ValueError("Additional formulas must contain a response name.") # The response is a name, not a function call for example - if not isinstance(response.term.components[0], Variable): + if not isinstance(response.term.components[0], fm.terms.variable.Variable): raise ValueError("The response must be a name.") self.additionals_lhs.append(response.term.name) @@ -87,3 +88,16 @@ def __repr__(self): formulas = [self.main] + list(self.additionals) middle = ", ".join([f"'{formula}'" for formula in formulas]) return f"Formula({middle})" + + +def formula_has_intercept(formula: str) -> bool: + description = fm.model_description(formula) + return any(isinstance(term, fm.terms.Intercept) for term in description.terms) + + +def check_ordinal_formula(formula: Formula) -> Formula: + if len(formula.additionals) > 0: + raise ValueError("Ordinal families don't accept multiple formulas") + if formula_has_intercept(formula.main): + warnings.warn("The intercept is omitted in ordinal families") + return formula diff --git a/bambi/model_components.py b/bambi/model_components.py index 978dc24d3..19aa351a7 100644 --- a/bambi/model_components.py +++ b/bambi/model_components.py @@ -286,7 +286,9 @@ def predict(self, idata, data=None, include_group_specific=True, hsgp_dict=None) # Handle more special cases if hasattr(family, "transform_linear_predictor"): - linear_predictor = family.transform_linear_predictor(self.spec, linear_predictor) + linear_predictor = family.transform_linear_predictor( + self.spec, linear_predictor, posterior + ) if self.response_kind == "data": linkinv = family.link[family.likelihood.parent].linkinv @@ -299,6 +301,9 @@ def predict(self, idata, data=None, 
include_group_specific=True, hsgp_dict=None) if hasattr(family, "transform_coords"): response = family.transform_coords(self.spec, response) + if hasattr(family, "transform_mean"): + response = family.transform_mean(self.spec, response) + return response @property diff --git a/bambi/models.py b/bambi/models.py index 81296fe47..bce64b727 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -5,32 +5,34 @@ from copy import deepcopy +import formulae as fm import pymc as pm import pandas as pd from arviz.plots import plot_posterior -from formulae import design_matrices - from bambi.backend import PyMCModel from bambi.defaults import get_builtin_family from bambi.model_components import ConstantComponent, DistributionalComponent from bambi.families import Family, univariate -from bambi.formula import Formula +from bambi.formula import Formula, check_ordinal_formula from bambi.priors import Prior, PriorScaler from bambi.transformations import transformations_namespace from bambi.utils import ( clean_formula_lhs, get_aliased_name, get_auxiliary_parameters, - listify, indentify, + listify, + remove_common_intercept, wrapify, ) from bambi.version import __version__ _log = logging.getLogger("bambi") +ORDINAL_FAMILIES = (univariate.Cumulative, univariate.StoppingRatio) + class Model: """Specification of model class. @@ -133,9 +135,6 @@ def __init__( # Obtain design matrices and related objects. na_action = "drop" if dropna else "error" - # Create family - self._set_family(family, link) - # Handle additional namespaces additional_namespace = transformations_namespace.copy() if not isinstance(extra_namespace, (type(None), dict)): @@ -144,8 +143,24 @@ def __init__( if isinstance(extra_namespace, dict): additional_namespace.update(extra_namespace) + # Create family + self._set_family(family, link) + ## Main component - design = design_matrices(self.formula.main, self.data, na_action, 1, additional_namespace) + if isinstance(self.family, ORDINAL_FAMILIES): + self.formula = check_ordinal_formula(self.formula) + # Notice the intercept is added so formulae constrains categorical predictors, avoiding + # linear dependencies with the cutpoints. + # Then the intercept is removed from the design matrix because of the cutpoints. + design = fm.design_matrices( + self.formula.main + " + 1", self.data, na_action, 1, additional_namespace + ) + design = remove_common_intercept(design) + else: + design = fm.design_matrices( + self.formula.main, self.data, na_action, 1, additional_namespace + ) + if design.response is None: raise ValueError( "No outcome variable is set! 
" @@ -177,7 +192,7 @@ def __init__( ) # Create design matrix, only for the response part - design = design_matrices( + design = fm.design_matrices( clean_formula_lhs(extra_formula), self.data, na_action, 1, additional_namespace ) diff --git a/bambi/priors/prior.py b/bambi/priors/prior.py index 217fdc6cc..fd5470d57 100644 --- a/bambi/priors/prior.py +++ b/bambi/priors/prior.py @@ -53,7 +53,7 @@ def __eq__(self, other): def __str__(self): args = ", ".join( [ - f"{k}: {np.round_(v, 4)}" if not isinstance(v, type(self)) else f"{k}: {v}" + f"{k}: {format_arg(v, 4)}" if not isinstance(v, type(self)) else f"{k}: {v}" for k, v in self.args.items() ] ) @@ -61,3 +61,14 @@ def __str__(self): def __repr__(self): return self.__str__() + + +def format_arg(value, decimals): + try: + outcome = np.round_(value, decimals) + except: # pylint: disable = bare-except + try: + outcome = value.name + except: # pylint: disable = bare-except + outcome = value + return outcome diff --git a/bambi/priors/scaler.py b/bambi/priors/scaler.py index 35b88210b..8e1402d5c 100644 --- a/bambi/priors/scaler.py +++ b/bambi/priors/scaler.py @@ -1,6 +1,7 @@ import numpy as np +import pymc as pm -from bambi.families.univariate import Gaussian, StudentT, VonMises +from bambi.families.univariate import Cumulative, Gaussian, StoppingRatio, StudentT, VonMises from bambi.model_components import ConstantComponent from bambi.priors.prior import Prior @@ -97,6 +98,25 @@ def scale_group_specific(self, term): sigma[i] = self.get_slope_sigma(value) term.prior.args["sigma"].update(sigma=np.squeeze(np.atleast_1d(sigma))) + def scale_threshold(self): + if isinstance(self.model.family, Cumulative): + threshold = self.model.components["threshold"] + if isinstance(threshold, ConstantComponent) and threshold.prior.auto_scale: + response_level_n = len(np.unique(self.response_component.response_term.data)) + mu = np.round(np.linspace(-2, 2, num=response_level_n - 1), 2) + threshold.prior = Prior( + "Normal", + mu=mu, + sigma=1, + transform=pm.distributions.transforms.univariate_ordered, + ) + elif isinstance(self.model.family, StoppingRatio): + threshold = self.model.components["threshold"] + if isinstance(threshold, ConstantComponent) and threshold.prior.auto_scale: + response_level_n = len(np.unique(self.response_component.response_term.data)) + mu = np.round(np.linspace(-2, 2, num=response_level_n - 1), 2) + threshold.prior = Prior("Normal", mu=mu, sigma=1) + def scale(self): # Scale response self.scale_response() @@ -116,3 +136,6 @@ def scale(self): for term in self.response_component.group_specific_terms.values(): if term.prior.auto_scale: self.scale_group_specific(term) + + # Scale threshold parameters in ordinal families + self.scale_threshold() diff --git a/bambi/utils.py b/bambi/utils.py index 4e906353a..55dcc4efd 100644 --- a/bambi/utils.py +++ b/bambi/utils.py @@ -3,7 +3,8 @@ import ast import textwrap -from formulae.terms.call import Call +import formulae as fm +import numpy as np from bambi.transformations import HSGP @@ -150,7 +151,7 @@ def is_single_component(term): def is_call_component(component): - return isinstance(component, Call) + return isinstance(component, fm.terms.call.Call) def has_stateful_transform(component): @@ -166,3 +167,10 @@ def is_hsgp_term(term): if not has_stateful_transform(component): return False return isinstance(component.call.stateful_transform, HSGP) + + +def remove_common_intercept(dm: fm.matrices.DesignMatrices) -> fm.matrices.DesignMatrices: + dm.common.terms.pop("Intercept") + intercept_slice = 
dm.common.slices.pop("Intercept") + dm.common.design_matrix = np.delete(dm.common.design_matrix, intercept_slice, axis=1) + return dm diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index cd3cdaa37..2582c2a87 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -4,6 +4,8 @@ ### New features +* Implement new families `"ordinal"` and `"sratio"` for modeling of ordinal responses (#678) + ### Maintenance and fixes ### Documentation diff --git a/stemcell.csv b/stemcell.csv new file mode 100644 index 000000000..4ec9593fb --- /dev/null +++ b/stemcell.csv @@ -0,0 +1,830 @@ +"belief","rating","gender" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"fundamentalist",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"moderate",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" 
+"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"liberal",1,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"fundamentalist",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" 
+"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"moderate",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"liberal",2,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"fundamentalist",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"moderate",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" 
+"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"liberal",3,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"fundamentalist",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"moderate",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"liberal",4,"female" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"fundamentalist",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"moderate",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" 
+"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"liberal",1,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"fundamentalist",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"moderate",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" 
+"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"liberal",2,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"fundamentalist",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"moderate",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"liberal",3,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"fundamentalist",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"moderate",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" +"liberal",4,"male" diff --git a/tests/test_built_models.py b/tests/test_built_models.py index 5b6c8505a..38bb0b6ea 100644 --- a/tests/test_built_models.py +++ b/tests/test_built_models.py @@ -906,4 +906,39 @@ def test_hurlde_families(): df = pd.DataFrame({"y": pm.draw(pm.HurdleLogNormal.dist(0.7, mu=0, sigma=0.2), 1000)}) model = bmb.Model("y ~ 1", df, family="hurdle_lognormal") idata = model.fit() - model.predict(idata, kind="pps") \ No newline at end of file + model.predict(idata, kind="pps") + +@pytest.mark.parametrize( + "family, link", + [ + ("cumulative", "logit"), + ("cumulative", "probit"), + ("cumulative", "cloglog"), + ("sratio", "logit"), + ("sratio", "probit"), + ("sratio", "cloglog"), + ] +) +def test_ordinal_families(inhaler, family, link): + data = inhaler.copy() + data["carry"] = pd.Categorical(data["carry"]) # To have both numeric and categoric predictors + model = bmb.Model("rating ~ period + carry + treat", data, family=family, link=link) + idata = model.fit(tune=100, 
draws=100) + model.predict(idata, kind="pps") + assert np.allclose(idata.posterior["rating_mean"].sum("rating_dim").to_numpy(), 1) + assert np.all(np.unique(idata.posterior_predictive["rating"]) == np.array([0, 1, 2, 3])) + + +def test_cumulative_family_priors(inhaler): + priors = { + "threshold": bmb.Prior( + "Normal", + mu=[-0.5, 0, 0.5], + sigma=1.5, + transform=pm.distributions.transforms.univariate_ordered + ) + } + model = bmb.Model( + "rating ~ period + carry + treat", inhaler, family="cumulative", priors=priors + ) + model.fit(tune=100, draws=100) \ No newline at end of file
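
The usage sketch below is not part of the patch; it only mirrors the new tests above and the bundled stemcell.csv so the branch can be tried out quickly. The column names (belief, rating, gender) come from that CSV, reading the file from the repository root and casting rating to an ordered categorical are assumptions about how the data is meant to be used, and the explicit threshold prior is optional since PriorScaler.scale_threshold assigns an auto-scaled Normal prior by default.

# Usage sketch (not part of the patch): fitting the new ordinal families.
import pandas as pd
import pymc as pm
import bambi as bmb

# Assumption: running from the repository root, where this PR adds stemcell.csv
data = pd.read_csv("stemcell.csv")
# Assumption: the ordinal response is encoded as a categorical with ordered levels 1..4
data["rating"] = pd.Categorical(data["rating"], ordered=True)

# Cumulative (proportional odds) model; family="sratio" follows the same pattern
model = bmb.Model("rating ~ belief + gender", data, family="cumulative")
idata = model.fit()
model.predict(idata, kind="pps")  # per the tests, this adds "rating_mean" and posterior predictive draws

# Optional: override the default threshold prior, as in test_cumulative_family_priors.
# With four response levels there are three thresholds, hence three values for mu.
priors = {
    "threshold": bmb.Prior(
        "Normal",
        mu=[-0.5, 0, 0.5],
        sigma=1.5,
        transform=pm.distributions.transforms.univariate_ordered,
    )
}
model = bmb.Model("rating ~ belief + gender", data, family="cumulative", priors=priors)
idata = model.fit()

The family names, the predict call, and the shape of the threshold prior are taken directly from the tests in this diff; everything else is ordinary Bambi usage.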