Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Add optional pass for runtime stability check #2327

Merged
merged 4 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
factorize, fission, fuse, optimize_pows,
optimize_hyperplanes)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize,
hoist_prodders, relax_incr_dimensions)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize,
mpiize, hoist_prodders, relax_incr_dimensions,
check_stability)
from devito.tools import timed_pass

__all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
Expand Down Expand Up @@ -76,6 +77,7 @@ def _normalize_kwargs(cls, **kwargs):
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
o['place-transfers'] = oo.pop('place-transfers', True)
o['errctl'] = oo.pop('errctl', cls.ERRCTL)

# Recognised but unused by the CPU backend
oo.pop('par-disabled', None)
Expand Down Expand Up @@ -189,6 +191,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Misc optimizations
hoist_prodders(graph)

# Perform error checking
check_stability(graph, **kwargs)

# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

Expand Down
9 changes: 7 additions & 2 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from devito.passes.clusters import (Lift, Streaming, Tasker, blocking, buffering,
cire, cse, factorize, fission, fuse,
optimize_pows)
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize, hoist_prodders,
linearize, pthreadify, relax_incr_dimensions)
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize,
hoist_prodders, linearize, pthreadify,
relax_incr_dimensions, check_stability)
from devito.tools import as_tuple, timed_pass

__all__ = ['DeviceNoopOperator', 'DeviceAdvOperator', 'DeviceCustomOperator',
Expand Down Expand Up @@ -91,6 +92,7 @@ def _normalize_kwargs(cls, **kwargs):
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
o['place-transfers'] = oo.pop('place-transfers', True)
o['errctl'] = oo.pop('errctl', cls.ERRCTL)

if oo:
raise InvalidOperator("Unsupported optimization options: [%s]"
Expand Down Expand Up @@ -226,6 +228,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Misc optimizations
hoist_prodders(graph)

# Perform error checking
check_stability(graph, **kwargs)

# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

Expand Down
10 changes: 10 additions & 0 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ class BasicOperator(Operator):
(default) or `int32`.
"""

ERRCTL = None
"""
Runtime error checking. If this option is enabled, the generated code will
include runtime checks for things that might go wrong, such as numerical
instability (e.g., NaNs) or failed library calls (e.g., kernel launches).
Accepted values are None, False, 'basic' and 'max'.
"""

_Target = None
"""
The target language constructor, to be specified by subclasses.
Expand Down Expand Up @@ -155,6 +162,9 @@ def _check_kwargs(cls, **kwargs):
if oo['deriv-unroll'] not in (False, 'inner', 'full'):
raise InvalidArgument("Illegal `deriv-unroll` value")

if oo['errctl'] not in (None, False, 'basic', 'max'):
raise InvalidArgument("Illegal `errctl` value")

def _autotune(self, args, setup):
if setup in [False, 'off']:
return args
Expand Down
4 changes: 4 additions & 0 deletions devito/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ class InvalidOperator(DevitoError):
pass


class ExecutionError(DevitoError):
    """Raised when the generated code returns a non-zero error code at runtime."""


# Devito-specific error type; adds no behavior on top of DevitoError.
class VisitorException(DevitoError):
pass
7 changes: 5 additions & 2 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,17 +794,19 @@ class CallableBody(MultiTraversable):
Data unbundling for `body`.
frees : list of Calls, optional
Data deallocations for `body`.
errors : list of Nodes, optional
Error handling for `body`.
"""

is_CallableBody = True

_traversable = ['unpacks', 'init', 'standalones', 'allocs', 'stacks',
'casts', 'bundles', 'maps', 'strides', 'objs', 'body',
'unmaps', 'unbundles', 'frees']
'unmaps', 'unbundles', 'frees', 'errors']

def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
allocs=(), stacks=(), casts=(), bundles=(), objs=(), maps=(),
unmaps=(), unbundles=(), frees=()):
unmaps=(), unbundles=(), frees=(), errors=()):
# Sanity check
assert not isinstance(body, CallableBody), "CallableBody's cannot be nested"

Expand All @@ -823,6 +825,7 @@ def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
self.unmaps = as_tuple(unmaps)
self.unbundles = as_tuple(unbundles)
self.frees = as_tuple(frees)
self.errors = as_tuple(errors)

def __repr__(self):
return ("<CallableBody <unpacks=%d, allocs=%d, casts=%d, maps=%d, "
Expand Down
20 changes: 17 additions & 3 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from devito.arch import compiler_registry, platform_registry
from devito.data import default_allocator
from devito.exceptions import InvalidOperator
from devito.exceptions import InvalidOperator, ExecutionError
from devito.logger import debug, info, perf, warning, is_log_enabled_for
from devito.ir.equations import LoweredEq, lower_exprs
from devito.ir.clusters import ClusterGroup, clusterize
Expand All @@ -21,7 +21,8 @@
from devito.mpi import MPI
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate)
generate_macros, minimize_symbols, unevaluate,
error_mapper)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -646,6 +647,16 @@ def _prepare_arguments(self, autotune=None, **kwargs):

return args

def _postprocess_errors(self, retval):
    """
    Turn the value returned by the compiled kernel into a Python exception.

    Parameters
    ----------
    retval : int
        The value returned by the generated code; 0 means success.

    Raises
    ------
    ExecutionError
        If `retval` is non-zero. The message is specific for the codes
        listed in `error_mapper` and generic for any other value.
    """
    if retval == 0:
        return

    # Known error codes and the message reported for each of them
    messages = {
        error_mapper['Stability']: "Detected nan/inf in some output Functions",
        error_mapper['KernelLaunch']: "Kernel launch failed",
    }
    fallback = "An error occurred during execution"

    raise ExecutionError(messages.get(retval, fallback))

def _postprocess_arguments(self, args, **kwargs):
"""Process runtime arguments upon returning from ``.apply()``."""
for p in self.parameters:
Expand Down Expand Up @@ -842,7 +853,7 @@ def apply(self, **kwargs):
try:
cfunction = self.cfunction
with self._profiler.timer_on('apply', comm=args.comm):
cfunction(*arg_values)
retval = cfunction(*arg_values)
except ctypes.ArgumentError as e:
if e.args[0].startswith("argument "):
argnum = int(e.args[0][9:].split(':')[0]) - 1
Expand All @@ -854,6 +865,9 @@ def apply(self, **kwargs):
else:
raise

# Perform error checking
self._postprocess_errors(retval)

# Post-process runtime arguments
self._postprocess_arguments(args, **kwargs)

Expand Down
1 change: 1 addition & 0 deletions devito/passes/iet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .asynchrony import * # noqa
from .instrument import * # noqa
from .languages import * # noqa
from .errors import * # noqa
87 changes: 87 additions & 0 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import cgen as c
import numpy as np
from sympy import Not

from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List,
Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.types import Eq, Inc, Symbol

__all__ = ['check_stability', 'error_mapper']


def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs):
    """
    Check if the simulation is stable. If not, return to Python as quickly as
    possible with an error code.

    The pass does nothing unless ``options['errctl']`` is ``'max'``.

    Parameters
    ----------
    graph : Graph
        The graph to instrument; its ``data_movs`` attribute is read to
        determine which Functions are eligible for the check.
    options : dict, optional
        The compiler options. If omitted, or if it carries no ``'errctl'``
        entry, the check is considered disabled.
    rcompile : callable, optional
        Forwarded to the underlying IET pass, which uses it to compile the
        equations performing the check.
    sregistry : optional
        Forwarded to the underlying IET pass, which uses it to generate
        unique names.
    **kwargs
        Ignored. Accepted so that callers can forward all compiler kwargs.
    """
    # `options` defaults to None, so it must not be subscripted blindly
    if options is None or options.get('errctl') != 'max':
        return

    _, wmovs = graph.data_movs

    _check_stability(graph, wmovs=wmovs, rcompile=rcompile, sregistry=sregistry)


@iet_pass
def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
    """
    Inject a periodic finiteness check into the time loops of the entry
    function.

    For each time Iteration found in `iet`, one TimeFunction is selected
    among those that appear in the loop, have a stepping time dimension and
    also belong to `wmovs`. A helper callable is generated which sums that
    TimeFunction with its time dimension fixed to 0 and returns
    ``isfinite(sum)``. The time loop body is then extended so that, whenever
    the time index is a multiple of 100, the helper is called and the kernel
    returns ``error_mapper['Stability']`` if the sum is not finite.

    Parameters
    ----------
    iet
        The IET to process. Anything but an EntryFunction is returned as is.
    wmovs : iterable, optional
        The Functions eligible for the check (the caller passes the second
        item of ``graph.data_movs``).
    rcompile : callable
        Called with a list of equations; returns the lowered IRs and a
        byproduct carrying additional callables (``funcs``) and ``includes``.
    sregistry
        Object used to generate unique names through ``make_name``.

    Returns
    -------
    tuple
        The (possibly transformed) IET and a dict with the newly generated
        ``efuncs`` and the required ``includes``.
    """
    if not isinstance(iet, EntryFunction):
        return iet, {}

    # NOTE: Stability is a domain-specific concept, hence looking for time
    # Iterations and TimeFunctions is acceptable
    efuncs = []
    includes = []
    mapper = {}
    for n in FindNodes(Iteration).visit(iet):
        if not n.dim.is_Time:
            continue

        # NOTE(review): there might be cases where there is no stepping dim
        # and only a time dim; such TimeFunctions are filtered out here. This
        # can be addressed later if someone runs into it (unlikely)
        functions = [f for f in FindSymbols().visit(n)
                     if f.is_TimeFunction and f.time_dim.is_Stepping]

        # We check just one TimeFunction, hence we sort for determinism and
        # reproducibility. `candidates` is already restricted to `wmovs`, so
        # the first entry, if any, is the one to use
        candidates = sorted(set(functions) & set(wmovs), key=lambda f: f.name)
        if not candidates:
            continue
        f = candidates[0]

        # Sum up `f` at time index 0; any NaN/inf in the summed data makes
        # the accumulator non-finite
        accumulator = Symbol(name='accumulator', dtype=f.dtype)
        eqns = [Eq(accumulator, 0.0),
                Inc(accumulator, f.subs(f.time_dim, 0))]
        irs, byproduct = rcompile(eqns)

        # Wrap the compiled summation into a callable returning
        # `isfinite(accumulator)`
        name = sregistry.make_name(prefix='is_finite')
        retval = Return(DefFunction('isfinite', accumulator))
        body = irs.iet.body.body + (retval,)
        efunc = make_callable(name, body, retval='int')

        efuncs.extend(i.root for i in byproduct.funcs)
        efuncs.append(efunc)

        includes.extend(byproduct.includes)

        name = sregistry.make_name(prefix='check')
        check = Symbol(name=name, dtype=np.int32)

        # The check is performed only when the time index is a multiple of
        # 100 (presumably to keep the runtime overhead low)
        errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[
            Call(efunc.name, efunc.parameters, retobj=check),
            Conditional(Not(check), Return(error_mapper['Stability']))
        ]))
        errctl = List(header=c.Comment("Stability check"), body=[errctl])
        mapper[n] = n._rebuild(nodes=n.nodes + (errctl,))

    iet = Transformer(mapper).visit(iet)

    return iet, {'efuncs': efuncs, 'includes': includes}


# Integer codes returned by the generated code to signal a runtime failure;
# a return value of 0 means success
error_mapper = {'Stability': 100, 'KernelLaunch': 200}
27 changes: 27 additions & 0 deletions tests/test_error_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from devito import Grid, Function, TimeFunction, Eq, Operator, switchconfig
from devito.exceptions import ExecutionError


@switchconfig(safe_math=True)
@pytest.mark.parametrize("expr", ['u/f', '(u + v)/f'])
def test_stability(expr):
    """
    An Operator built with `errctl='max'` must raise an ExecutionError when
    the update expression divides by `f`, whose data is never assigned in
    this test (presumably zero-filled, so that the division yields inf/nan).
    """
    grid = Grid(shape=(10, 10))

    # NOTE: `f`, `u` and `v` must be local names since `expr` is evaluated
    # against them
    f = Function(name='f', grid=grid, space_order=2)  # noqa
    u = TimeFunction(name='u', grid=grid, space_order=2)
    v = TimeFunction(name='v', grid=grid, space_order=2)

    rhs = eval(expr)
    update = Eq(u.forward, rhs)

    operator = Operator(update, opt=('advanced', {'errctl': 'max'}))

    u.data[:] = 1.
    v.data[:] = 2.

    # Run long enough for the periodic stability check to be hit
    with pytest.raises(ExecutionError):
        operator.apply(time_M=200, dt=.1)
Loading