Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Introduce cluster-level Temp #2281

Merged
merged 3 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,33 @@
from functools import singledispatch

from sympy import Add, Function, Indexed, Mul, Pow
from sympy.core.core import ordering_of_classes

from devito.finite_differences.differentiable import IndexDerivative
from devito.ir import Cluster, Scope, cluster_pass
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.types import Eq, Temp as Temp0
from devito.types import Eq, Temp

__all__ = ['cse']


class Temp(Temp0):
pass
class CTemp(Temp):

"""
A cluster-level Temp, similar to Temp, ensured to have different priority
"""
ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')


@cluster_pass
def cse(cluster, sregistry, options, *args):
"""
Common sub-expressions elimination (CSE).
"""
make = lambda: Temp(name=sregistry.make_name(), dtype=cluster.dtype)
make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype)
exprs = _cse(cluster, make, min_cost=options['cse-min-cost'])

return cluster.rebuild(exprs=exprs)
Expand Down Expand Up @@ -130,7 +135,7 @@ def _compact_temporaries(exprs, exclude):
# safely be compacted; a generic Symbol could instead be accessed in a subsequent
# Cluster, for example: `for (i = ...) { a = b; for (j = a ...) ...`
mapper = {e.lhs: e.rhs for e in exprs
if isinstance(e.lhs, Temp) and q_leaf(e.rhs) and e.lhs not in exclude}
if isinstance(e.lhs, CTemp) and q_leaf(e.rhs) and e.lhs not in exclude}

processed = []
for e in exprs:
Expand Down
33 changes: 23 additions & 10 deletions examples/performance/01_gpu.ipynb

Large diffs are not rendered by default.

28 changes: 22 additions & 6 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from cached_property import cached_property

from sympy import Mul # noqa
from sympy.core.mul import _mulsort

from conftest import (skipif, EVAL, _R, assert_structure, assert_blocking, # noqa
get_params, get_arrays, check_array)
Expand All @@ -18,13 +19,13 @@
FindSymbols, ParallelIteration, retrieve_iteration_tree)
from devito.passes.clusters.aliases import collect
from devito.passes.clusters.factorization import collect_nested
from devito.passes.clusters.cse import Temp, _cse
from devito.passes.clusters.cse import CTemp, _cse
from devito.passes.iet.parpragma import VExpanded
from devito.symbolics import (INT, FLOAT, DefFunction, FieldFromPointer, # noqa
IndexedPointer, Keyword, SizeOf, estimate_cost,
pow_to_mul, indexify)
from devito.tools import as_tuple, generator
from devito.types import Array, Scalar, Symbol, PrecomputedSparseTimeFunction
from devito.types import Array, Scalar, Symbol, PrecomputedSparseTimeFunction, Temp

from examples.seismic.acoustic import AcousticWaveSolver
from examples.seismic import demo_model, AcquisitionGeometry
Expand Down Expand Up @@ -132,9 +133,9 @@ def test_cse(exprs, expected, min_cost):
fx = Function(name="fx", grid=grid, dimensions=(x,), shape=(3,)) # noqa
ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa
ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa
t0 = Temp(name='t0') # noqa
t1 = Temp(name='t1') # noqa
t2 = Temp(name='t2') # noqa
t0 = CTemp(name='t0') # noqa
t1 = CTemp(name='t1') # noqa
t2 = CTemp(name='t2') # noqa
# Needs to not be a Temp to mimic nested index extraction and prevent
# cse to compact the temporary back.
e0 = Symbol(name='e0') # noqa
Expand All @@ -144,13 +145,28 @@ def test_cse(exprs, expected, min_cost):
exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate)))

counter = generator()
make = lambda: Temp(name='r%d' % counter()).indexify()
make = lambda: CTemp(name='r%d' % counter()).indexify()
processed = _cse(exprs, make, min_cost)

assert len(processed) == len(expected)
assert all(str(i.rhs) == j for i, j in zip(processed, expected))


def test_cse_temp_order():
# Test order of classes inserted to Sympy's core ordering
a = Temp(name='r6')
b = CTemp(name='r6')
c = Symbol(name='r6')

args = [b, a, c]

_mulsort(args)

assert type(args[0]) is Symbol
assert type(args[1]) is Temp
assert type(args[2]) is CTemp


@pytest.mark.parametrize('expr,expected', [
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
('fa[x]**2', 'fa[x]*fa[x]'),
Expand Down
1 change: 1 addition & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
from sympy.abc import a, b, c, d, e

import time

from devito.tools import (UnboundedMultiTuple, ctypes_to_cstr, toposort,
Expand Down
Loading