Skip to content

Commit

Permalink
BUG: Fix CSR/CSC matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Apr 8, 2024
1 parent 18c1596 commit 833e3b1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
19 changes: 13 additions & 6 deletions sparse/pydata_backend/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import operator
from collections.abc import Iterable
from functools import reduce
from typing import Union

import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin
Expand Down Expand Up @@ -876,11 +877,14 @@ def from_scipy_sparse(cls, x):
x = x.asformat("csr", copy=False)
return cls((x.data, x.indices, x.indptr), shape=x.shape)

def transpose(self, axes: None = None, copy: bool = False) -> "CSC":
if axes is not None:
raise ValueError
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:
axes = normalize_axis(axes, self.ndim)
if axes not in [(0, 1), (1, 0), None]:
raise ValueError(f"Invalid transpose axes: {axes}")
if copy:
self = self.copy()
if axes == (0, 1):
return self
return CSC((self.data, self.indices, self.indptr), self.shape[::-1])


Expand All @@ -905,9 +909,12 @@ def from_scipy_sparse(cls, x):
x = x.asformat("csc", copy=False)
return cls((x.data, x.indices, x.indptr), shape=x.shape)

def transpose(self, axes: None = None, copy: bool = False) -> CSR:
if axes is not None:
raise ValueError
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:
axes = normalize_axis(axes, self.ndim)
if axes not in [(0, 1), (1, 0), None]:
raise ValueError(f"Invalid transpose axes: {axes}")
if copy:
self = self.copy()
if axes == (0, 1):
return self
return CSR((self.data, self.indices, self.indptr), self.shape[::-1])
16 changes: 15 additions & 1 deletion sparse/pydata_backend/tests/test_compressed_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def data_rvs(n):

else:
data_rvs = None
return cls(sparse.random((20, 30, 40), density=0.25, data_rvs=data_rvs).astype(dtype))
return cls(sparse.random((20, 20), density=0.25, data_rvs=data_rvs).astype(dtype))


def test_repr(random_sparse):
Expand Down Expand Up @@ -111,7 +111,21 @@ def test_transpose(random_sparse, copy):
assert_eq(random_sparse, tt)
assert type(random_sparse) == type(tt)

assert_eq(random_sparse.transpose(axes=(0, 1)), random_sparse)
assert_eq(random_sparse.transpose(axes=(1, 0)), t)
with pytest.raises(ValueError, match="Invalid transpose axes"):
random_sparse.transpose(axes=0)


def test_transpose_error(random_sparse):
with pytest.raises(ValueError):
random_sparse.transpose(axes=1)


def test_matmul(random_sparse_small):
arr = random_sparse_small.todense()

actual = random_sparse_small @ random_sparse_small
expected = arr @ arr

assert_eq(actual, expected)

0 comments on commit 833e3b1

Please sign in to comment.