Skip to content

Commit

Permalink
API: Add kwargs to sparse.einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jan 8, 2024
1 parent b8f2717 commit 2d89129
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
10 changes: 9 additions & 1 deletion sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,7 +1393,7 @@ def _einsum_single(lhs, rhs, operand):
return to_output_format(COO(new_coords, new_data, shape=new_shape, has_duplicates=True))


def einsum(*operands):
def einsum(*operands, **kwargs):
"""
Perform the equivalent of :obj:`numpy.einsum`.
Expand All @@ -1406,6 +1406,11 @@ def einsum(*operands):
included as well as subscript labels of the precise output form.
operands : sequence of SparseArray
These are the arrays for the operation.
dtype : data-type, optional
If provided, forces the calculation to use the data type specified.
Default is ``None``.
**kwargs : dict, optional
Any additional arguments to pass to the function.
Returns
-------
Expand All @@ -1417,6 +1422,9 @@ def einsum(*operands):

check_zero_fill_value(*operands)

if "dtype" in kwargs and kwargs["dtype"] is not None:
operands = [o.astype(kwargs["dtype"]) for o in operands]

if len(operands) == 1:
return _einsum_single(lhs, rhs, operands[0])

Expand Down
12 changes: 12 additions & 0 deletions sparse/tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,15 @@ def test_einsum_shape_check():
y = sparse.random((2, 3, 4), density=0.5)
with pytest.raises(ValueError):
sparse.einsum("abc,acb", x, y)


@pytest.mark.parametrize("dtype", [np.int64, np.complex128])
def test_einsum_dtype(dtype):
x = sparse.random((3, 3), density=0.5) * 10.
x = x.astype(np.float64)

y = sparse.COO.from_numpy(np.ones((3,1), dtype=np.float64))

result = sparse.einsum("ij,i->j", x, y, dtype=dtype)

assert result.dtype == dtype

0 comments on commit 2d89129

Please sign in to comment.