Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👌 Add minor ticks to linlog plots #183

Merged
merged 2 commits into from
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
## 0.8.0 (Unreleased)

- 🩹 Fix crashes of plot_doas and plot_coherent_artifact for non dispersive IRF (#173)
- 👌 Add minor ticks to linlog plots (#183)
- 🚧📦 Remove upper python version limit (#174)

(changes-0_7_0)=
Expand Down
2 changes: 2 additions & 0 deletions pyglotaran_extras/plotting/plot_concentrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING

from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import get_shifted_traces

Expand Down Expand Up @@ -69,3 +70,4 @@ def plot_concentrations(

if linlog:
ax.set_xscale("symlog", linthresh=linthresh, linscale=linscale)
ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
3 changes: 3 additions & 0 deletions pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pyglotaran_extras.plotting.plot_svd import plot_lsv_data
from pyglotaran_extras.plotting.plot_svd import plot_rsv_data
from pyglotaran_extras.plotting.plot_svd import plot_sv_data
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import not_single_element_dims
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location

Expand Down Expand Up @@ -109,6 +110,7 @@ def plot_data_overview(

if linlog:
data_ax.set_xscale("symlog", linthresh=linthresh)
data_ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
return fig, (data_ax, lsv_ax, sv_ax, rsv_ax)


Expand Down Expand Up @@ -150,4 +152,5 @@ def _plot_single_trace(

if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
return fig, ax
2 changes: 2 additions & 0 deletions pyglotaran_extras/plotting/plot_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from pyglotaran_extras.plotting.plot_irf_dispersion_center import _plot_irf_dispersion_center
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location

Expand Down Expand Up @@ -83,4 +84,5 @@ def plot_residual(
ax.legend()
if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
ax.set_title(title)
3 changes: 3 additions & 0 deletions pyglotaran_extras/plotting/plot_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from glotaran.io.prepare_dataset import add_svd_to_dataset

from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location

Expand Down Expand Up @@ -140,6 +141,7 @@ def plot_lsv_data(
ax.set_title("data. LSV")
if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))


def plot_rsv_data(
Expand Down Expand Up @@ -240,6 +242,7 @@ def plot_lsv_residual(
ax.set_title("res. LSV")
if linlog:
ax.set_xscale("symlog", linthresh=linthresh)
ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))


def plot_rsv_residual(
Expand Down
2 changes: 2 additions & 0 deletions pyglotaran_extras/plotting/plot_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pyglotaran_extras.io.utils import result_dataset_mapping
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import PlotDuplicationWarning
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import add_unique_figure_legend
Expand Down Expand Up @@ -97,6 +98,7 @@ def plot_data_and_fits(
[next(axis._get_lines.prop_cycler) for _ in range(2)]
if linlog:
axis.set_xscale("symlog", linthresh=linthresh)
axis.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
if show_zero_line is True:
axis.axhline(0, color="k", linewidth=1)
axis.set_ylabel(y_label)
Expand Down
96 changes: 96 additions & 0 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import xarray as xr
from matplotlib.ticker import Locator

from pyglotaran_extras.inspect.utils import pretty_format_numerical_iterable
from pyglotaran_extras.io.utils import result_dataset_mapping
Expand Down Expand Up @@ -482,3 +483,98 @@ def not_single_element_dims(data_array: xr.DataArray) -> list[Hashable]:
Names of dimensions in ``data`` which don't have a size equal to one.
"""
return [dim for dim, values in data_array.coords.items() if values.size != 1]


class MinorSymLogLocator(Locator):
"""Dynamically find minor tick positions based on major ticks for a symlog scaling.

Ref.: https://stackoverflow.com/a/45696768
"""

def __init__(self, linthresh: float, nints: int = 10) -> None:
"""Ticks will be placed between the major ticks.

The placement is linear for x between -linthresh and linthresh,
otherwise its logarithmically. nints gives the number of
intervals that will be bounded by the minor ticks.

Parameters
----------
linthresh : float
A single float which defines the range (-x, x), within which the plot is linear.
nints : int
Number of minor tick between major ticks. Defaults to 10
"""
self.linthresh = linthresh
self.nintervals = nints

def __call__(self) -> list[float]:
"""Return the locations of the ticks.

Returns
-------
list[float]
Minor ticks position.
"""
# Return the locations of the ticks
majorlocs = self.axis.get_majorticklocs()

if len(majorlocs) == 1:
return self.raise_if_exceeds(np.array([]))

# add temporary major tick locs at either end of the current range
# to fill in minor tick gaps
dmlower = majorlocs[1] - majorlocs[0] # major tick difference at lower end
dmupper = majorlocs[-1] - majorlocs[-2] # major tick difference at upper end

# add temporary major tick location at the lower end
if majorlocs[0] != 0.0 and (
(majorlocs[0] != self.linthresh and dmlower > self.linthresh)
or (dmlower == self.linthresh and majorlocs[0] < 0)
):
majorlocs = np.insert(majorlocs, 0, majorlocs[0] * 10.0)
else:
majorlocs = np.insert(majorlocs, 0, majorlocs[0] - self.linthresh)

# add temporary major tick location at the upper end
if majorlocs[-1] != 0.0 and (
(np.abs(majorlocs[-1]) != self.linthresh and dmupper > self.linthresh)
or (dmupper == self.linthresh and majorlocs[-1] > 0)
):
majorlocs = np.append(majorlocs, majorlocs[-1] * 10.0)
else:
majorlocs = np.append(majorlocs, majorlocs[-1] + self.linthresh)

# iterate through minor locs
minorlocs: list[float] = []

# handle the lowest part
for i in range(1, len(majorlocs)):
majorstep = majorlocs[i] - majorlocs[i - 1]
if abs(majorlocs[i - 1] + majorstep / 2) < self.linthresh:
ndivs = self.nintervals
else:
ndivs = self.nintervals - 1

minorstep = majorstep / ndivs
locs = np.arange(majorlocs[i - 1], majorlocs[i], minorstep)[1:]
minorlocs.extend(locs)

return self.raise_if_exceeds(np.array(minorlocs))

def tick_values(self, vmin: float, vmax: float) -> None:
"""Return the values of the located ticks given **vmin** and **vmax** (not implemented).

Parameters
----------
vmin : float
Minimum value.
vmax : float
Maximum value.

Raises
------
NotImplementedError
Not used
"""
raise NotImplementedError(f"Cannot get tick locations for a {type(self)} type.")