-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rely on NEP-18 to dispatch to dask in duck_array_ops #5571
Changes from 10 commits
895c262
ec5fb50
78a3cee
bf8600e
b64732a
7405c02
1c394ee
895d44a
6b711a2
5b93fbd
6658439
59278ef
7cd65cd
7db47cb
bbc95cf
9654e22
d9afa59
2c85e39
a349538
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,38 +29,16 @@ | |
except ImportError: | ||
dask_array = None | ||
|
||
|
||
def _dask_or_eager_func( | ||
name, | ||
eager_module=np, | ||
dask_module=dask_array, | ||
list_of_args=False, | ||
array_args=slice(1), | ||
requires_dask=None, | ||
): | ||
"""Create a function that dispatches to dask for dask array inputs.""" | ||
if dask_module is not None: | ||
|
||
def f(*args, **kwargs): | ||
if list_of_args: | ||
dispatch_args = args[0] | ||
else: | ||
dispatch_args = args[array_args] | ||
if any(is_duck_dask_array(a) for a in dispatch_args): | ||
try: | ||
wrapped = getattr(dask_module, name) | ||
except AttributeError as e: | ||
raise AttributeError(f"{e}: requires dask >={requires_dask}") | ||
else: | ||
wrapped = getattr(eager_module, name) | ||
return wrapped(*args, **kwargs) | ||
|
||
else: | ||
|
||
def f(*args, **kwargs): | ||
return getattr(eager_module, name)(*args, **kwargs) | ||
|
||
return f | ||
from numpy import all as array_all # noqa | ||
from numpy import any as array_any # noqa | ||
from numpy import zeros_like # noqa | ||
from numpy import around, broadcast_to # noqa | ||
from numpy import concatenate as _concatenate | ||
from numpy import einsum, isclose, isin, isnan, isnat, pad # noqa | ||
from numpy import stack as _stack | ||
from numpy import take, tensordot, transpose, unravel_index # noqa | ||
from numpy import where as _where | ||
from numpy.ma import masked_invalid # noqa | ||
|
||
|
||
def fail_on_dask_array_input(values, msg=None, func_name=None): | ||
|
@@ -72,16 +50,15 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): | |
raise NotImplementedError(msg % func_name) | ||
|
||
|
||
around = _dask_or_eager_func("around") | ||
isclose = _dask_or_eager_func("isclose") | ||
|
||
|
||
isnat = np.isnat | ||
isnan = _dask_or_eager_func("isnan") | ||
zeros_like = _dask_or_eager_func("zeros_like") | ||
# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18 | ||
def _dask_or_eager_isnull(obj): | ||
if is_duck_dask_array(obj): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does this: d = dask_array.array([1, 2, 3])
q = pint.Quantity(d, units='m')
dask.array.isnull(q) raises warnings
and returns the result
I think that's a sensible result of calling |
||
return dask_array.isnull(obj) | ||
else: | ||
return pd.isnull(obj) | ||
|
||
|
||
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd) | ||
pandas_isnull = _dask_or_eager_isnull | ||
|
||
|
||
def isnull(data): | ||
|
@@ -114,23 +91,6 @@ def notnull(data): | |
return ~isnull(data) | ||
|
||
|
||
transpose = _dask_or_eager_func("transpose") | ||
_where = _dask_or_eager_func("where", array_args=slice(3)) | ||
isin = _dask_or_eager_func("isin", array_args=slice(2)) | ||
take = _dask_or_eager_func("take") | ||
broadcast_to = _dask_or_eager_func("broadcast_to") | ||
pad = _dask_or_eager_func("pad", dask_module=dask_array_compat) | ||
|
||
_concatenate = _dask_or_eager_func("concatenate", list_of_args=True) | ||
_stack = _dask_or_eager_func("stack", list_of_args=True) | ||
|
||
array_all = _dask_or_eager_func("all") | ||
array_any = _dask_or_eager_func("any") | ||
|
||
tensordot = _dask_or_eager_func("tensordot", array_args=slice(2)) | ||
einsum = _dask_or_eager_func("einsum", array_args=slice(1, None)) | ||
|
||
|
||
def gradient(x, coord, axis, edge_order): | ||
if is_duck_dask_array(x): | ||
return dask_array.gradient(x, coord, axis=axis, edge_order=edge_order) | ||
|
@@ -166,11 +126,6 @@ def cumulative_trapezoid(y, x, axis): | |
return cumsum(integrand, axis=axis, skipna=False) | ||
|
||
|
||
masked_invalid = _dask_or_eager_func( | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None) | ||
) | ||
|
||
|
||
def astype(data, dtype, **kwargs): | ||
if ( | ||
isinstance(data, sparse_array_type) | ||
|
@@ -317,9 +272,7 @@ def _ignore_warnings_if(condition): | |
yield | ||
|
||
|
||
def _create_nan_agg_method( | ||
name, dask_module=dask_array, coerce_strings=False, invariant_0d=False | ||
): | ||
def _create_nan_agg_method(name, coerce_strings=False, invariant_0d=False): | ||
from . import nanops | ||
|
||
def f(values, axis=None, skipna=None, **kwargs): | ||
|
@@ -344,7 +297,8 @@ def f(values, axis=None, skipna=None, **kwargs): | |
else: | ||
if name in ["sum", "prod"]: | ||
kwargs.pop("min_count", None) | ||
func = _dask_or_eager_func(name, dask_module=dask_module) | ||
|
||
func = getattr(np, name) | ||
|
||
try: | ||
with warnings.catch_warnings(): | ||
|
@@ -378,9 +332,7 @@ def f(values, axis=None, skipna=None, **kwargs): | |
std.numeric_only = True | ||
var = _create_nan_agg_method("var") | ||
var.numeric_only = True | ||
median = _create_nan_agg_method( | ||
"median", dask_module=dask_array_compat, invariant_0d=True | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
median = _create_nan_agg_method("median", invariant_0d=True) | ||
median.numeric_only = True | ||
prod = _create_nan_agg_method("prod", invariant_0d=True) | ||
prod.numeric_only = True | ||
|
@@ -389,7 +341,6 @@ def f(values, axis=None, skipna=None, **kwargs): | |
cumprod_1d.numeric_only = True | ||
cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True) | ||
cumsum_1d.numeric_only = True | ||
unravel_index = _dask_or_eager_func("unravel_index") | ||
|
||
|
||
_mean = _create_nan_agg_method("mean", invariant_0d=True) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do masked arrays support NEP-18?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I don't see an
__array_function__
method inMaskedArray
, but it does inherit fromndarray
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does have an
__array_function__
attributeThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right, it does support NEP-18.
However, for some functions (e.g.
np.median
) it returns a strangeMaskedConstant
object if the result would be a constant. Not sure if it needs special care, but something to be aware of.