Skip to content

Commit

Permalink
Merge pull request #86 from daxiongshu/util_dot_support_cupy
Browse files Browse the repository at this point in the history
Fix and clean dask-glm.utils.dot for cupy and sparse input
  • Loading branch information
TomAugspurger authored Sep 25, 2020
2 parents 9eab0c2 + 27ee58f commit 69a947e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 24 deletions.
39 changes: 39 additions & 0 deletions dask_glm/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,42 @@ def test_dask_array_is_sparse():
"(https://github.com/pydata/sparse/issues/292)")
def test_dok_dask_array_is_sparse():
assert utils.is_dask_array_sparse(da.from_array(sparse.DOK((10, 10))))


def test_dot_with_cupy():
cupy = pytest.importorskip('cupy')

# dot(cupy.array, cupy.array)
A = cupy.random.rand(100, 100)
B = cupy.random.rand(100)
ans = cupy.dot(A, B)
res = utils.dot(A, B)
assert_eq(ans, res)

# dot(dask.array, cupy.array)
dA = da.from_array(A, chunks=(10, 100))
res = utils.dot(dA, B).compute()
assert_eq(ans, res)

# dot(cupy.array, dask.array)
dB = da.from_array(B, chunks=(10))
res = utils.dot(A, dB).compute()
assert_eq(ans, res)


def test_dot_with_sparse():
A = sparse.random((1024, 64))
B = sparse.random((64))
ans = sparse.dot(A, B)

# dot(sparse.array, sparse.array)
res = utils.dot(A, B)
assert_eq(ans, res)

# dot(sparse.array, dask.array)
res = utils.dot(A, da.from_array(B, chunks=B.shape))
assert_eq(ans, res.compute())

# dot(dask.array, sparse.array)
res = utils.dot(da.from_array(A, chunks=A.shape), B)
assert_eq(ans, res.compute())
27 changes: 3 additions & 24 deletions dask_glm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,33 +106,12 @@ def log1p(A):

@dispatch(object, object)
def dot(A, B):
x = max([A, B], key=lambda x: getattr(x, '__array_priority__', 0))
module = package_of(x)
return module.dot(A, B)


@dispatch(da.Array, np.ndarray)
def dot(A, B):
B = da.from_array(B, chunks=B.shape)
return da.dot(A, B)


@dispatch(np.ndarray, da.Array)
def dot(A, B):
A = da.from_array(A, chunks=A.shape)
return da.dot(A, B)


@dispatch(np.ndarray, np.ndarray)
def dot(A, B):
if isinstance(A, da.Array) or isinstance(B, da.Array):
A = da.asarray(A, chunks=A.shape)
B = da.asarray(B, chunks=B.shape)
return np.dot(A, B)


@dispatch(da.Array, da.Array)
def dot(A, B):
return da.dot(A, B)


@dispatch(object)
def sum(A):
return A.sum()
Expand Down

0 comments on commit 69a947e

Please sign in to comment.