diff --git a/pyglotaran_extras/plotting/plot_data.py b/pyglotaran_extras/plotting/plot_data.py index 93b7b8fb..5252b944 100644 --- a/pyglotaran_extras/plotting/plot_data.py +++ b/pyglotaran_extras/plotting/plot_data.py @@ -12,11 +12,15 @@ from pyglotaran_extras.plotting.plot_svd import plot_lsv_data from pyglotaran_extras.plotting.plot_svd import plot_rsv_data from pyglotaran_extras.plotting.plot_svd import plot_sv_data +from pyglotaran_extras.plotting.utils import not_single_element_dims from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location __all__ = ["plot_data_overview"] if TYPE_CHECKING: + from typing import Hashable + + import xarray as xr from glotaran.project.result import Result from matplotlib.figure import Figure from matplotlib.pyplot import Axes @@ -33,7 +37,7 @@ def plot_data_overview( nr_of_data_svd_vectors: int = 4, show_data_svd_legend: bool = True, irf_location: float | None = None, -) -> tuple[Figure, Axes]: +) -> tuple[Figure, Axes] | tuple[Figure, Axis]: """Plot data as filled contour plot and SVD components. Parameters @@ -59,10 +63,21 @@ def plot_data_overview( Returns ------- - tuple[Figure, Axes] + tuple[Figure, Axes]|tuple[Figure,Axis] Figure and axes which can then be refined by the user. """ dataset = load_data(dataset, _stacklevel=3) + data = shift_time_axis_by_irf_location(dataset.data, irf_location) + + if len(not_single_element_dims(data)) == 1: + return _plot_single_trace( + data, + not_single_element_dims(data)[0], + title="Single trace data", + linlog=linlog, + linthresh=linthresh, + figsize=figsize, + ) fig = plt.figure(figsize=figsize) data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig)) @@ -71,8 +86,6 @@ def plot_data_overview( sv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 1), fig=fig)) rsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 2), fig=fig)) - data = shift_time_axis_by_irf_location(dataset.data, irf_location) - if len(data.time) > 1: data.plot(x="time", ax=data_ax, center=False) else: @@ -97,3 +110,44 @@ def plot_data_overview( if linlog: data_ax.set_xscale("symlog", linthresh=linthresh) return fig, (data_ax, lsv_ax, sv_ax, rsv_ax) + + +def _plot_single_trace( + data_array: xr.DataArray, + x_dim: Hashable, + *, + title: str = "Single trace data", + linlog: bool = False, + linthresh: float = 1, + figsize: tuple[int, int] = (15, 10), +) -> tuple[Figure, Axis]: + """Plot single trace data in case ``plot_data_overview`` gets passed ingle trace data. + + Parameters + ---------- + data_array: xr.DataArray + DataArray containing only data of a single trace. + x_dim: Hashable + Name of the x dimension. + title: str + Title to add to the figure. Defaults to "Data overview". + linlog: bool + Whether to use 'symlog' scale or not. Defaults to False. + linthresh: float + A single float which defines the range (-x, x), within which the plot is linear. + This avoids having the plot go to infinity around zero. Defaults to 1. + figsize: tuple[int, int] + Size of the figure (N, M) in inches. Defaults to (15, 10). + + Returns + ------- + tuple[Figure, Axis] + Figure and axis which can then be refined by the user. + """ + fig, ax = plt.subplots(1, 1, figsize=figsize) + data_array.plot.line(x=x_dim, ax=ax) + fig.suptitle(title, fontsize=16) + + if linlog: + ax.set_xscale("symlog", linthresh=linthresh) + return fig, ax diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index 455dc7cf..7449f531 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -463,3 +463,22 @@ def calculate_ticks_in_units_of_pi( return tick_labels * np.pi, ( str(val) for val in pretty_format_numerical_iterable(tick_labels, decimal_places=1) ) + + +def not_single_element_dims(data_array: xr.DataArray) -> list[Hashable]: + """Names of dimensions in ``data`` which don't have a size equal to one. + + This helper function is for example used to determine if a data only have a single trace, + since this requires different plotting code (e.g. ``data_array.plot.line(x="time")``). + + Parameters + ---------- + data_array: xr.DataArray + _description_ + + Returns + ------- + list[Hashable] + Names of dimensions in ``data`` which don't have a size equal to one. + """ + return [dim for dim, values in data_array.coords.items() if values.size != 1] diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index aede5a88..84c49efc 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -16,6 +16,7 @@ from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi +from pyglotaran_extras.plotting.utils import not_single_element_dims matplotlib.use("Agg") DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"] @@ -85,3 +86,22 @@ def test_calculate_ticks_in_units_of_pi( assert np.allclose(list(tick_values), expected_tick_values) assert list(tick_labels) == expected_tick_labels + + +@pytest.mark.parametrize( + "data_array, expected", + ( + (xr.DataArray([1]), []), + (xr.DataArray([1], coords={"dim1": [1]}), []), + (xr.DataArray([[1], [1]], coords={"dim1": [1, 2], "dim2": [1]}), ["dim1"]), + ( + xr.DataArray( + [[[1, 1]], [[1, 1]]], coords={"dim1": [1, 2], "dim2": [1], "dim3": [1, 2]} + ), + ["dim1", "dim3"], + ), + ), +) +def test_not_single_element_dims(data_array: xr.DataArray, expected: list[Hashable]): + """Only get dim with more than one element.""" + assert not_single_element_dims(data_array) == expected