From 65b8cc6b65e4e8e3bad2b81bd68a5bbb7944bb81 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 30 Oct 2023 09:27:47 +0000 Subject: [PATCH] compiler: Hotfix compare-ops --- devito/symbolics/inspection.py | 2 +- tests/test_unexpansion.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index f6a3a863eb..45b5dce754 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -50,7 +50,7 @@ def compare_ops(e1, e2): return True if e1 == e2 else False elif isinstance(e1, IndexDerivative) and isinstance(e2, IndexDerivative): if e1.mapper == e2.mapper: - return compare_ops(e1.base, e2.base) + return compare_ops(e1.expr, e2.expr) else: return False elif e1.is_Indexed and e2.is_Indexed: diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index cf962966de..9b6d30c9b1 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -349,6 +349,21 @@ def test_v1(self): op.cfunction + def test_diff_first_deriv(self): + grid = Grid(shape=(16, 16, 16)) + + u = TimeFunction(name='u', grid=grid, space_order=16) + + eq = Eq(u.forward, u.dy2.dz + u.dy.dx + 1) + + op = Operator(eq, opt=('advanced', {'expand': False})) + + xs, ys, zs = get_params(op, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + arrays = get_arrays(op) + assert len(arrays) == 2 + check_array(arrays[0], ((8, 8), (0, 0), (8, 8)), (xs+16, ys, zs+16)) + check_array(arrays[1], ((8, 8), (0, 0), (8, 8)), (xs+16, ys, zs+16)) + def tti_sa_eqns(grid): t = grid.stepping_dim