Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add plot_fitted_traces function #39

Merged
merged 18 commits into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
6122a77
✨ Added 'plot_fit_overview' function to plot data and fit per wavelength
s-weigand Oct 2, 2021
fe4a186
✨ Added 'wavelength_range' parameter so the user can select data to plot
s-weigand Oct 7, 2021
2018103
🧹 Resticted DatasetConvertible objects to always be xr.Dataset
s-weigand Oct 9, 2021
fe6cfea
♻️ Factored out extraction of irf location to a seperate function
s-weigand Oct 9, 2021
af3a0e9
👌 Plotted fits and data are now shifted by the irf location
s-weigand Oct 9, 2021
6c6c995
🧹 Unified xarray typing style
s-weigand Oct 18, 2021
c4503c9
👌♻️ Added 'dataset_name' argument, use 'load_dataset' to load result …
s-weigand Oct 18, 2021
a55327f
♻️ Refactored load_data and result_dataset_mapping to work with DataA…
s-weigand Oct 20, 2021
79a8c35
♻️✨ Made 'plot_fit_overview' work properly with unevenly spaced wavel…
s-weigand Oct 21, 2021
0cbffb2
♻️ Moved 'plot_fit_overview' helper functions to plotting.utils
s-weigand Oct 21, 2021
b336c7c
✨ Added figsize argument to plot_data_overview
s-weigand Oct 21, 2021
c46a4c4
👌 Reexport 'select_plot_wavelengths' from plotting.data for convenience
s-weigand Oct 21, 2021
e91a371
✨ Added divide_by_scale parameter to divide data by dataset_scale
s-weigand Oct 22, 2021
a0cdc16
👌 Made function name in 'select_plot_wavelengths' warning dynamic
s-weigand Oct 22, 2021
d46bc2f
📚 Added comments and docstrings to plot style code
s-weigand Oct 22, 2021
58ad974
👌 Made the ylabel of plots an argument and default to 'a.u.'
s-weigand Oct 22, 2021
7ec2491
♻️ Addressed requested renaming and moving functions suited for this PR
s-weigand Oct 22, 2021
d4f3183
♻️ Moved plot_concentrations to its own module with the same name
s-weigand Oct 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions pyglotaran_extras/io/load_data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,44 @@
from __future__ import annotations

from pathlib import Path
from typing import Union

import xarray as xr
from glotaran.io import load_dataset
from glotaran.project.result import Result

from pyglotaran_extras.types import DatasetConvertible


def load_data(result: DatasetConvertible, dataset_name: str | None = None) -> xr.Dataset:
"""Extract a single dataset from a :class:`DatasetConvertible` object.

Parameters
----------
result : DatasetConvertible
Result class instance, xarray Dataset or path to a dataset file.
dataset_name : str, optional
Name of a specific dataset contained in ``result``, if not provided
the first dataset will be extracted., by default None

Returns
-------
xr.Dataset
Extracted dataset.

def load_data(result: Union[xr.Dataset, Path, Result]) -> xr.Dataset:
Raises
------
TypeError
If ``result`` isn't a :class:`DatasetConvertible` object.
"""
if isinstance(result, xr.Dataset):
return result
elif isinstance(result, Result):
if isinstance(result, xr.DataArray):
return result.to_dataset(name="data")
if isinstance(result, Result):
if dataset_name is not None:
return result.data[dataset_name]
keys = list(result.data)
return result.data[keys[0]]
else:
return xr.open_dataset(result)
if isinstance(result, (str, Path)):
return load_data(load_dataset(result))
raise TypeError(f"Result needs to be of type {DatasetConvertible!r}, but was {result!r}.")
48 changes: 48 additions & 0 deletions pyglotaran_extras/io/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from collections.abc import Mapping
from collections.abc import Sequence
from pathlib import Path

import xarray as xr
from glotaran.project.result import Result

from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.types import ResultLike


def result_dataset_mapping(result: ResultLike) -> Mapping[str, xr.Dataset]:
"""Convert a ``ResultLike`` object to a per dataset mapping of result like data.

Parameters
----------
result : ResultLike
Data structure which can be converted to a mapping.

Returns
-------
Mapping[str, Dataset]
Per dataset mapping of result like data.

Raises
------
TypeError
If any value of a ``result`` isn't of :class:`DatasetConvertible`.
TypeError
If ``result`` isn't a :class:`ResultLike` object.
"""

result_mapping = {}
if isinstance(result, Result):
return result.data
if isinstance(result, (xr.Dataset, xr.DataArray, Path, str)):
return {"dataset": load_data(result)}
if isinstance(result, Sequence):
for index, value in enumerate(result):
result_mapping[f"dataset{index}"] = load_data(value)
return result_mapping
if isinstance(result, Mapping):
for key, value in result.items():
result_mapping[key] = load_data(value)
return result_mapping
raise TypeError(f"Result needs to be of type {ResultLike!r}, but was {result!r}.")
42 changes: 38 additions & 4 deletions pyglotaran_extras/plotting/data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,49 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import xarray as xr

from pyglotaran_extras.plotting.plot_svd import plot_lsv_data
from pyglotaran_extras.plotting.plot_svd import plot_rsv_data
from pyglotaran_extras.plotting.plot_svd import plot_sv_data
from pyglotaran_extras.plotting.utils import select_plot_wavelengths

__all__ = ["select_plot_wavelengths", "plot_data_overview"]

if TYPE_CHECKING:
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
from xarray import Dataset


def plot_data_overview(
dataset: xr.Dataset, title="Data overview", linlog: bool = False, linthresh: float = 1
):
fig = plt.figure()
dataset: Dataset,
title="Data overview",
linlog: bool = False,
linthresh: float = 1,
figsize: tuple[int, int] = (30, 15),
) -> tuple[Figure, Axes]:
"""Plot data as filled contour plot and SVD components.

Parameters
----------
dataset : Dataset
Dataset containing data and SVD of the data.
title : str, optional
Title to add to the figure., by default "Data overview"
linlog : bool, optional
Whether to use 'symlog' scale or not, by default False
linthresh : float, optional
A single float which defines the range (-x, x), within which the plot is linear.
This avoids having the plot go to infinity around zero., by default 1

Returns
-------
tuple[Figure, Axes]
Figure and axes which can then be refined by the user.
"""
fig = plt.figure(figsize=figsize)
data_ax = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig)
lsv_ax = plt.subplot2grid((4, 3), (3, 0), fig=fig)
sv_ax = plt.subplot2grid((4, 3), (3, 1), fig=fig)
Expand Down
64 changes: 64 additions & 0 deletions pyglotaran_extras/plotting/plot_concentrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import get_shifted_traces

if TYPE_CHECKING:
import xarray as xr
from matplotlib.pyplot import Axes


def plot_concentrations(
res: xr.Dataset,
ax: Axes,
center_λ: float | None,
linlog: bool = False,
linthresh: float = 1,
linscale: float = 1,
main_irf_nr: int = 0,
) -> None:
"""Plot traces on the given axis ``ax``

Parameters
----------
res: xr.Dataset
Result dataset from a pyglotaran optimization.
ax: Axes
Axes to plot the traces on
center_λ: float | None
Center wavelength (λ in nm)
linlog: bool
Whether to use 'symlog' scale or not, by default False
linthresh: int
A single float which defines the range (-x, x), within which the plot is linear.
This avoids having the plot go to infinity around zero., by default 1
linscale: int
This allows the linear range (-linthresh to linthresh) to be stretched
relative to the logarithmic range.
Its value is the number of decades to use for each half of the linear range.
For example, when linscale == 1.0 (the default), the space used for the
positive and negative halves of the linear range will be equal to one
decade in the logarithmic range., by default 1
main_irf_nr: int
Index of the main ``irf`` component when using an ``irf``
parametrized with multiple peaks , by default 0

See Also
--------
get_shifted_traces
"""
traces = get_shifted_traces(res, center_λ, main_irf_nr)
plot_style = PlotStyle()
plt.rc("axes", prop_cycle=plot_style.cycler)

if "spectral" in traces.coords:
traces.sel(spectral=center_λ, method="nearest").plot.line(x="time", ax=ax)
else:
traces.plot.line(x="time", ax=ax)

if linlog:
ax.set_xscale("symlog", linthresh=linthresh, linscale=linscale)
4 changes: 2 additions & 2 deletions pyglotaran_extras/plotting/plot_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import xarray as xr

from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.plotting.plot_concentrations import plot_concentrations
from pyglotaran_extras.plotting.plot_residual import plot_residual
from pyglotaran_extras.plotting.plot_spectra import plot_spectra
from pyglotaran_extras.plotting.plot_svd import plot_svd
from pyglotaran_extras.plotting.plot_traces import plot_traces
from pyglotaran_extras.plotting.style import PlotStyle

if TYPE_CHECKING:
Expand Down Expand Up @@ -73,7 +73,7 @@ def plot_overview(
center_λ = min(res.dims["spectral"], round(res.dims["spectral"] / 2))

# First and second row: concentrations - SAS/EAS - DAS
plot_traces(
plot_concentrations(
res,
ax[0, 0],
center_λ,
Expand Down
Loading