Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Add de-aliaser function for plots #1073

Merged
merged 22 commits into from
Mar 2, 2020
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* **Experimental Feature**: Added `arviz.wrappers` module to allow ArviZ to
refit the models if necessary
* **Experimental Feature**: Added `reloo` function to ArviZ
* Added new helper function `matplotlib_kwarg_dealiaser` (#1073)
* ArviZ version to InferenceData attributes. (#1086)


Expand Down
5 changes: 2 additions & 3 deletions arviz/plots/backends/matplotlib/distplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import backend_show
from ...kdeplot import plot_kde
from ...plot_utils import matplotlib_kwarg_dealiaser


def plot_dist(
Expand Down Expand Up @@ -49,9 +50,7 @@ def plot_dist(
)

elif kind == "kde":
if plot_kwargs is None:
plot_kwargs = {}

plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
plot_kwargs.setdefault("color", color)
legend = label is not None

Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/backends/matplotlib/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...plot_utils import (
make_label,
_create_axes_grid,
matplotlib_kwarg_dealiaser,
)


Expand Down Expand Up @@ -63,8 +64,7 @@ def plot_ess(
ess_tail = ess_tail_dataset[var_name].sel(**selection)
ax_.plot(xdata, ess_tail, **extra_kwargs)
elif rug:
if rug_kwargs is None:
rug_kwargs = {}
rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
if not hasattr(idata, "sample_stats"):
raise ValueError("InferenceData object must contain sample_stats for rug plot")
if not hasattr(idata.sample_stats, rug_kind):
Expand Down
21 changes: 7 additions & 14 deletions arviz/plots/backends/matplotlib/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from . import backend_show
from ...plot_utils import _scale_fig_size
from ...plot_utils import _scale_fig_size, matplotlib_kwarg_dealiaser


def plot_kde(
Expand Down Expand Up @@ -53,19 +53,15 @@ def plot_kde(
figsize, *_, xt_labelsize, linewidth, markersize = _scale_fig_size(figsize, textsize, 1, 1)

if values2 is None:
if plot_kwargs is None:
plot_kwargs = {}
plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
plot_kwargs.setdefault("color", "C0")

default_color = plot_kwargs.get("color")

if fill_kwargs is None:
fill_kwargs = {}

fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin")
fill_kwargs.setdefault("color", default_color)

if rug_kwargs is None:
rug_kwargs = {}
rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
rug_kwargs.setdefault("marker", "_" if rotated else "|")
rug_kwargs.setdefault("linestyle", "None")
rug_kwargs.setdefault("color", default_color)
Expand Down Expand Up @@ -122,13 +118,10 @@ def plot_kde(
if legend and label:
ax.legend()
else:
if contour_kwargs is None:
contour_kwargs = {}
contour_kwargs = matplotlib_kwarg_dealiaser(contour_kwargs, "contour")
contour_kwargs.setdefault("colors", "0.5")
if contourf_kwargs is None:
contourf_kwargs = {}
if pcolormesh_kwargs is None:
pcolormesh_kwargs = {}
contourf_kwargs = matplotlib_kwarg_dealiaser(contourf_kwargs, "contour")
pcolormesh_kwargs = matplotlib_kwarg_dealiaser(pcolormesh_kwargs, "pcolormesh")

# gridsize = (128, 128) if contour else (256, 256)

Expand Down
4 changes: 2 additions & 2 deletions arviz/plots/backends/matplotlib/mcseplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...plot_utils import (
make_label,
_create_axes_grid,
matplotlib_kwarg_dealiaser,
)


Expand Down Expand Up @@ -87,8 +88,7 @@ def plot_mcse(
**text_kwargs,
)
if rug:
if rug_kwargs is None:
rug_kwargs = {}
rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
if not hasattr(idata, "sample_stats"):
raise ValueError("InferenceData object must contain sample_stats for rug plot")
if not hasattr(idata.sample_stats, rug_kind):
Expand Down
6 changes: 2 additions & 4 deletions arviz/plots/distplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pylint: disable=unexpected-keyword-arg
"""Plot distribution as histogram or kernel density estimates."""
from .plot_utils import get_bins, get_plotting_function
from .plot_utils import get_bins, get_plotting_function, matplotlib_kwarg_dealiaser


def plot_dist(
Expand Down Expand Up @@ -147,9 +147,7 @@ def plot_dist(
kind = "hist" if values.dtype.kind == "i" else "kde"

if kind == "hist":
if hist_kwargs is None:
hist_kwargs = {}

hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist")
hist_kwargs.setdefault("bins", get_bins(values))
hist_kwargs.setdefault("cumulative", cumulative)
hist_kwargs.setdefault("color", color)
Expand Down
12 changes: 9 additions & 3 deletions arviz/plots/elpdplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from matplotlib.lines import Line2D

from ..data import convert_to_inference_data
from .plot_utils import get_coords, format_coords_as_labels, color_from_dim, get_plotting_function
from .plot_utils import (
get_coords,
format_coords_as_labels,
color_from_dim,
get_plotting_function,
matplotlib_kwarg_dealiaser,
)
from ..stats import waic, loo, ELPDData
from ..rcparams import rcParams

Expand Down Expand Up @@ -142,8 +148,8 @@ def plot_elpd(
if coords is None:
coords = {}

if plot_kwargs is None:
plot_kwargs = {}
plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "scatter")

if backend == "bokeh":
plot_kwargs.setdefault("marker", rcParams["plot.bokeh.marker"])

Expand Down
10 changes: 4 additions & 6 deletions arviz/plots/energyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from ..data import convert_to_dataset
from .plot_utils import _scale_fig_size, get_plotting_function
from .plot_utils import _scale_fig_size, get_plotting_function, matplotlib_kwarg_dealiaser


def plot_energy(
Expand Down Expand Up @@ -92,11 +92,9 @@ def plot_energy(
"""
energy = convert_to_dataset(data, group="sample_stats").energy.values

if fill_kwargs is None:
fill_kwargs = {}

if plot_kwargs is None:
plot_kwargs = {}
fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin")
types = "hist" if kind in {"hist", "histogram"} else "plot"
plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, types)

figsize, _, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize, 1, 1)

Expand Down
35 changes: 18 additions & 17 deletions arviz/plots/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_coords,
filter_plotters_list,
get_plotting_function,
matplotlib_kwarg_dealiaser,
)
from ..utils import _var_names

Expand Down Expand Up @@ -247,14 +248,15 @@ def plot_ess(
(figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _markersize) = _scale_fig_size(
figsize, textsize, rows, cols
)
_linestyle = kwargs.pop("ls", "-" if kind == "evolution" else "none")
kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot")
_linestyle = "-" if kind == "evolution" else "none"
kwargs.setdefault("linestyle", _linestyle)
kwargs.setdefault("linewidth", kwargs.pop("lw", _linewidth))
kwargs.setdefault("markersize", kwargs.pop("ms", _markersize))
kwargs.setdefault("linewidth", _linewidth)
kwargs.setdefault("markersize", _markersize)
kwargs.setdefault("marker", "o")
kwargs.setdefault("zorder", 3)
if extra_kwargs is None:
extra_kwargs = {}

extra_kwargs = matplotlib_kwarg_dealiaser(extra_kwargs, "plot")
if kind == "evolution":
extra_kwargs = {
**extra_kwargs,
Expand All @@ -263,28 +265,27 @@ def plot_ess(
kwargs.setdefault("label", "bulk")
extra_kwargs.setdefault("label", "tail")
else:
extra_kwargs.setdefault("linestyle", extra_kwargs.pop("ls", "-"))
extra_kwargs.setdefault("linewidth", extra_kwargs.pop("lw", _linewidth / 2))
extra_kwargs.setdefault("linestyle", "-")
extra_kwargs.setdefault("linewidth", _linewidth / 2)
extra_kwargs.setdefault("color", "k")
extra_kwargs.setdefault("alpha", 0.5)
kwargs.setdefault("label", kind)
if hline_kwargs is None:
hline_kwargs = {}
hline_kwargs.setdefault("linewidth", hline_kwargs.pop("lw", _linewidth))
hline_kwargs.setdefault("linestyle", hline_kwargs.pop("ls", "--"))
hline_kwargs.setdefault("color", hline_kwargs.pop("c", "gray"))

hline_kwargs = matplotlib_kwarg_dealiaser(hline_kwargs, "plot")
hline_kwargs.setdefault("linewidth", _linewidth)
hline_kwargs.setdefault("linestyle", "--")
hline_kwargs.setdefault("color", "gray")
hline_kwargs.setdefault("alpha", 0.7)
if extra_methods:
mean_ess = ess(data, var_names=var_names, method="mean", relative=relative)
sd_ess = ess(data, var_names=var_names, method="sd", relative=relative)
if text_kwargs is None:
text_kwargs = {}
text_kwargs = matplotlib_kwarg_dealiaser(text_kwargs, "text")
text_x = text_kwargs.pop("x", 1)
text_kwargs.setdefault("fontsize", text_kwargs.pop("size", xt_labelsize * 0.7))
text_kwargs.setdefault("fontsize", xt_labelsize * 0.7)
text_kwargs.setdefault("alpha", extra_kwargs["alpha"])
text_kwargs.setdefault("color", extra_kwargs["color"])
text_kwargs.setdefault("horizontalalignment", text_kwargs.pop("ha", "right"))
text_va = text_kwargs.pop("verticalalignment", text_kwargs.pop("va", None))
text_kwargs.setdefault("horizontalalignment", "right")
text_va = text_kwargs.pop("verticalalignment", None)

essplot_kwargs = dict(
ax=ax,
Expand Down
8 changes: 3 additions & 5 deletions arviz/plots/hpdplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.signal import savgol_filter

from ..stats import hpd
from .plot_utils import get_plotting_function
from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser
from ..rcparams import rcParams


Expand Down Expand Up @@ -64,13 +64,11 @@ def plot_hpd(
-------
axes : matplotlib axes or bokeh figures
"""
if plot_kwargs is None:
plot_kwargs = {}
plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
plot_kwargs.setdefault("color", color)
plot_kwargs.setdefault("alpha", 0)

if fill_kwargs is None:
fill_kwargs = {}
fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin")
fill_kwargs.setdefault("color", color)
fill_kwargs.setdefault("alpha", 0.5)

Expand Down
17 changes: 14 additions & 3 deletions arviz/plots/jointplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Joint scatter plot of two variables."""
from ..data import convert_to_dataset
from .plot_utils import _scale_fig_size, xarray_var_iter, get_coords, get_plotting_function
from .plot_utils import (
_scale_fig_size,
xarray_var_iter,
get_coords,
get_plotting_function,
matplotlib_kwarg_dealiaser,
)
from ..utils import _var_names


Expand Down Expand Up @@ -156,8 +162,13 @@ def plot_joint(

figsize, ax_labelsize, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize)

if joint_kwargs is None:
joint_kwargs = {}
if kind == "kde":
types = "plot"
elif kind == "scatter":
types = "scatter"
else:
types = "hexbin"
joint_kwargs = matplotlib_kwarg_dealiaser(joint_kwargs, types)

if marginal_kwargs is None:
marginal_kwargs = {}
Expand Down
6 changes: 4 additions & 2 deletions arviz/plots/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
color_from_dim,
format_coords_as_labels,
get_plotting_function,
matplotlib_kwarg_dealiaser,
)
from ..stats import ELPDData

Expand Down Expand Up @@ -131,8 +132,7 @@ def plot_khat(
>>> az.plot_khat(loo_radon, color=colors)

"""
if hlines_kwargs is None:
hlines_kwargs = {}
hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
hlines_kwargs.setdefault("alpha", 0.7)
hlines_kwargs.setdefault("zorder", -1)
Expand Down Expand Up @@ -172,6 +172,8 @@ def plot_khat(
if markersize is None:
markersize = scaled_markersize ** 2 # s in scatter plot mus be markersize square
# for dots to have the same size

kwargs = matplotlib_kwarg_dealiaser(kwargs, "scatter")
kwargs.setdefault("s", markersize)
kwargs.setdefault("marker", "+")
color_mapping = None
Expand Down
7 changes: 3 additions & 4 deletions arviz/plots/loopitplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_scale_fig_size,
get_plotting_function,
_fast_kde,
matplotlib_kwarg_dealiaser,
)
from ..rcparams import rcParams

Expand Down Expand Up @@ -137,8 +138,7 @@ def plot_loo_pit(
loo_pit = _loo_pit(idata=idata, y=y, y_hat=y_hat, log_weights=log_weights)
loo_pit = loo_pit.flatten() if isinstance(loo_pit, np.ndarray) else loo_pit.values.flatten()

if plot_kwargs is None:
plot_kwargs = {}
plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
plot_kwargs["color"] = to_hex(color)
plot_kwargs.setdefault("linewidth", linewidth * 1.4)
if isinstance(y, str):
Expand All @@ -155,8 +155,7 @@ def plot_loo_pit(
plot_kwargs.setdefault("label", label)
plot_kwargs.setdefault("zorder", 5)

if plot_unif_kwargs is None:
plot_unif_kwargs = {}
plot_unif_kwargs = matplotlib_kwarg_dealiaser(plot_unif_kwargs, "plot")
light_color = rgb_to_hsv(to_rgb(plot_kwargs.get("color")))
light_color[1] /= 2 # pylint: disable=unsupported-assignment-operation
light_color[2] += (1 - light_color[2]) / 2 # pylint: disable=unsupported-assignment-operation
Expand Down
Loading