Skip to content

Commit

Permalink
👌 Made plot_data_overview able to plot single trace data
Browse files Browse the repository at this point in the history
  • Loading branch information
s-weigand committed Feb 5, 2023
1 parent 32ce905 commit 1724884
Showing 1 changed file with 58 additions and 4 deletions.
62 changes: 58 additions & 4 deletions pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
from pyglotaran_extras.plotting.plot_svd import plot_lsv_data
from pyglotaran_extras.plotting.plot_svd import plot_rsv_data
from pyglotaran_extras.plotting.plot_svd import plot_sv_data
from pyglotaran_extras.plotting.utils import not_single_element_dims
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location

__all__ = ["plot_data_overview"]

if TYPE_CHECKING:
from typing import Hashable

import xarray as xr
from glotaran.project.result import Result
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
Expand All @@ -33,7 +37,7 @@ def plot_data_overview(
nr_of_data_svd_vectors: int = 4,
show_data_svd_legend: bool = True,
irf_location: float | None = None,
) -> tuple[Figure, Axes]:
) -> tuple[Figure, Axes] | tuple[Figure, Axis]:
"""Plot data as filled contour plot and SVD components.
Parameters
Expand All @@ -59,10 +63,21 @@ def plot_data_overview(
Returns
-------
tuple[Figure, Axes]
tuple[Figure, Axes]|tuple[Figure,Axis]
Figure and axes which can then be refined by the user.
"""
dataset = load_data(dataset, _stacklevel=3)
data = shift_time_axis_by_irf_location(dataset.data, irf_location)

if len(not_single_element_dims(data)) == 1:
return _plot_single_trace(
data,
not_single_element_dims(data)[0],
title="Single trace data",
linlog=linlog,
linthresh=linthresh,
figsize=figsize,
)

fig = plt.figure(figsize=figsize)
data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig))
Expand All @@ -71,8 +86,6 @@ def plot_data_overview(
sv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 1), fig=fig))
rsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 2), fig=fig))

data = shift_time_axis_by_irf_location(dataset.data, irf_location)

if len(data.time) > 1:
data.plot(x="time", ax=data_ax, center=False)
else:
Expand All @@ -97,3 +110,44 @@ def plot_data_overview(
if linlog:
data_ax.set_xscale("symlog", linthresh=linthresh)
return fig, (data_ax, lsv_ax, sv_ax, rsv_ax)


def _plot_single_trace(
data_array: xr.DataArray,
x_dim: Hashable,
*,
title: str = "Single trace data",
linlog: bool = False,
linthresh: float = 1,
figsize: tuple[int, int] = (15, 10),
) -> tuple[Figure, Axis]:
"""Plot single trace data in case ``plot_data_overview`` gets passed ingle trace data.
Parameters
----------
data_array: xr.DataArray
DataArray containing only data of a single trace.
x_dim: Hashable
Name of the x dimension.
title: str
Title to add to the figure. Defaults to "Data overview".
linlog: bool
Whether to use 'symlog' scale or not. Defaults to False.
linthresh: float
A single float which defines the range (-x, x), within which the plot is linear.
This avoids having the plot go to infinity around zero. Defaults to 1.
figsize: tuple[int, int]
Size of the figure (N, M) in inches. Defaults to (15, 10).
Returns
-------
tuple[Figure, Axis]
Figure and axis which can then be refined by the user.
"""
fig, ax = plt.subplots(1, 1, figsize=figsize)
data_array.plot.line(x=x_dim, ax=ax)
fig.suptitle(title, fontsize=16)

if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
return fig, ax

0 comments on commit 1724884

Please sign in to comment.