diff --git a/dask_image/dispatch/_dispatch_ndmorph.py b/dask_image/dispatch/_dispatch_ndmorph.py new file mode 100644 index 00000000..7c4f5edb --- /dev/null +++ b/dask_image/dispatch/_dispatch_ndmorph.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +import numpy as np +import scipy.ndimage + +from ._dispatcher import Dispatcher + +__all__ = [ + "dispatch_binary_dilation", + "dispatch_binary_erosion", + "dispatch_binary_structure", +] + +dispatch_binary_dilation = Dispatcher(name="dispatch_binary_dilation") +dispatch_binary_erosion = Dispatcher(name="dispatch_binary_erosion") +dispatch_binary_structure = Dispatcher(name='dispatch_binary_structure') + + +# ================== binary_dilation ================== +@dispatch_binary_dilation.register(np.ndarray) +def numpy_binary_dilation(*args, **kwargs): + return scipy.ndimage.binary_dilation + + +@dispatch_binary_dilation.register_lazy("cupy") +def register_cupy_binary_dilation(): + import cupy + import cupyx.scipy.ndimage + + @dispatch_binary_dilation.register(cupy.ndarray) + def cupy_binary_dilation(*args, **kwargs): + return cupyx.scipy.ndimage.binary_dilation + + +# ================== binary_erosion ================== +@dispatch_binary_erosion.register(np.ndarray) +def numpy_binary_erosion(*args, **kwargs): + return scipy.ndimage.binary_erosion + + +@dispatch_binary_erosion.register_lazy("cupy") +def register_cupy_binary_erosion(): + import cupy + import cupyx.scipy.ndimage + + @dispatch_binary_erosion.register(cupy.ndarray) + def cupy_binary_erosion(*args, **kwargs): + return cupyx.scipy.ndimage.binary_erosion + + +# ================== generate_binary_structure ================== +@dispatch_binary_structure.register(np.ndarray) +def numpy_binary_structure(*args, **kwargs): + return scipy.ndimage.generate_binary_structure + + +@dispatch_binary_structure.register_lazy("cupy") +def register_cupy_binary_structure(): + import cupy + import cupyx.scipy.ndimage + + @dispatch_binary_structure.register(cupy.ndarray) + def cupy_binary_structure(*args, **kwargs): + return cupyx.scipy.ndimage.generate_binary_structure diff --git a/dask_image/ndmorph/__init__.py b/dask_image/ndmorph/__init__.py index 6c2119ea..fa2e1526 100644 --- a/dask_image/ndmorph/__init__.py +++ b/dask_image/ndmorph/__init__.py @@ -8,6 +8,16 @@ from . import _utils from . import _ops +from ..dispatch._dispatch_ndmorph import ( + dispatch_binary_dilation, + dispatch_binary_erosion) + +__all__ = [ + "binary_closing", + "binary_dilation", + "binary_erosion", + "binary_opening", +] @_utils._update_wrapper(scipy.ndimage.binary_closing) @@ -43,7 +53,7 @@ def binary_dilation(image, border_value = _utils._get_border_value(border_value) result = _ops._binary_op( - scipy.ndimage.binary_dilation, + dispatch_binary_dilation(image), image, structure=structure, iterations=iterations, @@ -67,7 +77,7 @@ def binary_erosion(image, border_value = _utils._get_border_value(border_value) result = _ops._binary_op( - scipy.ndimage.binary_erosion, + dispatch_binary_erosion(image), image, structure=structure, iterations=iterations, diff --git a/dask_image/ndmorph/_utils.py b/dask_image/ndmorph/_utils.py index db167065..e2839b32 100644 --- a/dask_image/ndmorph/_utils.py +++ b/dask_image/ndmorph/_utils.py @@ -8,6 +8,7 @@ import dask.array +from ..dispatch._dispatch_ndmorph import dispatch_binary_structure from ..ndfilters._utils import ( _update_wrapper, _get_depth_boundary, @@ -24,8 +25,9 @@ def _get_structure(image, structure): # Create square connectivity as default if structure is None: - structure = scipy.ndimage.generate_binary_structure(image.ndim, 1) - elif isinstance(structure, (numpy.ndarray, dask.array.Array)): + generate_binary_structure = dispatch_binary_structure(image) + structure = generate_binary_structure(image.ndim, 1) + elif hasattr(structure, 'ndim'): if structure.ndim != image.ndim: raise RuntimeError( "`structure` must have the same rank as `image`." diff --git a/tests/test_dask_image/test_ndmorph/test_cupy_ndmorph.py b/tests/test_dask_image/test_ndmorph/test_cupy_ndmorph.py new file mode 100644 index 00000000..22fec6d6 --- /dev/null +++ b/tests/test_dask_image/test_ndmorph/test_cupy_ndmorph.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import dask.array as da +import numpy as np +import pytest + +from dask_image import ndmorph + +cupy = pytest.importorskip("cupy", minversion="7.7.0") + + +@pytest.fixture +def array(): + s = (10, 10) + a = da.from_array(cupy.arange(int(np.prod(s)), + dtype=cupy.float32).reshape(s), chunks=5) + return a + + +@pytest.mark.cupy +@pytest.mark.parametrize("func", [ + ndmorph.binary_closing, + ndmorph.binary_dilation, + ndmorph.binary_erosion, + ndmorph.binary_opening, +]) +def test_cupy_ndmorph(array, func): + """Test convolve & correlate filters with cupy input arrays.""" + result = func(array) + result.compute()