From 0dd900bf208d19eaafcc50c9e2393e84e2560de7 Mon Sep 17 00:00:00 2001 From: Hadrien Date: Fri, 24 Feb 2023 14:48:24 +0100 Subject: [PATCH 1/5] change parsing of the input in einsum --- sparse/_common.py | 174 +++++++++++++++++++++++++++++++++++- sparse/tests/test_einsum.py | 4 + 2 files changed, 176 insertions(+), 2 deletions(-) diff --git a/sparse/_common.py b/sparse/_common.py index 64343560..6b7675d8 100644 --- a/sparse/_common.py +++ b/sparse/_common.py @@ -3,7 +3,7 @@ import scipy.sparse from functools import wraps, reduce from itertools import chain -from operator import mul +from operator import mul, index from collections.abc import Iterable from scipy.sparse import spmatrix from numba import literal_unroll @@ -1212,6 +1212,173 @@ def _dot_ndarray_coo(array1, coords2, data2, out_shape): # pragma: no cover return _dot_ndarray_coo +def _parse_einsum_input(operands): + """ + A copy of the numpy parse_einsum_input that + does not cast the operands to numpy array. + Returns + ------- + input_strings : str + Parsed input strings + output_string : str + Parsed output string + operands : list of array_like + The operands to use in the numpy contraction + Examples + -------- + The operand list is simplified to reduce printing: + >>> np.random.seed(123) + >>> a = np.random.rand(4, 4) + >>> b = np.random.rand(4, 4, 4) + >>> _parse_einsum_input(('...a,...a->...', a, b)) + ('za,xza', 'xz', [a, b]) # may vary + >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) + ('za,xza', 'xz', [a, b]) # may vary + """ + + if len(operands) == 0: + raise ValueError("No input operands") + + if isinstance(operands[0], str): + subscripts = operands[0].replace(" ", "") + operands = [v for v in operands[1:]] + + # Ensure all characters are valid + for s in subscripts: + if s in ".,->": + continue + if s not in np.core.einsumfunc.einsum_symbols: + raise ValueError("Character %s is not a valid symbol." % s) + + else: + tmp_operands = list(operands) + operand_list = [] + subscript_list = [] + for p in range(len(operands) // 2): + operand_list.append(tmp_operands.pop(0)) + subscript_list.append(tmp_operands.pop(0)) + + output_list = tmp_operands[-1] if len(tmp_operands) else None + operands = [v for v in operand_list] + subscripts = "" + last = len(subscript_list) - 1 + for num, sub in enumerate(subscript_list): + for s in sub: + if s is Ellipsis: + subscripts += "..." + else: + try: + s = index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " "either int or Ellipsis") from e + subscripts += np.core.einsumfunc.einsum_symbols[s] + if num != last: + subscripts += "," + + if output_list is not None: + subscripts += "->" + for s in output_list: + if s is Ellipsis: + subscripts += "..." + else: + try: + s = index(s) + except TypeError as e: + raise TypeError("For this input type lists must contain " "either int or Ellipsis") from e + subscripts += np.core.einsumfunc.einsum_symbols[s] + # Check for proper "->" + if ("-" in subscripts) or (">" in subscripts): + invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1) + if invalid or (subscripts.count("->") != 1): + raise ValueError("Subscripts can only contain one '->'.") + + # Parse ellipses + if "." in subscripts: + used = subscripts.replace(".", "").replace(",", "").replace("->", "") + unused = list(np.core.einsumfunc.einsum_symbols_set - set(used)) + ellipse_inds = "".join(unused) + longest = 0 + + if "->" in subscripts: + input_tmp, output_sub = subscripts.split("->") + split_subscripts = input_tmp.split(",") + out_sub = True + else: + split_subscripts = subscripts.split(",") + out_sub = False + + for num, sub in enumerate(split_subscripts): + if "." in sub: + if (sub.count(".") != 3) or (sub.count("...") != 1): + raise ValueError("Invalid Ellipses.") + + # Take into account numerical values + if operands[num].shape == (): + ellipse_count = 0 + else: + ellipse_count = max(operands[num].ndim, 1) + ellipse_count -= len(sub) - 3 + + if ellipse_count > longest: + longest = ellipse_count + + if ellipse_count < 0: + raise ValueError("Ellipses lengths do not match.") + elif ellipse_count == 0: + split_subscripts[num] = sub.replace("...", "") + else: + rep_inds = ellipse_inds[-ellipse_count:] + split_subscripts[num] = sub.replace("...", rep_inds) + + subscripts = ",".join(split_subscripts) + if longest == 0: + out_ellipse = "" + else: + out_ellipse = ellipse_inds[-longest:] + + if out_sub: + subscripts += "->" + output_sub.replace("...", out_ellipse) + else: + # Special care for outputless ellipses + output_subscript = "" + tmp_subscripts = subscripts.replace(",", "") + for s in sorted(set(tmp_subscripts)): + if s not in (np.core.einsumfunc.einsum_symbols): + raise ValueError("Character %s is not a valid symbol." % s) + if tmp_subscripts.count(s) == 1: + output_subscript += s + normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse))) + + subscripts += "->" + out_ellipse + normal_inds + + # Build output string if does not exist + if "->" in subscripts: + input_subscripts, output_subscript = subscripts.split("->") + else: + input_subscripts = subscripts + # Build output subscripts + tmp_subscripts = subscripts.replace(",", "") + output_subscript = "" + for s in sorted(set(tmp_subscripts)): + if s not in np.core.einsumfunc.einsum_symbols: + raise ValueError("Character %s is not a valid symbol." % s) + if tmp_subscripts.count(s) == 1: + output_subscript += s + + # Make sure output subscripts are in the input + for char in output_subscript: + if char not in input_subscripts: + raise ValueError("Output character %s did not appear in the input" % char) + + # Make sure number operands is equivalent to the number of terms + if len(input_subscripts.split(",")) != len(operands): + raise ValueError("Number of einsum subscripts must be equal to the " "number of operands.") + + return (input_subscripts, output_subscript, operands) + + + + def _einsum_single(lhs, rhs, operand): """Perform a single term einsum, i.e. any combination of transposes, sums and traces of dimensions. @@ -1287,7 +1454,7 @@ def _einsum_single(lhs, rhs, operand): ) -def einsum(subscripts, *operands): +def einsum(*operands): """ Perform the equivalent of :obj:`numpy.einsum`. @@ -1306,6 +1473,9 @@ def einsum(subscripts, *operands): output : SparseArray The calculation based on the Einstein summation convention. """ + + lhs, rhs, operands = _parse_einsum_input(operands) # Parse input + check_zero_fill_value(*operands) if "->" not in subscripts: diff --git a/sparse/tests/test_einsum.py b/sparse/tests/test_einsum.py index f593ebcf..de07dc1a 100644 --- a/sparse/tests/test_einsum.py +++ b/sparse/tests/test_einsum.py @@ -74,6 +74,10 @@ "dba,ead,cad->bce", "aef,fbc,dca->bde", "abab->ba", + "...ab,...ab", + "...ab,...b->...a", + "a...,a...", + "a...,a...", ] From 00ec6c0e990195bb358769ba405bd5d8725543dc Mon Sep 17 00:00:00 2001 From: Hadrien Date: Fri, 24 Feb 2023 18:26:18 +0100 Subject: [PATCH 2/5] Link to the airport licence with source of the function --- sparse/_common.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/sparse/_common.py b/sparse/_common.py index 6b7675d8..fe587f9e 100644 --- a/sparse/_common.py +++ b/sparse/_common.py @@ -1216,6 +1216,10 @@ def _parse_einsum_input(operands): """ A copy of the numpy parse_einsum_input that does not cast the operands to numpy array. + + Copied from : https://github.com/numpy/numpy/blob/main/numpy/core/einsumfunc.py + under BSD-3-Clause license : https://github.com/numpy/numpy/blob/main/LICENSE.txt + Returns ------- input_strings : str @@ -1478,20 +1482,6 @@ def einsum(*operands): check_zero_fill_value(*operands) - if "->" not in subscripts: - # from opt_einsum: calc the output automatically - lhs = subscripts - tmp_subscripts = lhs.replace(",", "") - rhs = "".join( - # sorted sequence of indices - s - for s in sorted(set(tmp_subscripts)) - # that appear exactly once - if tmp_subscripts.count(s) == 1 - ) - else: - lhs, rhs = subscripts.split("->") - if len(operands) == 1: return _einsum_single(lhs, rhs, operands[0]) From b097cf1d80cdcd5202ba9b4d0f759c52b5db4b69 Mon Sep 17 00:00:00 2001 From: Hadrien Date: Fri, 24 Feb 2023 18:49:57 +0100 Subject: [PATCH 3/5] Move comment on reference to numpy and link the exact commit --- sparse/_common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sparse/_common.py b/sparse/_common.py index fe587f9e..6de5457e 100644 --- a/sparse/_common.py +++ b/sparse/_common.py @@ -1212,14 +1212,13 @@ def _dot_ndarray_coo(array1, coords2, data2, out_shape): # pragma: no cover return _dot_ndarray_coo +# Copied from : https://github.com/numpy/numpy/blob/59fec4619403762a5d785ad83fcbde5a230416fc/numpy/core/einsumfunc.py#L523 +# under BSD-3-Clause license : https://github.com/numpy/numpy/blob/v1.24.0/LICENSE.txt def _parse_einsum_input(operands): """ A copy of the numpy parse_einsum_input that does not cast the operands to numpy array. - Copied from : https://github.com/numpy/numpy/blob/main/numpy/core/einsumfunc.py - under BSD-3-Clause license : https://github.com/numpy/numpy/blob/main/LICENSE.txt - Returns ------- input_strings : str From 4e1adb8e0655f7ceeecadd2ba1a37cbd0c5c7f9a Mon Sep 17 00:00:00 2001 From: Hadrien Date: Mon, 27 Feb 2023 13:58:39 +0100 Subject: [PATCH 4/5] more test and black formatting --- sparse/_common.py | 26 ++++++++++++-------- sparse/tests/test_einsum.py | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/sparse/_common.py b/sparse/_common.py index 6de5457e..96f43c72 100644 --- a/sparse/_common.py +++ b/sparse/_common.py @@ -1212,13 +1212,13 @@ def _dot_ndarray_coo(array1, coords2, data2, out_shape): # pragma: no cover return _dot_ndarray_coo -# Copied from : https://github.com/numpy/numpy/blob/59fec4619403762a5d785ad83fcbde5a230416fc/numpy/core/einsumfunc.py#L523 +# Copied from : https://github.com/numpy/numpy/blob/59fec4619403762a5d785ad83fcbde5a230416fc/numpy/core/einsumfunc.py#L523 # under BSD-3-Clause license : https://github.com/numpy/numpy/blob/v1.24.0/LICENSE.txt def _parse_einsum_input(operands): """ - A copy of the numpy parse_einsum_input that + A copy of the numpy parse_einsum_input that does not cast the operands to numpy array. - + Returns ------- input_strings : str @@ -1273,7 +1273,10 @@ def _parse_einsum_input(operands): try: s = index(s) except TypeError as e: - raise TypeError("For this input type lists must contain " "either int or Ellipsis") from e + raise TypeError( + "For this input type lists must contain " + "either int or Ellipsis" + ) from e subscripts += np.core.einsumfunc.einsum_symbols[s] if num != last: subscripts += "," @@ -1287,7 +1290,10 @@ def _parse_einsum_input(operands): try: s = index(s) except TypeError as e: - raise TypeError("For this input type lists must contain " "either int or Ellipsis") from e + raise TypeError( + "For this input type lists must contain " + "either int or Ellipsis" + ) from e subscripts += np.core.einsumfunc.einsum_symbols[s] # Check for proper "->" if ("-" in subscripts) or (">" in subscripts): @@ -1375,13 +1381,13 @@ def _parse_einsum_input(operands): # Make sure number operands is equivalent to the number of terms if len(input_subscripts.split(",")) != len(operands): - raise ValueError("Number of einsum subscripts must be equal to the " "number of operands.") + raise ValueError( + "Number of einsum subscripts must be equal to the " "number of operands." + ) return (input_subscripts, output_subscript, operands) - - def _einsum_single(lhs, rhs, operand): """Perform a single term einsum, i.e. any combination of transposes, sums and traces of dimensions. @@ -1476,9 +1482,9 @@ def einsum(*operands): output : SparseArray The calculation based on the Einstein summation convention. """ - + lhs, rhs, operands = _parse_einsum_input(operands) # Parse input - + check_zero_fill_value(*operands) if len(operands) == 1: diff --git a/sparse/tests/test_einsum.py b/sparse/tests/test_einsum.py index de07dc1a..22517d47 100644 --- a/sparse/tests/test_einsum.py +++ b/sparse/tests/test_einsum.py @@ -98,12 +98,60 @@ def test_einsum(subscripts, density): assert np.allclose(numpy_out, sparse_out.todense()) +@pytest.mark.parametrize( + "input", [[[0, 0]], [[0, Ellipsis]], [[Ellipsis, 1], [Ellipsis]], [[0, 1], [0]]] +) +@pytest.mark.parametrize("density", [0.1, 1.0]) +def test_einsum_nosubscript(input, density): + d = 4 + arrays = [sparse.random((d, d), density=density)] + sparse_out = sparse.einsum(*arrays, *input) + numpy_out = np.einsum(*(s.todense() for s in arrays), *input) + + if not numpy_out.shape: + # scalar output + assert np.allclose(numpy_out, sparse_out) + else: + # array output + assert np.allclose(numpy_out, sparse_out.todense()) + + def test_einsum_input_fill_value(): x = sparse.random(shape=(2,), density=0.5, format="coo", fill_value=2) with pytest.raises(ValueError): sparse.einsum("cba", x) +def test_einsum_no_input(): + with pytest.raises(ValueError): + sparse.einsum() + + +@pytest.mark.parametrize( + "subscript", ["a+b->c", "i->&", "i->ij", "ij->jij", "a..,a...", ".i...", "a,a->->"] +) +def test_einsum_invalid_input(subscript): + x = sparse.random(shape=(2,), density=0.5, format="coo") + y = sparse.random(shape=(2,), density=0.5, format="coo") + with pytest.raises(ValueError): + sparse.einsum(subscript, x, y) + + +@pytest.mark.parametrize("subscript", [0, [0, 0]]) +def test_einsum_type_error(subscript): + x = sparse.random(shape=(2,), density=0.5, format="coo") + y = sparse.random(shape=(2,), density=0.5, format="coo") + with pytest.raises(TypeError): + sparse.einsum(subscript, x, y) + + +# +# def test_einsum_type_error_nosubscript(): +# x = sparse.random(shape=(2, 2), density=0.5, format="coo") +# with pytest.raises(ValueError): +# sparse.einsum(x, [[0, 1.0], [0]]) + + format_test_cases = [ (("coo",), "coo"), (("dok",), "dok"), From 07d4932f053cd3e42b4afa4a86595d021a41a9ae Mon Sep 17 00:00:00 2001 From: Hadrien Date: Mon, 27 Feb 2023 19:05:14 +0100 Subject: [PATCH 5/5] remove tentative code for a test --- sparse/tests/test_einsum.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sparse/tests/test_einsum.py b/sparse/tests/test_einsum.py index 22517d47..b1f16ad2 100644 --- a/sparse/tests/test_einsum.py +++ b/sparse/tests/test_einsum.py @@ -145,13 +145,6 @@ def test_einsum_type_error(subscript): sparse.einsum(subscript, x, y) -# -# def test_einsum_type_error_nosubscript(): -# x = sparse.random(shape=(2, 2), density=0.5, format="coo") -# with pytest.raises(ValueError): -# sparse.einsum(x, [[0, 1.0], [0]]) - - format_test_cases = [ (("coo",), "coo"), (("dok",), "dok"),