diff --git a/swcgeom/transforms/images.py b/swcgeom/transforms/images.py index f6b0794..dc96707 100644 --- a/swcgeom/transforms/images.py +++ b/swcgeom/transforms/images.py @@ -6,7 +6,7 @@ import numpy as np import numpy.typing as npt -from swcgeom.transforms.base import Transform +from swcgeom.transforms.base import Identity, Transform __all__ = [ "ImagesCenterCrop", @@ -14,6 +14,8 @@ "ImagesClip", "ImagesNormalizer", "ImagesMeanVarianceAdjustment", + "ImagesScaleToUnitRange", + "ImagesHistogramEqualization", "Center", # legacy ] @@ -66,6 +68,9 @@ def __init__(self, scaler: float) -> None: def __call__(self, x: NDArrayf32) -> NDArrayf32: return self.scaler * x + def extra_repr(self) -> str: + return f"scaler={self.scaler}" + class ImagesClip(Transform[NDArrayf32, NDArrayf32]): def __init__(self, vmin: float = 0, vmax: float = 1, /) -> None: @@ -75,6 +80,9 @@ def __init__(self, vmin: float = 0, vmax: float = 1, /) -> None: def __call__(self, x: NDArrayf32) -> NDArrayf32: return np.clip(x, self.vmin, self.vmax) + def extra_repr(self) -> str: + return f"vmin={self.vmin}, vmax={self.vmax}" + class ImagesNormalizer(Transform[NDArrayf32, NDArrayf32]): """Normalize image stack.""" @@ -101,5 +109,61 @@ def __init__(self, mean: float, variance: float) -> None: def __call__(self, x: NDArrayf32) -> NDArrayf32: return (x - self.mean) / self.variance - def extra_repr(self): + def extra_repr(self) -> str: return f"mean={self.mean}, variance={self.variance}" + + +class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]): + """Scale image stack to unit range.""" + + def __init__(self, vmin: float, vmax: float, *, clip: bool = True) -> None: + """Scale image stack to unit range. + + Parameters + ---------- + vmin : float + Minimum value. + vmax : float + Maximum value. + clip : bool, default True + Clip values to [0, 1] to avoid numerical issues. + """ + + super().__init__() + self.vmin = vmin + self.vmax = vmax + self.diff = vmax - vmin + self.clip = clip + self.post = ImagesClip(0, 1) if self.clip else Identity() + + def __call__(self, x: NDArrayf32) -> NDArrayf32: + return self.post((x - self.vmin) / self.diff) + + def extra_repr(self) -> str: + return f"vmin={self.vmin}, vmax={self.vmax}, clip={self.clip}" + + +class ImagesHistogramEqualization(Transform[NDArrayf32, NDArrayf32]): + """Image histogram equalization. + + References + ---------- + http://www.janeriksolem.net/histogram-equalization-with-python-and.html + """ + + def __init__(self, bins: int = 256) -> None: + super().__init__() + self.bins = bins + + def __call__(self, x: NDArrayf32) -> NDArrayf32: + # get image histogram + hist, bin_edges = np.histogram(x.flatten(), self.bins, density=True) + cdf = hist.cumsum() # cumulative distribution function + cdf = cdf / cdf[-1] # normalize + + # use linear interpolation of cdf to find new pixel values + equalized = np.interp(x.flatten(), bin_edges[:-1], cdf) + return equalized.reshape(x.shape).astype(np.float32) + + def extra_repr(self) -> str: + return f"bins={self.bins}"