Skip to content

Commit

Permalink
Raise error in parse_einsum_input when output subscript is specified …
Browse files Browse the repository at this point in the history
…multiple times (#222)
  • Loading branch information
lgeiger authored May 5, 2024
1 parent 1a984b7 commit 2824c9e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 36 deletions.
4 changes: 3 additions & 1 deletion opt_einsum/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,10 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L
else:
input_subscripts, output_subscript = subscripts, find_output_str(subscripts)

# Make sure output subscripts are in the input
# Make sure output subscripts are unique and in the input
for char in output_subscript:
if output_subscript.count(char) != 1:
raise ValueError("Output character '{}' appeared more than once in the output.".format(char))
if char not in input_subscripts:
raise ValueError("Output character '{}' did not appear in the input".format(char))

Expand Down
70 changes: 35 additions & 35 deletions opt_einsum/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,92 +81,92 @@ def test_type_errors():
contract(views[0], [Ellipsis, dict()], [Ellipsis, "a"])


def test_value_errors():
@pytest.mark.parametrize("contract_fn", [contract, contract_path])
def test_value_errors(contract_fn):
with pytest.raises(ValueError):
contract("")
contract_fn("")

# subscripts must be a string
with pytest.raises(TypeError):
contract(0, 0)
contract_fn(0, 0)

# invalid subscript character
with pytest.raises(ValueError):
contract("i%...", [0, 0])
contract_fn("i%...", [0, 0])
with pytest.raises(ValueError):
contract("...j$", [0, 0])
contract_fn("...j$", [0, 0])
with pytest.raises(ValueError):
contract("i->&", [0, 0])
contract_fn("i->&", [0, 0])

with pytest.raises(ValueError):
contract("")
contract_fn("")
# number of operands must match count in subscripts string
with pytest.raises(ValueError):
contract("", 0, 0)
contract_fn("", 0, 0)
with pytest.raises(ValueError):
contract(",", 0, [0], [0])
contract_fn(",", 0, [0], [0])
with pytest.raises(ValueError):
contract(",", [0])
contract_fn(",", [0])

# can't have more subscripts than dimensions in the operand
with pytest.raises(ValueError):
contract("i", 0)
contract_fn("i", 0)
with pytest.raises(ValueError):
contract("ij", [0, 0])
contract_fn("ij", [0, 0])
with pytest.raises(ValueError):
contract("...i", 0)
contract_fn("...i", 0)
with pytest.raises(ValueError):
contract("i...j", [0, 0])
contract_fn("i...j", [0, 0])
with pytest.raises(ValueError):
contract("i...", 0)
contract_fn("i...", 0)
with pytest.raises(ValueError):
contract("ij...", [0, 0])
contract_fn("ij...", [0, 0])

# invalid ellipsis
with pytest.raises(ValueError):
contract("i..", [0, 0])
contract_fn("i..", [0, 0])
with pytest.raises(ValueError):
contract(".i...", [0, 0])
contract_fn(".i...", [0, 0])
with pytest.raises(ValueError):
contract("j->..j", [0, 0])
contract_fn("j->..j", [0, 0])
with pytest.raises(ValueError):
contract("j->.j...", [0, 0])
contract_fn("j->.j...", [0, 0])

# invalid subscript character
with pytest.raises(ValueError):
contract("i%...", [0, 0])
contract_fn("i%...", [0, 0])
with pytest.raises(ValueError):
contract("...j$", [0, 0])
contract_fn("...j$", [0, 0])
with pytest.raises(ValueError):
contract("i->&", [0, 0])
contract_fn("i->&", [0, 0])

# output subscripts must appear in input
with pytest.raises(ValueError):
contract("i->ij", [0, 0])
contract_fn("i->ij", [0, 0])

# output subscripts may only be specified once
with pytest.raises(ValueError):
contract("ij->jij", [[0, 0], [0, 0]])
contract_fn("ij->jij", [[0, 0], [0, 0]])

# dimensions much match when being collapsed
with pytest.raises(ValueError):
contract("ii", np.arange(6).reshape(2, 3))
contract_fn("ii", np.arange(6).reshape(2, 3))
with pytest.raises(ValueError):
contract("ii->i", np.arange(6).reshape(2, 3))
contract_fn("ii->i", np.arange(6).reshape(2, 3))

# broadcasting to new dimensions must be enabled explicitly
with pytest.raises(ValueError):
contract("i", np.arange(6).reshape(2, 3))
with pytest.raises(ValueError):
contract("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2))


def test_contract_inputs():
contract_fn("i", np.arange(6).reshape(2, 3))
if contract_fn is contract:
# contract_path does not have an `out` parameter
with pytest.raises(ValueError):
contract_fn("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2))

with pytest.raises(TypeError):
contract_path("i->i", [[0, 1], [0, 1]], bad_kwarg=True)
contract_fn("i->i", [[0, 1], [0, 1]], bad_kwarg=True)

with pytest.raises(ValueError):
contract_path("i->i", [[0, 1], [0, 1]], memory_limit=-1)
contract_fn("i->i", [[0, 1], [0, 1]], memory_limit=-1)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 2824c9e

Please sign in to comment.