From cdcb3fb7d87e7420210ed4a3219f4ee90f56fe21 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 12 Sep 2022 17:56:23 -0400 Subject: [PATCH] simplify __array_ufunc__ check --- xarray/core/arithmetic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index ff7af02abfc..8d6a1d3ed8c 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -16,7 +16,7 @@ from .common import ImplementsArrayReduce, ImplementsDatasetReduce from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_array class SupportsArithmetic: @@ -33,12 +33,11 @@ class SupportsArithmetic: # TODO: allow extending this with some sort of registration system _HANDLED_TYPES = ( - np.ndarray, np.generic, numbers.Number, bytes, str, - ) + dask_array_type + ) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): from .computation import apply_ufunc @@ -46,7 +45,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. out = kwargs.get("out", ()) for x in inputs + out: - if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): + if not is_duck_array(x) and not isinstance( + x, self._HANDLED_TYPES + (SupportsArithmetic,) + ): return NotImplemented if ufunc.signature is not None: