diff --git a/src/elisa/plot/scale.py b/src/elisa/plot/scale.py index 616d0d62..883d62cb 100644 --- a/src/elisa/plot/scale.py +++ b/src/elisa/plot/scale.py @@ -140,7 +140,7 @@ def __init__(self, base: float, lin_thresh: float, lin_scale: float): self.base = base self.lin_thresh = lin_thresh self.lin_scale = lin_scale - self._lin_scale_adj = lin_scale / (1.0 - self.base**-1) + self._lin_scale_adj = lin_scale / (1.0 - 1.0 / self.base) self.inv_lin_thresh = lin_thresh * self._lin_scale_adj def transform_non_affine(self, values: np.ndarray): @@ -163,7 +163,13 @@ def inverted(self): class _LinLogFormatter(LogFormatterSciNotation): """Formatter for LinLogScale axes ticks.""" - def __init__(self, base: float, lin_thresh: float, lin_scale: float): + def __init__( + self, + base: float, + lin_thresh: float, + lin_scale: float, + label_only_base: bool = False, + ): base = float(base) lin_thresh = float(lin_thresh) lin_scale = float(lin_scale) @@ -179,20 +185,35 @@ def __init__(self, base: float, lin_thresh: float, lin_scale: float): self.__lin_thresh = lin_thresh self.__lin_scale = lin_scale self._formatter_lin = ScalarFormatter() - super().__init__(base, linthresh=lin_thresh) + super().__init__( + base=base, + labelOnlyBase=label_only_base, + linthresh=lin_thresh, + ) def __call__(self, x: float, pos: int | None = None): if x >= self.__lin_thresh: return super().__call__(x, pos) else: - return self._formatter_lin(x, pos) + s = self._formatter_lin(x, pos) + try: + if float(s) == 0.0: + s = '0' + except ValueError: + pass + return s def set_axis(self, axis: Axis): - dummy = _DummyAxis( - axis, self.__base, self.__lin_thresh, self.__lin_scale, False + self._formatter_lin.set_axis( + _DummyAxis( + axis, self.__base, self.__lin_thresh, self.__lin_scale, False + ) + ) + super().set_axis( + _DummyAxis( + axis, self.__base, self.__lin_thresh, self.__lin_scale, True + ) ) - self._formatter_lin.set_axis(dummy) - super().set_axis(axis) def create_dummy_axis(self, **kwargs): self._formatter_lin.create_dummy_axis(**kwargs) @@ -201,7 +222,8 @@ def create_dummy_axis(self, **kwargs): def set_locs(self, locs=None): """Set the locations of the ticks.""" super().set_locs(locs) - self._formatter_lin.set_locs(locs) + lin_locs = [i for i in locs if i <= self.__lin_thresh] + self._formatter_lin.set_locs(lin_locs) class LinLogLocator(Locator): @@ -294,21 +316,11 @@ def tick_values(self, vmin: float, vmax: float): ticks_log = self._locator_log() mask = np.greater(ticks_log, log_lower_lim) if mask.any(): - ticks_log = ticks_log[mask] - - # ignore the first major tick of log range if too close to 0 - if not self._is_minor: - t0 = ticks_log[0] - fx = np.log(t0) / self._log_base - if np.isclose(fx, round(fx)): - t0_ = self.transform_non_affine(np.asarray([t0])) - if t0_ < self._lin_scale_adj: - ticks_log = ticks_log[1:] - - ticks.append(ticks_log) + ticks.append(ticks_log[mask]) if ticks: ticks = np.unique(np.hstack(ticks)) + return self.raise_if_exceeds(ticks) def view_limits(self, vmin, vmax): @@ -397,10 +409,14 @@ def set_default_locators_and_formatters(self, axis): lin_thresh, lin_scale, base, 'auto', is_minor=True ) axis.set_minor_locator(minor_locator) - minor_formatter = LogFormatterSciNotation( - base=self.base, labelOnlyBase=(self.subs is not None) + axis.set_minor_formatter( + _LinLogFormatter( + base, + lin_thresh, + lin_scale, + label_only_base=(self.subs is not None), + ) ) - axis.set_minor_formatter(minor_formatter) def get_transform(self): """Return the `.LinLogTransform` associated with this scale."""