Skip to content

Commit

Permalink
👌 Make plot_data_overview able to plot single trace data (#137)
Browse files Browse the repository at this point in the history
This change allows plot_data_overview to be able to plot single trace data without crashing on a single axis dataset.
The SVD part of the plot will be skipped since it does not make sense.

* ✨ Implemented not_single_element_dims util function

* 👌 Made plot_data_overview able to plot single trace data
  • Loading branch information
s-weigand authored Feb 23, 2023
1 parent 08d18d3 commit 0ed9a7b
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 4 deletions.
62 changes: 58 additions & 4 deletions pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
from pyglotaran_extras.plotting.plot_svd import plot_lsv_data
from pyglotaran_extras.plotting.plot_svd import plot_rsv_data
from pyglotaran_extras.plotting.plot_svd import plot_sv_data
from pyglotaran_extras.plotting.utils import not_single_element_dims
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location

__all__ = ["plot_data_overview"]

if TYPE_CHECKING:
from typing import Hashable

import xarray as xr
from glotaran.project.result import Result
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
Expand All @@ -33,7 +37,7 @@ def plot_data_overview(
nr_of_data_svd_vectors: int = 4,
show_data_svd_legend: bool = True,
irf_location: float | None = None,
) -> tuple[Figure, Axes]:
) -> tuple[Figure, Axes] | tuple[Figure, Axis]:
"""Plot data as filled contour plot and SVD components.
Parameters
Expand All @@ -59,10 +63,21 @@ def plot_data_overview(
Returns
-------
tuple[Figure, Axes]
tuple[Figure, Axes]|tuple[Figure,Axis]
Figure and axes which can then be refined by the user.
"""
dataset = load_data(dataset, _stacklevel=3)
data = shift_time_axis_by_irf_location(dataset.data, irf_location)

if len(not_single_element_dims(data)) == 1:
return _plot_single_trace(
data,
not_single_element_dims(data)[0],
title="Single trace data",
linlog=linlog,
linthresh=linthresh,
figsize=figsize,
)

fig = plt.figure(figsize=figsize)
data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig))
Expand All @@ -71,8 +86,6 @@ def plot_data_overview(
sv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 1), fig=fig))
rsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 2), fig=fig))

data = shift_time_axis_by_irf_location(dataset.data, irf_location)

if len(data.time) > 1:
data.plot(x="time", ax=data_ax, center=False)
else:
Expand All @@ -97,3 +110,44 @@ def plot_data_overview(
if linlog:
data_ax.set_xscale("symlog", linthresh=linthresh)
return fig, (data_ax, lsv_ax, sv_ax, rsv_ax)


def _plot_single_trace(
data_array: xr.DataArray,
x_dim: Hashable,
*,
title: str = "Single trace data",
linlog: bool = False,
linthresh: float = 1,
figsize: tuple[int, int] = (15, 10),
) -> tuple[Figure, Axis]:
"""Plot single trace data in case ``plot_data_overview`` gets passed ingle trace data.
Parameters
----------
data_array: xr.DataArray
DataArray containing only data of a single trace.
x_dim: Hashable
Name of the x dimension.
title: str
Title to add to the figure. Defaults to "Data overview".
linlog: bool
Whether to use 'symlog' scale or not. Defaults to False.
linthresh: float
A single float which defines the range (-x, x), within which the plot is linear.
This avoids having the plot go to infinity around zero. Defaults to 1.
figsize: tuple[int, int]
Size of the figure (N, M) in inches. Defaults to (15, 10).
Returns
-------
tuple[Figure, Axis]
Figure and axis which can then be refined by the user.
"""
fig, ax = plt.subplots(1, 1, figsize=figsize)
data_array.plot.line(x=x_dim, ax=ax)
fig.suptitle(title, fontsize=16)

if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
return fig, ax
19 changes: 19 additions & 0 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,3 +463,22 @@ def calculate_ticks_in_units_of_pi(
return tick_labels * np.pi, (
str(val) for val in pretty_format_numerical_iterable(tick_labels, decimal_places=1)
)


def not_single_element_dims(data_array: xr.DataArray) -> list[Hashable]:
"""Names of dimensions in ``data`` which don't have a size equal to one.
This helper function is for example used to determine if a data only have a single trace,
since this requires different plotting code (e.g. ``data_array.plot.line(x="time")``).
Parameters
----------
data_array: xr.DataArray
_description_
Returns
-------
list[Hashable]
Names of dimensions in ``data`` which don't have a size equal to one.
"""
return [dim for dim, values in data_array.coords.items() if values.size != 1]
20 changes: 20 additions & 0 deletions tests/plotting/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi
from pyglotaran_extras.plotting.utils import not_single_element_dims

matplotlib.use("Agg")
DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"]
Expand Down Expand Up @@ -85,3 +86,22 @@ def test_calculate_ticks_in_units_of_pi(

assert np.allclose(list(tick_values), expected_tick_values)
assert list(tick_labels) == expected_tick_labels


@pytest.mark.parametrize(
"data_array, expected",
(
(xr.DataArray([1]), []),
(xr.DataArray([1], coords={"dim1": [1]}), []),
(xr.DataArray([[1], [1]], coords={"dim1": [1, 2], "dim2": [1]}), ["dim1"]),
(
xr.DataArray(
[[[1, 1]], [[1, 1]]], coords={"dim1": [1, 2], "dim2": [1], "dim3": [1, 2]}
),
["dim1", "dim3"],
),
),
)
def test_not_single_element_dims(data_array: xr.DataArray, expected: list[Hashable]):
"""Only get dim with more than one element."""
assert not_single_element_dims(data_array) == expected

0 comments on commit 0ed9a7b

Please sign in to comment.