Skip to content

Commit

Permalink
API: Add set functions [Array API] (#619)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol authored Jan 8, 2024
1 parent b8f2717 commit 925112f
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 3 deletions.
4 changes: 4 additions & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
roll,
tril,
triu,
unique_counts,
unique_values,
where,
)
from ._dok import DOK
Expand Down Expand Up @@ -114,6 +116,8 @@
"min",
"max",
"nanreduce",
"unique_counts",
"unique_values",
]

__array_api_version__ = "2022.12"
4 changes: 4 additions & 0 deletions sparse/_coo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
stack,
tril,
triu,
unique_counts,
unique_values,
where,
)
from .core import COO, as_coo
Expand Down Expand Up @@ -49,4 +51,6 @@
"result_type",
"diagonal",
"diagonalize",
"unique_counts",
"unique_values",
]
110 changes: 107 additions & 3 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from collections.abc import Iterable
from functools import reduce
from typing import Optional, Tuple
from typing import NamedTuple, Optional, Tuple

import numba

Expand Down Expand Up @@ -1059,6 +1059,106 @@ def clip(a, a_min=None, a_max=None, out=None):
return a.clip(a_min, a_max)


# Array API set functions


class UniqueCountsResult(NamedTuple):
values: np.ndarray
counts: np.ndarray


def unique_counts(x, /):
"""
Returns the unique elements of an input array `x`, and the corresponding
counts for each unique element in `x`.
Parameters
----------
x : COO
Input COO array. It will be flattened if it is not already 1-D.
Returns
-------
out : namedtuple
The result containing:
* values - The unique elements of an input array.
* counts - The corresponding counts for each unique element.
Raises
------
ValueError
If the input array is in a different format than COO.
Examples
--------
>>> import sparse
>>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3])
>>> sparse.unique_counts(x)
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
"""
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)

x = x.flatten()
values, counts = np.unique(x.data, return_counts=True)
if x.nnz < x.size:
values = np.concatenate([[x.fill_value], values])
counts = np.concatenate([[x.size - x.nnz], counts])
sorted_indices = np.argsort(values)
values[sorted_indices] = values.copy()
counts[sorted_indices] = counts.copy()

return UniqueCountsResult(values, counts)


def unique_values(x, /):
"""
Returns the unique elements of an input array `x`.
Parameters
----------
x : COO
Input COO array. It will be flattened if it is not already 1-D.
Returns
-------
out : ndarray
The unique elements of an input array.
Raises
------
ValueError
If the input array is in a different format than COO.
Examples
--------
>>> import sparse
>>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3])
>>> sparse.unique_values(x)
array([-3, 0, 1, 2])
"""
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)

x = x.flatten()
values = np.unique(x.data)
if x.nnz < x.size:
values = np.sort(np.concatenate([[x.fill_value], values]))
return values


@numba.jit(nopython=True, nogil=True)
def _compute_minmax_args(
coords: np.ndarray,
Expand Down Expand Up @@ -1121,8 +1221,12 @@ def _arg_minmax_common(

from .core import COO

if not isinstance(x, COO):
raise ValueError(f"Only COO arrays are supported but {type(x)} was passed.")
if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)

if not isinstance(axis, (int, type(None))):
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")
Expand Down
32 changes: 32 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,3 +1745,35 @@ def test_squeeze_validation(self):

with pytest.raises(ValueError, match="Specified axis `0` has a size greater than one: 3"):
s_arr.squeeze(0)


class TestUnique:
arr = np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64)
arr_empty = np.zeros((5, 5))
arr_full = np.arange(1, 10)

@pytest.mark.parametrize("arr", [arr, arr_empty, arr_full])
@pytest.mark.parametrize("fill_value", [-1, 0, 1])
def test_unique_counts(self, arr, fill_value):
s_arr = sparse.COO.from_numpy(arr, fill_value)

result_values, result_counts = sparse.unique_counts(s_arr)
expected_values, expected_counts = np.unique(arr, return_counts=True)

np.testing.assert_equal(result_values, expected_values)
np.testing.assert_equal(result_counts, expected_counts)

@pytest.mark.parametrize("arr", [arr, arr_empty, arr_full])
@pytest.mark.parametrize("fill_value", [-1, 0, 1])
def test_unique_values(self, arr, fill_value):
s_arr = sparse.COO.from_numpy(arr, fill_value)

result = sparse.unique_values(s_arr)
expected = np.unique(arr)

np.testing.assert_equal(result, expected)

@pytest.mark.parametrize("func", [sparse.unique_counts, sparse.unique_values])
def test_input_validation(self, func):
with pytest.raises(ValueError, match=r"Input must be an instance of SparseArray"):
func(self.arr)

0 comments on commit 925112f

Please sign in to comment.