Skip to content

Commit

Permalink
Merge pull request #2176 from devitocodes/builtin-batching
Browse files Browse the repository at this point in the history
builtins: Support batched initialize_function
  • Loading branch information
FabioLuporini authored Sep 8, 2023
2 parents 40975d2 + 4f62b4d commit a7c3446
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 66 deletions.
158 changes: 95 additions & 63 deletions devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,66 @@ def fset(f, g):
return f


def _initialize_function(function, data, nbl, mapper=None, mode='constant'):
"""
Construct the symbolic objects for `initialize_function`.
"""
nbl, slices = nbl_to_padsize(nbl, function.ndim)
if isinstance(data, dv.Function):
function.data[slices] = data.data[:]
else:
function.data[slices] = data
lhs = []
rhs = []
options = []

if mode == 'reflect' and function.grid.distributor.is_parallel:
# Check that HALO size is appropriate
halo = function.halo
local_size = function.shape

def buff(i, j):
return [(i + k - 2*max(max(nbl))) for k in j]

b = [min(l) for l in (w for w in (buff(i, j) for i, j in zip(local_size, halo)))]
if any(np.array(b) < 0):
raise ValueError("Function `%s` halo is not sufficiently thick." % function)

for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)):
dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl)
dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr)
if mode == 'constant':
subsl = nl
subsr = d.symbolic_max - nr
elif mode == 'reflect':
subsl = 2*nl - 1 - dim_l
subsr = 2*(d.symbolic_max - nr) + 1 - dim_r
else:
raise ValueError("Mode not available")
lhs.append(function.subs({d: dim_l}))
lhs.append(function.subs({d: dim_r}))
rhs.append(function.subs({d: subsl}))
rhs.append(function.subs({d: subsr}))
options.extend([None, None])

if mapper and d in mapper.keys():
exprs = mapper[d]
lhs_extra = exprs['lhs']
rhs_extra = exprs['rhs']
lhs.extend(as_list(lhs_extra))
rhs.extend(as_list(rhs_extra))
options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ])
if isinstance(options_extra, list):
options.extend(options_extra)
else:
options.extend([options_extra])

if all(options is None for i in options):
options = None

return lhs, rhs, options


def initialize_function(function, data, nbl, mapper=None, mode='constant',
name=None, pad_halo=True, **kwargs):
"""
Expand All @@ -225,9 +285,9 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',
Parameters
----------
function : Function
function : Function or list of Functions
The initialised object.
data : ndarray or Function
data : ndarray or Function or list of ndarray/Function
The data used for initialisation.
nbl : int or tuple of int or tuple of tuple of int
Number of outer layers (such as absorbing layers for boundary damping).
Expand Down Expand Up @@ -286,73 +346,45 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',
[2, 3, 3, 3, 3, 2],
[2, 2, 2, 2, 2, 2]], dtype=int32)
"""
name = name or 'pad_%s' % function.name
if isinstance(function, dv.TimeFunction):
if isinstance(function, (list, tuple)):
if not isinstance(data, (list, tuple)):
raise TypeError("Expected a list of `data`")
elif len(function) != len(data):
raise ValueError("Expected %d `data` items, got %d" %
(len(function), len(data)))

if mapper is not None:
raise NotImplementedError("Unsupported `mapper` with batching")

functions = function
datas = data
else:
functions = (function,)
datas = (data,)

if any(isinstance(f, dv.TimeFunction) for f in functions):
raise NotImplementedError("TimeFunctions are not currently supported.")

if nbl == 0:
if isinstance(data, dv.Function):
function.data[:] = data.data[:]
else:
function.data[:] = data[:]
if pad_halo:
pad_outhalo(function)
return

nbl, slices = nbl_to_padsize(nbl, function.ndim)
if isinstance(data, dv.Function):
function.data[slices] = data.data[:]
for f, data in zip(functions, datas):
if isinstance(data, dv.Function):
f.data[:] = data.data[:]
else:
f.data[:] = data[:]
else:
function.data[slices] = data
lhs = []
rhs = []
options = []

if mode == 'reflect' and function.grid.distributor.is_parallel:
# Check that HALO size is appropriate
halo = function.halo
local_size = function.shape

def buff(i, j):
return [(i + k - 2*max(max(nbl))) for k in j]
lhss, rhss, optionss = [], [], []
for f, data in zip(functions, datas):
lhs, rhs, options = _initialize_function(f, data, nbl, mapper, mode)

b = [min(l) for l in (w for w in (buff(i, j) for i, j in zip(local_size, halo)))]
if any(np.array(b) < 0):
raise ValueError("Function `%s` halo is not sufficiently thick." % function)
lhss.extend(lhs)
rhss.extend(rhs)
optionss.extend(options)

for d, (nl, nr) in zip(function.space_dimensions, as_tuple(nbl)):
dim_l = dv.SubDimension.left(name='abc_%s_l' % d.name, parent=d, thickness=nl)
dim_r = dv.SubDimension.right(name='abc_%s_r' % d.name, parent=d, thickness=nr)
if mode == 'constant':
subsl = nl
subsr = d.symbolic_max - nr
elif mode == 'reflect':
subsl = 2*nl - 1 - dim_l
subsr = 2*(d.symbolic_max - nr) + 1 - dim_r
else:
raise ValueError("Mode not available")
lhs.append(function.subs({d: dim_l}))
lhs.append(function.subs({d: dim_r}))
rhs.append(function.subs({d: subsl}))
rhs.append(function.subs({d: subsr}))
options.extend([None, None])

if mapper and d in mapper.keys():
exprs = mapper[d]
lhs_extra = exprs['lhs']
rhs_extra = exprs['rhs']
lhs.extend(as_list(lhs_extra))
rhs.extend(as_list(rhs_extra))
options_extra = exprs.get('options', len(as_list(lhs_extra))*[None, ])
if isinstance(options_extra, list):
options.extend(options_extra)
else:
options.extend([options_extra])

if all(options is None for i in options):
options = None
assert len(lhss) == len(rhss) == len(optionss)

assign(lhs, rhs, options=options, name=name, **kwargs)
name = name or 'initialize_%s' % '_'.join(f.name for f in functions)
assign(lhss, rhss, options=optionss, name=name, **kwargs)

if pad_halo:
pad_outhalo(function)
for f in functions:
pad_outhalo(f)
2 changes: 1 addition & 1 deletion devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _interp_idx(self, variables, implicit_dims=None):
temps.extend(self._coeff_temps(implicit_dims))

# Substitution mapper for variables
idx_subs = {v: v.subs({k: c - v.origin.get(k, 0) + p
idx_subs = {v: v.subs({k: c + p
for ((k, c), p) in zip(mapper.items(), pos)})
for v in variables}
idx_subs.update(dict(zip(self._rdim, mapper.values())))
Expand Down
11 changes: 9 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,15 @@ def indexify(self, indices=None, subs=None):
subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]

# Indices after substitutions
indices = [sympy.sympify(a.subs(d, d - o).xreplace(s)) for a, d, o, s in
zip(self.args, self.dimensions, self.origin, subs)]
indices = []
for a, d, o, s in zip(self.args, self.dimensions, self.origin, subs):
if d in a.free_symbols:
# Shift by origin d -> d - o.
indices.append(sympy.sympify(a.subs(d, d - o).xreplace(s)))
else:
# Dimension has been removed, e.g. u[10], plain shift by origin
indices.append(sympy.sympify(a - o).xreplace(s))

indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]

Expand Down
17 changes: 17 additions & 0 deletions tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,23 @@ def test_if_halo_mpi(self, nbl):
expected = np.pad(a[na//2:, na//2:], [(0, 1+nbl), (0, 1+nbl)], 'edge')
assert np.all(f._data_with_outhalo._local == expected)

def test_batching(self):
grid = Grid(shape=(12, 12))

a = np.arange(16).reshape((4, 4))

f = Function(name='f', grid=grid, dtype=np.int32)
g = Function(name='g', grid=grid, dtype=np.float32)
h = Function(name='h', grid=grid, dtype=np.float64)

initialize_function([f, g, h], [a, a, a], 4, mode='reflect')

for i in [f, g, h]:
assert np.all(a[:, ::-1] - np.array(i.data[4:8, 0:4]) == 0)
assert np.all(a[:, ::-1] - np.array(i.data[4:8, 8:12]) == 0)
assert np.all(a[::-1, :] - np.array(i.data[0:4, 4:8]) == 0)
assert np.all(a[::-1, :] - np.array(i.data[8:12, 4:8]) == 0)


class TestBuiltinsResult(object):

Expand Down
11 changes: 11 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ def test_indexed():
assert ub.indexed.free_symbols == {ub.indexed}


def test_indexed_staggered():
grid = Grid(shape=(10, 10))
x, y = grid.dimensions
hx, hy = x.spacing, y.spacing

u = Function(name='u', grid=grid, staggered=(x, y))
u0 = u.subs({x: 1, y: 2})
assert u0.indices == (1 + hx / 2, 2 + hy / 2)
assert u0.indexify().indices == (1, 2)


def test_bundle():
grid = Grid(shape=(4, 4))

Expand Down

0 comments on commit a7c3446

Please sign in to comment.