Skip to content

Commit

Permalink
Add add_legend, general fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eachanjohnson committed Oct 5, 2024
1 parent fdfa4f3 commit 0ba148b
Showing 1 changed file with 54 additions and 14 deletions.
68 changes: 54 additions & 14 deletions carabiner/mpl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"\nor reinstall carabiner with matplotlib:\n"
"\n\t$ pip install carabiner[mpl]\n")
else:
from matplotlib import axes, cycler, figure, rcParams
from matplotlib import axes, cycler, figure, rcParams, legend
import numpy as np
from pandas import DataFrame
from tqdm.auto import tqdm
Expand Down Expand Up @@ -87,6 +87,36 @@ def grid(
ax.yaxis.set_tick_params(labelleft=True)
return fig, axes


def add_legend(
ax: axes.Axes,
**kwargs
) -> legend.Legend:

"""Add a legend to the right of a Matplotlib plotting axis.
Uses a sensible default for putting the legend out of the way. Keyword arguments
override `loc` and `bbox_to_anchor`, and additional arguments are passed to
`matplotlib.axes.Axes.legend()`.
Parameters
----------
ax : matplotlib.axes.Axes
Axes to add a legend to.
Returns
-------
matplotlib.legend.Legend
"""
default_opts = {
"loc": 'center left',
"bbox_to_anchor": (1, .5)
}
default_opts.update(kwargs)
return ax.legend(**default_opts)


def scattergrid(
df: DataFrame,
grid_columns: Union[str, Iterable[str]],
Expand All @@ -95,6 +125,7 @@ def scattergrid(
log: Optional[Union[str, Iterable[str]]] = None,
n_bins: int = 40,
scatter_opts: Optional[Mapping[str, Any]] = None,
hist_opts: Optional[Mapping[str, Any]] = None,
legend_opts: Optional[Mapping[str, Any]] = None,
*args, **kwargs
) -> TFigAx:
Expand Down Expand Up @@ -127,6 +158,8 @@ def scattergrid(
scatter grid. Default: 40.
scatter_opts : dict, optional
Extra keyword arguments to pass to the Matplotlib scatter plots.
hist_opts : dict, optional
Extra keyword arguments to pass to the Matplotlib histogram plots.
legend_opts : dict, optional
Extra keyword arguments to pass to the Matplotlib legend.
Expand All @@ -150,11 +183,6 @@ def scattergrid(
log = log or []
log = [name for name in cast(log, to=list) if name in all_names]

_scatter_opts = {"s": 3.}
_scatter_opts.update(scatter_opts or {})
_legend_opts = {"loc": "center left", "bbox_to_anchor": (1., .5)}
_legend_opts.update(legend_opts or {})

if grouping is None:
grouping = "__group__"
df = df.assign(__group__=grouping).groupby(grouping)
Expand All @@ -168,18 +196,27 @@ def scattergrid(
df = df.groupby(grouping)
dummy_group = False

_scatter_opts = {"s": 3.}
_scatter_opts.update(scatter_opts or {})
_hist_opts = {"alpha": .7} if not dummy_group else {}
_hist_opts.update(hist_opts or {})
_legend_opts = {}
_legend_opts.update(legend_opts or {})

fig, axes = grid(
nrow=len(grid_rows),
ncol=len(grid_columns),
*args, **kwargs
)
for axrow, grid_row_name in zip(tqdm(axes), grid_rows):
yscale = "log" if grid_row_name in log else "linear"
for ax, grid_col_name in zip(axrow, grid_columns):
make_histogram = grid_row_name == grid_col_name
xscale = "log" if grid_col_name in log else "linear"
yscale = "log" if (grid_row_name in log and not make_histogram) else "linear"
ylabel = grid_row_name if not make_histogram else "Frequency"
for group_name, group_df in df:
extras = {"label": group_name} if not dummy_group else {}
if grid_row_name == grid_col_name:
labels = {"label": ":".join(map(str, group_name))} if not dummy_group else {}
if make_histogram:
if xscale == "log":
values = group_df[grid_col_name].values
bins = np.geomspace(
Expand All @@ -193,26 +230,28 @@ def scattergrid(
grid_col_name,
data=group_df,
bins=bins,
**extras
**_hist_opts,
**labels,
)
ylabel = "Frequency"
else:
ax.scatter(
grid_col_name,
grid_row_name,
data=group_df,
**_scatter_opts,
**extras
**labels,
)
ylabel = grid_row_name
ax.set(
xlabel=grid_col_name,
ylabel=ylabel,
xscale=xscale,
yscale=yscale,
)
if not dummy_group:
ax.legend(**_legend_opts)
add_legend(ax, **_legend_opts)
return fig, axes


def figsaver(
output_dir: str = ".",
prefix: Optional[str] = None,
Expand Down Expand Up @@ -242,6 +281,7 @@ def figsaver(
"""

prefix = prefix or ""
if not os.path.exists(output_dir):
os.mkdir(output_dir)

Expand Down

0 comments on commit 0ba148b

Please sign in to comment.