diff --git a/pyglotaran_extras/plotting/plot_data.py b/pyglotaran_extras/plotting/plot_data.py index 93b7b8fb..5252b944 100644 --- a/pyglotaran_extras/plotting/plot_data.py +++ b/pyglotaran_extras/plotting/plot_data.py @@ -12,11 +12,15 @@ from pyglotaran_extras.plotting.plot_svd import plot_lsv_data from pyglotaran_extras.plotting.plot_svd import plot_rsv_data from pyglotaran_extras.plotting.plot_svd import plot_sv_data +from pyglotaran_extras.plotting.utils import not_single_element_dims from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location __all__ = ["plot_data_overview"] if TYPE_CHECKING: + from typing import Hashable + + import xarray as xr from glotaran.project.result import Result from matplotlib.figure import Figure from matplotlib.pyplot import Axes @@ -33,7 +37,7 @@ def plot_data_overview( nr_of_data_svd_vectors: int = 4, show_data_svd_legend: bool = True, irf_location: float | None = None, -) -> tuple[Figure, Axes]: +) -> tuple[Figure, Axes] | tuple[Figure, Axis]: """Plot data as filled contour plot and SVD components. Parameters @@ -59,10 +63,21 @@ def plot_data_overview( Returns ------- - tuple[Figure, Axes] + tuple[Figure, Axes]|tuple[Figure,Axis] Figure and axes which can then be refined by the user. """ dataset = load_data(dataset, _stacklevel=3) + data = shift_time_axis_by_irf_location(dataset.data, irf_location) + + if len(not_single_element_dims(data)) == 1: + return _plot_single_trace( + data, + not_single_element_dims(data)[0], + title="Single trace data", + linlog=linlog, + linthresh=linthresh, + figsize=figsize, + ) fig = plt.figure(figsize=figsize) data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig)) @@ -71,8 +86,6 @@ def plot_data_overview( sv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 1), fig=fig)) rsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 2), fig=fig)) - data = shift_time_axis_by_irf_location(dataset.data, irf_location) - if len(data.time) > 1: data.plot(x="time", ax=data_ax, center=False) else: @@ -97,3 +110,44 @@ def plot_data_overview( if linlog: data_ax.set_xscale("symlog", linthresh=linthresh) return fig, (data_ax, lsv_ax, sv_ax, rsv_ax) + + +def _plot_single_trace( + data_array: xr.DataArray, + x_dim: Hashable, + *, + title: str = "Single trace data", + linlog: bool = False, + linthresh: float = 1, + figsize: tuple[int, int] = (15, 10), +) -> tuple[Figure, Axis]: + """Plot single trace data in case ``plot_data_overview`` gets passed ingle trace data. + + Parameters + ---------- + data_array: xr.DataArray + DataArray containing only data of a single trace. + x_dim: Hashable + Name of the x dimension. + title: str + Title to add to the figure. Defaults to "Data overview". + linlog: bool + Whether to use 'symlog' scale or not. Defaults to False. + linthresh: float + A single float which defines the range (-x, x), within which the plot is linear. + This avoids having the plot go to infinity around zero. Defaults to 1. + figsize: tuple[int, int] + Size of the figure (N, M) in inches. Defaults to (15, 10). + + Returns + ------- + tuple[Figure, Axis] + Figure and axis which can then be refined by the user. + """ + fig, ax = plt.subplots(1, 1, figsize=figsize) + data_array.plot.line(x=x_dim, ax=ax) + fig.suptitle(title, fontsize=16) + + if linlog: + ax.set_xscale("symlog", linthresh=linthresh) + return fig, ax