Skip to content

Commit

Permalink
Merge pull request #224 from ahuang11/no_compute
Browse files Browse the repository at this point in the history
  • Loading branch information
raybellwaves authored Jan 12, 2021
2 parents ef46406 + 1348a72 commit 7f04738
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ Bug Fixes
~~~~~~~~~
- :py:func:`~xskillscore.sign_test` now works for ``xr.Dataset`` inputs.
(:issue:`198`, :pr:`199`) `Aaron Spring`_
- Passing weights no longer triggers eager computation.
(:issue:`218`, :pr:`224`). `Andrew Huang`_

Internal Changes
~~~~~~~~~~~~~~~~
Expand Down
5 changes: 5 additions & 0 deletions xskillscore/core/np_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def _check_weights(weights):
elif np.all(np.isnan(weights)):
return None
else:
if weights.min() < 0:
raise ValueError(
"Weights has a minimum below 0. Please submit a weights array "
"of positive numbers."
)
return weights


Expand Down
6 changes: 0 additions & 6 deletions xskillscore/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ def _preprocess_weights(a, dim, new_dim, weights):
if weights is None:
return None
else:
# Throw error if there are negative weights.
if weights.min() < 0:
raise ValueError(
"Weights has a minimum below 0. Please submit a weights array "
"of positive numbers."
)
# Scale weights to vary from 0 to 1.
weights = weights / weights.max()
# Check that the weights array has the same size
Expand Down
55 changes: 44 additions & 11 deletions xskillscore/tests/test_skipna_functionality.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dask
import numpy as np
import pytest
from xarray.tests import assert_allclose
Expand Down Expand Up @@ -34,6 +35,33 @@
NON_WEIGHTED_METRICS = [median_absolute_error]


class CountingScheduler:
"""Simple dask scheduler counting the number of computes.
Reference: https://stackoverflow.com/questions/53289286/"""

def __init__(self, max_computes=0):
self.total_computes = 0
self.max_computes = max_computes

def __call__(self, dsk, keys, **kwargs):
self.total_computes += 1
if self.total_computes > self.max_computes:
raise RuntimeError(
"Too many computes. Total: %d > max: %d."
% (self.total_computes, self.max_computes)
)
return dask.get(dsk, keys, **kwargs)


def raise_if_dask_computes(max_computes=0):
"""
Return a dummy context manager so
that this can be used for non-dask objects
"""
scheduler = CountingScheduler(max_computes)
return dask.config.set(scheduler=scheduler)


def drop_nans(a, b, weights=None, dim="time"):
"""
Masks a and b where they have pairwise nans.
Expand All @@ -51,8 +79,9 @@ def test_skipna_returns_same_value_as_dropped_pairwise_nans(a_1d_nan, b_1d_nan,
"""Tests that DataArrays with pairwise nans return the same result
as the same two with those nans dropped."""
a_dropped, b_dropped, _ = drop_nans(a_1d_nan, b_1d_nan)
res_with_nans = metric(a_1d_nan, b_1d_nan, "time", skipna=True)
res_dropped_nans = metric(a_dropped, b_dropped, "time")
with raise_if_dask_computes():
res_with_nans = metric(a_1d_nan, b_1d_nan, "time", skipna=True)
res_dropped_nans = metric(a_dropped, b_dropped, "time")
assert_allclose(res_with_nans, res_dropped_nans)


Expand All @@ -65,20 +94,22 @@ def test_skipna_returns_same_value_as_dropped_pairwise_nans_with_weights(
a_dropped, b_dropped, weights_time_dropped = drop_nans(
a_1d_nan, b_1d_nan, weights_time
)
res_with_nans = metric(
a_1d_nan, b_1d_nan, "time", skipna=True, weights=weights_time
)
res_dropped_nans = metric(
a_dropped, b_dropped, "time", weights=weights_time_dropped
)
with raise_if_dask_computes():
res_with_nans = metric(
a_1d_nan, b_1d_nan, "time", skipna=True, weights=weights_time
)
res_dropped_nans = metric(
a_dropped, b_dropped, "time", weights=weights_time_dropped
)
assert_allclose(res_with_nans, res_dropped_nans)


@pytest.mark.parametrize("metric", WEIGHTED_METRICS + NON_WEIGHTED_METRICS)
def test_skipna_returns_nan_when_false(a_1d_nan, b_1d_nan, metric):
"""Tests that nan is returned if there's any nans in the time series
and skipna is False."""
res = metric(a_1d_nan, b_1d_nan, "time", skipna=False)
with raise_if_dask_computes():
res = metric(a_1d_nan, b_1d_nan, "time", skipna=False)
assert np.isnan(res).all()


Expand All @@ -88,10 +119,12 @@ def test_skipna_broadcast_weights_assignment_destination(
):
"""Tests that 'assignment destination is read-only' is not raised
https://github.com/xarray-contrib/xskillscore/issues/79"""
metric(a_nan, b_nan, ["lat", "lon"], weights=weights_lonlat, skipna=True)
with raise_if_dask_computes():
metric(a_nan, b_nan, ["lat", "lon"], weights=weights_lonlat, skipna=True)


def test_nan_skipna(a, b):
# Randomly add some nans to a
a = a.where(np.random.random(a.shape) < 0.5)
pearson_r(a, b, dim="lat", skipna=True)
with raise_if_dask_computes():
pearson_r(a, b, dim="lat", skipna=True)

0 comments on commit 7f04738

Please sign in to comment.