From 0744570685c8893bdf1ea82eef832783774bca1d Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Sat, 28 Oct 2023 11:05:31 +0000 Subject: [PATCH] compiler: Patch compare_ops for IndexDerivatives --- devito/symbolics/inspection.py | 5 +++++ tests/test_unexpansion.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 25ed8f84eb..f6a3a863eb 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -48,6 +48,11 @@ def compare_ops(e1, e2): if type(e1) is type(e2) and len(e1.args) == len(e2.args): if e1.is_Atom: return True if e1 == e2 else False + elif isinstance(e1, IndexDerivative) and isinstance(e2, IndexDerivative): + if e1.mapper == e2.mapper: + return compare_ops(e1.base, e2.base) + else: + return False elif e1.is_Indexed and e2.is_Indexed: return True if e1.base == e2.base else False else: diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index e32f2ad4a3..cf962966de 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from conftest import assert_structure, get_params, get_arrays, check_array from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator, @@ -61,6 +62,34 @@ def test_numeric_coeffs(self): # Compound expression Operator(Eq(u, (v*u.dx).dy, coefficients=coeffs), opt=opt).cfunction + @pytest.mark.parametrize('coeffs,expected', [ + ((7, 7, 7), 1), # We've had a bug triggered by identical coeffs + ((5, 7, 9), 3), + ]) + def test_multiple_cross_derivs(self, coeffs, expected): + grid = Grid(shape=(11, 11, 11), extent=(10., 10., 10.)) + x, y, z = grid.dimensions + + p = TimeFunction(name='p', grid=grid, space_order=4, + coefficients='symbolic') + + c0, c1, c2 = coeffs + coeffs0 = np.full(5, c0) + coeffs1 = np.full(5, c1) + coeffs2 = np.full(5, c2) + + subs = Substitutions(Coefficient(1, p, x, coeffs0), + Coefficient(1, p, y, coeffs1), + Coefficient(1, p, z, coeffs2)) + + eq = Eq(p.forward, p.dy.dz + p.dx.dy, coefficients=subs) + + op = Operator(eq, opt=('advanced', {'expand': False})) + op.cfunction + + # w0, w1, ... + assert len(op._globals) == expected + class Test1Pass(object):