2 changes: 2 additions & 0 deletions opt_einsum_fx/__init__.py
@@ -3,11 +3,13 @@
from ._script import jitable
from ._opt_ein import optimize_einsums, optimize_einsums_full
from ._fuse import fuse_einsums, fuse_scalars
from ._efficient_shape_prop import EfficientShapeProp

__all__ = [
"jitable",
"optimize_einsums",
"optimize_einsums_full",
"fuse_einsums",
"fuse_scalars",
"EfficientShapeProp",
]
107 changes: 107 additions & 0 deletions opt_einsum_fx/_efficient_shape_prop.py
@@ -0,0 +1,107 @@
from typing import Any, NamedTuple

import opt_einsum
import torch
from torch.fx.node import Node

from ._fuse import _EINSUM_FUNCS


class SimpleMeta(NamedTuple):
"""
    The full ShapeProp defines and uses a NamedTuple that stores
    extensive metadata about the tensors going into and out of each
    Node op. Most of that information is unavailable here, and it is
    not needed by opt_einsum or opt_einsum_fx, which are only
    concerned with computing a summation order.

    Rather than fill the remaining fields with dummy or default
    values that are merely assumed to be safe, this NamedTuple holds
    only the values we actually know. If anything else is ever
    required downstream, the result is a clear error message rather
    than a silently wrong answer.
"""

shape: torch.Size
dtype: torch.dtype


class EfficientShapeProp(torch.fx.Interpreter):
"""
    Like ShapeProp, traverses the graph node by node and records
    the shape and dtype of each result in the node's metadata.

    Unlike ShapeProp, 'einsum' (and 'tensordot') are treated as
    special cases: they are never actually executed on tensors.
    Shape propagation runs before optimization, so the einsums are
    typically still unoptimized, and an inefficient summation order
    can build enormous intermediate tensors, leading to needless
    out-of-memory errors.

    'run_node' is therefore overridden only for these ops; the
    output shape follows directly from the output indices and the
    operand shapes.

    (The opt_einsum call that typically follows this pass likewise
    never materializes tensors while exploring contraction paths.)
"""

def run_node(self, n: Node) -> Any:
if n.op == "call_function" and n.target in _EINSUM_FUNCS:
args, kwargs = self.fetch_args_kwargs_from_env(n)
equation, *operands = args
shapes = [op.shape for op in operands]

assert len({op.dtype for op in operands}) == 1
meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype)
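            # Build a stand-in result: an expanded view of a single zero has the
            # right shape, dtype, and device without allocating the full output.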
result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape)
elif n.op == "call_function" and n.target == torch.tensordot:
args, kwargs = self.fetch_args_kwargs_from_env(n)
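            # The output shape is the uncontracted dims of the first operand
            # followed by those of the second. This assumes 'dims' is passed as
            # a keyword argument holding a pair of index sequences.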
shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]]
shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]]

assert len({op.dtype for op in args}) == 1
meta = SimpleMeta(shape_a + shape_b, args[0].dtype)
result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape)
else:
result = super().run_node(n)

if isinstance(result, torch.Tensor):
meta = SimpleMeta(result.shape, result.dtype)
else:
meta = None

n.meta = dict()
n.meta['tensor_meta'] = meta
n.meta['type'] = type(result)

return result

def propagate(self, *args):
return super().run(*args)


def einsum_shape(subscripts, *shapes):
"""
Given an einsum equation and input shapes, returns the output
shape of the einsum.

Args:
subscripts: the einsum formula
shapes: the input shapes
"""
Shaped = NamedTuple('Shaped', [('shape', tuple)])
input_subscripts, output_subscript, _ = opt_einsum.parser.parse_einsum_input(
(subscripts,) + tuple(Shaped(shape) for shape in shapes)
)
dims = {
i: dim
for ii, shape in zip(input_subscripts.split(','), shapes)
for i, dim in zip(ii, shape)
}
return tuple(dims[i] for i in output_subscript)
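
A minimal usage sketch of the new helper (not part of the diff; it assumes einsum_shape stays importable from the private module opt_einsum_fx._efficient_shape_prop):

# Hypothetical example: compute an einsum output shape symbolically,
# without allocating any tensors.
from opt_einsum_fx._efficient_shape_prop import einsum_shape

# batched matmul: "bij,bjk->bik" with b=8, i=3, j=4, k=5
out_shape = einsum_shape("bij,bjk->bik", (8, 3, 4), (8, 4, 5))
assert out_shape == (8, 3, 5)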
2 changes: 1 addition & 1 deletion opt_einsum_fx/_opt_ein.py
@@ -3,7 +3,7 @@

import torch
from torch import fx
from torch.fx.passes.shape_prop import ShapeProp
from ._efficient_shape_prop import EfficientShapeProp as ShapeProp

import opt_einsum
from opt_einsum.contract import _core_contract
30 changes: 8 additions & 22 deletions opt_einsum_fx/fx_utils.py
@@ -1,28 +1,14 @@
from typing import Optional
from packaging import version

import torch
from torch import fx

_TORCH_IS_GE_19: bool = version.parse(torch.__version__) >= version.parse("1.9.0")

# The torch FX APIs are not stable, so we need helper wrappers

if _TORCH_IS_GE_19:

def get_shape(n: fx.Node) -> Optional[torch.Size]:
"""Get the shape of a node after ``ShapeProp``"""
try:
return n.meta["tensor_meta"].shape
except KeyError:
return None


else:

def get_shape(n: fx.Node) -> Optional[torch.Size]:
"""Get the shape of a node after ``ShapeProp``"""
try:
return n.shape
except AttributeError:
return None
def get_shape(n: fx.Node) -> Optional[torch.Size]:
"""Get the shape of a node after ``ShapeProp``"""
    try:
        return n.meta["tensor_meta"].shape
    except (KeyError, AttributeError):
        return None
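
A minimal end-to-end sketch (not part of the diff) of how EfficientShapeProp and get_shape are expected to fit together, assuming get_shape is imported from opt_einsum_fx.fx_utils:

import torch
import torch.fx

from opt_einsum_fx import EfficientShapeProp
from opt_einsum_fx.fx_utils import get_shape


def matmul_einsum(x, y):
    return torch.einsum("ij,jk->ik", x, y)


graph_module = torch.fx.symbolic_trace(matmul_einsum)
# Propagate shapes without executing the (still unoptimized) einsum.
EfficientShapeProp(graph_module).propagate(torch.randn(3, 4), torch.randn(4, 5))
for node in graph_module.graph.nodes:
    print(node.name, get_shape(node))  # the einsum node should report torch.Size([3, 5])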
5 changes: 2 additions & 3 deletions tests/test_einsum_optimizer.py
@@ -2,9 +2,8 @@

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp

from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable
from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable, EfficientShapeProp


def einmatmul(x, y):
@@ -74,7 +73,7 @@ def test_optimize_einsums(einfunc, allclose):
func_res = einfunc(x, y)

func_fx = torch.fx.symbolic_trace(einfunc)
sp = ShapeProp(func_fx)
sp = EfficientShapeProp(func_fx)
sp.run(x, y)

func_fx_res = func_fx(x, y)