diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index 69c38711..865400aa 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -28,6 +28,7 @@ from matplotlib.pyplot import Axes from pyglotaran_extras.types import BuiltinSubPlotLabelFormatFunctionKey + from pyglotaran_extras.types import CyclerColor from pyglotaran_extras.types import ResultLike from pyglotaran_extras.types import SubPlotLabelCoord @@ -215,6 +216,28 @@ def add_unique_figure_legend(fig: Figure, axes: Axes) -> None: fig.legend(*zip(*unique, strict=True)) +def get_next_cycler_color(ax: Axes) -> CyclerColor: + """Get next color from cycler assigned to ``ax``. + + Note + ---- + This will advance the cycler to the next state. + + Parameters + ---------- + ax : Axes + Axes to get the color from. + + Returns + ------- + CyclerColor + """ + # Matplotlib<3.8 compat + if hasattr(ax._get_lines, "prop_cycler"): + return next(ax._get_lines.prop_cycler) + return {"color": ax._get_lines.get_next_color()} + + def select_plot_wavelengths( result: ResultLike, axes_shape: tuple[int, int] = (4, 4), diff --git a/pyglotaran_extras/types.py b/pyglotaran_extras/types.py index 17b62d40..cbf013d0 100644 --- a/pyglotaran_extras/types.py +++ b/pyglotaran_extras/types.py @@ -5,12 +5,17 @@ from collections.abc import Mapping from collections.abc import Sequence from pathlib import Path +from typing import TYPE_CHECKING from typing import Literal from typing import TypeAlias +from typing import TypedDict import xarray as xr from glotaran.project.result import Result +if TYPE_CHECKING: + from pyglotaran_extras.plotting.style import ColorCode + class UnsetType: """Type for the ``Unset`` singleton.""" @@ -26,6 +31,13 @@ def __repr__(self) -> str: # noqa: DOC This way we can prevent regressions. """ + +class CyclerColor(TypedDict): + """Color value returned by a cycler.""" + + color: str | ColorCode + + DatasetConvertible: TypeAlias = xr.Dataset | xr.DataArray | str | Path """Types of data which can be converted to a dataset.""" ResultLike: TypeAlias = ( diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index bc5cb326..41c21e16 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -17,6 +17,7 @@ from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi from pyglotaran_extras.plotting.utils import ensure_axes_array from pyglotaran_extras.plotting.utils import format_sub_plot_number_upper_case_letter +from pyglotaran_extras.plotting.utils import get_next_cycler_color from pyglotaran_extras.plotting.utils import not_single_element_dims if TYPE_CHECKING: @@ -42,11 +43,9 @@ def test_add_cycler_if_not_none_single_axis(cycler: Cycler | None, expected_cycl ax = plt.subplot() add_cycler_if_not_none(ax, cycler) - ax_cycler = iter(ax._get_lines._cycler_items) - for _ in range(10): expected = next(expected_cycler) - assert next(ax_cycler) == expected + assert get_next_cycler_color(ax) == expected @pytest.mark.parametrize( @@ -58,13 +57,10 @@ def test_add_cycler_if_not_none_multiple_axes(cycler: Cycler | None, expected_cy _, axes = plt.subplots(1, 2) add_cycler_if_not_none(axes, cycler) - ax0_cycler = iter(axes[0]._get_lines._cycler_items) - ax1_cycler = iter(axes[1]._get_lines._cycler_items) - for _ in range(10): expected = next(expected_cycler) - assert next(ax0_cycler) == expected - assert next(ax1_cycler) == expected + assert get_next_cycler_color(axes[0]) == expected + assert get_next_cycler_color(axes[1]) == expected @pytest.mark.parametrize(