Skip to content

Commit

Permalink
Added color/linestyles aesthetics and simplified min_ess plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
imperorrp committed Aug 2, 2024
1 parent 7820243 commit 28f2167
Showing 1 changed file with 31 additions and 38 deletions.
69 changes: 31 additions & 38 deletions src/arviz_plots/plots/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# imports
# import warnings
from copy import copy
from importlib import import_module

import arviz_stats # pylint: disable=unused-import
import numpy as np
Expand Down Expand Up @@ -145,6 +146,11 @@ def plot_ess(
dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
)

# ensuring plot_kwargs['rug'] is not False
rug_kwargs = copy(plot_kwargs.get("rug", {}))
if rug_kwargs is False:
raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug")

# set plot collection initialization defaults if it doesnt exist
if plot_collection is None:
if backend is None:
Expand Down Expand Up @@ -184,6 +190,10 @@ def plot_ess(
aes_map = aes_map.copy()
aes_map.setdefault(kind, plot_collection.aes_set.difference({"overlay"}))
aes_map.setdefault("rug", {"overlay"})
if "model" in distribution:
aes_map.setdefault("mean", {"color"})
aes_map.setdefault("sd", {"color"})
aes_map.setdefault("min_ess", {"color"})
if labeller is None:
labeller = BaseLabeller()

Expand Down Expand Up @@ -233,9 +243,6 @@ def plot_ess(
# overlaying divergences(or other 'rug_kind') for each chain
if rug:
sample_stats = get_group(dt, "sample_stats", allow_missing=True)
rug_kwargs = copy(plot_kwargs.get("rug", {}))
if rug_kwargs is False:
raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug")
if (
sample_stats is not None
and rug_kind in sample_stats.data_vars
Expand All @@ -248,12 +255,6 @@ def plot_ess(
rug_kwargs.setdefault("color", "black")
if "marker" not in div_aes:
rug_kwargs.setdefault("marker", "|")
# WIP: if using a default linewidth once defined in backend/agnostic defaults
# if "width" not in div_aes:
# # get default linewidth for backends
# plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
# default_linewidth = plot_bknd.get_default_aes("linewidth", 1, {})
# rug_kwargs.setdefault("width", default_linewidth)
if "size" not in div_aes:
rug_kwargs.setdefault("size", 30)
div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list]
Expand Down Expand Up @@ -290,6 +291,12 @@ def plot_ess(
x_range = [0, 1]
x_range = xr.DataArray(x_range)

# getting backend specific linestyles
plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
linestyles = plot_bknd.get_default_aes("linestyle", 4, {})
# and default color
default_color = plot_bknd.get_default_aes("color", 1, {})[0]

# plot mean and sd
if extra_methods is not False:
mean_kwargs = copy(plot_kwargs.get("mean", {}))
Expand All @@ -300,13 +307,12 @@ def plot_ess(
mean_ess = distribution.azstats.ess(
dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {})
)
print(f"\n mean_ess = {mean_ess}")

if "linestyle" not in mean_aes:
if backend == "matplotlib":
mean_kwargs.setdefault("linestyle", "--")
elif backend == "bokeh":
mean_kwargs.setdefault("linestyle", "dashed")
# getting 2nd default linestyle for chosen backend and assigning it by default
mean_kwargs.setdefault("linestyle", linestyles[1])

if "color" not in mean_aes:
mean_kwargs.setdefault("color", default_color)

plot_collection.map(
line_xy,
Expand All @@ -323,13 +329,11 @@ def plot_ess(
sd_ess = distribution.azstats.ess(
dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {})
)
print(f"\n sd_ess = {sd_ess}")

if "linestyle" not in sd_aes:
if backend == "matplotlib":
sd_kwargs.setdefault("linestyle", "--")
elif backend == "bokeh":
sd_kwargs.setdefault("linestyle", "dashed")
sd_kwargs.setdefault("linestyle", linestyles[2])

if "color" not in sd_aes:
sd_kwargs.setdefault("color", default_color)

plot_collection.map(
line_xy, "sd", data=sd_ess, ignore_aes=sd_ignore, x=x_range, **sd_kwargs
Expand All @@ -339,36 +343,25 @@ def plot_ess(
min_ess_kwargs = copy(plot_kwargs.get("min_ess", {}))

if min_ess_kwargs is not False:
min_ess_dims, min_ess_aes, min_ess_ignore = filter_aes(
_, min_ess_aes, min_ess_ignore = filter_aes(
plot_collection, aes_map, "min_ess", sample_dims
)

if relative:
min_ess = min_ess / n_points

# for each variable of distribution, put min_ess as the value, reducing all min_ess_dims
min_ess_data = {}
for var in distribution.data_vars:
reduced_data = distribution[var].mean(
dim=[dim for dim in distribution[var].dims if dim in min_ess_dims]
)
min_ess_data[var] = xr.full_like(reduced_data, min_ess)

min_ess_dataset = xr.Dataset(min_ess_data)
print(f"\n min_ess = {min_ess_dataset}")
min_ess_kwargs.setdefault("linestyle", linestyles[3])

if "linestyle" not in min_ess_aes:
if backend == "matplotlib":
min_ess_kwargs.setdefault("linestyle", "--")
elif backend == "bokeh":
min_ess_kwargs.setdefault("linestyle", "dashed")
if "color" not in min_ess_aes:
min_ess_kwargs.setdefault("color", "gray")

plot_collection.map(
line_xy,
"min_ess",
data=min_ess_dataset,
data=distribution,
ignore_aes=min_ess_ignore,
x=x_range,
y=min_ess,
**min_ess_kwargs,
)

Expand Down

0 comments on commit 28f2167

Please sign in to comment.