Skip to content

Commit

Permalink
mpi: rework subclassing of Compute calls
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jun 27, 2023
1 parent 026d354 commit 8bb5b24
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 66 deletions.
86 changes: 44 additions & 42 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,42 @@

# ElementalFunction machinery

class ElementalCall(Call):

def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
incr=False, retobj=None, is_indirect=False):
self._mapper = mapper or {}

arguments = list(as_tuple(arguments))
dynamic_args_mapper = dynamic_args_mapper or {}
for k, v in dynamic_args_mapper.items():
tv = as_tuple(v)

# Sanity check
if k not in self._mapper:
raise ValueError("`k` is not a dynamic parameter" % k)
if len(self._mapper[k]) != len(tv):
raise ValueError("Expected %d values for dynamic parameter `%s`, given %d"
% (len(self._mapper[k]), k, len(tv)))
# Create the argument list
for i, j in zip(self._mapper[k], tv):
arguments[i] = j if incr is False else (arguments[i] + j)

super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)

def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
retobj=None, **kwargs):
# This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
return super(ElementalCall, self)._rebuild(
*args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
retobj=retobj, **kwargs
)

@cached_property
def dynamic_defaults(self):
return {k: tuple(self.arguments[i] for i in v) for k, v in self._mapper.items()}


class ElementalFunction(Callable):

"""
Expand All @@ -21,6 +57,7 @@ class ElementalFunction(Callable):
supplying bounds and step increment for each Dimension listed in
``dynamic_parameters``.
"""
_Call_cls = ElementalCall

is_ElementalFunction = True

Expand All @@ -47,53 +84,18 @@ def dynamic_defaults(self):

def make_call(self, dynamic_args_mapper=None, incr=False, retobj=None,
is_indirect=False):
return ElementalCall(self.name, list(self.parameters), dict(self._mapper),
dynamic_args_mapper, incr, retobj, is_indirect)


class ElementalCall(Call):

def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
incr=False, retobj=None, is_indirect=False):
self._mapper = mapper or {}

arguments = list(as_tuple(arguments))
dynamic_args_mapper = dynamic_args_mapper or {}
for k, v in dynamic_args_mapper.items():
tv = as_tuple(v)

# Sanity check
if k not in self._mapper:
raise ValueError("`k` is not a dynamic parameter" % k)
if len(self._mapper[k]) != len(tv):
raise ValueError("Expected %d values for dynamic parameter `%s`, given %d"
% (len(self._mapper[k]), k, len(tv)))
# Create the argument list
for i, j in zip(self._mapper[k], tv):
arguments[i] = j if incr is False else (arguments[i] + j)

super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)

def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
retobj=None, **kwargs):
# This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
return super(ElementalCall, self)._rebuild(
*args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
retobj=retobj, **kwargs
)

@cached_property
def dynamic_defaults(self):
return {k: tuple(self.arguments[i] for i in v) for k, v in self._mapper.items()}
return self._Call_cls(self.name, list(self.parameters), dict(self._mapper),
dynamic_args_mapper, incr, retobj, is_indirect)


def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static'):
def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static',
efunc_type=ElementalFunction):
"""
Shortcut to create an ElementalFunction.
"""
return ElementalFunction(name, iet, retval=retval,
parameters=derive_parameters(iet), prefix=prefix,
dynamic_parameters=dynamic_parameters)
return efunc_type(name, iet, retval=retval,
parameters=derive_parameters(iet), prefix=prefix,
dynamic_parameters=dynamic_parameters)


# Callable machinery
Expand Down
37 changes: 13 additions & 24 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir.iet import (Call, Callable, Conditional, ElementalFunction,
Expression, ExpressionBundle, AugmentedExpression,
Iteration, List, Prodder, Return, make_efunc, FindNodes,
Transformer, derive_parameters, ElementalCall)
Transformer, ElementalCall)
from devito.mpi import MPI
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast_mapper, subs_op_args)
Expand Down Expand Up @@ -572,6 +572,14 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
return HaloUpdate('haloupdate%s' % key, iet, parameters)


class ComputeCall(ElementalCall):
pass


class ComputeFunction(ElementalFunction):
_Call_cls = ComputeCall


class OverlapHaloExchangeBuilder(DiagHaloExchangeBuilder):

"""
Expand Down Expand Up @@ -647,7 +655,8 @@ def _make_compute(self, hs, key, *args):
if hs.body.is_Call:
return None
else:
return make_compute_func('compute%d' % key, hs.body, hs.arguments)
return make_efunc('compute%d' % key, hs.body, hs.arguments,
efunc_type=ComputeFunction)

def _call_compute(self, hs, compute, *args):
if compute is None:
Expand Down Expand Up @@ -952,7 +961,8 @@ def _make_compute(self, hs, key, msgs, callpoke):
mapper = {i: List(body=[callpoke, i]) for i in
FindNodes(ExpressionBundle).visit(hs.body)}
iet = Transformer(mapper).visit(hs.body)
return make_compute_func('compute%d' % key, iet, hs.arguments)
return make_efunc('compute%d' % key, iet, hs.arguments,
efunc_type=ComputeFunction)

def _make_poke(self, hs, key, msgs):
lflag = Symbol(name='lflag')
Expand Down Expand Up @@ -1025,23 +1035,6 @@ def __init__(self, name, body, parameters):
super(HaloUpdate, self).__init__(name, body, parameters)


class ComputeFunction(ElementalFunction):

def make_call(self, dynamic_args_mapper=None, incr=False, retobj=None,
is_indirect=False):
return ComputeCall(self.name, list(self.parameters), dict(self._mapper),
dynamic_args_mapper, incr, retobj, is_indirect)


def make_compute_func(name, iet, dynamic_parameters=None, retval='void', prefix='static'):
"""
Shortcut to create a ComputeFunction.
"""
return ComputeFunction(name, iet, retval=retval,
parameters=derive_parameters(iet), prefix=prefix,
dynamic_parameters=dynamic_parameters)


class Remainder(ElementalFunction):
pass

Expand Down Expand Up @@ -1086,10 +1079,6 @@ class HaloWaitCall(MPICall):
pass


class ComputeCall(ElementalCall):
pass


class RemainderCall(MPICall):
pass

Expand Down

0 comments on commit 8bb5b24

Please sign in to comment.