diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 1f3daf8..e390ec5 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -379,8 +379,10 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L else: input_subscripts, output_subscript = subscripts, find_output_str(subscripts) - # Make sure output subscripts are in the input + # Make sure output subscripts are unique and in the input for char in output_subscript: + if output_subscript.count(char) != 1: + raise ValueError("Output character '{}' appeared more than once in the output.".format(char)) if char not in input_subscripts: raise ValueError("Output character '{}' did not appear in the input".format(char)) diff --git a/opt_einsum/tests/test_input.py b/opt_einsum/tests/test_input.py index 37a6db9..17dc0ac 100644 --- a/opt_einsum/tests/test_input.py +++ b/opt_einsum/tests/test_input.py @@ -81,92 +81,92 @@ def test_type_errors(): contract(views[0], [Ellipsis, dict()], [Ellipsis, "a"]) -def test_value_errors(): +@pytest.mark.parametrize("contract_fn", [contract, contract_path]) +def test_value_errors(contract_fn): with pytest.raises(ValueError): - contract("") + contract_fn("") # subscripts must be a string with pytest.raises(TypeError): - contract(0, 0) + contract_fn(0, 0) # invalid subscript character with pytest.raises(ValueError): - contract("i%...", [0, 0]) + contract_fn("i%...", [0, 0]) with pytest.raises(ValueError): - contract("...j$", [0, 0]) + contract_fn("...j$", [0, 0]) with pytest.raises(ValueError): - contract("i->&", [0, 0]) + contract_fn("i->&", [0, 0]) with pytest.raises(ValueError): - contract("") + contract_fn("") # number of operands must match count in subscripts string with pytest.raises(ValueError): - contract("", 0, 0) + contract_fn("", 0, 0) with pytest.raises(ValueError): - contract(",", 0, [0], [0]) + contract_fn(",", 0, [0], [0]) with pytest.raises(ValueError): - contract(",", [0]) + contract_fn(",", [0]) # can't have more subscripts than dimensions in the operand with pytest.raises(ValueError): - contract("i", 0) + contract_fn("i", 0) with pytest.raises(ValueError): - contract("ij", [0, 0]) + contract_fn("ij", [0, 0]) with pytest.raises(ValueError): - contract("...i", 0) + contract_fn("...i", 0) with pytest.raises(ValueError): - contract("i...j", [0, 0]) + contract_fn("i...j", [0, 0]) with pytest.raises(ValueError): - contract("i...", 0) + contract_fn("i...", 0) with pytest.raises(ValueError): - contract("ij...", [0, 0]) + contract_fn("ij...", [0, 0]) # invalid ellipsis with pytest.raises(ValueError): - contract("i..", [0, 0]) + contract_fn("i..", [0, 0]) with pytest.raises(ValueError): - contract(".i...", [0, 0]) + contract_fn(".i...", [0, 0]) with pytest.raises(ValueError): - contract("j->..j", [0, 0]) + contract_fn("j->..j", [0, 0]) with pytest.raises(ValueError): - contract("j->.j...", [0, 0]) + contract_fn("j->.j...", [0, 0]) # invalid subscript character with pytest.raises(ValueError): - contract("i%...", [0, 0]) + contract_fn("i%...", [0, 0]) with pytest.raises(ValueError): - contract("...j$", [0, 0]) + contract_fn("...j$", [0, 0]) with pytest.raises(ValueError): - contract("i->&", [0, 0]) + contract_fn("i->&", [0, 0]) # output subscripts must appear in input with pytest.raises(ValueError): - contract("i->ij", [0, 0]) + contract_fn("i->ij", [0, 0]) # output subscripts may only be specified once with pytest.raises(ValueError): - contract("ij->jij", [[0, 0], [0, 0]]) + contract_fn("ij->jij", [[0, 0], [0, 0]]) # dimensions much match when being collapsed with pytest.raises(ValueError): - contract("ii", np.arange(6).reshape(2, 3)) + contract_fn("ii", np.arange(6).reshape(2, 3)) with pytest.raises(ValueError): - contract("ii->i", np.arange(6).reshape(2, 3)) + contract_fn("ii->i", np.arange(6).reshape(2, 3)) # broadcasting to new dimensions must be enabled explicitly with pytest.raises(ValueError): - contract("i", np.arange(6).reshape(2, 3)) - with pytest.raises(ValueError): - contract("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2)) - - -def test_contract_inputs(): + contract_fn("i", np.arange(6).reshape(2, 3)) + if contract_fn is contract: + # contract_path does not have an `out` parameter + with pytest.raises(ValueError): + contract_fn("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2)) with pytest.raises(TypeError): - contract_path("i->i", [[0, 1], [0, 1]], bad_kwarg=True) + contract_fn("i->i", [[0, 1], [0, 1]], bad_kwarg=True) with pytest.raises(ValueError): - contract_path("i->i", [[0, 1], [0, 1]], memory_limit=-1) + contract_fn("i->i", [[0, 1], [0, 1]], memory_limit=-1) @pytest.mark.parametrize(