From 727206984e1b485262fe7d96e118258ca936c7fe Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 15 Dec 2023 17:10:19 +0000 Subject: [PATCH] compiler: Improve estimate_cost --- devito/symbolics/inspection.py | 63 ++++++++++++++++++++++------------ tests/test_dse.py | 2 ++ 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 45b5dce754..99e752abce 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -91,6 +91,7 @@ def estimate_cost(exprs, estimate=False): # We don't use SymPy's count_ops because we do not count integer arithmetic # (e.g., array index functions such as i+1 in A[i+1]) # Also, the routine below is *much* faster than count_ops + seen = {} flops = 0 for expr in as_tuple(exprs): # TODO: this if-then should be part of singledispatch too, but because @@ -103,7 +104,7 @@ def estimate_cost(exprs, estimate=False): else: e = expr - flops += _estimate_cost(e, estimate)[0] + flops += _estimate_cost(e, estimate, seen)[0] return flops except: @@ -121,11 +122,27 @@ def estimate_cost(exprs, estimate=False): } +def dont_count_if_seen(func): + """ + This decorator is used to avoid counting the same expression multiple + times. This is necessary because the same expression may appear multiple + times in the same expression tree or even across different expressions. + """ + def wrapper(expr, estimate, seen): + try: + _, flags = seen[expr] + flops = 0 + except KeyError: + flops, flags = seen[expr] = func(expr, estimate, seen) + return flops, flags + return wrapper + + @singledispatch -def _estimate_cost(expr, estimate): +def _estimate_cost(expr, estimate, seen): # Retval: flops (int), flag (bool) # The flag tells wether it's an integer expression (implying flops==0) or not - flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args]) + flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) flops = sum(flops) if all(flags): # `expr` is an operation involving integer operands only @@ -138,28 +155,28 @@ def _estimate_cost(expr, estimate): @_estimate_cost.register(Tuple) @_estimate_cost.register(CallFromPointer) -def _(expr, estimate): +def _(expr, estimate, seen): try: - flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args]) + flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) except ValueError: flops, flags = [], [] return sum(flops), all(flags) @_estimate_cost.register(Integer) -def _(expr, estimate): +def _(expr, estimate, seen): return 0, True @_estimate_cost.register(Number) @_estimate_cost.register(ReservedWord) -def _(expr, estimate): +def _(expr, estimate, seen): return 0, False @_estimate_cost.register(Symbol) @_estimate_cost.register(Indexed) -def _(expr, estimate): +def _(expr, estimate, seen): try: if issubclass(expr.dtype, np.integer): return 0, True @@ -169,27 +186,27 @@ def _(expr, estimate): @_estimate_cost.register(Mul) -def _(expr, estimate): - flops, flags = _estimate_cost.registry[object](expr, estimate) +def _(expr, estimate, seen): + flops, flags = _estimate_cost.registry[object](expr, estimate, seen) if {S.One, S.NegativeOne}.intersection(expr.args): flops -= 1 return flops, flags @_estimate_cost.register(INT) -def _(expr, estimate): - return _estimate_cost(expr.base, estimate)[0], True +def _(expr, estimate, seen): + return _estimate_cost(expr.base, estimate, seen)[0], True @_estimate_cost.register(Cast) -def _(expr, estimate): - return _estimate_cost(expr.base, estimate)[0], False +def _(expr, estimate, seen): + return _estimate_cost(expr.base, estimate, seen)[0], False @_estimate_cost.register(Function) -def _(expr, estimate): +def _(expr, estimate, seen): if q_routine(expr): - flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args]) + flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) flops = sum(flops) if isinstance(expr, DefFunction): # Bypass user-defined or language-specific functions @@ -207,8 +224,8 @@ def _(expr, estimate): @_estimate_cost.register(Pow) -def _(expr, estimate): - flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args]) +def _(expr, estimate, seen): + flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args]) flops = sum(flops) if estimate: if expr.exp.is_Number: @@ -229,13 +246,15 @@ def _(expr, estimate): @_estimate_cost.register(Derivative) -def _(expr, estimate): - return _estimate_cost(expr._evaluate(expand=False), estimate) +@dont_count_if_seen +def _(expr, estimate, seen): + return _estimate_cost(expr._evaluate(expand=False), estimate, seen) @_estimate_cost.register(IndexDerivative) -def _(expr, estimate): - flops, _ = _estimate_cost(expr.expr, estimate) +@dont_count_if_seen +def _(expr, estimate, seen): + flops, _ = _estimate_cost(expr.expr, estimate, seen) # It's an increment flops += 1 diff --git a/tests/test_dse.py b/tests/test_dse.py index 5fc88b5d94..8f4d756a8b 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -270,6 +270,8 @@ def test_factorize(expr, expected): ('Eq(fb, fd.dx)', 10, True), ('Eq(fb, fd.dx._evaluate(expand=False))', 10, False), ('Eq(fb, fd.dx.dy + fa.dx)', 66, False), + # Ensure redundancies aren't counted + ('Eq(t0, fd.dx.dy + fa*fd.dx.dy)', 62, True), ]) def test_estimate_cost(expr, expected, estimate): # Note: integer arithmetic isn't counted