From e01780026998faa1c7b4ed986ce6d73649704655 Mon Sep 17 00:00:00 2001 From: Mathias Louboutin Date: Tue, 26 Sep 2023 09:21:54 -0400 Subject: [PATCH] compiler: prevent radius dependent temps for sparse operations --- devito/ir/clusters/algorithms.py | 2 +- tests/test_dse.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 2c0bdacc0c..dcea8e2f1c 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -458,7 +458,7 @@ def normalize_reductions(cluster, sregistry, options): processed = [] for e in cluster.exprs: - if e.is_Reduction and e.lhs.is_Indexed and cluster.is_sparse: + if e.is_Reduction and (e.lhs.is_Indexed or cluster.is_sparse): # Transform `e` such that we reduce into a scalar (ultimately via # atomic ops, though this part is carried out by a much later pass) # For example, given `i = m[p_src]` (i.e., indirection array), turn: diff --git a/tests/test_dse.py b/tests/test_dse.py index 728f8f9357..e1ae16eb69 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -24,7 +24,7 @@ IndexedPointer, Keyword, SizeOf, estimate_cost, pow_to_mul, indexify) from devito.tools import as_tuple, generator -from devito.types import Array, Scalar, Symbol +from devito.types import Array, Scalar, Symbol, PrecomputedSparseTimeFunction from examples.seismic.acoustic import AcousticWaveSolver from examples.seismic import demo_model, AcquisitionGeometry @@ -2664,6 +2664,18 @@ def test_dtype_aliases(self): assert FindNodes(Expression).visit(op)[0].dtype == np.float32 assert np.all(fo.data[:-1, :-1] == 8) + def test_sparse_const(self): + grid = Grid((11, 11, 11)) + + u = TimeFunction(name="u", grid=grid) + src = PrecomputedSparseTimeFunction(name="src", grid=grid, npoint=1, nt=11, + r=2, interpolation_coeffs=np.ones((1, 3, 2))) + op = Operator(src.interpolate(u)) + + cond = FindNodes(Conditional).visit(op) + assert len(cond) == 1 + assert all(e.is_scalar for e in cond[0].args['then_body'][0].exprs) + class TestIsoAcoustic(object):