Skip to content

Commit

Permalink
api: make Derivative reconstructable, Fix devitocodes#2330
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout authored and enwask committed Jul 26, 2024
1 parent bfc4db7 commit 916e44d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 30 deletions.
45 changes: 18 additions & 27 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from .differentiable import Differentiable
from .tools import direct, transpose
from .rsfd import d45
from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict, is_integer
from devito.tools import (as_mapper, as_tuple, filter_ordered, frozendict, is_integer,
Reconstructable)
from devito.types.utils import DimensionTuple

__all__ = ['Derivative']


class Derivative(sympy.Derivative, Differentiable):
class Derivative(sympy.Derivative, Differentiable, Reconstructable):

"""
An unevaluated Derivative, which carries metadata (Dimensions,
Expand Down Expand Up @@ -86,7 +87,7 @@ class Derivative(sympy.Derivative, Differentiable):

_fd_priority = 3

__rargs__ = ('expr', 'dims')
__rargs__ = ('expr', '*dims')
__rkwargs__ = ('side', 'deriv_order', 'fd_order', 'transpose', '_ppsubs',
'x0', 'method')

Expand Down Expand Up @@ -125,6 +126,10 @@ def __new__(cls, expr, *dims, **kwargs):

return obj

def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
return super()._rebuild(*args, **kwargs)

@classmethod
def _process_kwargs(cls, expr, *dims, **kwargs):
"""
Expand Down Expand Up @@ -215,8 +220,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None):
fd_order = fd_order or self._fd_order
side = side or self._side
method = method or self._method
return self._new_from_self(fd_order=fd_order, side=side, x0=_x0,
method=method)
return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method)

if side is not None:
raise TypeError("Side only supported for first order single"
Expand All @@ -230,18 +234,7 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None):
except AttributeError:
raise TypeError("Multi-dimensional Derivative, input expected as a dict")

return self._new_from_self(fd_order=_fd_order, x0=_x0)

def _new_from_self(self, **kwargs):
expr = kwargs.pop('expr', self.expr)
_kwargs = {'deriv_order': self.deriv_order, 'fd_order': self.fd_order,
'side': self.side, 'transpose': self.transpose, 'subs': self._ppsubs,
'x0': self.x0, 'preprocessed': True, 'method': self.method}
_kwargs.update(**kwargs)
return Derivative(expr, *self.dims, **_kwargs)

def func(self, expr, *args, **kwargs):
return self._new_from_self(expr=expr, **kwargs)
return self._rebuild(fd_order=_fd_order, x0=_x0)

def _subs(self, old, new, **hints):
# Basic case
Expand All @@ -251,7 +244,7 @@ def _subs(self, old, new, **hints):
if self.expr.has(old):
newexpr = self.expr._subs(old, new, **hints)
try:
return self._new_from_self(expr=newexpr)
return self._rebuild(newexpr)
except ValueError:
# Expr replacement leads to non-differentiable expression
# e.g `f.dx.subs(f: 1) = 1.dx = 0`
Expand All @@ -260,7 +253,7 @@ def _subs(self, old, new, **hints):

# In case `x0` was passed as a substitution instead of `(x0=`
if str(old) == 'x0':
return self._new_from_self(x0={self.dims[0]: new})
return self._rebuild(x0={self.dims[0]: new})

# Trying to substitute by another derivative with different metadata
# Only need to check if is a Derivative since one for the cases above would
Expand Down Expand Up @@ -289,13 +282,11 @@ def _xreplace(self, subs):
return new, True

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._new_from_self(subs=subs), True
return self._rebuild(subs=subs), True

@cached_property
def _metadata(self):
state = list(self.__rargs__ + self.__rkwargs__)
state.remove('expr')
ret = [getattr(self, i) for i in state]
ret = [self.dims] + [getattr(self, i) for i in self.__rkwargs__]
ret.append(self.expr.staggered or (None,))
return tuple(ret)

Expand Down Expand Up @@ -348,7 +339,7 @@ def T(self):
else:
adjoint = direct

return self._new_from_self(transpose=adjoint)
return self._rebuild(transpose=adjoint)

def _eval_at(self, func):
"""
Expand All @@ -374,19 +365,19 @@ def _eval_at(self, func):
mapper = as_mapper(self.expr._args_diff, lambda i: i.staggered)
args = [self.expr.func(*v) for v in mapper.values()]
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
args = [self._new_from_self(expr=a, x0=x0) for a in args]
args = [self._rebuild(a, x0=x0) for a in args]
return self.expr.func(*args)
elif self.expr.is_Mul:
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear
# in most equation with div(a * u) for example. The expression is re-centered
# at the highest priority index (see _gather_for_diff) to compute the
# derivative at x0.
return self._new_from_self(x0=x0, expr=self.expr._gather_for_diff)
return self._rebuild(expr=self.expr._gather_for_diff, x0=x0)
else:
# For every other cases, that has more functions or more complexe arithmetic,
# there is not actual way to decide what to do so it’s as safe to use
# the expression as is.
return self._new_from_self(x0=x0)
return self._rebuild(x0=x0)

def _evaluate(self, **kwargs):
# Evaluate finite-difference.
Expand Down
6 changes: 3 additions & 3 deletions devito/passes/equations/linearity.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ def _(expr, mapper, nn_derivs=None):
return expr

if len(derivs) == 1 and with_deriv is derivs[0]:
expr = with_deriv._new_from_self(expr=expr.func(*hope_coeffs, with_deriv.expr))
expr = with_deriv._rebuild(expr.func(*hope_coeffs, with_deriv.expr))
else:
others = [expr.func(*hope_coeffs, a) for a in others]
derivs = [a._new_from_self(expr=expr.func(*hope_coeffs, a.expr)) for a in derivs]
derivs = [a._rebuild(expr.func(*hope_coeffs, a.expr)) for a in derivs]
expr = with_deriv.func(*(derivs + others))

return expr
Expand Down Expand Up @@ -216,7 +216,7 @@ def _(expr):
if len(v) == 1:
args.append(c)
else:
args.append(c._new_from_self(expr=expr.func(*[i.expr for i in v])))
args.append(c._rebuild(expr.func(*[i.expr for i in v])))
expr = expr.func(*args)

return expr

0 comments on commit 916e44d

Please sign in to comment.