
np.isnan triggers eager computation #218

Closed
ahuang11 opened this issue Oct 27, 2020 · 17 comments

@ahuang11
Member

I was reading about this:
pydata/xarray#4541

And I think xskillscore's `np.isnan` call may also trigger an eager computation:
https://github.com/xarray-contrib/xskillscore/blob/master/xskillscore/core/np_deterministic.py#L23

@ahuang11
Member Author

But maybe not; xskillscore's is nested under `apply_ufunc`.

@raybellwaves
Member

I don't think it does, but it's hard to dissect the steps under the hood. Probably easiest is to explore in the binder: https://mybinder.org/v2/gh/raybellwaves/xskillscore-tutorial/master?urlpath=lab
using an example, and see if the dashboard computes something earlier than expected.

@bradyrx
Collaborator

bradyrx commented Oct 27, 2020

You could also just make a dask array that's much larger than your system's memory and see if it breaks when chunking and sending it through np.isnan. My guess is that apply_ufunc avoids eagerly evaluating it.
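For illustration, a minimal sketch of this check (array sizes arbitrary): a lazily built array far larger than typical RAM passed through `np.isnan`, which dispatches to dask via the array protocol and stays lazy.

```python
import dask.array as da
import numpy as np

# ~80 GB of float64 if materialized, but only a task graph here.
big = da.ones((100_000, 100_000), chunks=(2_000, 2_000))

# np.isnan dispatches to dask's implementation, so this only extends
# the graph; nothing is loaded into memory.
mask = np.isnan(big)
assert isinstance(mask, da.Array)  # still lazy
```

If something in the pipeline evaluated eagerly, this would exhaust memory instead of returning instantly.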

@aaronspring
Collaborator

aaronspring commented Oct 28, 2020

don't we have tests that check this?

@pytest.mark.parametrize("metrics", distance_metrics)
@pytest.mark.parametrize("dim", AXES)
@pytest.mark.parametrize("weight_bool", [True, False])
@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("has_nan", [True, False])
def test_distance_metrics_xr_dask(
    a_dask, b_dask, dim, weight_bool, weights_dask, metrics, skipna, has_nan
):
    """Test whether distance metrics for xarray functions can be lazy when
    chunked by using dask and give same results."""
    a = a_dask.copy()
    if has_nan:
        a = a.load()
        a[0] = np.nan
        a = a.chunk()
    b = b_dask.copy()
    weights = weights_dask.copy()
    # unpack metrics
    metric, _metric = metrics
    # Generates subsetted weights to pass in as arg to main function and for
    # the numpy testing.
    _weights = adjust_weights(dim, weight_bool, weights)
    if _weights is not None:
        _weights = _weights.load()
    if metric is median_absolute_error:
        actual = metric(a, b, dim, skipna=skipna)
    else:
        actual = metric(a, b, dim, weights=_weights, skipna=skipna)
    # check that output is chunked for chunked inputs
    assert actual.chunks is not None
    if metric is median_absolute_error:
        expected = metric(a.load(), b.load(), dim, skipna=skipna)
    else:
        expected = metric(a.load(), b.load(), dim, weights=_weights, skipna=skipna)
    assert expected.chunks is None
    assert_allclose(actual.compute(), expected)

seems not.

test_deterministic.py needs to be refactored to make it more readable and to see what's missing, also by using as many fixtures as possible, i.e. adding `("chunk", [True, False])` in

@pytest.mark.parametrize("metrics", distance_metrics)
@pytest.mark.parametrize("dim", AXES)
@pytest.mark.parametrize("weight_bool", [True, False])
@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("has_nan", [True, False])
def test_distance_metrics_xr(a, b, dim, weight_bool, weights, metrics, skipna, has_nan):

@ahuang11
Member Author

ahuang11 commented Nov 2, 2020

That doesn't actually test whether it triggers eager computation or not; only whether it works (even though it triggered an eager computation somewhere in between). Could add an overloading test that reflects Riley's comment: "You could also just make a dask array that's much larger than your system's memory and see if it breaks when chunking and sending it through np.isnan. My guess is that apply_ufunc avoids eagerly evaluating it."

@aaronspring
Collaborator

xarray has some tests that check whether operations trigger any computation.

Also, in the PR you mention, they seem to have found a nice solution which could also be our solution.

@ahuang11
Member Author

ahuang11 commented Nov 3, 2020

What Aaron is referencing:


def raise_if_dask_computes(max_computes=0):
    # return a dummy context manager so that this can be used for non-dask objects
    if not has_dask:
        return dummy_context()
    scheduler = CountingScheduler(max_computes)
    return dask.config.set(scheduler=scheduler)

with raise_if_dask_computes():
     weighted = data.weighted(weights)

https://github.com/pydata/xarray/pull/4559/files

@ahuang11
Member Author

ahuang11 commented Nov 4, 2020

Ha it does trigger a computation!

import dask
import dask.array as da
import numpy as np
import xarray as xr
import xskillscore as xs


class CountingScheduler:
    """Simple dask scheduler counting the number of computes.
    Reference: https://stackoverflow.com/questions/53289286/"""

    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many computes. Total: %d > max: %d."
                % (self.total_computes, self.max_computes)
            )
        return dask.get(dsk, keys, **kwargs)


def raise_if_dask_computes(max_computes=0):
    # return a dummy context manager so that this can be used for non-dask objects
    scheduler = CountingScheduler(max_computes)
    return dask.config.set(scheduler=scheduler)


value = dask.delayed(np.ones)((10000, 10000))
array = da.from_delayed(value, (10000, 10000), dtype=float)

obs3 = xr.DataArray(
    array,
    dims=["lat", "lon"],
    name="var",
)
fct3 = xr.DataArray(
    array,
    dims=["lat", "lon"],
    name="var",
)
with raise_if_dask_computes():
    xs.rmse(obs3, fct3, skipna=True)

/mnt/c/Users/Solactus/GOOGLE~1/Bash/xskillscore/xskillscore/core/np_deterministic.py in _rmse(a, b, weights, axis, skipna)
    527     sumfunc, meanfunc = _get_numpy_funcs(skipna)
    528     if skipna:
--> 529         a, b, weights = _match_nans(a, b, weights)
    530     weights = _check_weights(weights)
    531 

/mnt/c/Users/Solactus/GOOGLE~1/Bash/xskillscore/xskillscore/core/np_deterministic.py in _match_nans(a, b, weights)
     34 
     35     """
---> 36     if np.isnan(a).any() or np.isnan(b).any():
     37         # Avoids mutating original arrays and bypasses read-only issue.
     38         a, b = a.copy(), b.copy()
...
<ipython-input-27-37dcfba2ae67> in __call__(self, dsk, keys, **kwargs)
     12             raise RuntimeError(
     13                 "Too many computes. Total: %d > max: %d."
---> 14                 % (self.total_computes, self.max_computes)
     15             )
     16         return dask.get(dsk, keys, **kwargs)

RuntimeError: Too many computes. Total: 1 > max: 0.

@ahuang11
Member Author

ahuang11 commented Nov 4, 2020

After looking at the code and trying a few things, I think we may need to move the `_match_nans` call up to the xarray level (using `map_blocks` like the PR does) instead of calling it from `np_deterministic`.

I tried the following, but the if statement still triggers it.


def _get_nan_funcs(*args):
    """
    Returns da.isnan if any argument is a dask collection;
    returns np.isnan otherwise.
    """
    import dask
    import dask.array as da

    if any(dask.is_dask_collection(arg) for arg in args):
        return da.isnan
    else:
        return np.isnan


def _match_nans(a, b, weights):
    """
    Considers missing values pairwise. If a value is missing
    in a, the corresponding value in b is turned to nan, and
    vice versa.

    Returns
    -------
    a, b, weights : ndarray
        a, b, and weights (if not None) with nans placed at
        pairwise locations.
    """
    isnan = _get_nan_funcs(a, b, weights)

    if isnan(a).any() or isnan(b).any():
        # Avoids mutating original arrays and bypasses read-only issue.
        a, b = a.copy(), b.copy()
        # Find pairwise indices in a and b that have nans.
        idx = np.logical_or(isnan(a), isnan(b))
        a[idx], b[idx] = np.nan, np.nan
        # https://github.com/xarray-contrib/xskillscore/issues/168
        if isinstance(weights, np.ndarray):
            if weights.shape:  # not None
                weights = weights.copy()
                weights[idx] = np.nan
    return a, b, weights

@aaronspring
Collaborator

how about their .map_blocks solution? https://github.com/pydata/xarray/pull/4559/files#
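For reference, a minimal sketch of what a `map_blocks`-based nan matcher could look like (the function name and the toy arrays are hypothetical, not xskillscore's API):

```python
import dask.array as da
import numpy as np


def _match_nans_block(a, b):
    # Runs on one in-memory block at a time; never touches the full array,
    # so the `isnan` calls here cannot trigger a global compute.
    idx = np.logical_or(np.isnan(a), np.isnan(b))
    a = a.copy()
    a[idx] = np.nan
    return a


a = da.from_array(np.array([1.0, np.nan, 3.0]), chunks=2)
b = da.from_array(np.array([np.nan, 2.0, 3.0]), chunks=2)

# Builds a lazy graph; apply symmetrically (swap a and b) for the other array.
matched_a = da.map_blocks(_match_nans_block, a, b)
assert isinstance(matched_a, da.Array)  # still lazy
result = matched_a.compute()  # nan wherever either input had nan
```

The key difference from the `if isnan(a).any()` guard above is that no decision is made on the whole array, so nothing forces evaluation before `.compute()`.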

@ahuang11
Member Author

ahuang11 commented Nov 4, 2020 via email

@ahuang11
Member Author

ahuang11 commented Nov 14, 2020

Actually, the above doesn't trigger computation (I think I might have accidentally called compute() above); it was the weights that triggered the computation:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-944fb188883e> in <module>
     42 )
     43 with raise_if_dask_computes():
---> 44     xs.rmse(obs3, fct3, skipna=True, weights=obs3)

/mnt/c/Users/Solactus/GOOGLE~1/Bash/xskillscore/xskillscore/core/deterministic.py in rmse(a, b, dim, weights, skipna, keep_attrs)
    831     dim, axis = _preprocess_dims(dim, a)
    832     a, b = xr.broadcast(a, b, exclude=dim)
--> 833     weights = _preprocess_weights(a, dim, dim, weights)
    834     input_core_dims = _determine_input_core_dims(dim, weights)
    835 

/mnt/c/Users/Solactus/GOOGLE~1/Bash/xskillscore/xskillscore/core/utils.py in _preprocess_weights(a, dim, new_dim, weights)
     83     else:
     84         # Throw error if there are negative weights.
---> 85         if weights.min() < 0:
     86             raise ValueError(
     87                 "Weights has a minimum below 0. Please submit a weights array "

@aaronspring
Collaborator

so setting the np.nan in weights[idx] is the bad guy?

@ahuang11
Member Author

not quite. the if is evaluated eagerly.
if weights.min() < 0:
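To make the culprit concrete: the `if` implicitly calls `bool()` on the lazy dask scalar returned by `weights.min() < 0`, and `bool()` on a dask object forces a compute. One lazy-safe option is to skip the eager check for dask inputs (a sketch of the idea, not necessarily the fix that was adopted):

```python
import dask
import dask.array as da

weights = da.ones((4,), chunks=2)

# `if weights.min() < 0:` would force evaluation here.
# Guarding on is_dask_collection keeps dask inputs uncomputed:
if not dask.is_dask_collection(weights):
    if weights.min() < 0:
        raise ValueError("Weights has a minimum below 0.")

assert dask.is_dask_collection(weights)  # still lazy
```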

@bradyrx
Collaborator

bradyrx commented Dec 6, 2020

Please ping me if you need any thoughts specifically from me on this. T minus one week on dissertation writing and then will be taking a short break after that. Just trying to pop in and address anything that folks are waiting on me for.

@ahuang11
Member Author

Nah nothing from me, I still haven't tested because I'm working on a personal project at the moment which is absorbing all my time outside of work~

@raybellwaves
Member

Closed via #224
