Skip to content

Commit

Permalink
Merge pull request #2327 from devitocodes/error-checking
Browse files Browse the repository at this point in the history
compiler: Add optional pass for runtime stability check
  • Loading branch information
mloubout authored Mar 12, 2024
2 parents 32308f4 + aaea94a commit f29cb35
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 9 deletions.
9 changes: 7 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
factorize, fission, fuse, optimize_pows,
optimize_hyperplanes)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize,
hoist_prodders, relax_incr_dimensions)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize,
mpiize, hoist_prodders, relax_incr_dimensions,
check_stability)
from devito.tools import timed_pass

__all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
Expand Down Expand Up @@ -76,6 +77,7 @@ def _normalize_kwargs(cls, **kwargs):
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
o['place-transfers'] = oo.pop('place-transfers', True)
o['errctl'] = oo.pop('errctl', cls.ERRCTL)

# Recognised but unused by the CPU backend
oo.pop('par-disabled', None)
Expand Down Expand Up @@ -189,6 +191,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Misc optimizations
hoist_prodders(graph)

# Perform error checking
check_stability(graph, **kwargs)

# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

Expand Down
9 changes: 7 additions & 2 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from devito.passes.clusters import (Lift, Streaming, Tasker, blocking, buffering,
cire, cse, factorize, fission, fuse,
optimize_pows)
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize, hoist_prodders,
linearize, pthreadify, relax_incr_dimensions)
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize,
hoist_prodders, linearize, pthreadify,
relax_incr_dimensions, check_stability)
from devito.tools import as_tuple, timed_pass

__all__ = ['DeviceNoopOperator', 'DeviceAdvOperator', 'DeviceCustomOperator',
Expand Down Expand Up @@ -91,6 +92,7 @@ def _normalize_kwargs(cls, **kwargs):
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
o['place-transfers'] = oo.pop('place-transfers', True)
o['errctl'] = oo.pop('errctl', cls.ERRCTL)

if oo:
raise InvalidOperator("Unsupported optimization options: [%s]"
Expand Down Expand Up @@ -226,6 +228,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Misc optimizations
hoist_prodders(graph)

# Perform error checking
check_stability(graph, **kwargs)

# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

Expand Down
10 changes: 10 additions & 0 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ class BasicOperator(Operator):
(default) or `int32`.
"""

ERRCTL = None
"""
Runtime error checking. If this option is enabled, the generated code will
include runtime checks for various things that might go wrong, such as
instability (e.g., NaNs) or failed library calls (e.g., kernel launches).
"""

_Target = None
"""
The target language constructor, to be specified by subclasses.
Expand Down Expand Up @@ -155,6 +162,9 @@ def _check_kwargs(cls, **kwargs):
if oo['deriv-unroll'] not in (False, 'inner', 'full'):
raise InvalidArgument("Illegal `deriv-unroll` value")

if oo['errctl'] not in (None, False, 'basic', 'max'):
raise InvalidArgument("Illegal `errctl` value")

def _autotune(self, args, setup):
if setup in [False, 'off']:
return args
Expand Down
4 changes: 4 additions & 0 deletions devito/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ class InvalidOperator(DevitoError):
pass


class ExecutionError(DevitoError):
    """
    Raised after running an Operator when the generated code reports a
    non-zero status (e.g., nan/inf detected, failed kernel launch).
    """


class VisitorException(DevitoError):
    # NOTE(review): presumably raised by the IET visitors -- no raise site is
    # visible here, confirm against the callers
    pass
7 changes: 5 additions & 2 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,17 +794,19 @@ class CallableBody(MultiTraversable):
Data unbundling for `body`.
frees : list of Calls, optional
Data deallocations for `body`.
errors : list of Nodes, optional
Error handling for `body`.
"""

is_CallableBody = True

_traversable = ['unpacks', 'init', 'standalones', 'allocs', 'stacks',
'casts', 'bundles', 'maps', 'strides', 'objs', 'body',
'unmaps', 'unbundles', 'frees']
'unmaps', 'unbundles', 'frees', 'errors']

def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
allocs=(), stacks=(), casts=(), bundles=(), objs=(), maps=(),
unmaps=(), unbundles=(), frees=()):
unmaps=(), unbundles=(), frees=(), errors=()):
# Sanity check
assert not isinstance(body, CallableBody), "CallableBody's cannot be nested"

Expand All @@ -823,6 +825,7 @@ def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
self.unmaps = as_tuple(unmaps)
self.unbundles = as_tuple(unbundles)
self.frees = as_tuple(frees)
self.errors = as_tuple(errors)

def __repr__(self):
return ("<CallableBody <unpacks=%d, allocs=%d, casts=%d, maps=%d, "
Expand Down
20 changes: 17 additions & 3 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from devito.arch import compiler_registry, platform_registry
from devito.data import default_allocator
from devito.exceptions import InvalidOperator
from devito.exceptions import InvalidOperator, ExecutionError
from devito.logger import debug, info, perf, warning, is_log_enabled_for
from devito.ir.equations import LoweredEq, lower_exprs
from devito.ir.clusters import ClusterGroup, clusterize
Expand All @@ -21,7 +21,8 @@
from devito.mpi import MPI
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate)
generate_macros, minimize_symbols, unevaluate,
error_mapper)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -646,6 +647,16 @@ def _prepare_arguments(self, autotune=None, **kwargs):

return args

def _postprocess_errors(self, retval):
if retval == 0:
return
elif retval == error_mapper['Stability']:
raise ExecutionError("Detected nan/inf in some output Functions")
elif retval == error_mapper['KernelLaunch']:
raise ExecutionError("Kernel launch failed")
else:
raise ExecutionError("An error occurred during execution")

def _postprocess_arguments(self, args, **kwargs):
"""Process runtime arguments upon returning from ``.apply()``."""
for p in self.parameters:
Expand Down Expand Up @@ -842,7 +853,7 @@ def apply(self, **kwargs):
try:
cfunction = self.cfunction
with self._profiler.timer_on('apply', comm=args.comm):
cfunction(*arg_values)
retval = cfunction(*arg_values)
except ctypes.ArgumentError as e:
if e.args[0].startswith("argument "):
argnum = int(e.args[0][9:].split(':')[0]) - 1
Expand All @@ -854,6 +865,9 @@ def apply(self, **kwargs):
else:
raise

# Perform error checking
self._postprocess_errors(retval)

# Post-process runtime arguments
self._postprocess_arguments(args, **kwargs)

Expand Down
1 change: 1 addition & 0 deletions devito/passes/iet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .asynchrony import * # noqa
from .instrument import * # noqa
from .languages import * # noqa
from .errors import * # noqa
87 changes: 87 additions & 0 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import cgen as c
import numpy as np
from sympy import Not

from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List,
Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.types import Eq, Inc, Symbol

__all__ = ['check_stability', 'error_mapper']


def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs):
    """
    Check if the simulation is stable. If not, return to Python as quickly as
    possible with an error code.

    The pass is a no-op unless `options['errctl']` is `'max'`; missing options
    (or a missing `errctl` entry) are treated as "error checking disabled".

    Parameters
    ----------
    graph : Graph
        The call graph to instrument; modified through `_check_stability`.
    options : dict-like, optional
        The optimization options; only the `errctl` entry is read.
    rcompile : callable, optional
        Forwarded to `_check_stability`, which uses it to compile the
        helper equations.
    sregistry : object, optional
        Forwarded to `_check_stability`, which uses it to generate unique
        names.
    **kwargs
        Accepted for call-site uniformity with other passes; unused.
    """
    # `options` defaults to None, so it cannot be indexed unconditionally
    if options is None or options.get('errctl') != 'max':
        return

    _, wmovs = graph.data_movs

    _check_stability(graph, wmovs=wmovs, rcompile=rcompile, sregistry=sregistry)


@iet_pass
def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
    """
    Append a periodic finiteness check to each time Iteration of the
    EntryFunction.

    For every time Iteration, one TimeFunction found both in the Iteration and
    in `wmovs` is selected. A helper callable is generated which accumulates
    that function at time index 0 and returns `isfinite(accumulator)`. The
    helper is called whenever the time index is a multiple of 100; if it
    reports a non-finite value, the generated code returns
    `error_mapper['Stability']`.

    Any other kind of IET is returned untouched.
    """
    if not isinstance(iet, EntryFunction):
        return iet, {}

    # NOTE: Stability is a domain-specific concept, hence looking for time
    # Iterations and TimeFunctions is acceptable
    new_efuncs = []
    new_includes = []
    subs = {}
    for loop in FindNodes(Iteration).visit(iet):
        if not loop.dim.is_Time:
            continue

        stepping = {s for s in FindSymbols().visit(loop)
                    if s.is_TimeFunction and s.time_dim.is_Stepping}

        # We compute the norm of just one TimeFunction, hence we sort for
        # determinism and reproducibility
        candidates = sorted(stepping & set(wmovs), key=lambda i: i.name)
        if not candidates:
            continue
        target = candidates[0]

        # Helper kernel: accumulate `target` and test the result for finiteness
        accumulator = Symbol(name='accumulator', dtype=target.dtype)
        init = Eq(accumulator, 0.0)
        update = Inc(accumulator, target.subs(target.time_dim, 0))
        irs, byproduct = rcompile([init, update])

        efunc_name = sregistry.make_name(prefix='is_finite')
        finite = Return(DefFunction('isfinite', accumulator))
        efunc = make_callable(efunc_name, irs.iet.body.body + (finite,),
                              retval='int')

        new_efuncs.extend(i.root for i in byproduct.funcs)
        new_efuncs.append(efunc)
        new_includes.extend(byproduct.includes)

        check = Symbol(name=sregistry.make_name(prefix='check'), dtype=np.int32)

        # Every 100 time steps, call the helper and bail out if it fails
        period_hit = CondEq(loop.dim % 100, 0)
        call = Call(efunc.name, efunc.parameters, retobj=check)
        on_failure = Conditional(Not(check), Return(error_mapper['Stability']))
        guarded = Conditional(period_hit, List(body=[call, on_failure]))

        errctl = List(header=c.Comment("Stability check"), body=[guarded])
        subs[loop] = loop._rebuild(nodes=loop.nodes + (errctl,))

    iet = Transformer(subs).visit(iet)

    return iet, {'efuncs': new_efuncs, 'includes': new_includes}


# Integer status codes returned by the generated code to Python. The stability
# check returns `error_mapper['Stability']` when it detects a non-finite value;
# `_postprocess_errors` in devito/operator/operator.py maps these codes to
# ExecutionError messages.
# NOTE(review): no producer of 'KernelLaunch' is visible here -- confirm
error_mapper = {
    'Stability': 100,
    'KernelLaunch': 200,
}
27 changes: 27 additions & 0 deletions tests/test_error_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from devito import Grid, Function, TimeFunction, Eq, Operator, switchconfig
from devito.exceptions import ExecutionError


@switchconfig(safe_math=True)
@pytest.mark.parametrize("expr", [
    'u/f',
    '(u + v)/f',
])
def test_stability(expr):
    """
    Check that an Operator built with `errctl='max'` raises ExecutionError
    when applying an update that divides by `f`.
    """
    grid = Grid(shape=(10, 10))

    # NOTE: `expr` is evaluated with `eval` below, so the local names `f`, `u`
    # and `v` must match the names used in the parametrized strings
    f = Function(name='f', grid=grid, space_order=2)  # noqa
    u = TimeFunction(name='u', grid=grid, space_order=2)
    v = TimeFunction(name='v', grid=grid, space_order=2)

    eq = Eq(u.forward, eval(expr))

    op = Operator(eq, opt=('advanced', {'errctl': 'max'}))

    # `f.data` is never set -- presumably it is zero-initialised, so that the
    # division produces inf/nan (TODO confirm)
    u.data[:] = 1.
    v.data[:] = 2.

    # The stability check in devito/passes/iet/errors.py is performed when the
    # time index is a multiple of 100, so the run must span such time steps
    with pytest.raises(ExecutionError):
        op.apply(time_M=200, dt=.1)

0 comments on commit f29cb35

Please sign in to comment.