Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change parsing of the input in einsum #579

Merged
merged 6 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 181 additions & 16 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import scipy.sparse
from functools import wraps, reduce
from itertools import chain
from operator import mul
from operator import mul, index
from collections.abc import Iterable
from scipy.sparse import spmatrix
from numba import literal_unroll
Expand Down Expand Up @@ -1212,6 +1212,182 @@ def _dot_ndarray_coo(array1, coords2, data2, out_shape): # pragma: no cover
return _dot_ndarray_coo


# Copied from : https://github.com/numpy/numpy/blob/59fec4619403762a5d785ad83fcbde5a230416fc/numpy/core/einsumfunc.py#L523
# under BSD-3-Clause license : https://github.com/numpy/numpy/blob/v1.24.0/LICENSE.txt
def _parse_einsum_input(operands):
hameerabbasi marked this conversation as resolved.
Show resolved Hide resolved
"""
A copy of the numpy parse_einsum_input that
does not cast the operands to numpy array.

Returns
-------
input_strings : str
Parsed input strings
output_string : str
Parsed output string
operands : list of array_like
The operands to use in the numpy contraction
Examples
--------
The operand list is simplified to reduce printing:
>>> np.random.seed(123)
>>> a = np.random.rand(4, 4)
>>> b = np.random.rand(4, 4, 4)
>>> _parse_einsum_input(('...a,...a->...', a, b))
('za,xza', 'xz', [a, b]) # may vary
>>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
('za,xza', 'xz', [a, b]) # may vary
"""

if len(operands) == 0:
raise ValueError("No input operands")

if isinstance(operands[0], str):
subscripts = operands[0].replace(" ", "")
operands = [v for v in operands[1:]]

# Ensure all characters are valid
for s in subscripts:
if s in ".,->":
continue
if s not in np.core.einsumfunc.einsum_symbols:
raise ValueError("Character %s is not a valid symbol." % s)

else:
tmp_operands = list(operands)
operand_list = []
subscript_list = []
for p in range(len(operands) // 2):
operand_list.append(tmp_operands.pop(0))
subscript_list.append(tmp_operands.pop(0))

output_list = tmp_operands[-1] if len(tmp_operands) else None
operands = [v for v in operand_list]
subscripts = ""
last = len(subscript_list) - 1
for num, sub in enumerate(subscript_list):
for s in sub:
if s is Ellipsis:
subscripts += "..."
else:
try:
s = index(s)
except TypeError as e:
raise TypeError(
"For this input type lists must contain "
"either int or Ellipsis"
) from e
subscripts += np.core.einsumfunc.einsum_symbols[s]
if num != last:
subscripts += ","

if output_list is not None:
subscripts += "->"
for s in output_list:
if s is Ellipsis:
subscripts += "..."
else:
try:
s = index(s)
except TypeError as e:
raise TypeError(
"For this input type lists must contain "
"either int or Ellipsis"
) from e
subscripts += np.core.einsumfunc.einsum_symbols[s]
# Check for proper "->"
if ("-" in subscripts) or (">" in subscripts):
invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
if invalid or (subscripts.count("->") != 1):
raise ValueError("Subscripts can only contain one '->'.")

# Parse ellipses
if "." in subscripts:
used = subscripts.replace(".", "").replace(",", "").replace("->", "")
unused = list(np.core.einsumfunc.einsum_symbols_set - set(used))
ellipse_inds = "".join(unused)
longest = 0

if "->" in subscripts:
input_tmp, output_sub = subscripts.split("->")
split_subscripts = input_tmp.split(",")
out_sub = True
else:
split_subscripts = subscripts.split(",")
out_sub = False

for num, sub in enumerate(split_subscripts):
if "." in sub:
if (sub.count(".") != 3) or (sub.count("...") != 1):
raise ValueError("Invalid Ellipses.")

# Take into account numerical values
if operands[num].shape == ():
ellipse_count = 0
else:
ellipse_count = max(operands[num].ndim, 1)
ellipse_count -= len(sub) - 3

if ellipse_count > longest:
longest = ellipse_count

if ellipse_count < 0:
raise ValueError("Ellipses lengths do not match.")
elif ellipse_count == 0:
split_subscripts[num] = sub.replace("...", "")
else:
rep_inds = ellipse_inds[-ellipse_count:]
split_subscripts[num] = sub.replace("...", rep_inds)

subscripts = ",".join(split_subscripts)
if longest == 0:
out_ellipse = ""
else:
out_ellipse = ellipse_inds[-longest:]

if out_sub:
subscripts += "->" + output_sub.replace("...", out_ellipse)
else:
# Special care for outputless ellipses
output_subscript = ""
tmp_subscripts = subscripts.replace(",", "")
for s in sorted(set(tmp_subscripts)):
if s not in (np.core.einsumfunc.einsum_symbols):
raise ValueError("Character %s is not a valid symbol." % s)
if tmp_subscripts.count(s) == 1:
output_subscript += s
normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))

subscripts += "->" + out_ellipse + normal_inds

# Build output string if does not exist
if "->" in subscripts:
input_subscripts, output_subscript = subscripts.split("->")
else:
input_subscripts = subscripts
# Build output subscripts
tmp_subscripts = subscripts.replace(",", "")
output_subscript = ""
for s in sorted(set(tmp_subscripts)):
if s not in np.core.einsumfunc.einsum_symbols:
raise ValueError("Character %s is not a valid symbol." % s)
if tmp_subscripts.count(s) == 1:
output_subscript += s

# Make sure output subscripts are in the input
for char in output_subscript:
if char not in input_subscripts:
raise ValueError("Output character %s did not appear in the input" % char)

# Make sure number operands is equivalent to the number of terms
if len(input_subscripts.split(",")) != len(operands):
raise ValueError(
"Number of einsum subscripts must be equal to the " "number of operands."
)

return (input_subscripts, output_subscript, operands)


def _einsum_single(lhs, rhs, operand):
"""Perform a single term einsum, i.e. any combination of transposes, sums
and traces of dimensions.
Expand Down Expand Up @@ -1287,7 +1463,7 @@ def _einsum_single(lhs, rhs, operand):
)


def einsum(subscripts, *operands):
def einsum(*operands):
"""
Perform the equivalent of :obj:`numpy.einsum`.

Expand All @@ -1306,21 +1482,10 @@ def einsum(subscripts, *operands):
output : SparseArray
The calculation based on the Einstein summation convention.
"""
check_zero_fill_value(*operands)

if "->" not in subscripts:
# from opt_einsum: calc the output automatically
lhs = subscripts
tmp_subscripts = lhs.replace(",", "")
rhs = "".join(
# sorted sequence of indices
s
for s in sorted(set(tmp_subscripts))
# that appear exactly once
if tmp_subscripts.count(s) == 1
)
else:
lhs, rhs = subscripts.split("->")
lhs, rhs, operands = _parse_einsum_input(operands) # Parse input

check_zero_fill_value(*operands)

if len(operands) == 1:
return _einsum_single(lhs, rhs, operands[0])
Expand Down
45 changes: 45 additions & 0 deletions sparse/tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
"dba,ead,cad->bce",
"aef,fbc,dca->bde",
"abab->ba",
"...ab,...ab",
"...ab,...b->...a",
"a...,a...",
"a...,a...",
]


Expand All @@ -94,12 +98,53 @@ def test_einsum(subscripts, density):
assert np.allclose(numpy_out, sparse_out.todense())


@pytest.mark.parametrize(
    "input", [[[0, 0]], [[0, Ellipsis]], [[Ellipsis, 1], [Ellipsis]], [[0, 1], [0]]]
)
@pytest.mark.parametrize("density", [0.1, 1.0])
def test_einsum_nosubscript(input, density):
    """Compare ``sparse.einsum`` with ``np.einsum`` for the sublist call form.

    ``input`` holds the index sublists (and, for some cases, an output
    sublist) that follow the single 4x4 random operand in the call.
    """
    side = 4
    operand = sparse.random((side, side), density=density)

    actual = sparse.einsum(operand, *input)
    expected = np.einsum(operand.todense(), *input)

    # A scalar result is compared as-is; an array result is densified first.
    if expected.shape:
        actual = actual.todense()
    assert np.allclose(expected, actual)


def test_einsum_input_fill_value():
    """``sparse.einsum`` raises ``ValueError`` for an operand built with fill_value=2."""
    operand = sparse.random(shape=(2,), density=0.5, format="coo", fill_value=2)
    with pytest.raises(ValueError):
        sparse.einsum("cba", operand)


def test_einsum_no_input():
    """Calling ``sparse.einsum`` with no arguments at all raises ``ValueError``."""
    with pytest.raises(ValueError):
        sparse.einsum()


@pytest.mark.parametrize(
    "subscript", ["a+b->c", "i->&", "i->ij", "ij->jij", "a..,a...", ".i...", "a,a->->"]
)
def test_einsum_invalid_input(subscript):
    """Each malformed subscript string makes ``sparse.einsum`` raise ``ValueError``."""
    # Two independent random 1-D COO operands, created in the same order as before.
    first, second = (
        sparse.random(shape=(2,), density=0.5, format="coo") for _ in range(2)
    )
    with pytest.raises(ValueError):
        sparse.einsum(subscript, first, second)


@pytest.mark.parametrize("subscript", [0, [0, 0]])
def test_einsum_type_error(subscript):
    """A non-string first argument followed by two arrays raises ``TypeError``."""
    # Two independent random 1-D COO operands, created in the same order as before.
    first, second = (
        sparse.random(shape=(2,), density=0.5, format="coo") for _ in range(2)
    )
    with pytest.raises(TypeError):
        sparse.einsum(subscript, first, second)


format_test_cases = [
(("coo",), "coo"),
(("dok",), "dok"),
Expand Down