diff --git a/carabiner/mpl/utils.py b/carabiner/mpl/utils.py index dc931f3..e80e1c0 100644 --- a/carabiner/mpl/utils.py +++ b/carabiner/mpl/utils.py @@ -12,7 +12,7 @@ "\nor reinstall carabiner with matplotlib:\n" "\n\t$ pip install carabiner[mpl]\n") else: - from matplotlib import axes, cycler, figure, rcParams + from matplotlib import axes, cycler, figure, rcParams, legend import numpy as np from pandas import DataFrame from tqdm.auto import tqdm @@ -87,6 +87,36 @@ def grid( ax.yaxis.set_tick_params(labelleft=True) return fig, axes + +def add_legend( + ax: axes.Axes, + **kwargs +) -> legend.Legend: + + """Add a legend to the right of a Matplotlib plotting axis. + + Uses a sensible default for putting the legend out of the way. Keyword arguments + override `loc` and `bbox_to_anchor`, and additional arguments are passed to + `matplotlib.axes.Axes.legend()`. + + Parameters + ---------- + ax : matplotlib.axes.Axes + Axes to add a legend to. + + Returns + ------- + matplotlib.legend.Legend + + """ + default_opts = { + "loc": 'center left', + "bbox_to_anchor": (1, .5) + } + default_opts.update(kwargs) + return ax.legend(**default_opts) + + def scattergrid( df: DataFrame, grid_columns: Union[str, Iterable[str]], @@ -95,6 +125,7 @@ def scattergrid( log: Optional[Union[str, Iterable[str]]] = None, n_bins: int = 40, scatter_opts: Optional[Mapping[str, Any]] = None, + hist_opts: Optional[Mapping[str, Any]] = None, legend_opts: Optional[Mapping[str, Any]] = None, *args, **kwargs ) -> TFigAx: @@ -127,6 +158,8 @@ def scattergrid( scatter grid. Default: 40. scatter_opts : dict, optional Extra keyword arguments to pass to the Matplotlib scatter plots. + hist_opts : dict, optional + Extra keyword arguments to pass to the Matplotlib histogram plots. legend_opts : dict, optional Extra keyword arguments to pass to the Matplotlib legend. @@ -150,11 +183,6 @@ def scattergrid( log = log or [] log = [name for name in cast(log, to=list) if name in all_names] - _scatter_opts = {"s": 3.} - _scatter_opts.update(scatter_opts or {}) - _legend_opts = {"loc": "center left", "bbox_to_anchor": (1., .5)} - _legend_opts.update(legend_opts or {}) - if grouping is None: grouping = "__group__" df = df.assign(__group__=grouping).groupby(grouping) @@ -168,18 +196,27 @@ def scattergrid( df = df.groupby(grouping) dummy_group = False + _scatter_opts = {"s": 3.} + _scatter_opts.update(scatter_opts or {}) + _hist_opts = {"alpha": .7} if not dummy_group else {} + _hist_opts.update(hist_opts or {}) + _legend_opts = {} + _legend_opts.update(legend_opts or {}) + fig, axes = grid( nrow=len(grid_rows), ncol=len(grid_columns), *args, **kwargs ) for axrow, grid_row_name in zip(tqdm(axes), grid_rows): - yscale = "log" if grid_row_name in log else "linear" for ax, grid_col_name in zip(axrow, grid_columns): + make_histogram = grid_row_name == grid_col_name xscale = "log" if grid_col_name in log else "linear" + yscale = "log" if (grid_row_name in log and not make_histogram) else "linear" + ylabel = grid_row_name if not make_histogram else "Frequency" for group_name, group_df in df: - extras = {"label": group_name} if not dummy_group else {} - if grid_row_name == grid_col_name: + labels = {"label": ":".join(map(str, group_name))} if not dummy_group else {} + if make_histogram: if xscale == "log": values = group_df[grid_col_name].values bins = np.geomspace( @@ -193,26 +230,28 @@ def scattergrid( grid_col_name, data=group_df, bins=bins, - **extras + **_hist_opts, + **labels, ) - ylabel = "Frequency" else: ax.scatter( grid_col_name, grid_row_name, data=group_df, **_scatter_opts, - **extras + **labels, ) - ylabel = grid_row_name ax.set( xlabel=grid_col_name, ylabel=ylabel, + xscale=xscale, + yscale=yscale, ) if not dummy_group: - ax.legend(**_legend_opts) + add_legend(ax, **_legend_opts) return fig, axes + def figsaver( output_dir: str = ".", prefix: Optional[str] = None, @@ -242,6 +281,7 @@ def figsaver( """ + prefix = prefix or "" if not os.path.exists(output_dir): os.mkdir(output_dir)