Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test raise if compute #224

Merged
merged 18 commits into from
Jan 12, 2021
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Bug Fixes
~~~~~~~~~
- :py:func:`~xskillscore.sign_test` now works for ``xr.Dataset`` inputs.
(:issue:`198`, :pr:`199`) `Aaron Spring`_
- Passing weights no longer triggers eager computation. (:issue:`218`, :pr:`224`) `Andrew Huang`_

Internal Changes
~~~~~~~~~~~~~~~~
Expand Down
5 changes: 5 additions & 0 deletions xskillscore/core/np_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def _check_weights(weights):
elif np.all(np.isnan(weights)):
return None
else:
if weights.min() < 0:
raise ValueError(
"Weights has a minimum below 0. Please submit a weights array "
"of positive numbers."
)
return weights


Expand Down
6 changes: 0 additions & 6 deletions xskillscore/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ def _preprocess_weights(a, dim, new_dim, weights):
if weights is None:
return None
else:
# Throw error if there are negative weights.
if weights.min() < 0:
raise ValueError(
"Weights has a minimum below 0. Please submit a weights array "
"of positive numbers."
)
# Scale weights to vary from 0 to 1.
weights = weights / weights.max()
# Check that the weights array has the same size
Expand Down
53 changes: 42 additions & 11 deletions xskillscore/tests/test_skipna_functionality.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dask
import numpy as np
import pytest
from xarray.tests import assert_allclose
Expand Down Expand Up @@ -34,6 +35,31 @@
NON_WEIGHTED_METRICS = [median_absolute_error]


class CountingScheduler:
    """Dask scheduler wrapper that counts how many computes it is asked to run.

    Every call bumps ``total_computes``. While the count stays within
    ``max_computes`` the call is delegated to ``dask.get``; once the count
    exceeds it, a ``RuntimeError`` is raised instead.

    Reference: https://stackoverflow.com/questions/53289286/
    """

    def __init__(self, max_computes=0):
        # Budget of allowed computes, and the running tally against it.
        self.max_computes = max_computes
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        """Count this compute, then run it or raise if over budget."""
        self.total_computes += 1
        exceeded = self.total_computes > self.max_computes
        if not exceeded:
            return dask.get(dsk, keys, **kwargs)
        message = "Too many computes. Total: %d > max: %d." % (
            self.total_computes,
            self.max_computes,
        )
        raise RuntimeError(message)


def raise_if_dask_computes(max_computes=0):
# Build a CountingScheduler that allows at most ``max_computes`` computes
# and install it via ``dask.config.set(scheduler=...)``. The returned object
# is used as a context manager (``with raise_if_dask_computes(): ...``);
# the scheduler raises RuntimeError once the compute count exceeds the limit.
# NOTE(review): no separate "dummy context manager" path for non-dask
# objects is visible here -- the same dask.config.set result is always returned.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm being nit-picky, you should change this to a """ docstring for style.

scheduler = CountingScheduler(max_computes)
return dask.config.set(scheduler=scheduler)


def drop_nans(a, b, weights=None, dim="time"):
"""
Masks a and b where they have pairwise nans.
Expand All @@ -51,8 +77,9 @@ def test_skipna_returns_same_value_as_dropped_pairwise_nans(a_1d_nan, b_1d_nan,
"""Tests that DataArrays with pairwise nans return the same result
as the same two with those nans dropped."""
a_dropped, b_dropped, _ = drop_nans(a_1d_nan, b_1d_nan)
res_with_nans = metric(a_1d_nan, b_1d_nan, "time", skipna=True)
res_dropped_nans = metric(a_dropped, b_dropped, "time")
with raise_if_dask_computes():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a nice one. It could be implemented all over the package.

res_with_nans = metric(a_1d_nan, b_1d_nan, "time", skipna=True)
res_dropped_nans = metric(a_dropped, b_dropped, "time")
assert_allclose(res_with_nans, res_dropped_nans)


Expand All @@ -65,20 +92,22 @@ def test_skipna_returns_same_value_as_dropped_pairwise_nans_with_weights(
a_dropped, b_dropped, weights_time_dropped = drop_nans(
a_1d_nan, b_1d_nan, weights_time
)
res_with_nans = metric(
a_1d_nan, b_1d_nan, "time", skipna=True, weights=weights_time
)
res_dropped_nans = metric(
a_dropped, b_dropped, "time", weights=weights_time_dropped
)
with raise_if_dask_computes():
res_with_nans = metric(
a_1d_nan, b_1d_nan, "time", skipna=True, weights=weights_time
)
res_dropped_nans = metric(
a_dropped, b_dropped, "time", weights=weights_time_dropped
)
assert_allclose(res_with_nans, res_dropped_nans)


@pytest.mark.parametrize("metric", WEIGHTED_METRICS + NON_WEIGHTED_METRICS)
def test_skipna_returns_nan_when_false(a_1d_nan, b_1d_nan, metric):
"""Tests that nan is returned if there's any nans in the time series
and skipna is False."""
res = metric(a_1d_nan, b_1d_nan, "time", skipna=False)
with raise_if_dask_computes():
res = metric(a_1d_nan, b_1d_nan, "time", skipna=False)
assert np.isnan(res).all()


Expand All @@ -88,10 +117,12 @@ def test_skipna_broadcast_weights_assignment_destination(
):
"""Tests that 'assignment destination is read-only' is not raised
https://github.com/xarray-contrib/xskillscore/issues/79"""
metric(a_nan, b_nan, ["lat", "lon"], weights=weights_lonlat, skipna=True)
with raise_if_dask_computes():
metric(a_nan, b_nan, ["lat", "lon"], weights=weights_lonlat, skipna=True)


def test_nan_skipna(a, b):
# Randomly add some nans to a
a = a.where(np.random.random(a.shape) < 0.5)
pearson_r(a, b, dim="lat", skipna=True)
with raise_if_dask_computes():
pearson_r(a, b, dim="lat", skipna=True)