Skip to content

Commit

Permalink
✨ Add add_subplot_labels function (#181)
Browse files Browse the repository at this point in the history
This adds `add_subplot_labels ` as a convenience function to add subplot labels.

Instead of duplicating code for each axis possibly creating inconsistency.

```py
axes[0].annotate("A", xy=(0.01, 0.89), xycoords="axes fraction", fontsize=16)
axes[1].annotate("B", xy=(0.01, 0.89), xycoords="axes fraction", fontsize=16)
axes[2].annotate("C", xy=(0.01, 0.89), xycoords="axes fraction", fontsize=16)
```

This function allows adding labels consistently for all axes

```py
add_subplot_labels(axes, label_position=(0.01, 0.89), label_format_function="upper_case_letter")
```

### Change summary

- [✨ Added format_sub_plot_number_upper_case_letter
function](46e826d)
- [🩹📚 Fixed missing arg docstring for
not_single_element_dims](3a13ae4)
- [✨ Added ensure_axes_array
function](f28ceb0)
- [✨ Added add_subplot_labels
function](a6edab8)
  • Loading branch information
s-weigand authored Jul 3, 2023
1 parent 289c3c3 commit a2f14b4
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- 🩹 Fix crashes of plot_doas and plot_coherent_artifact for non dispersive IRF (#173, #182)
- 👌 Add minor ticks to linlog plots (#183)
- 🚧📦 Remove upper python version limit (#174)
- ✨ Add add_subplot_labels function (#181)

(changes-0_7_0)=

Expand Down
2 changes: 2 additions & 0 deletions pyglotaran_extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyglotaran_extras.plotting.plot_overview import plot_simple_overview
from pyglotaran_extras.plotting.plot_traces import plot_fitted_traces
from pyglotaran_extras.plotting.plot_traces import select_plot_wavelengths
from pyglotaran_extras.plotting.utils import add_subplot_labels

__all__ = [
"load_data",
Expand All @@ -23,6 +24,7 @@
"plot_simple_overview",
"plot_fitted_traces",
"select_plot_wavelengths",
"add_subplot_labels",
]

__version__ = "0.8.0.dev0"
167 changes: 163 additions & 4 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Module containing plotting utility functionality."""
from __future__ import annotations

from math import ceil
from math import log
from types import MappingProxyType
from typing import TYPE_CHECKING
from typing import Iterable
from warnings import warn
Expand All @@ -13,14 +16,19 @@
from pyglotaran_extras.io.utils import result_dataset_mapping

if TYPE_CHECKING:
from typing import Callable
from typing import Hashable
from typing import Literal
from typing import Mapping

from cycler import Cycler
from matplotlib.axis import Axis
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes

from pyglotaran_extras.types import BuiltinSubPlotLabelFormatFunctionKey
from pyglotaran_extras.types import ResultLike
from pyglotaran_extras.types import SubPlotLabelCoord


class PlotDuplicationWarning(UserWarning):
Expand Down Expand Up @@ -365,6 +373,25 @@ def get_shifted_traces(
return shift_time_axis_by_irf_location(traces, irf_location)


def ensure_axes_array(axes: Axis | Axes) -> Axes:
"""Ensure that axes have flatten method even if it is a single axis.
Parameters
----------
axes: Axis | Axes
Axis or Axes to convert for API consistency.
Returns
-------
Axes
Numpy ndarray of axes.
"""
# We can't use `Axis` in isinstance so we check for the np.ndarray attribute of `Axes`
if hasattr(axes, "flatten") is False:
axes = np.array([axes])
return axes


def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None:
"""Add cycler to and axis if it is not None.
Expand All @@ -381,9 +408,7 @@ def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None:
Plot style cycler to use.
"""
if cycler is not None:
# We can't use `Axis` in isinstance so we check for the np.ndarray attribute of `Axes`
if hasattr(axis, "flatten") is False:
axis = np.array([axis])
axis = ensure_axes_array(axis)
for ax in axis.flatten():
ax.set_prop_cycle(cycler)

Expand Down Expand Up @@ -475,7 +500,7 @@ def not_single_element_dims(data_array: xr.DataArray) -> list[Hashable]:
Parameters
----------
data_array: xr.DataArray
_description_
DataArray to check if it has only a single dimension.
Returns
-------
Expand Down Expand Up @@ -578,3 +603,137 @@ def tick_values(self, vmin: float, vmax: float) -> None:
Not used
"""
raise NotImplementedError(f"Cannot get tick locations for a {type(self)} type.")


def format_sub_plot_number_upper_case_letter(sub_plot_number: int, size: None | int = None) -> str:
"""Format ``sub_plot_number`` into an upper case letter, that can be used as label.
Parameters
----------
sub_plot_number : int
Number of the subplot starting at One.
size : None | int
Size of the axes array (number of plots). Defaults to None
Returns
-------
str
Upper case label for a sub plot.
Examples
--------
>>> print(format_sub_plot_number_upper_case_letter(1))
A
>>> print(format_sub_plot_number_upper_case_letter(26))
Z
>>> print(format_sub_plot_number_upper_case_letter(27))
AA
>>> print(format_sub_plot_number_upper_case_letter(1, 26))
AA
>>> print(format_sub_plot_number_upper_case_letter(2, 26))
AB
>>> print(format_sub_plot_number_upper_case_letter(26, 26))
AZ
>>> print(format_sub_plot_number_upper_case_letter(27, 50))
BA
See Also
--------
BuiltinLabelFormatFunctions
add_subplot_labels
"""
sub_plot_number -= 1
if size is not None and size > 26:
return "".join(
format_sub_plot_number_upper_case_letter(((sub_plot_number // (26**i)) % 26) + 1)
for i in reversed(range(1, ceil(log(size, 26))))
) + format_sub_plot_number_upper_case_letter((sub_plot_number % 26) + 1)
if sub_plot_number < 26:
return chr(ord("A") + sub_plot_number)
return format_sub_plot_number_upper_case_letter(
sub_plot_number // 26
) + format_sub_plot_number_upper_case_letter((sub_plot_number % 26) + 1)


BuiltinSubPlotLabelFormatFunctions: Mapping[
str, Callable[[int, int | None], str]
] = MappingProxyType(
{
"number": lambda x, y: f"{x}",
"upper_case_letter": format_sub_plot_number_upper_case_letter,
"lower_case_letter": lambda x, y: format_sub_plot_number_upper_case_letter(x, y).lower(),
}
)


def get_subplot_label_format_function(
format_function: BuiltinSubPlotLabelFormatFunctionKey | Callable[[int, int | None], str]
) -> Callable[[int, int | None], str]:
"""Get subplot label function from ``BuiltinSubPlotLabelFormatFunctions`` if it is a key.
This function is mainly needed for typing reasons.
Parameters
----------
format_function : BuiltinSubPlotLabelFormatFunctionKey | Callable[[int, int | None], str]
Key ``BuiltinSubPlotLabelFormatFunctions`` to retrieve builtin function or user defined
format function.
Returns
-------
Callable[[int, int | None], str]
Function to format subplot label.
"""
if isinstance(format_function, str) and format_function in BuiltinSubPlotLabelFormatFunctions:
return BuiltinSubPlotLabelFormatFunctions[format_function]
return format_function # type:ignore[return-value]


def add_subplot_labels(
axes: Axis | Axes,
*,
label_position: tuple[float, float] = (-0.05, 1.05),
label_coords: SubPlotLabelCoord = "axes fraction",
direction: Literal["row", "column"] = "row",
label_format_template: str = "{}",
label_format_function: BuiltinSubPlotLabelFormatFunctionKey
| Callable[[int, int | None], str] = "number",
fontsize: int = 16,
) -> None:
"""Add labels to all subplots in ``axes`` in a consistent manner.
Parameters
----------
axes : Axis | Axes
Axes (subplots) on which the labels should be added.
label_position : tuple[float, float]
Position of the label in ``label_coords`` coordinates.
label_coords : SubPlotLabelCoord
Coordinate system used for ``label_position``. Defaults to "axes fraction"
direction : Literal["row", "column"]
Direct in which the axes should be iterated in. Defaults to "row"
label_format_template : str
Template string to inject the return value of ``label_format_function`` into.
Defaults to "{}"
label_format_function : BuiltinSubPlotLabelFormatFunctionKey | Callable[[int, int | None], str]
Function to calculate the label for the axis index and ``axes`` size. Defaults to "number"
fontsize : int
Font size used for the label. Defaults to 16
"""
axes = ensure_axes_array(axes)
format_function = get_subplot_label_format_function(label_format_function)
if direction == "column":
axes = axes.T
for i, ax in enumerate(axes.flatten(), start=1):
ax.annotate(
label_format_template.format(format_function(i, axes.size)),
xy=label_position,
xycoords=label_coords,
fontsize=fontsize,
)
26 changes: 26 additions & 0 deletions pyglotaran_extras/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from __future__ import annotations

from pathlib import Path
from typing import Literal
from typing import Mapping
from typing import Sequence
from typing import TypeAlias
from typing import Union

import xarray as xr
Expand All @@ -15,3 +17,27 @@
Result, DatasetConvertible, Mapping[str, DatasetConvertible], Sequence[DatasetConvertible]
]
"""Result like data which can be converted to a per dataset mapping."""


BuiltinSubPlotLabelFormatFunctionKey: TypeAlias = Literal[
"number", "upper_case_letter", "lower_case_letter"
]
"""Key supported by ``BuiltinLabelFormatFunctions``."""

SubPlotLabelCoordStrs: TypeAlias = Literal[
"figure points",
"figure pixels",
"figure fraction",
"subfigure points",
"subfigure pixels",
"subfigure fraction",
"axes points",
"axes pixels",
"axes fraction",
"data",
"polar",
]

SubPlotLabelCoord: TypeAlias = (
SubPlotLabelCoordStrs | tuple[SubPlotLabelCoordStrs, SubPlotLabelCoordStrs]
)
94 changes: 94 additions & 0 deletions tests/plotting/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import Hashable
from typing import Iterable
from typing import Literal

import matplotlib
import matplotlib.pyplot as plt
Expand All @@ -15,8 +16,12 @@
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import add_subplot_labels
from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi
from pyglotaran_extras.plotting.utils import ensure_axes_array
from pyglotaran_extras.plotting.utils import format_sub_plot_number_upper_case_letter
from pyglotaran_extras.plotting.utils import not_single_element_dims
from pyglotaran_extras.types import SubPlotLabelCoord

matplotlib.use("Agg")
DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"]
Expand Down Expand Up @@ -105,3 +110,92 @@ def test_calculate_ticks_in_units_of_pi(
def test_not_single_element_dims(data_array: xr.DataArray, expected: list[Hashable]):
"""Only get dim with more than one element."""
assert not_single_element_dims(data_array) == expected


@pytest.mark.parametrize(
("value", "size", "expected"),
(
(1, None, "A"),
(2, None, "B"),
(26, None, "Z"),
(27, None, "AA"),
(26**2 + 26, None, "ZZ"),
(1, 26**2, "AA"),
(2, 26**2, "AB"),
(26, 26**2, "AZ"),
(26**2, 26**2, "ZZ"),
(1, 26**3, "AAA"),
(26**3, 26**3, "ZZZ"),
),
)
def test_format_sub_plot_number_upper_case_letter(value: int, size: int | None, expected: str):
"""Expected string format."""
assert format_sub_plot_number_upper_case_letter(value, size) == expected


def test_ensure_axes_array():
"""Hasa flatten method."""
_, ax = plt.subplots(1, 1)
assert hasattr(ax, "flatten") is False
assert hasattr(ensure_axes_array(ax), "flatten") is True

_, axes = plt.subplots(1, 2)
assert hasattr(axes, "flatten") is True
assert hasattr(ensure_axes_array(axes), "flatten") is True


def test_add_subplot_labels_defaults():
"""Sanity check that default arguments got passed on to mpl annotate method."""
_, axes = plt.subplots(2, 2)

add_subplot_labels(axes)

assert [ax.texts[0].get_text() for ax in axes.flatten()] == ["1", "2", "3", "4"]
assert [ax.texts[0].get_position() for ax in axes.flatten()] == pytest.approx(
[(-0.05, 1.05)] * 4
)
assert [ax.texts[0].get_anncoords() for ax in axes.flatten()] == ["axes fraction"] * 4
assert [ax.texts[0].get_fontsize() for ax in axes.flatten()] == [16] * 4


@pytest.mark.parametrize(
"direction, expected", (("row", ["1", "2", "3", "4"]), ("column", ["1", "3", "2", "4"]))
)
@pytest.mark.parametrize("label_position", ((0.01, 0.95), (-0.1, 1.0)))
@pytest.mark.parametrize("label_coords", ("data", ("axes fraction", "data")))
@pytest.mark.parametrize("fontsize", (12, 26))
def test_add_subplot_labels_assignment(
direction: Literal["row", "column"],
label_position: tuple[float, float],
label_coords: SubPlotLabelCoord,
fontsize: int,
expected: list[str],
):
"""Test basic label text assignment."""
_, axes = plt.subplots(2, 2)

add_subplot_labels(
axes,
label_position=label_position,
label_coords=label_coords,
direction=direction,
fontsize=fontsize,
)

assert [ax.texts[0].get_text() for ax in axes.flatten()] == expected
assert [ax.texts[0].get_position() for ax in axes.flatten()] == pytest.approx(
[label_position] * 4
)
assert [ax.texts[0].get_anncoords() for ax in axes.flatten()] == [label_coords] * 4
assert [ax.texts[0].get_fontsize() for ax in axes.flatten()] == [fontsize] * 4

plt.close()


@pytest.mark.parametrize("label_format_template, expected", (("{})", "1)"), ("({})", "(1)")))
def test_add_subplot_labels_label_format_template(label_format_template: str, expected: str):
"""Template is used."""
_, ax = plt.subplots(1, 1)
add_subplot_labels(ax, label_format_template=label_format_template)

assert ax.texts[0].get_text() == expected

0 comments on commit a2f14b4

Please sign in to comment.