Skip to content

Commit

Permalink
api: fix EvalDerivative and expand arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Mar 6, 2024
1 parent a0426f2 commit ae85ddf
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 114 deletions.
5 changes: 2 additions & 3 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .finite_difference import generic_derivative, first_derivative, cross_derivative
from .differentiable import Differentiable
from .tools import direct, transpose
from .rsfd import difrot
from .rsfd import d45
from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict
from devito.types.utils import DimensionTuple

Expand Down Expand Up @@ -396,8 +396,7 @@ def _eval_fd(self, expr, **kwargs):
if self.method == 'RSFD':
assert len(self.dims) == 1
assert self.deriv_order == 1
fdfunc = difrot[expr.grid.dim]['d%s' % self.dims[0].name]
res = fdfunc(expr, self.x0, expand=expand)
res = d45(expr, self.dims[0], x0=self.x0, expand=expand)
elif self.side is not None and self.deriv_order == 1:
assert self.method == 'FD'
res = first_derivative(expr, self.dims[0], self.fd_order,
Expand Down
4 changes: 2 additions & 2 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def _evaluate(self, **kwargs):
expr = self.expr._evaluate(**kwargs)

if not kwargs.get('expand', True):
return self.func(expr, self.dimensions)
return self._rebuild(expr)

values = product(*[list(d.range) for d in self.dimensions])
terms = []
Expand Down Expand Up @@ -834,7 +834,7 @@ def _evaluate(self, **kwargs):
mapper = {w.subs(d, i): f.weights[n] for n, i in enumerate(d.range)}
expr = expr.xreplace(mapper)

return EvalDerivative(expr, base=self.base)
return EvalDerivative(*expr.args, base=self.base)


class DiffDerivative(IndexDerivative, DifferentiableOp):
Expand Down
93 changes: 44 additions & 49 deletions devito/finite_differences/rsfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from devito.types.dimension import StencilDimension
from .differentiable import Weights, DiffDerivative
from .tools import generate_indices_staggered, fd_weights_registry
from .elementary import sqrt

__all__ = ['drot', 'dxrot', 'dyrot', 'dzrot']
__all__ = ['drot', 'd45']

smapper = {1: (1, 1, 1), 2: (1, 1, -1), 3: (1, -1, 1), 4: (1, -1, -1)}

Expand All @@ -20,12 +19,12 @@ def shift(sign, x0):

def drot(expr, dim, dir=1, x0=None):
"""
Finite difference approximation of the derivative along d1
Finite difference approximation of the derivative along dir
of a Function `f` at point `x0`.
Rotated finite differences based on:
https://www.sciencedirect.com/science/article/pii/S0165212599000232
The rotated axis (the four diagonal of a cube) are:
The rotated axis (the four diagonals of a cube) are:
d1 = dx/dr x + dz/dl y + dy/dl z
d2 = dx/dl x + dz/dl y - dy/dr z
d3 = dx/dr x - dz/dl y + dy/dl z
Expand All @@ -39,8 +38,8 @@ def drot(expr, dim, dir=1, x0=None):
if dir > 2 and ndim == 2:
return 0

# Spacing along diagonal
r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions))
# RSFD scaling
s = 2**(expr.grid.dim - 1)

# Center point and indices
start, indices = generate_indices_staggered(expr, dim, expr.space_order, x0=x0)
Expand All @@ -63,15 +62,17 @@ def drot(expr, dim, dir=1, x0=None):
signs = smapper[dir][::(1 if ndim == 3 else 2)]

# Direction substitutions
dim_mapper = {d: d + signs[di]*i*d.spacing - shift(signs[di], mid)*d.spacing
for (di, d) in enumerate(expr.grid.dimensions)}
dim_mapper = {}
for (di, d) in enumerate(expr.grid.dimensions):
s0 = 0 if mid == adim_start else shift(signs[di], mid)*d.spacing
dim_mapper[d] = d + signs[di]*i*d.spacing - s0

# Create IndexDerivative
ui = expr.subs(dim_mapper)

deriv = DiffDerivative(w0*ui, {d: i for d in expr.grid.dimensions})
deriv = DiffDerivative(w0*ui/(s*dim.spacing), {d: i for d in expr.grid.dimensions})

return deriv/r
return deriv


grid_node = lambda grid: {d: d for d in grid.dimensions}
Expand All @@ -97,7 +98,7 @@ def check_staggering(func):
- grid.dimension center point and NODE staggering
"""
@wraps(func)
def wrapper(expr, x0=None, expand=True):
def wrapper(expr, dim, x0=None, expand=True):
grid = expr.grid
x0 = {k: v for k, v in x0.items() if k.is_Space}
if expr.staggered is NODE or expr.staggered is None:
Expand All @@ -107,52 +108,46 @@ def wrapper(expr, x0=None, expand=True):
else:
cond = False
if cond:
return func(expr, x0=x0, expand=expand)
return func(expr, dim, x0=x0, expand=expand)
else:
raise ValueError('Invalid staggering or x0 for rotated finite differences')
return wrapper


@check_staggering
def dxrot(expr, x0=None, expand=True):
x = expr.grid.dimensions[0]
r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions))
s = 2**(expr.grid.dim - 1)
dxrsfd = (drot(expr, x, x0=x0, dir=1) + drot(expr, x, x0=x0, dir=2) +
drot(expr, x, x0=x0, dir=3) + drot(expr, x, x0=x0, dir=4))
dx45 = r / (s * x.spacing) * dxrsfd
if expand:
return dx45.evaluate
else:
return dx45

def d45(expr, dim, x0=None, expand=True):
"""
RSFD approximation of the derivative of `expr` along `dim` at point `x0`.
Parameters
----------
expr : expr-like
Expression for which the derivative is produced.
dim : Dimension
The Dimension w.r.t. which to differentiate.
x0 : dict, optional
Origin of the finite-difference. Defaults to 0 for all dimensions.
expand : bool, optional
Expand the expression. Defaults to True.
"""
# Make sure the grid supports RSFD
if expr.grid.dim == 1:
raise ValueError('RSFD only supported in 2D and 3D')

@check_staggering
def dyrot(expr, x0=None, expand=True):
y = expr.grid.dimensions[1]
r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions))
s = 2**(expr.grid.dim - 1)
dyrsfd = (drot(expr, y, x0=x0, dir=1) + drot(expr, y, x0=x0, dir=2) -
drot(expr, y, x0=x0, dir=3) - drot(expr, y, x0=x0, dir=4))
dy45 = r / (s * y.spacing) * dyrsfd
if expand:
return dy45.evaluate
else:
return dy45
# Diagonals weights
w = dir_weights[(dim.name, expr.grid.dim)]

# RSFD
rsfd = (w[0] * drot(expr, dim, x0=x0, dir=1) +
w[1] * drot(expr, dim, x0=x0, dir=2) +
w[2] * drot(expr, dim, x0=x0, dir=3) +
w[3] * drot(expr, dim, x0=x0, dir=4))

@check_staggering
def dzrot(expr, x0=None, expand=True):
z = expr.grid.dimensions[-1]
r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions))
s = 2**(expr.grid.dim - 1)
dzrsfd = (drot(expr, z, x0=x0, dir=1) - drot(expr, z, x0=x0, dir=2) +
drot(expr, z, x0=x0, dir=3) - drot(expr, z, x0=x0, dir=4))
dz45 = r / (s * z.spacing) * dzrsfd
if expand:
return dz45.evaluate
else:
return dz45
# Evaluate
return rsfd._evaluate(expand=expand)


difrot = {2: {'dx': dxrot, 'dy': dzrot}, 3: {'dx': dxrot, 'dy': dyrot, 'dz': dzrot}}
# How to sum d1, d2, d3, d4 depending on the dimension
dir_weights = {('x', 2): (1, 1, 1, 1), ('x', 3): (1, 1, 1, 1),
('y', 2): (1, -1, 1, -1), ('y', 3): (1, 1, -1, -1),
('z', 2): (1, -1, 1, -1), ('z', 3): (1, -1, 1, -1)}
2 changes: 2 additions & 0 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def wrapper(expr, *args, **kwargs):
"with symbolic coefficients is not currently "
"supported")
kwargs['coefficients'] = 'symbolic'
else:
kwargs['coefficients'] = expr.coefficients
return func(expr, *args, **kwargs)
return wrapper

Expand Down
6 changes: 4 additions & 2 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from ctypes import POINTER, Structure, c_int, c_ulong, c_void_p, cast, byref
from functools import wraps, reduce
from operator import mul
import warnings

import numpy as np
import sympy
import warnings
from psutil import virtual_memory
from cached_property import cached_property

Expand Down Expand Up @@ -44,6 +44,8 @@ class DiscreteFunction(AbstractFunction, ArgProvider, Differentiable):
Users should not instantiate this class directly. Use Function or
SparseFunction (or their subclasses) instead.
"""

# Default method for the finite difference approximation weights computation.
_default_fd = 'taylor'

# Required by SymPy, otherwise the presence of __getitem__ will make SymPy
Expand Down Expand Up @@ -71,7 +73,7 @@ def __init_finalize__(self, *args, function=None, **kwargs):

# Symbolic (finite difference) coefficients
self._coefficients = kwargs.get('coefficients', self._default_fd)
if self._coefficients not in fd_weights_registry.keys():
if self._coefficients not in fd_weights_registry:
if self._coefficients == 'standard':
self._coefficients = 'taylor'
warnings.warn("The `standard` mode is deprecated and will be removed in "
Expand Down
Loading

0 comments on commit ae85ddf

Please sign in to comment.