Skip to content

Commit

Permalink
compiler: Improve estimate_cost
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Dec 19, 2023
1 parent b7016f2 commit 7272069
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 22 deletions.
63 changes: 41 additions & 22 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def estimate_cost(exprs, estimate=False):
# We don't use SymPy's count_ops because we do not count integer arithmetic
# (e.g., array index functions such as i+1 in A[i+1])
# Also, the routine below is *much* faster than count_ops
seen = {}
flops = 0
for expr in as_tuple(exprs):
# TODO: this if-then should be part of singledispatch too, but because
Expand All @@ -103,7 +104,7 @@ def estimate_cost(exprs, estimate=False):
else:
e = expr

flops += _estimate_cost(e, estimate)[0]
flops += _estimate_cost(e, estimate, seen)[0]

return flops
except:
Expand All @@ -121,11 +122,27 @@ def estimate_cost(exprs, estimate=False):
}


def dont_count_if_seen(func):
"""
This decorator is used to avoid counting the same expression multiple
times. This is necessary because the same expression may appear multiple
times in the same expression tree or even across different expressions.
"""
def wrapper(expr, estimate, seen):
try:
_, flags = seen[expr]
flops = 0
except KeyError:
flops, flags = seen[expr] = func(expr, estimate, seen)
return flops, flags
return wrapper


@singledispatch
def _estimate_cost(expr, estimate):
def _estimate_cost(expr, estimate, seen):
# Retval: flops (int), flag (bool)
# The flag tells wether it's an integer expression (implying flops==0) or not
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if all(flags):
# `expr` is an operation involving integer operands only
Expand All @@ -138,28 +155,28 @@ def _estimate_cost(expr, estimate):

@_estimate_cost.register(Tuple)
@_estimate_cost.register(CallFromPointer)
def _(expr, estimate):
def _(expr, estimate, seen):
try:
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
except ValueError:
flops, flags = [], []
return sum(flops), all(flags)


@_estimate_cost.register(Integer)
def _(expr, estimate):
def _(expr, estimate, seen):
return 0, True


@_estimate_cost.register(Number)
@_estimate_cost.register(ReservedWord)
def _(expr, estimate):
def _(expr, estimate, seen):
return 0, False


@_estimate_cost.register(Symbol)
@_estimate_cost.register(Indexed)
def _(expr, estimate):
def _(expr, estimate, seen):
try:
if issubclass(expr.dtype, np.integer):
return 0, True
Expand All @@ -169,27 +186,27 @@ def _(expr, estimate):


@_estimate_cost.register(Mul)
def _(expr, estimate):
flops, flags = _estimate_cost.registry[object](expr, estimate)
def _(expr, estimate, seen):
flops, flags = _estimate_cost.registry[object](expr, estimate, seen)
if {S.One, S.NegativeOne}.intersection(expr.args):
flops -= 1
return flops, flags


@_estimate_cost.register(INT)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], True
def _(expr, estimate, seen):
return _estimate_cost(expr.base, estimate, seen)[0], True


@_estimate_cost.register(Cast)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], False
def _(expr, estimate, seen):
return _estimate_cost(expr.base, estimate, seen)[0], False


@_estimate_cost.register(Function)
def _(expr, estimate):
def _(expr, estimate, seen):
if q_routine(expr):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if isinstance(expr, DefFunction):
# Bypass user-defined or language-specific functions
Expand All @@ -207,8 +224,8 @@ def _(expr, estimate):


@_estimate_cost.register(Pow)
def _(expr, estimate):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
def _(expr, estimate, seen):
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
flops = sum(flops)
if estimate:
if expr.exp.is_Number:
Expand All @@ -229,13 +246,15 @@ def _(expr, estimate):


@_estimate_cost.register(Derivative)
def _(expr, estimate):
return _estimate_cost(expr._evaluate(expand=False), estimate)
@dont_count_if_seen
def _(expr, estimate, seen):
return _estimate_cost(expr._evaluate(expand=False), estimate, seen)


@_estimate_cost.register(IndexDerivative)
def _(expr, estimate):
flops, _ = _estimate_cost(expr.expr, estimate)
@dont_count_if_seen
def _(expr, estimate, seen):
flops, _ = _estimate_cost(expr.expr, estimate, seen)

# It's an increment
flops += 1
Expand Down
2 changes: 2 additions & 0 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ def test_factorize(expr, expected):
('Eq(fb, fd.dx)', 10, True),
('Eq(fb, fd.dx._evaluate(expand=False))', 10, False),
('Eq(fb, fd.dx.dy + fa.dx)', 66, False),
# Ensure redundancies aren't counted
('Eq(t0, fd.dx.dy + fa*fd.dx.dy)', 62, True),
])
def test_estimate_cost(expr, expected, estimate):
# Note: integer arithmetic isn't counted
Expand Down

0 comments on commit 7272069

Please sign in to comment.