From 2ec2de62dd20ed9404ac864edc6a54af0b2097a2 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 12 Oct 2023 09:07:31 +0000 Subject: [PATCH] compiler: Patch cluster.is_sparse --- devito/ir/clusters/cluster.py | 13 +++---------- tests/test_lower_clusters.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 10 deletions(-) create mode 100644 tests/test_lower_clusters.py diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 2143331aba..e7354b8693 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -231,17 +231,10 @@ def is_dense(self): @cached_property def is_sparse(self): """ - A cluster is sparse if it represent a sparse operation i.e if both - - * The cluster contains sparse functions - * The cluster uses dense functions - - If only the first case is true, the cluster only contains operation on the sparse - function itself without indirection and therefore only contains dense operations. + A Cluster is sparse if it represents a sparse operation, i.e iff + There's at least one irregular access. """ - return (any(f.is_SparseFunction for f in self.functions) and - len([f for f in self.functions - if (f.is_Function and not f.is_SparseFunction)]) > 0) + return any(a.is_irregular for a in self.scope.accesses) @property def is_halo_touch(self): diff --git a/tests/test_lower_clusters.py b/tests/test_lower_clusters.py new file mode 100644 index 0000000000..f5ecf46bc6 --- /dev/null +++ b/tests/test_lower_clusters.py @@ -0,0 +1,18 @@ +from devito import Grid, SparseTimeFunction, TimeFunction, Operator +from devito.ir.iet import FindSymbols + + +class TestLowerReductions(object): + + def test_no_temp_upon_reduce_expansion(self): + grid = Grid(shape=(10, 10, 10)) + + u = TimeFunction(name='u', grid=grid) + sf = SparseTimeFunction(name='sf', grid=grid, npoint=1, nt=5) + + rec_term = sf.interpolate(expr=u) + + op = Operator(rec_term, opt=('advanced', {'mapify-reduce': True})) + + arrays = [i for i in FindSymbols().visit(op) if i.is_Array] + assert len([i for i in arrays if i.ndim == grid.dim]) == 0