Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot comparisons notebook for docs #695

Merged
merged 84 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from 82 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
db8e4c4
plot_cap draft outline for docs example
GStechschulte Apr 24, 2023
d0ce8df
intro. to GLMs and Negative Binomial model
GStechschulte May 8, 2023
8e0b942
added logistic regression and other model params. demo
GStechschulte May 9, 2023
4f02cf1
basic linear model demo
GStechschulte May 10, 2023
6b966a4
comparisons learning from marginaleffects
GStechschulte May 14, 2023
652e8a2
comparison contrasts using make_cap_data code
GStechschulte May 16, 2023
a901346
CreateData class added to __init__.py
GStechschulte May 18, 2023
d6ae8d4
CreateData class for all plotting functions
GStechschulte May 18, 2023
4d0582a
functions for computing and plotting comparisons
GStechschulte May 18, 2023
af0b824
plot_comparisons demo on categorical data
GStechschulte May 18, 2023
d1ca048
logic of main, group, panel for building contrasts df
GStechschulte May 19, 2023
d0bb1b3
add make_group_panel_values and enforce_dtypes functions
GStechschulte May 19, 2023
e007876
plot_comparisons demo
GStechschulte May 19, 2023
5c0a242
cleanup demo notebook
GStechschulte May 19, 2023
bc775ea
cleanup demo notebook
GStechschulte May 20, 2023
8a396f8
move util functions to utils.py and renaming of modules and functions
GStechschulte May 22, 2023
7d57d07
re-run demo.
GStechschulte May 22, 2023
f55b403
use dataclass for returning covariates instead of dict
GStechschulte May 23, 2023
5db0910
remove unused variables in plot_comparison
GStechschulte May 23, 2023
0bb018e
type hinting and added dataclass for covariates
GStechschulte May 23, 2023
bc93a06
module for plot kinds based on response
GStechschulte May 23, 2023
2ca68b1
plot_comparisons demo notebook update
GStechschulte May 23, 2023
5bff599
changed and added modules to be treated as packages
GStechschulte May 23, 2023
76ca977
modularize create cap, comparisons, and slopes data functions
GStechschulte May 23, 2023
bd1b1c9
deleted and moved into plotting.py
GStechschulte May 23, 2023
dc47e11
plot cap and comparisons and type of plots in separate modules
GStechschulte May 23, 2023
3e166cb
commonly used functions for create_data.py
GStechschulte May 23, 2023
46c30f2
delete print statement
GStechschulte May 23, 2023
dfb9301
re-run demo notebook
GStechschulte May 23, 2023
ac7004e
comparisons numerical default
GStechschulte May 28, 2023
4619e41
default contrast level for numeric and char. variables
GStechschulte Jun 7, 2023
5131b1c
replace np.repeat with np.tile
GStechschulte Jun 7, 2023
f2961ef
default contrast level for numeric and char. variables
GStechschulte Jun 7, 2023
03c6b46
plot_comparisons default numeric demo
GStechschulte Jun 7, 2023
0080b1c
re-run plot_cap.ipynb
GStechschulte Jun 7, 2023
5beaf28
modularize cap, comparisons, and slopes
GStechschulte Jun 7, 2023
827b59c
add Comparisons class to reduce redundant passing of args.
GStechschulte Jun 7, 2023
9c5fc9e
re-run notebook to ensure everything still works
GStechschulte Jun 7, 2023
39d5c50
added class objects for attribute lookup, more informative contrast d…
GStechschulte Jun 8, 2023
167d8b2
re-run notebook
GStechschulte Jun 8, 2023
e3d2e9e
re-run notebook with default numeric contrast value
GStechschulte Jun 8, 2023
ad5210d
reduce number of args. in plotting functions and use dataclasses inst…
GStechschulte Jun 9, 2023
c8e6e22
new examples with both numeric and categoric variables
GStechschulte Jun 9, 2023
2e21dc7
add comments for review
GStechschulte Jun 9, 2023
dc67b96
Show working version and some ideas of plot_comparisons with xarray
tomicapretto Jun 10, 2023
0c50780
comparisons computed using entire posterior and better error handling
GStechschulte Jun 12, 2023
468f39c
re-run comparisons notebook w/updated comparisons code
GStechschulte Jun 12, 2023
a8ed2d3
delete notebook
GStechschulte Jun 12, 2023
44b6a8c
refactor plot_cap to work and pass Covariates class instead of dict t…
GStechschulte Jun 13, 2023
32970ee
UserWarning if level > 2
GStechschulte Jun 13, 2023
4c93aea
re-run plot_cap and plot_comparisons notebooks
GStechschulte Jun 13, 2023
b50fcf9
assertions, docstrings, and type hints
GStechschulte Jun 15, 2023
3b6b8f0
re-run to make sure ValueErrors work
GStechschulte Jun 15, 2023
69e381c
comparisons for > 1 contrast level
GStechschulte Jun 16, 2023
625f505
raise ValueError if user passes > 1 contrast level when plotting
GStechschulte Jun 16, 2023
b80481f
add predictions as sub-package
GStechschulte Jun 19, 2023
02d7d8c
type hints, doctrings, and run black
GStechschulte Jun 19, 2023
16fdb66
added comparison tests, and move cap and comparisons tests into classes
GStechschulte Jun 19, 2023
b62d0ba
GSoC code review 22.06 and added arg. for
GStechschulte Jun 24, 2023
930d636
added documentation on new arg.
GStechschulte Jun 24, 2023
c7a4b30
added classes for organizing , and added tests for
GStechschulte Jun 24, 2023
f3deccc
delete print statement
GStechschulte Jun 24, 2023
4cc4aa2
add test_hdi_prob in class TestCommon
GStechschulte Jun 24, 2023
bf78ad5
pylint C2801 use str() instead of __str__()
GStechschulte Jun 26, 2023
620a293
resolve pylint error messages
GStechschulte Jun 26, 2023
cfcc56a
remove lambda expression as it is not needed
GStechschulte Jun 26, 2023
f0c4a5b
add support for unit-level contrasts and 'average_by=True'
GStechschulte Jun 27, 2023
309ea3c
improved OOP with dataclasses, error handling, and added unit-level c…
GStechschulte Jun 28, 2023
d20f72f
ran black
GStechschulte Jun 29, 2023
8763a43
move isinstance logic to dataclass, improved error handling, and remo…
GStechschulte Jun 29, 2023
9dc1ca7
resolve pylint message codes
GStechschulte Jun 29, 2023
20998da
remove imports that users should not have access to
GStechschulte Jul 2, 2023
2baa5ac
fix/add docstrings
GStechschulte Jul 2, 2023
86f829c
fix/add docstrings and f-string attributes to ResponseInfo class
GStechschulte Jul 2, 2023
ce4367a
fix/add docstrings
GStechschulte Jul 2, 2023
b4a9e20
plot_comparisons demo for docs
GStechschulte Jul 3, 2023
03a785a
details added on how default and grid values are computed
GStechschulte Jul 4, 2023
24a1df7
bug fix for building contrast_df when len(contrast values) > 3
GStechschulte Jul 5, 2023
0a92294
more explanations on generated data, multiple contrast levels, and wo…
GStechschulte Jul 5, 2023
96b82bd
raise ValueError if user tries to plot with > 2 contrast values
GStechschulte Jul 6, 2023
d25cac5
pylinter error solved, make contrast_df column ordering consistent
GStechschulte Jul 6, 2023
14ffeb7
wording and added dataset descriptions and comparison type section
GStechschulte Jul 6, 2023
8df0628
change wording to posterior samples and deleted unused cell
GStechschulte Jul 6, 2023
6f269e0
Merge branch 'main' into plot-comparisons-docs
GStechschulte Jul 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ def set_alias(self, aliases):
assert component_name in self.distributional_components
component = self.distributional_components[component_name]
for name, alias in component_aliases.items():

is_used = False

if name in component.terms:
Expand Down Expand Up @@ -660,7 +659,7 @@ def plot_priors(
unobserved_rvs_names = []
flat_rvs = []
for unobserved in self.backend.model.unobserved_RVs:
if "Flat" in unobserved.__str__():
if "Flat" in str(unobserved):
flat_rvs.append(unobserved.name)
else:
unobserved_rvs_names.append(unobserved.name)
Expand Down
6 changes: 4 additions & 2 deletions bambi/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .plot_cap import create_cap_data, plot_cap
from bambi.plots.effects import comparisons, predictions
from bambi.plots.plotting import plot_cap, plot_comparison

__all__ = ["create_cap_data", "plot_cap"]

__all__ = ["comparisons", "predictions", "plot_cap", "plot_comparison"]
129 changes: 129 additions & 0 deletions bambi/plots/create_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import itertools

import numpy as np
import pandas as pd

from bambi.models import Model
from bambi.plots.utils import (
ConditionalInfo,
ContrastInfo,
enforce_dtypes,
get_covariates,
get_model_covariates,
make_group_panel_values,
make_main_values,
set_default_values,
)


def create_cap_data(model: Model, covariates: dict) -> pd.DataFrame:
"""Create data for a Conditional Adjusted Predictions

Parameters
----------
model : bambi.Model
An instance of a Bambi model
covariates : dict
A dictionary of length between one and three.
Keys must be taken from ("horizontal", "color", "panel").
The values indicate the names of variables.

Returns
-------
pandas.DataFrame
The data for the Conditional Adjusted Predictions dataframe and or
plotting.
"""
data = model.data
covariates = get_covariates(covariates)
main, group, panel = covariates.main, covariates.group, covariates.panel

# Obtain data for main variable
main_values = make_main_values(data[main])
data_dict = {main: main_values}

# Obtain data for group and panel variables if not None
data_dict = make_group_panel_values(data, data_dict, main, group, panel, kind="predictions")
data_dict = set_default_values(model, data_dict, kind="predictions")
return enforce_dtypes(data, pd.DataFrame(data_dict))


def create_comparisons_data(
condition: ConditionalInfo, contrast: ContrastInfo, user_passed: bool = False
) -> pd.DataFrame:
"""Create data for a Conditional Adjusted Comparisons

Parameters
----------
condition: ConditionalInfo
A dataclass instance containing the model, contrast, and conditional
covariates to be used in the comparisons.
contrast: ContrastInfo
A dataclass instance containing the model, and contrast name and values.
user_passed: bool, optional
Whether the user passed their own 'conditional' data. Defaults to False.

Returns
-------
pd.DataFrame
The data for the Conditional Adjusted Comparisons dataframe and or
plotting.
"""

def _grid_level(condition: ConditionalInfo, contrast: ContrastInfo):
"""
Creates the data for grid-level contrasts by using the covariates passed
into the `conditional` arg. Values for the grid are either: (1) computed
using a equally spaced grid, mean, and or mode (depending on the covariate
dtype), and (2) a user specified value or range of values.
"""
covariates = get_covariates(condition.covariates)

if user_passed:
data_dict = {**condition.conditional}
else:
main_values = make_main_values(condition.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind="comparison",
)

data_dict[contrast.name] = contrast.values
comparison_data = set_default_values(condition.model, data_dict, kind="comparison")
# use cartesian product (cross join) to create contrasts
keys, values = zip(*comparison_data.items())
contrast_dict = [dict(zip(keys, v)) for v in itertools.product(*values)]

return enforce_dtypes(condition.model.data, pd.DataFrame(contrast_dict))

def _unit_level(contrast: ContrastInfo):
"""
Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
default contrast value, or (2) the user specified contrast value.
"""
covariates = get_model_covariates(contrast.model)
df = contrast.model.data[covariates].drop(labels=contrast.name, axis=1)

contrast_vals = np.array(contrast.values)[..., None]
contrast_vals = np.repeat(contrast_vals, contrast.model.data.shape[0], axis=1)

contrast_df_dict = {}
for idx, value in enumerate(contrast_vals):
contrast_df_dict[f"contrast_{idx}"] = df.copy()
contrast_df_dict[f"contrast_{idx}"][contrast.name] = value

return pd.concat(contrast_df_dict.values())

if not condition.conditional:
df = _unit_level(contrast)
else:
df = _grid_level(condition, contrast)

return df
Loading