Skip to content

Commit

Permalink
Merge pull request #658 from pydata/asarray-update
Browse files Browse the repository at this point in the history
API: Update `asarray` function
  • Loading branch information
mtsokol authored Mar 29, 2024
2 parents f95eeb3 + 63a8593 commit 9a8b31a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
2 changes: 2 additions & 0 deletions sparse/finch_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from finch import (
add,
asarray,
astype,
bool,
compiled,
Expand Down Expand Up @@ -38,6 +39,7 @@

__all__ = [
"add",
"asarray",
"astype",
"bool",
"compiled",
Expand Down
8 changes: 4 additions & 4 deletions sparse/pydata_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2033,16 +2033,16 @@ def asarray(obj, /, *, dtype=None, format="coo", device=None, copy=False):
<COO: shape=(8, 8), dtype=int64, nnz=8, fill_value=0>
"""

if format not in {"coo", "dok", "gcxs"}:
if format not in {"coo", "dok", "gcxs", "csc", "csr"}:
raise ValueError(f"{format} format not supported.")

from ._compressed import GCXS
from ._compressed import CSC, CSR, GCXS
from ._coo import COO
from ._dok import DOK

format_dict = {"coo": COO, "dok": DOK, "gcxs": GCXS}
format_dict = {"coo": COO, "dok": DOK, "gcxs": GCXS, "csc": CSC, "csr": CSR}

if isinstance(obj, COO | DOK | GCXS):
if isinstance(obj, COO | DOK | GCXS | CSC | CSR):
return obj.asformat(format)

if _is_scipy_sparse_obj(obj):
Expand Down
23 changes: 19 additions & 4 deletions sparse/pydata_backend/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,8 @@ def isnan(self):


class _Compressed2d(GCXS):
class_compressed_axes: tuple[int]

def __init__(self, arg, shape=None, compressed_axes=None, prune=False, fill_value=0):
if not hasattr(arg, "shape") and shape is None:
raise ValueError("missing `shape` argument")
Expand Down Expand Up @@ -847,6 +849,11 @@ def __str__(self):
def ndim(self) -> int:
return 2

@classmethod
def from_numpy(cls, x, fill_value=0, idx_dtype=None):
coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype)
return cls.from_coo(coo, cls.class_compressed_axes, idx_dtype)


class CSR(_Compressed2d):
"""
Expand All @@ -857,8 +864,12 @@ class CSR(_Compressed2d):
Sparse supports 2-D CSR.
"""

def __init__(self, arg, shape=None, prune=False, fill_value=0):
super().__init__(arg, shape=shape, compressed_axes=(0,), fill_value=fill_value)
class_compressed_axes: tuple[int] = (0,)

def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune=False, fill_value=0):
if compressed_axes != self.class_compressed_axes:
raise ValueError(f"CSR only accepts rows as compressed axis but got: {compressed_axes}")
super().__init__(arg, shape=shape, compressed_axes=compressed_axes, fill_value=fill_value)

@classmethod
def from_scipy_sparse(cls, x):
Expand All @@ -882,8 +893,12 @@ class CSC(_Compressed2d):
Sparse supports 2-D CSC.
"""

def __init__(self, arg, shape=None, prune=False, fill_value=0):
super().__init__(arg, shape=shape, compressed_axes=(1,), fill_value=fill_value)
class_compressed_axes: tuple[int] = (1,)

def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune=False, fill_value=0):
if compressed_axes != self.class_compressed_axes:
raise ValueError(f"CSC only accepts columns as compressed axis but got: {compressed_axes}")
super().__init__(arg, shape=shape, compressed_axes=compressed_axes, fill_value=fill_value)

@classmethod
def from_scipy_sparse(cls, x):
Expand Down
11 changes: 11 additions & 0 deletions sparse/tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sparse

import pytest

import numpy as np
import scipy.sparse as sp
from numpy.testing import assert_equal
Expand Down Expand Up @@ -54,3 +56,12 @@ def my_fun(tns1, tns2):
result = my_fun(finch_dense, finch_arr)

assert_equal(result.todense(), np.sum(2 * np_eye, axis=0))


@pytest.mark.parametrize("format", ["csc", "csr", "coo"])
def test_asarray(backend, format):
arr = np.eye(5)

result = sparse.asarray(arr, format=format)

assert_equal(result.todense(), arr)

0 comments on commit 9a8b31a

Please sign in to comment.