diff --git a/opt_einsum_fx/__init__.py b/opt_einsum_fx/__init__.py
index 4057049..c50a6cc 100644
--- a/opt_einsum_fx/__init__.py
+++ b/opt_einsum_fx/__init__.py
@@ -3,6 +3,7 @@
 from ._script import jitable
 from ._opt_ein import optimize_einsums, optimize_einsums_full
 from ._fuse import fuse_einsums, fuse_scalars
+from ._efficient_shape_prop import EfficientShapeProp
 
 __all__ = [
     "jitable",
@@ -10,4 +11,5 @@
     "optimize_einsums_full",
     "fuse_einsums",
     "fuse_scalars",
+    "EfficientShapeProp",
 ]
diff --git a/opt_einsum_fx/_efficient_shape_prop.py b/opt_einsum_fx/_efficient_shape_prop.py
new file mode 100644
index 0000000..c379af5
--- /dev/null
+++ b/opt_einsum_fx/_efficient_shape_prop.py
@@ -0,0 +1,111 @@
+from typing import Any, NamedTuple
+
+import opt_einsum
+import torch
+from torch.fx.node import Node
+
+from ._fuse import _EINSUM_FUNCS
+
+
+class SimpleMeta(NamedTuple):
+    """
+    The full ShapeProp defines and uses a NamedTuple to
+    store a whole bunch of metadata about the tensors
+    going into and out of the Node op. But we don't
+    have most of that info, and anyway, I don't think
+    most of it is used in opt_einsum or opt_einsum_fx.
+    (These are only concerned with computing a summation
+    order.)
+
+    Rather than give dummy or default values, which I
+    only *assume* would be fine, I'm defining a NamedTuple
+    with only the values we actually know. So if I'm wrong,
+    we will get a very clear error message rather than
+    a silent one.
+    """
+
+    shape: torch.Size
+    dtype: torch.dtype
+
+
+class EfficientShapeProp(torch.fx.Interpreter):
+    """
+    Like ShapeProp, traverses a graph Node-by-Node
+    and records the shape and type of the result
+    into each Node.
+
+    Except that we treat 'einsum' as a special case.
+    We don't actually execute 'einsum' on tensors,
+    since the einsums will typically not be optimized
+    yet (ShapeProp is called before optimization),
+    and an inefficient summation order can create
+    enormous intermediate tensors, which often causes
+    needless out-of-memory errors.
+
+    So we override 'run_node' only for einsums.
+    It's straightforward to determine the shape of the
+    result just from the output indices.
+
+    (The call to opt_einsum that typically follows
+    this also doesn't actually build the tensors
+    during its exploration.)
+ """ + + def run_node(self, n: Node) -> Any: + if n.op == "call_function" and n.target in _EINSUM_FUNCS: + args, kwargs = self.fetch_args_kwargs_from_env(n) + equation, *operands = args + shapes = [op.shape for op in operands] + + assert len({op.dtype for op in operands}) == 1 + meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype) + result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape) + elif n.op == "call_function" and n.target == torch.tensordot: + args, kwargs = self.fetch_args_kwargs_from_env(n) + shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]] + shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]] + + assert len({op.dtype for op in args}) == 1 + meta = SimpleMeta(shape_a + shape_b, args[0].dtype) + result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape) + else: + result = super().run_node(n) + + if isinstance(result, torch.Tensor): + meta = SimpleMeta(result.shape, result.dtype) + else: + meta = None + + n.meta = dict() + n.meta['tensor_meta'] = meta + n.meta['type'] = type(result) + + return result + + def propagate(self, *args): + return super().run(*args) + + +def einsum_shape(subscripts, *shapes): + """ + Given an einsum equation and input shapes, returns the output + shape of the einsum. + + Args: + subscripts: the einsum formula + shapes: the input shapes + """ + Shaped = NamedTuple('Shaped', [('shape', tuple)]) + input_subscripts, output_subscript, _ = opt_einsum.parser.parse_einsum_input( + (subscripts,) + tuple(Shaped(shape) for shape in shapes) + ) + dims = { + i: dim + for ii, shape in zip(input_subscripts.split(','), shapes) + for i, dim in zip(ii, shape) + } + return tuple(dims[i] for i in output_subscript) diff --git a/opt_einsum_fx/_opt_ein.py b/opt_einsum_fx/_opt_ein.py index 9c8a7a9..dc6c124 100644 --- a/opt_einsum_fx/_opt_ein.py +++ b/opt_einsum_fx/_opt_ein.py @@ -3,7 +3,7 @@ import torch from torch import fx -from torch.fx.passes.shape_prop import ShapeProp +from ._efficient_shape_prop import EfficientShapeProp as ShapeProp import opt_einsum from opt_einsum.contract import _core_contract diff --git a/opt_einsum_fx/fx_utils.py b/opt_einsum_fx/fx_utils.py index 021ea93..c1d4023 100644 --- a/opt_einsum_fx/fx_utils.py +++ b/opt_einsum_fx/fx_utils.py @@ -1,28 +1,14 @@ from typing import Optional -from packaging import version import torch from torch import fx -_TORCH_IS_GE_19: bool = version.parse(torch.__version__) >= version.parse("1.9.0") -# The torch FX APIs are not stable, so we need helper wrappers - -if _TORCH_IS_GE_19: - - def get_shape(n: fx.Node) -> Optional[torch.Size]: - """Get the shape of a node after ``ShapeProp``""" - try: - return n.meta["tensor_meta"].shape - except KeyError: - return None - - -else: - - def get_shape(n: fx.Node) -> Optional[torch.Size]: - """Get the shape of a node after ``ShapeProp``""" - try: - return n.shape - except AttributeError: - return None +def get_shape(n: fx.Node) -> Optional[torch.Size]: + """Get the shape of a node after ``ShapeProp``""" + try: + return n.meta["tensor_meta"].shape + except KeyError: + return None + except AttributeError: + return None diff --git a/tests/test_einsum_optimizer.py b/tests/test_einsum_optimizer.py index 1b5083d..593c9fa 100644 --- a/tests/test_einsum_optimizer.py +++ b/tests/test_einsum_optimizer.py @@ -2,9 +2,8 @@ import torch import torch.fx -from torch.fx.passes.shape_prop import ShapeProp -from 
+from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable, EfficientShapeProp
 
 
 def einmatmul(x, y):
@@ -74,7 +73,7 @@ def test_optimize_einsums(einfunc, allclose):
     func_res = einfunc(x, y)
 
     func_fx = torch.fx.symbolic_trace(einfunc)
-    sp = ShapeProp(func_fx)
+    sp = EfficientShapeProp(func_fx)
     sp.run(x, y)
     func_fx_res = func_fx(x, y)
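
Usage sketch (not part of the diff): it assumes only the API added above, mirrors the call pattern in the updated test, and uses a made-up traced function and shapes for illustration.

import torch
import torch.fx

from opt_einsum_fx import EfficientShapeProp
from opt_einsum_fx.fx_utils import get_shape


def f(x, y):
    return torch.einsum("ij,jk->ik", x, y)


graph_module = torch.fx.symbolic_trace(f)

# Propagates shapes without ever materializing the einsum's result:
# the einsum node gets a zeros((1, 1)).expand(2, 4) stand-in instead.
EfficientShapeProp(graph_module).run(torch.randn(2, 3), torch.randn(3, 4))

for node in graph_module.graph.nodes:
    print(node.name, get_shape(node))  # e.g. "einsum torch.Size([2, 4])"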
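
The einsum shape rule itself is a pure function of the equation and the operand shapes. A quick sanity check of the einsum_shape helper (a private function, so the import below reaches into the new module):

from opt_einsum_fx._efficient_shape_prop import einsum_shape

# The output shape is read directly off the output subscript: i -> 2, k -> 4.
assert einsum_shape("ij,jk->ik", (2, 3), (3, 4)) == (2, 4)

# Implicit equations (no "->") work too, since opt_einsum's parser fills in
# the output subscript: indices appearing once survive, in alphabetical order.
assert einsum_shape("ij,jk", (2, 3), (3, 4)) == (2, 4)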
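
For torch.tensordot, the branch above concatenates the uncontracted dims of the first operand with those of the second, which matches what PyTorch computes. Note that it reads dims from kwargs, so it assumes tensordot was traced with dims passed as a keyword argument:

import torch

a, b = torch.randn(2, 3, 4), torch.randn(4, 5)
# dims=([2], [0]) contracts a's last axis against b's first;
# the result keeps (2, 3) from a followed by (5,) from b.
assert torch.tensordot(a, b, dims=([2], [0])).shape == torch.Size([2, 3, 5])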