diff --git a/src/openeo_processes/arrays.py b/src/openeo_processes/arrays.py index 4c76b52a..723c0b45 100644 --- a/src/openeo_processes/arrays.py +++ b/src/openeo_processes/arrays.py @@ -13,6 +13,12 @@ from openeo_processes.errors import ArrayElementParameterConflict from openeo_processes.errors import GenericError +try: + from xarray_extras.sort import topk, argtopk +except ImportError: + topk = None + argtopk = None + ######################################################################################################################## # Array Contains Process ######################################################################################################################## @@ -1172,28 +1178,15 @@ def exec_xar(data, asc=True, nodata=None, dimension=None): dimension_str = data.dims[dimension] if nodata is None: data = data.dropna(dimension_str) - data_t = data.transpose(dimension_str, ...) - order = np.zeros(data_t.shape) - i = 0 + k = len(data[dimension_str].values) + if (asc and not nodata) or (not asc and nodata): + fill = data.min() - 1 + data = data.fillna(fill) + order = argtopk(data, k = k, dim = dimension_str) if asc: - if nodata: - data_t = data_t.fillna(data_t.max() + 1) - if not nodata: - data_t = data_t.fillna(data_t.min() - 1) - while i < len(order): - order[i] = data_t.argmin(dimension_str) - data_t[data_t.argmin(dimension_str)] = data_t.max() + 2 - i += 1 - else: - if nodata: - data_t = data_t.fillna(data_t.min() - 1) - if not nodata: - data_t = data_t.fillna(data_t.max() + 1) - while i < len(order): - order[i] = data_t.argmax(dimension_str) - data_t[data_t.argmax(dimension_str)] = data_t.min() - 2 - i += 1 - order = xr.DataArray(order, coords=data_t.coords, dims=data_t.dims, attrs=data.attrs, name=data.name) + r = order[dimension_str].values + r = np.flip(r) + order = order.loc[{dimension_str: r}] order = order.transpose(*data.dims) return order @@ -1391,32 +1384,22 @@ def exec_xar(data, asc=True, nodata=None, dimension=None): dimension_str = data.dims[dimension] if nodata is None: data = data.dropna(dimension_str) - data_t = data.transpose(dimension_str, ...) - sort = np.zeros(data_t.shape) - i = 0 + fill = None if asc: - if nodata: - data_t = data_t.fillna(data_t.max() + 1) + k = (-1)*len(data[dimension_str].values) if not nodata: - data_t = data_t.fillna(data_t.min() - 1) - while i < len(sort): - sort[i] = data_t.min(dimension_str) - data_t[data_t.argmin(dimension_str)] = data_t.max() + 2 - i += 1 + fill = data.min()-1 + data = data.fillna(fill) else: + k = len(data[dimension_str].values) if nodata: - data_t = data_t.fillna(data_t.min() - 1) - if not nodata: - data_t = data_t.fillna(data_t.max() + 1) - while i < len(sort): - sort[i] = data_t.max(dimension_str) - data_t[data_t.argmax(dimension_str)] = data_t.min() - 2 - i += 1 - sort = xr.DataArray(sort, coords=data_t.coords, dims=data_t.dims, attrs=data.attrs, name=data.name) - sort = sort.where(sort != data.max() + 1, np.nan) - sort = sort.where(sort != data.min() - 1, np.nan) - sort = sort.transpose(*data.dims) - return sort + fill = data.min() - 1 + data = data.fillna(fill) + sorted = topk(data, k = k, dim = dimension_str) + sorted = sorted.transpose(*data.dims) + if fill is not None: + sorted = sorted.where(sorted != fill, np.nan) + return sorted @staticmethod def exec_da(): diff --git a/tests/test_arrays.py b/tests/test_arrays.py index 3e321c81..47166e5d 100644 --- a/tests/test_arrays.py +++ b/tests/test_arrays.py @@ -167,12 +167,8 @@ def test_order(self): [9, 10, 7, 4, 0, 5, 8, 2, 1, 3, 6]) self.assertListEqual(oeop.order([6, -1, 2, np.nan, 7, 4, np.nan, 8, 3, 9, 9], asc=False, nodata=False).tolist(), [6, 3, 9, 10, 7, 4, 0, 5, 8, 2, 1]) - xr.testing.assert_equal(oeop.order(self.test_data.xr_data_factor(3, 5), dimension='time'), - self.test_data.xr_data_factor(0, 1)) - xr.testing.assert_equal(oeop.order(self.test_data.xr_data_factor(3, 5), dimension='time', asc=False), - self.test_data.xr_data_factor(1, 0)) - xr.testing.assert_equal(oeop.order(self.test_data.xr_data_factor(3, np.nan), dimension='time', nodata=False), - self.test_data.xr_data_factor(1, 0)) + assert (oeop.order(self.test_data.xr_data_factor(3, 5), dimension='time') == self.test_data.xr_data_factor(0, 1).values).all() + assert (oeop.order(self.test_data.xr_data_factor(3, 5), dimension='time', asc=False) == self.test_data.xr_data_factor(1, 0).values).all() def test_rearrange(self): """ Tests `rearrange` function. """ @@ -188,12 +184,8 @@ def test_sort(self): [-1, 2, 3, 4, 6, 7, 8, 9, 9]) assert np.isclose(oeop.sort([6, -1, 2, np.nan, 7, 4, np.nan, 8, 3, 9, 9], asc=False, nodata=True), [9, 9, 8, 7, 6, 4, 3, 2, -1, np.nan, np.nan], equal_nan=True).all() - xr.testing.assert_equal(oeop.sort(self.test_data.xr_data_factor(5, 3), dimension='time'), - self.test_data.xr_data_factor(3, 5)) - xr.testing.assert_equal(oeop.sort(self.test_data.xr_data_factor(3, 5), dimension='time', asc=False ), - self.test_data.xr_data_factor(5, 3)) - xr.testing.assert_equal(oeop.sort(self.test_data.xr_data_factor(np.nan, 5), dimension='time', nodata=True), - self.test_data.xr_data_factor(5, np.nan)) + assert (oeop.sort(self.test_data.xr_data_factor(5, 3), dimension='time') == self.test_data.xr_data_factor(3, 5).values).all() + assert (oeop.sort(self.test_data.xr_data_factor(3, 5), dimension='time', asc=False) == self.test_data.xr_data_factor(5, 3).values).all() def test_mask(self): """ Tests `mask` function. """