Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Add de-aliaser function for plots #1073

Merged
merged 22 commits into from
Mar 2, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions arviz/plots/backends/matplotlib/distplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import backend_show
from ...kdeplot import plot_kde
from ...plot_utils import dealiaser


def plot_dist(
Expand Down Expand Up @@ -51,6 +52,8 @@ def plot_dist(
elif kind == "kde":
if plot_kwargs is None:
plot_kwargs = {}
else:
plot_kwargs = dealiaser(plot_kwargs, type="plot")

plot_kwargs.setdefault("color", color)
legend = label is not None
Expand Down
3 changes: 3 additions & 0 deletions arviz/plots/backends/matplotlib/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...plot_utils import (
make_label,
_create_axes_grid,
dealiaser,
)


Expand Down Expand Up @@ -65,6 +66,8 @@ def plot_ess(
elif rug:
if rug_kwargs is None:
rug_kwargs = {}
else:
rug_kwargs = dealiaser(rug_kwargs, type="plot")
if not hasattr(idata, "sample_stats"):
raise ValueError("InferenceData object must contain sample_stats for rug plot")
if not hasattr(idata.sample_stats, rug_kind):
Expand Down
14 changes: 13 additions & 1 deletion arviz/plots/backends/matplotlib/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from . import backend_show
from ...plot_utils import _scale_fig_size
from ...plot_utils import _scale_fig_size, dealiaser


def plot_kde(
Expand Down Expand Up @@ -55,17 +55,23 @@ def plot_kde(
if values2 is None:
if plot_kwargs is None:
plot_kwargs = {}
else:
plot_kwargs = dealiaser(plot_kwargs, type="plot")
plot_kwargs.setdefault("color", "C0")

default_color = plot_kwargs.get("color")

if fill_kwargs is None:
fill_kwargs = {}
else:
fill_kwargs = dealiaser(fill_kwargs, type="hexbin")

fill_kwargs.setdefault("color", default_color)

if rug_kwargs is None:
rug_kwargs = {}
else:
rug_kwargs = dealiaser(rug_kwargs, type="plot")
rug_kwargs.setdefault("marker", "_" if rotated else "|")
rug_kwargs.setdefault("linestyle", "None")
rug_kwargs.setdefault("color", default_color)
Expand Down Expand Up @@ -124,11 +130,17 @@ def plot_kde(
else:
if contour_kwargs is None:
contour_kwargs = {}
else:
contour_kwargs = dealiaser(contour_kwargs, type="contour")
contour_kwargs.setdefault("colors", "0.5")
if contourf_kwargs is None:
contourf_kwargs = {}
else:
contourf_kwargs = dealiaser(contourf_kwargs, type="contour")
if pcolormesh_kwargs is None:
pcolormesh_kwargs = {}
else:
pcolormesh_kwargs = dealiaser(pcolormesh_kwargs, type="pcolormesh")

# gridsize = (128, 128) if contour else (256, 256)

Expand Down
3 changes: 3 additions & 0 deletions arviz/plots/backends/matplotlib/mcseplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...plot_utils import (
make_label,
_create_axes_grid,
dealiaser,
)


Expand Down Expand Up @@ -89,6 +90,8 @@ def plot_mcse(
if rug:
if rug_kwargs is None:
rug_kwargs = {}
else:
rug_kwargs = dealiaser(rug_kwargs, type="plot")
if not hasattr(idata, "sample_stats"):
raise ValueError("InferenceData object must contain sample_stats for rug plot")
if not hasattr(idata.sample_stats, rug_kind):
Expand Down
4 changes: 3 additions & 1 deletion arviz/plots/distplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pylint: disable=unexpected-keyword-arg
"""Plot distribution as histogram or kernel density estimates."""
from .plot_utils import get_bins, get_plotting_function
from .plot_utils import get_bins, get_plotting_function, dealiaser


def plot_dist(
Expand Down Expand Up @@ -149,6 +149,8 @@ def plot_dist(
if kind == "hist":
if hist_kwargs is None:
hist_kwargs = {}
else:
hist_kwargs = dealiaser(hist_kwargs, type="hist")

hist_kwargs.setdefault("bins", get_bins(values))
hist_kwargs.setdefault("cumulative", cumulative)
Expand Down
11 changes: 10 additions & 1 deletion arviz/plots/elpdplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from matplotlib.lines import Line2D

from ..data import convert_to_inference_data
from .plot_utils import get_coords, format_coords_as_labels, color_from_dim, get_plotting_function
from .plot_utils import (
get_coords,
format_coords_as_labels,
color_from_dim,
get_plotting_function,
dealiaser,
)
from ..stats import waic, loo, ELPDData
from ..rcparams import rcParams

Expand Down Expand Up @@ -144,6 +150,9 @@ def plot_elpd(

if plot_kwargs is None:
plot_kwargs = {}
else:
plot_kwargs = dealiaser(plot_kwargs, type="scatter")

if backend == "bokeh":
plot_kwargs.setdefault("marker", rcParams["plot.bokeh.marker"])

Expand Down
5 changes: 4 additions & 1 deletion arviz/plots/energyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from ..data import convert_to_dataset
from .plot_utils import _scale_fig_size, get_plotting_function
from .plot_utils import _scale_fig_size, get_plotting_function, dealiaser


def plot_energy(
Expand Down Expand Up @@ -97,6 +97,9 @@ def plot_energy(

if plot_kwargs is None:
plot_kwargs = {}
else:
types = "hist" if kind in {"hist", "histogram"} else "plot"
plot_kwargs = dealiaser(plot_kwargs, type=types)

figsize, _, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize, 1, 1)

Expand Down
7 changes: 7 additions & 0 deletions arviz/plots/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_coords,
filter_plotters_list,
get_plotting_function,
dealiaser,
)
from ..utils import _var_names

Expand Down Expand Up @@ -255,6 +256,8 @@ def plot_ess(
kwargs.setdefault("zorder", 3)
if extra_kwargs is None:
extra_kwargs = {}
else:
extra_kwargs = dealiaser(extra_kwargs, type="plot")
if kind == "evolution":
extra_kwargs = {
**extra_kwargs,
Expand All @@ -270,6 +273,8 @@ def plot_ess(
kwargs.setdefault("label", kind)
if hline_kwargs is None:
hline_kwargs = {}
else:
hline_kwargs = dealiaser(hline_kwargs, type="plot")
hline_kwargs.setdefault("linewidth", hline_kwargs.pop("lw", _linewidth))
hline_kwargs.setdefault("linestyle", hline_kwargs.pop("ls", "--"))
percygautam marked this conversation as resolved.
Show resolved Hide resolved
hline_kwargs.setdefault("color", hline_kwargs.pop("c", "gray"))
Expand All @@ -279,6 +284,8 @@ def plot_ess(
sd_ess = ess(data, var_names=var_names, method="sd", relative=relative)
if text_kwargs is None:
text_kwargs = {}
else:
text_kwargs = dealiaser(text_kwargs, type="text")
text_x = text_kwargs.pop("x", 1)
text_kwargs.setdefault("fontsize", text_kwargs.pop("size", xt_labelsize * 0.7))
text_kwargs.setdefault("alpha", extra_kwargs["alpha"])
Expand Down
4 changes: 3 additions & 1 deletion arviz/plots/hpdplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.signal import savgol_filter

from ..stats import hpd
from .plot_utils import get_plotting_function
from .plot_utils import get_plotting_function, dealiaser
from ..rcparams import rcParams


Expand Down Expand Up @@ -66,6 +66,8 @@ def plot_hpd(
"""
if plot_kwargs is None:
plot_kwargs = {}
else:
plot_kwargs = dealiaser(plot_kwargs, type="plot")
plot_kwargs.setdefault("color", color)
plot_kwargs.setdefault("alpha", 0)

Expand Down
16 changes: 15 additions & 1 deletion arviz/plots/jointplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Joint scatter plot of two variables."""
from ..data import convert_to_dataset
from .plot_utils import _scale_fig_size, xarray_var_iter, get_coords, get_plotting_function
from .plot_utils import (
_scale_fig_size,
xarray_var_iter,
get_coords,
get_plotting_function,
dealiaser,
)
from ..utils import _var_names


Expand Down Expand Up @@ -158,6 +164,14 @@ def plot_joint(

if joint_kwargs is None:
joint_kwargs = {}
else:
if kind == "kde":
types = "plot"
elif kind == "scatter":
types = "scatter"
else:
types = "hexbin"
joint_kwargs = dealiaser(joint_kwargs, type=types)

if marginal_kwargs is None:
marginal_kwargs = {}
Expand Down
6 changes: 6 additions & 0 deletions arviz/plots/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
color_from_dim,
format_coords_as_labels,
get_plotting_function,
dealiaser,
)
from ..stats import ELPDData

Expand Down Expand Up @@ -133,6 +134,8 @@ def plot_khat(
"""
if hlines_kwargs is None:
hlines_kwargs = {}
else:
hlines_kwargs = dealiaser(hlines_kwargs, type="hlines")
hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
hlines_kwargs.setdefault("alpha", 0.7)
hlines_kwargs.setdefault("zorder", -1)
Expand Down Expand Up @@ -172,6 +175,9 @@ def plot_khat(
if markersize is None:
markersize = scaled_markersize ** 2 # s in scatter plot mus be markersize square
# for dots to have the same size
# scatter plot kwargs dealiasing
if kwargs is not None:
kwargs = dealiaser(kwargs, type="scatter")
kwargs.setdefault("s", markersize)
kwargs.setdefault("marker", "+")
color_mapping = None
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/loopitplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_scale_fig_size,
get_plotting_function,
_fast_kde,
dealiaser,
)
from ..rcparams import rcParams

Expand Down Expand Up @@ -139,6 +140,8 @@ def plot_loo_pit(

if plot_kwargs is None:
plot_kwargs = {}
else:
plot_kwargs = dealiaser(plot_kwargs, type="plot")
plot_kwargs["color"] = to_hex(color)
plot_kwargs.setdefault("linewidth", linewidth * 1.4)
if isinstance(y, str):
Expand All @@ -157,6 +160,8 @@ def plot_loo_pit(

if plot_unif_kwargs is None:
plot_unif_kwargs = {}
else:
plot_unif_kwargs = dealiaser(plot_unif_kwargs, type="plot")
light_color = rgb_to_hsv(to_rgb(plot_kwargs.get("color")))
light_color[1] /= 2 # pylint: disable=unsupported-assignment-operation
light_color[2] += (1 - light_color[2]) / 2 # pylint: disable=unsupported-assignment-operation
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/mcseplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_coords,
filter_plotters_list,
get_plotting_function,
dealiaser,
)
from ..utils import _var_names

Expand Down Expand Up @@ -140,6 +141,8 @@ def plot_mcse(
kwargs.setdefault("zorder", 3)
if extra_kwargs is None:
extra_kwargs = {}
else:
extra_kwargs = dealiaser(extra_kwargs, type="plot")
extra_kwargs.setdefault("linestyle", extra_kwargs.pop("ls", "-"))
extra_kwargs.setdefault("linewidth", extra_kwargs.pop("lw", _linewidth / 2))
percygautam marked this conversation as resolved.
Show resolved Hide resolved
extra_kwargs.setdefault("color", "k")
Expand All @@ -149,6 +152,8 @@ def plot_mcse(
sd_mcse = mcse(data, var_names=var_names, method="sd")
if text_kwargs is None:
text_kwargs = {}
else:
text_kwargs = dealiaser(text_kwargs, type="text")
text_x = text_kwargs.pop("x", 1)
text_kwargs.setdefault("fontsize", text_kwargs.pop("size", xt_labelsize * 0.7))
percygautam marked this conversation as resolved.
Show resolved Hide resolved
text_kwargs.setdefault("alpha", extra_kwargs["alpha"])
Expand Down
17 changes: 12 additions & 5 deletions arviz/plots/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from ..data import convert_to_dataset, convert_to_inference_data
from .plot_utils import xarray_to_ndarray, get_coords, get_plotting_function
from .plot_utils import xarray_to_ndarray, get_coords, get_plotting_function, dealiaser
from ..utils import _var_names


Expand Down Expand Up @@ -131,13 +131,20 @@ def plot_pair(

if plot_kwargs is None:
plot_kwargs = {}

if kind == "scatter":
plot_kwargs.setdefault("marker", ".")
plot_kwargs.setdefault("lw", 0)
else:
if kind == "scatter":
plot_kwargs = dealiaser(plot_kwargs, type="scatter")
plot_kwargs.setdefault("marker", ".")
plot_kwargs.setdefault("lw", 0)
percygautam marked this conversation as resolved.
Show resolved Hide resolved
elif kind == "kde":
plot_kwargs = dealiaser(plot_kwargs, type="plot")
else:
plot_kwargs = dealiaser(plot_kwargs, type="hexbin")

if divergences_kwargs is None:
divergences_kwargs = {}
else:
divergences_kwargs = dealiaser(divergences_kwargs, type="plot")

divergences_kwargs.setdefault("marker", "o")
divergences_kwargs.setdefault("markeredgecolor", "k")
Expand Down
16 changes: 16 additions & 0 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.cbook as cbook
import xarray as xr


Expand Down Expand Up @@ -897,3 +898,18 @@ def _fast_kde_2d(x, y, gridsize=(128, 128), circular=False):
grid /= norm_factor

return grid, xmin, xmax, ymin, ymax


def dealiaser(args, type):
"""De-aliase the kwargs passed to plots."""
dealiaser_dict = {
"scatter": mpl.collections.PathCollection,
"plot": mpl.lines.Line2D,
"hist": mpl.patches.Patch,
"hexbin": mpl.collections.PolyCollection,
"hlines": mpl.collections.LineCollection,
"text": mpl.text.Text,
"contour": mpl.contour.ContourSet,
"pcolormesh": mpl.collections.QuadMesh,
}
return cbook.normalize_kwargs(args, getattr(dealiaser_dict[type], "_alias_map", {}))
Loading