Skip to content

Commit

Permalink
fix: linlog plot scale (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve authored Jul 29, 2024
1 parent ef3c575 commit d10c33f
Showing 1 changed file with 40 additions and 24 deletions.
64 changes: 40 additions & 24 deletions src/elisa/plot/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, base: float, lin_thresh: float, lin_scale: float):
self.base = base
self.lin_thresh = lin_thresh
self.lin_scale = lin_scale
self._lin_scale_adj = lin_scale / (1.0 - self.base**-1)
self._lin_scale_adj = lin_scale / (1.0 - 1.0 / self.base)
self.inv_lin_thresh = lin_thresh * self._lin_scale_adj

def transform_non_affine(self, values: np.ndarray):
Expand All @@ -163,7 +163,13 @@ def inverted(self):
class _LinLogFormatter(LogFormatterSciNotation):
"""Formatter for LinLogScale axes ticks."""

def __init__(self, base: float, lin_thresh: float, lin_scale: float):
def __init__(
self,
base: float,
lin_thresh: float,
lin_scale: float,
label_only_base: bool = False,
):
base = float(base)
lin_thresh = float(lin_thresh)
lin_scale = float(lin_scale)
Expand All @@ -179,20 +185,35 @@ def __init__(self, base: float, lin_thresh: float, lin_scale: float):
self.__lin_thresh = lin_thresh
self.__lin_scale = lin_scale
self._formatter_lin = ScalarFormatter()
super().__init__(base, linthresh=lin_thresh)
super().__init__(
base=base,
labelOnlyBase=label_only_base,
linthresh=lin_thresh,
)

def __call__(self, x: float, pos: int | None = None):
if x >= self.__lin_thresh:
return super().__call__(x, pos)
else:
return self._formatter_lin(x, pos)
s = self._formatter_lin(x, pos)
try:
if float(s) == 0.0:
s = '0'
except ValueError:
pass
return s

def set_axis(self, axis: Axis):
dummy = _DummyAxis(
axis, self.__base, self.__lin_thresh, self.__lin_scale, False
self._formatter_lin.set_axis(
_DummyAxis(
axis, self.__base, self.__lin_thresh, self.__lin_scale, False
)
)
super().set_axis(
_DummyAxis(
axis, self.__base, self.__lin_thresh, self.__lin_scale, True
)
)
self._formatter_lin.set_axis(dummy)
super().set_axis(axis)

def create_dummy_axis(self, **kwargs):
self._formatter_lin.create_dummy_axis(**kwargs)
Expand All @@ -201,7 +222,8 @@ def create_dummy_axis(self, **kwargs):
def set_locs(self, locs=None):
"""Set the locations of the ticks."""
super().set_locs(locs)
self._formatter_lin.set_locs(locs)
lin_locs = [i for i in locs if i <= self.__lin_thresh]
self._formatter_lin.set_locs(lin_locs)


class LinLogLocator(Locator):
Expand Down Expand Up @@ -294,21 +316,11 @@ def tick_values(self, vmin: float, vmax: float):
ticks_log = self._locator_log()
mask = np.greater(ticks_log, log_lower_lim)
if mask.any():
ticks_log = ticks_log[mask]

# ignore the first major tick of log range if too close to 0
if not self._is_minor:
t0 = ticks_log[0]
fx = np.log(t0) / self._log_base
if np.isclose(fx, round(fx)):
t0_ = self.transform_non_affine(np.asarray([t0]))
if t0_ < self._lin_scale_adj:
ticks_log = ticks_log[1:]

ticks.append(ticks_log)
ticks.append(ticks_log[mask])

if ticks:
ticks = np.unique(np.hstack(ticks))

return self.raise_if_exceeds(ticks)

def view_limits(self, vmin, vmax):
Expand Down Expand Up @@ -397,10 +409,14 @@ def set_default_locators_and_formatters(self, axis):
lin_thresh, lin_scale, base, 'auto', is_minor=True
)
axis.set_minor_locator(minor_locator)
minor_formatter = LogFormatterSciNotation(
base=self.base, labelOnlyBase=(self.subs is not None)
axis.set_minor_formatter(
_LinLogFormatter(
base,
lin_thresh,
lin_scale,
label_only_base=(self.subs is not None),
)
)
axis.set_minor_formatter(minor_formatter)

def get_transform(self):
"""Return the `.LinLogTransform` associated with this scale."""
Expand Down

0 comments on commit d10c33f

Please sign in to comment.