From a2f14b43140c7f8971da6fbf9fdbf725b0b28d6a Mon Sep 17 00:00:00 2001 From: Sebastian Weigand Date: Mon, 3 Jul 2023 11:40:47 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20add=5Fsubplot=5Flabels=20func?= =?UTF-8?q?tion=20(#181)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds `add_subplot_labels ` as a convenience function to add subplot labels. Instead of duplicating code for each axis possibly creating inconsistency. ```py axes[0].annotate("A", xy=(0.01, 0.89), xycoords="axes fraction", fontsize=16) axes[1].annotate("B", xy=(0.01, 0.89), xycoords="axes fraction", fontsize=16) axes[2].annotate("C", xy=(0.01, 0.89), xycoords="axes fraction", fontsize=16) ``` This function allows adding labels consistently for all axes ```py add_subplot_labels(axes, label_position=(0.01, 0.89), label_format_function="upper_case_letter") ``` ### Change summary - [✨ Added format_sub_plot_number_upper_case_letter function](https://github.com/glotaran/pyglotaran-extras/commit/46e826d9c230b5ccb71004bfaf77797b96968bf1) - [🩹📚 Fixed missing arg docstring for not_single_element_dims](https://github.com/glotaran/pyglotaran-extras/commit/3a13ae4b6be86541b81bf0c15a7e4ab706c19b4d) - [✨ Added ensure_axes_array function](https://github.com/glotaran/pyglotaran-extras/commit/f28ceb0d326ffe443e6cc2acaf5c91fada0198d6) - [✨ Added add_subplot_labels function](https://github.com/glotaran/pyglotaran-extras/commit/a6edab8286c05da2c8cecfd895f2749a4cd6ff6a) --- changelog.md | 1 + pyglotaran_extras/__init__.py | 2 + pyglotaran_extras/plotting/utils.py | 167 +++++++++++++++++++++++++++- pyglotaran_extras/types.py | 26 +++++ tests/plotting/test_utils.py | 94 ++++++++++++++++ 5 files changed, 286 insertions(+), 4 deletions(-) diff --git a/changelog.md b/changelog.md index 0c32d171..fb9d349a 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ - 🩹 Fix crashes of plot_doas and plot_coherent_artifact for non dispersive IRF (#173, #182) - 👌 Add minor ticks to linlog plots (#183) - 🚧📦 Remove upper python version limit (#174) +- ✨ Add add_subplot_labels function (#181) (changes-0_7_0)= diff --git a/pyglotaran_extras/__init__.py b/pyglotaran_extras/__init__.py index f351a273..2654b81a 100644 --- a/pyglotaran_extras/__init__.py +++ b/pyglotaran_extras/__init__.py @@ -10,6 +10,7 @@ from pyglotaran_extras.plotting.plot_overview import plot_simple_overview from pyglotaran_extras.plotting.plot_traces import plot_fitted_traces from pyglotaran_extras.plotting.plot_traces import select_plot_wavelengths +from pyglotaran_extras.plotting.utils import add_subplot_labels __all__ = [ "load_data", @@ -23,6 +24,7 @@ "plot_simple_overview", "plot_fitted_traces", "select_plot_wavelengths", + "add_subplot_labels", ] __version__ = "0.8.0.dev0" diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index 8081ebee..5e286603 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -1,6 +1,9 @@ """Module containing plotting utility functionality.""" from __future__ import annotations +from math import ceil +from math import log +from types import MappingProxyType from typing import TYPE_CHECKING from typing import Iterable from warnings import warn @@ -13,14 +16,19 @@ from pyglotaran_extras.io.utils import result_dataset_mapping if TYPE_CHECKING: + from typing import Callable from typing import Hashable + from typing import Literal + from typing import Mapping from cycler import Cycler from matplotlib.axis import Axis from matplotlib.figure import Figure from matplotlib.pyplot import Axes + from pyglotaran_extras.types import BuiltinSubPlotLabelFormatFunctionKey from pyglotaran_extras.types import ResultLike + from pyglotaran_extras.types import SubPlotLabelCoord class PlotDuplicationWarning(UserWarning): @@ -365,6 +373,25 @@ def get_shifted_traces( return shift_time_axis_by_irf_location(traces, irf_location) +def ensure_axes_array(axes: Axis | Axes) -> Axes: + """Ensure that axes have flatten method even if it is a single axis. + + Parameters + ---------- + axes: Axis | Axes + Axis or Axes to convert for API consistency. + + Returns + ------- + Axes + Numpy ndarray of axes. + """ + # We can't use `Axis` in isinstance so we check for the np.ndarray attribute of `Axes` + if hasattr(axes, "flatten") is False: + axes = np.array([axes]) + return axes + + def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None: """Add cycler to and axis if it is not None. @@ -381,9 +408,7 @@ def add_cycler_if_not_none(axis: Axis | Axes, cycler: Cycler | None) -> None: Plot style cycler to use. """ if cycler is not None: - # We can't use `Axis` in isinstance so we check for the np.ndarray attribute of `Axes` - if hasattr(axis, "flatten") is False: - axis = np.array([axis]) + axis = ensure_axes_array(axis) for ax in axis.flatten(): ax.set_prop_cycle(cycler) @@ -475,7 +500,7 @@ def not_single_element_dims(data_array: xr.DataArray) -> list[Hashable]: Parameters ---------- data_array: xr.DataArray - _description_ + DataArray to check if it has only a single dimension. Returns ------- @@ -578,3 +603,137 @@ def tick_values(self, vmin: float, vmax: float) -> None: Not used """ raise NotImplementedError(f"Cannot get tick locations for a {type(self)} type.") + + +def format_sub_plot_number_upper_case_letter(sub_plot_number: int, size: None | int = None) -> str: + """Format ``sub_plot_number`` into an upper case letter, that can be used as label. + + Parameters + ---------- + sub_plot_number : int + Number of the subplot starting at One. + size : None | int + Size of the axes array (number of plots). Defaults to None + + Returns + ------- + str + Upper case label for a sub plot. + + Examples + -------- + >>> print(format_sub_plot_number_upper_case_letter(1)) + A + + >>> print(format_sub_plot_number_upper_case_letter(26)) + Z + + >>> print(format_sub_plot_number_upper_case_letter(27)) + AA + + >>> print(format_sub_plot_number_upper_case_letter(1, 26)) + AA + + >>> print(format_sub_plot_number_upper_case_letter(2, 26)) + AB + + >>> print(format_sub_plot_number_upper_case_letter(26, 26)) + AZ + + >>> print(format_sub_plot_number_upper_case_letter(27, 50)) + BA + + See Also + -------- + BuiltinLabelFormatFunctions + add_subplot_labels + """ + sub_plot_number -= 1 + if size is not None and size > 26: + return "".join( + format_sub_plot_number_upper_case_letter(((sub_plot_number // (26**i)) % 26) + 1) + for i in reversed(range(1, ceil(log(size, 26)))) + ) + format_sub_plot_number_upper_case_letter((sub_plot_number % 26) + 1) + if sub_plot_number < 26: + return chr(ord("A") + sub_plot_number) + return format_sub_plot_number_upper_case_letter( + sub_plot_number // 26 + ) + format_sub_plot_number_upper_case_letter((sub_plot_number % 26) + 1) + + +BuiltinSubPlotLabelFormatFunctions: Mapping[ + str, Callable[[int, int | None], str] +] = MappingProxyType( + { + "number": lambda x, y: f"{x}", + "upper_case_letter": format_sub_plot_number_upper_case_letter, + "lower_case_letter": lambda x, y: format_sub_plot_number_upper_case_letter(x, y).lower(), + } +) + + +def get_subplot_label_format_function( + format_function: BuiltinSubPlotLabelFormatFunctionKey | Callable[[int, int | None], str] +) -> Callable[[int, int | None], str]: + """Get subplot label function from ``BuiltinSubPlotLabelFormatFunctions`` if it is a key. + + This function is mainly needed for typing reasons. + + Parameters + ---------- + format_function : BuiltinSubPlotLabelFormatFunctionKey | Callable[[int, int | None], str] + Key ``BuiltinSubPlotLabelFormatFunctions`` to retrieve builtin function or user defined + format function. + + Returns + ------- + Callable[[int, int | None], str] + Function to format subplot label. + """ + if isinstance(format_function, str) and format_function in BuiltinSubPlotLabelFormatFunctions: + return BuiltinSubPlotLabelFormatFunctions[format_function] + return format_function # type:ignore[return-value] + + +def add_subplot_labels( + axes: Axis | Axes, + *, + label_position: tuple[float, float] = (-0.05, 1.05), + label_coords: SubPlotLabelCoord = "axes fraction", + direction: Literal["row", "column"] = "row", + label_format_template: str = "{}", + label_format_function: BuiltinSubPlotLabelFormatFunctionKey + | Callable[[int, int | None], str] = "number", + fontsize: int = 16, +) -> None: + """Add labels to all subplots in ``axes`` in a consistent manner. + + Parameters + ---------- + axes : Axis | Axes + Axes (subplots) on which the labels should be added. + label_position : tuple[float, float] + Position of the label in ``label_coords`` coordinates. + label_coords : SubPlotLabelCoord + Coordinate system used for ``label_position``. Defaults to "axes fraction" + direction : Literal["row", "column"] + Direct in which the axes should be iterated in. Defaults to "row" + label_format_template : str + Template string to inject the return value of ``label_format_function`` into. + Defaults to "{}" + label_format_function : BuiltinSubPlotLabelFormatFunctionKey | Callable[[int, int | None], str] + Function to calculate the label for the axis index and ``axes`` size. Defaults to "number" + fontsize : int + Font size used for the label. Defaults to 16 + """ + axes = ensure_axes_array(axes) + format_function = get_subplot_label_format_function(label_format_function) + if direction == "column": + axes = axes.T + for i, ax in enumerate(axes.flatten(), start=1): + ax.annotate( + label_format_template.format(format_function(i, axes.size)), + xy=label_position, + xycoords=label_coords, + fontsize=fontsize, + ) diff --git a/pyglotaran_extras/types.py b/pyglotaran_extras/types.py index 93bac6eb..2a40fbc7 100644 --- a/pyglotaran_extras/types.py +++ b/pyglotaran_extras/types.py @@ -2,8 +2,10 @@ from __future__ import annotations from pathlib import Path +from typing import Literal from typing import Mapping from typing import Sequence +from typing import TypeAlias from typing import Union import xarray as xr @@ -15,3 +17,27 @@ Result, DatasetConvertible, Mapping[str, DatasetConvertible], Sequence[DatasetConvertible] ] """Result like data which can be converted to a per dataset mapping.""" + + +BuiltinSubPlotLabelFormatFunctionKey: TypeAlias = Literal[ + "number", "upper_case_letter", "lower_case_letter" +] +"""Key supported by ``BuiltinLabelFormatFunctions``.""" + +SubPlotLabelCoordStrs: TypeAlias = Literal[ + "figure points", + "figure pixels", + "figure fraction", + "subfigure points", + "subfigure pixels", + "subfigure fraction", + "axes points", + "axes pixels", + "axes fraction", + "data", + "polar", +] + +SubPlotLabelCoord: TypeAlias = ( + SubPlotLabelCoordStrs | tuple[SubPlotLabelCoordStrs, SubPlotLabelCoordStrs] +) diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index 84c49efc..89b5c004 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -3,6 +3,7 @@ from typing import Hashable from typing import Iterable +from typing import Literal import matplotlib import matplotlib.pyplot as plt @@ -15,8 +16,12 @@ from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none +from pyglotaran_extras.plotting.utils import add_subplot_labels from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi +from pyglotaran_extras.plotting.utils import ensure_axes_array +from pyglotaran_extras.plotting.utils import format_sub_plot_number_upper_case_letter from pyglotaran_extras.plotting.utils import not_single_element_dims +from pyglotaran_extras.types import SubPlotLabelCoord matplotlib.use("Agg") DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"] @@ -105,3 +110,92 @@ def test_calculate_ticks_in_units_of_pi( def test_not_single_element_dims(data_array: xr.DataArray, expected: list[Hashable]): """Only get dim with more than one element.""" assert not_single_element_dims(data_array) == expected + + +@pytest.mark.parametrize( + ("value", "size", "expected"), + ( + (1, None, "A"), + (2, None, "B"), + (26, None, "Z"), + (27, None, "AA"), + (26**2 + 26, None, "ZZ"), + (1, 26**2, "AA"), + (2, 26**2, "AB"), + (26, 26**2, "AZ"), + (26**2, 26**2, "ZZ"), + (1, 26**3, "AAA"), + (26**3, 26**3, "ZZZ"), + ), +) +def test_format_sub_plot_number_upper_case_letter(value: int, size: int | None, expected: str): + """Expected string format.""" + assert format_sub_plot_number_upper_case_letter(value, size) == expected + + +def test_ensure_axes_array(): + """Hasa flatten method.""" + _, ax = plt.subplots(1, 1) + assert hasattr(ax, "flatten") is False + assert hasattr(ensure_axes_array(ax), "flatten") is True + + _, axes = plt.subplots(1, 2) + assert hasattr(axes, "flatten") is True + assert hasattr(ensure_axes_array(axes), "flatten") is True + + +def test_add_subplot_labels_defaults(): + """Sanity check that default arguments got passed on to mpl annotate method.""" + _, axes = plt.subplots(2, 2) + + add_subplot_labels(axes) + + assert [ax.texts[0].get_text() for ax in axes.flatten()] == ["1", "2", "3", "4"] + assert [ax.texts[0].get_position() for ax in axes.flatten()] == pytest.approx( + [(-0.05, 1.05)] * 4 + ) + assert [ax.texts[0].get_anncoords() for ax in axes.flatten()] == ["axes fraction"] * 4 + assert [ax.texts[0].get_fontsize() for ax in axes.flatten()] == [16] * 4 + + +@pytest.mark.parametrize( + "direction, expected", (("row", ["1", "2", "3", "4"]), ("column", ["1", "3", "2", "4"])) +) +@pytest.mark.parametrize("label_position", ((0.01, 0.95), (-0.1, 1.0))) +@pytest.mark.parametrize("label_coords", ("data", ("axes fraction", "data"))) +@pytest.mark.parametrize("fontsize", (12, 26)) +def test_add_subplot_labels_assignment( + direction: Literal["row", "column"], + label_position: tuple[float, float], + label_coords: SubPlotLabelCoord, + fontsize: int, + expected: list[str], +): + """Test basic label text assignment.""" + _, axes = plt.subplots(2, 2) + + add_subplot_labels( + axes, + label_position=label_position, + label_coords=label_coords, + direction=direction, + fontsize=fontsize, + ) + + assert [ax.texts[0].get_text() for ax in axes.flatten()] == expected + assert [ax.texts[0].get_position() for ax in axes.flatten()] == pytest.approx( + [label_position] * 4 + ) + assert [ax.texts[0].get_anncoords() for ax in axes.flatten()] == [label_coords] * 4 + assert [ax.texts[0].get_fontsize() for ax in axes.flatten()] == [fontsize] * 4 + + plt.close() + + +@pytest.mark.parametrize("label_format_template, expected", (("{})", "1)"), ("({})", "(1)"))) +def test_add_subplot_labels_label_format_template(label_format_template: str, expected: str): + """Template is used.""" + _, ax = plt.subplots(1, 1) + add_subplot_labels(ax, label_format_template=label_format_template) + + assert ax.texts[0].get_text() == expected