Skip to content

Commit

Permalink
handle empty contraction list in PathInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 6, 2024
1 parent 2824c9e commit f2bb3b7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
6 changes: 3 additions & 3 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def __init__(
self.scale_list = scale_list
self.naive_cost = Decimal(naive_cost)
self.opt_cost = Decimal(opt_cost)
self.speedup = self.naive_cost / self.opt_cost
self.speedup = self.naive_cost / max(self.opt_cost, 1)
self.size_list = size_list
self.size_dict = size_dict

self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(",")]
self.eq = "{}->{}".format(input_subscripts, output_subscript)
self.largest_intermediate = Decimal(max(size_list))
self.largest_intermediate = Decimal(max(size_list, default=1))

def __repr__(self) -> str:
# Return the path along with a nice string representation
Expand All @@ -65,7 +65,7 @@ def __repr__(self) -> str:
path_print = [
" Complete contraction: {}\n".format(self.eq),
" Naive scaling: {}\n".format(len(self.indices)),
" Optimized scaling: {}\n".format(max(self.scale_list)),
" Optimized scaling: {}\n".format(max(self.scale_list, default=0)),
" Naive FLOP count: {:.3e}\n".format(self.naive_cost),
" Optimized FLOP count: {:.3e}\n".format(self.opt_cost),
" Theoretical speedup: {:.3e}\n".format(self.speedup),
Expand Down
11 changes: 10 additions & 1 deletion opt_einsum/tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from opt_einsum import contract, contract_expression
from opt_einsum import contract, contract_path, contract_expression


def test_contract_expression_checks():
Expand Down Expand Up @@ -123,3 +123,12 @@ def test_can_blas_on_healed_broadcast_dimensions():
# but then is healed GEMM is usable
assert expr.contraction_list[1][2] == "bca,bd->acd"
assert expr.contraction_list[1][-1] == "GEMM"


def test_pathinfo_for_empty_contraction():
eq = "->"
arrays = (1.0,)
path = []
_, info = contract_path(eq, *arrays, optimize=path)
print(info)
assert info.largest_intermediate == 1

0 comments on commit f2bb3b7

Please sign in to comment.