Skip to content

Commit

Permalink
Allow an ndarray of integers as a shape.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Mar 27, 2022
1 parent 3923946 commit 9d50c26
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
21 changes: 16 additions & 5 deletions sparse/_coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,11 @@ def __init__(
if self.data.ndim != 1:
raise ValueError("data must be a scalar or 1-dimensional.")

if shape and not self.coords.size:
self.coords = np.zeros(
(len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp
)

if shape is None:
warnings.warn(
"shape should be provided. This will raise a ValueError in the future.",
DeprecationWarning,
)
if self.coords.nbytes:
shape = tuple((self.coords.max(axis=1) + 1))
else:
Expand All @@ -255,6 +254,18 @@ def __init__(
if not isinstance(shape, Iterable):
shape = (shape,)

if isinstance(shape, np.ndarray):
shape = tuple(shape)

if shape and not self.coords.size:
warnings.warn(
"coords should be an ndarray. This will raise a ValueError in the future.",
DeprecationWarning,
)
self.coords = np.zeros(
(len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp
)

super().__init__(shape, fill_value=fill_value)
if idx_dtype:
if not can_store(idx_dtype, max(shape)):
Expand Down
7 changes: 7 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,3 +1672,10 @@ def test_scalar_elemwise():
x1 = s1.todense()

assert_eq(s1 * x2, x1 * x2)


def test_array_as_shape():
coords = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]
data = [10, 20, 30, 40, 50]

s = sparse.COO(coords, data, shape=np.array((5, 5)))

0 comments on commit 9d50c26

Please sign in to comment.