Skip to content

Commit

Permalink
api: fix subfunction handling (subs/rebuild/...)
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 8, 2023
1 parent 3209435 commit 2a3654b
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 84 deletions.
4 changes: 4 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def _print_Mod(self, expr):
args = ['(%s)' % self._print(a) for a in expr.args]
return '%'.join(args)

def _print_Mul(self, expr):
term = super()._print_Mul(expr)
return term.replace("(-1)*", "-")

def _print_Min(self, expr):
if has_integer_args(*expr.args) and len(expr.args) == 2:
return "MIN(%s)" % self._print(expr.args)[1:-1]
Expand Down
3 changes: 2 additions & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from devito.finite_differences import Differentiable, generate_fd_shortcuts
from devito.tools import (ReducerMap, as_tuple, c_restrict_void_p, flatten, is_integer,
memoized_meth, dtype_to_ctype, humanbytes)
from devito.types.dimension import Dimension, DynamicDimension
from devito.types.dimension import Dimension
from devito.types.args import ArgProvider
from devito.types.caching import CacheManager
from devito.types.basic import AbstractFunction, Size
Expand Down Expand Up @@ -1040,6 +1040,7 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = grid.dimensions

if args:
assert len(args) == len(dimensions)
return tuple(dimensions), tuple(args)

# Staggered indices
Expand Down
113 changes: 47 additions & 66 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@ def __indices_setup__(cls, *args, **kwargs):
dimensions = (Dimension(name='p_%s' % kwargs["name"]),)

if args:
indices = args
return tuple(dimensions), tuple(args)
else:
indices = dimensions

return dimensions, indices
return dimensions, dimensions

@classmethod
def __shape_setup__(cls, **kwargs):
Expand All @@ -80,16 +78,6 @@ def __shape_setup__(cls, **kwargs):
shape = (glb_npoint[grid.distributor.myrank],)
return shape

def func(self, *args, **kwargs):
# Rebuild subfunctions first to avoid new data creation as we have to use `_data`
# as a reconstruction kwargs to avoid the circular dependency
# with the parent in SubFunction
# This is also necessary to avoid shape issue in the SubFunction with mpi
for s in self._sub_functions:
if getattr(self, s) is not None:
kwargs.update({s: getattr(self, s).func(*args, **kwargs)})
return super().func(*args, **kwargs)

def __fd_setup__(self):
"""
Dynamically add derivative short-cuts.
Expand All @@ -108,24 +96,39 @@ def __distributor_setup__(self, **kwargs):
)

def __subfunc_setup__(self, key, suffix, dtype=None):
# Shape and dimensions from args
name = '%s_%s' % (self.name, suffix)

if key is not None and not isinstance(key, SubFunction):
key = np.array(key)

if key is not None:
dimensions = (self._sparse_dim, Dimension(name='d'))
if key.ndim > 2:
dimensions = (self._sparse_dim, Dimension(name='d'),
*mkdims("i", n=key.ndim-2))
else:
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim, *key.shape[2:])
else:
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

# Check if already a SubFunction
if isinstance(key, SubFunction):
return key
# Need to rebuild so the dimensions match the parent SparseFunction
indices = (self.indices[self._sparse_position], *key.indices[1:])
return key._rebuild(*indices, name=name, shape=shape,
alias=self.alias, halo=None)
elif key is not None and not isinstance(key, Iterable):
raise ValueError("`%s` must be either SubFunction "
"or iterable (e.g., list, np.ndarray)" % key)

name = '%s_%s' % (self.name, suffix)
dimensions = (self._sparse_dim, Dimension(name='d'))
shape = (self.npoint, self.grid.dim)

if key is None:
# Fallback to default behaviour
dtype = dtype or self.dtype
else:
if key is not None:
key = np.array(key)

if (shape != key.shape[:2] and key.shape != (shape[1],)) and \
if (shape != key.shape and key.shape != (shape[1],)) and \
self._distributor.nprocs == 1:
raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" %
(suffix, key.shape[:2], shape))
Expand All @@ -136,12 +139,8 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
else:
dtype = dtype or self.dtype

if key is not None and key.ndim > 2:
shape = (*shape, *key.shape[2:])
dimensions = (*dimensions, *mkdims("i", n=key.ndim-2))

sf = SubFunction(
name=name, parent=self, dtype=dtype, dimensions=dimensions,
name=name, dtype=dtype, dimensions=dimensions,
shape=shape, space_order=0, initializer=key, alias=self.alias,
distributor=self._distributor
)
Expand Down Expand Up @@ -657,20 +656,6 @@ def time_dim(self):
"""The time Dimension."""
return self._time_dim

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
shape = kwargs.get('shape')
Expand All @@ -686,6 +671,18 @@ def __shape_setup__(cls, **kwargs):

return tuple(shape)

@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
if not dimensions:
dimensions = (kwargs['grid'].time_dim,
Dimension(name='p_%s' % kwargs["name"]),)

if args:
return tuple(dimensions), tuple(args)
else:
return dimensions, dimensions

@property
def nt(self):
return self.shape[self._time_position]
Expand Down Expand Up @@ -1032,13 +1029,14 @@ def __init_finalize__(self, *args, **kwargs):
if r <= 0:
raise ValueError('`r` must be > 0')
# Make sure radius matches the coefficients size
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
if interpolation_coeffs is not None:
nr = interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
r = r // 2
else:
raise ValueError("Interpolation coefficients shape %d do "
"not match specified radius %d" % (r, nr))
self._radius = r

if coordinates is not None and gridpoints is not None:
Expand Down Expand Up @@ -1680,23 +1678,6 @@ def inject(self, field, expr, u_t=None, p_t=None):

return out

@classmethod
def __indices_setup__(cls, *args, **kwargs):
"""
Return the default Dimension indices for a given data shape.
"""
dimensions = kwargs.get('dimensions')
if dimensions is None:
dimensions = (kwargs['grid'].time_dim, Dimension(
name='p_%s' % kwargs["name"]))

if args:
indices = args
else:
indices = dimensions

return dimensions, indices

@classmethod
def __shape_setup__(cls, **kwargs):
# This happens before __init__, so we have to get 'npoint'
Expand Down
22 changes: 9 additions & 13 deletions examples/seismic/inversion/inversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,15 @@ def compute_residual(res, dobs, dsyn):
"""
Computes the data residual dsyn - dobs into residual
"""
if res.grid.distributor.is_parallel:
# If we run with MPI, we have to compute the residual via an operator
# First make sure we can take the difference and that receivers are at the
# same position
assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data)
assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data)
# Create a difference operator
diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) -
dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]}))
Operator(diff_eq)()
else:
# A simple data difference is enough in serial
res.data[:] = dsyn.data[:] - dobs.data[:]
# If we run with MPI, we have to compute the residual via an operator
# First make sure we can take the difference and that receivers are at the
# same position
assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data)
assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data)
# Create a difference operator
diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) -
dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]}))
Operator(diff_eq)()

return res

Expand Down
4 changes: 2 additions & 2 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def test_precomputed_injection_time(r):
sf = PrecomputedSparseTimeFunction(name='s', grid=m.grid, r=r, npoint=len(coords),
gridpoints=gridpoints, nt=nt,
interpolation_coeffs=interpolation_coeffs)

expr = sf.inject(m, Float(1.))
sf.data.fill(1.)
expr = sf.inject(m, sf)

op = Operator(expr)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_precomputed_sparse_function(self, mode, pickle):

sf = PrecomputedSparseTimeFunction(
name='sf', grid=grid, r=2, npoint=3, nt=5,
interpolation_coeffs=np.ndarray(shape=(3, 2, 2)), **kw
interpolation_coeffs=np.random.randn(3, 2, 2), **kw
)
sf.data[2, 1] = 5.

Expand Down
64 changes: 63 additions & 1 deletion tests/test_msparse.py → tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import numpy as np
import scipy.sparse

from devito import Grid, TimeFunction, Eq, Operator, MatrixSparseTimeFunction
from devito import Grid, TimeFunction, Eq, Operator, Dimension
from devito import (SparseFunction, SparseTimeFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction, MatrixSparseTimeFunction)


_sptypes = [SparseFunction, SparseTimeFunction,
PrecomputedSparseFunction, PrecomputedSparseTimeFunction]


class TestMatrixSparseTimeFunction(object):
Expand Down Expand Up @@ -394,5 +400,61 @@ def test_mpi(self):
assert sf.data[0, 0] == -3.0 # 1 * (1 * 1) * 1 + (-1) * (2 * 2) * 1


class TestSparseFunction(object):

@pytest.mark.parametrize('sptype', _sptypes)
def test_rebuild(self, sptype):
grid = Grid((3, 3, 3))
# Base object
sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2,
interpolation_coeffs=np.random.randn(1, 3, 2),
coordinates=np.random.randn(1, 3))

# Check subfunction setup
for subf in sp._sub_functions:
if getattr(sp, subf) is not None:
assert getattr(sp, subf).name.startswith("s_")

# Rebuild with different name, this should drop the function
# and create new data
sp2 = sp._rebuild(name="sr")

# Check new subfunction
for subf in sp2._sub_functions:
if getattr(sp2, subf) is not None:
assert getattr(sp2, subf).name.startswith("sr_")
assert np.all(getattr(sp2, subf).data == 0)

# Rebuild with different name as an alias
sp2 = sp._rebuild(name="sr2", alias=True)
for subf in sp2._sub_functions:
if getattr(sp2, subf) is not None:
assert getattr(sp2, subf).name.startswith("sr2_")
assert getattr(sp2, subf).data is None

@pytest.mark.parametrize('sptype', _sptypes)
def test_subs(self, sptype):
grid = Grid((3, 3, 3))
# Base object
sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2,
interpolation_coeffs=np.random.randn(1, 3, 2),
coordinates=np.random.randn(1, 3))

# Check subfunction setup
for subf in sp._sub_functions:
if getattr(sp, subf) is not None:
assert getattr(sp, subf).dimensions[0] == sp._sparse_dim

# Do substitution on sparse dimension
new_spdim = Dimension(name="newsp")

sps = sp._subs(sp._sparse_dim, new_spdim)
assert sps.indices[sp._sparse_position] == new_spdim
for subf in sps._sub_functions:
if getattr(sps, subf) is not None:
assert getattr(sps, subf).indices[0] == new_spdim
assert np.all(getattr(sps, subf).data == getattr(sp, subf).data)


if __name__ == "__main__":
TestMatrixSparseTimeFunction().test_mpi()

0 comments on commit 2a3654b

Please sign in to comment.