From 221a09a1100d6d25a04dc67a57f2e539b0263990 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 6 Mar 2024 16:23:20 +0000 Subject: [PATCH 1/4] compiler: Add hook for error checking --- devito/operator/operator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index df03ca7653..1c042ce29e 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -646,6 +646,9 @@ def _prepare_arguments(self, autotune=None, **kwargs): return args + def _postprocess_errors(self, retval): + return + def _postprocess_arguments(self, args, **kwargs): """Process runtime arguments upon returning from ``.apply()``.""" for p in self.parameters: @@ -842,7 +845,7 @@ def apply(self, **kwargs): try: cfunction = self.cfunction with self._profiler.timer_on('apply', comm=args.comm): - cfunction(*arg_values) + retval = cfunction(*arg_values) except ctypes.ArgumentError as e: if e.args[0].startswith("argument "): argnum = int(e.args[0][9:].split(':')[0]) - 1 @@ -854,6 +857,9 @@ def apply(self, **kwargs): else: raise + # Perform error checking + self._postprocess_errors(retval) + # Post-process runtime arguments self._postprocess_arguments(args, **kwargs) From a5c9062c84a55d2b4010eb73663e621783f6193c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 7 Mar 2024 14:41:01 +0000 Subject: [PATCH 2/4] compiler: Draft operator error handling --- devito/exceptions.py | 4 ++++ devito/ir/iet/nodes.py | 7 +++++-- devito/operator/operator.py | 14 +++++++++++--- devito/passes/iet/__init__.py | 1 + devito/passes/iet/errors.py | 19 +++++++++++++++++++ 5 files changed, 40 insertions(+), 5 deletions(-) create mode 100644 devito/passes/iet/errors.py diff --git a/devito/exceptions.py b/devito/exceptions.py index f03b0720ec..fa5619d4ca 100644 --- a/devito/exceptions.py +++ b/devito/exceptions.py @@ -14,5 +14,9 @@ class InvalidOperator(DevitoError): pass +class ExecutionError(DevitoError): + pass + + class VisitorException(DevitoError): pass diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 72307c5ac5..36a7a2f836 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -794,17 +794,19 @@ class CallableBody(MultiTraversable): Data unbundling for `body`. frees : list of Calls, optional Data deallocations for `body`. + errors : list of Nodes, optional + Error handling for `body`. """ is_CallableBody = True _traversable = ['unpacks', 'init', 'standalones', 'allocs', 'stacks', 'casts', 'bundles', 'maps', 'strides', 'objs', 'body', - 'unmaps', 'unbundles', 'frees'] + 'unmaps', 'unbundles', 'frees', 'errors'] def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(), allocs=(), stacks=(), casts=(), bundles=(), objs=(), maps=(), - unmaps=(), unbundles=(), frees=()): + unmaps=(), unbundles=(), frees=(), errors=()): # Sanity check assert not isinstance(body, CallableBody), "CallableBody's cannot be nested" @@ -823,6 +825,7 @@ def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(), self.unmaps = as_tuple(unmaps) self.unbundles = as_tuple(unbundles) self.frees = as_tuple(frees) + self.errors = as_tuple(errors) def __repr__(self): return (" Date: Thu, 7 Mar 2024 21:50:23 +0000 Subject: [PATCH 3/4] compiler: Add optional pass to check stability --- devito/core/cpu.py | 9 ++++- devito/core/gpu.py | 9 ++++- devito/core/operator.py | 10 +++++ devito/passes/iet/errors.py | 78 ++++++++++++++++++++++++++++++++++-- tests/test_error_checking.py | 21 ++++++++++ 5 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 tests/test_error_checking.py diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 81611da9b4..1d34314697 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -6,8 +6,9 @@ from devito.passes.clusters import (Lift, blocking, buffering, cire, cse, factorize, fission, fuse, optimize_pows, optimize_hyperplanes) -from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize, - hoist_prodders, relax_incr_dimensions) +from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, + mpiize, hoist_prodders, relax_incr_dimensions, + check_stability) from devito.tools import timed_pass __all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator', @@ -76,6 +77,7 @@ def _normalize_kwargs(cls, **kwargs): o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE) o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE) o['place-transfers'] = oo.pop('place-transfers', True) + o['errctl'] = oo.pop('errctl', cls.ERRCTL) # Recognised but unused by the CPU backend oo.pop('par-disabled', None) @@ -189,6 +191,9 @@ def _specialize_iet(cls, graph, **kwargs): # Misc optimizations hoist_prodders(graph) + # Perform error checking + check_stability(graph, **kwargs) + # Symbol definitions cls._Target.DataManager(**kwargs).process(graph) diff --git a/devito/core/gpu.py b/devito/core/gpu.py index de070c6890..a0c2da774a 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -10,8 +10,9 @@ from devito.passes.clusters import (Lift, Streaming, Tasker, blocking, buffering, cire, cse, factorize, fission, fuse, optimize_pows) -from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize, hoist_prodders, - linearize, pthreadify, relax_incr_dimensions) +from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize, + hoist_prodders, linearize, pthreadify, + relax_incr_dimensions, check_stability) from devito.tools import as_tuple, timed_pass __all__ = ['DeviceNoopOperator', 'DeviceAdvOperator', 'DeviceCustomOperator', @@ -91,6 +92,7 @@ def _normalize_kwargs(cls, **kwargs): o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE) o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE) o['place-transfers'] = oo.pop('place-transfers', True) + o['errctl'] = oo.pop('errctl', cls.ERRCTL) if oo: raise InvalidOperator("Unsupported optimization options: [%s]" @@ -226,6 +228,9 @@ def _specialize_iet(cls, graph, **kwargs): # Misc optimizations hoist_prodders(graph) + # Perform error checking + check_stability(graph, **kwargs) + # Symbol definitions cls._Target.DataManager(**kwargs).process(graph) diff --git a/devito/core/operator.py b/devito/core/operator.py index 1ba976d3b6..e61ff2c806 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -119,6 +119,13 @@ class BasicOperator(Operator): (default) or `int32`. """ + ERRCTL = None + """ + Runtime error checking. If this option is enabled, the generated code will + include runtime checks for various things that might go south, such as + instability (e.g., NaNs), failed library calls (e.g., kernel launches). + """ + _Target = None """ The target language constructor, to be specified by subclasses. @@ -155,6 +162,9 @@ def _check_kwargs(cls, **kwargs): if oo['deriv-unroll'] not in (False, 'inner', 'full'): raise InvalidArgument("Illegal `deriv-unroll` value") + if oo['errctl'] not in (None, False, 'basic', 'max'): + raise InvalidArgument("Illegal `errctl` value") + def _autotune(self, args, setup): if setup in [False, 'off']: return args diff --git a/devito/passes/iet/errors.py b/devito/passes/iet/errors.py index 440b92762c..a3810c72eb 100644 --- a/devito/passes/iet/errors.py +++ b/devito/passes/iet/errors.py @@ -1,16 +1,86 @@ +import cgen as c +from sympy import Not + +from devito.finite_differences import Abs +from devito.finite_differences.differentiable import Pow +from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List, + Return, FindNodes, FindSymbols, Transformer, + make_callable) from devito.passes.iet.engine import iet_pass +from devito.symbolics import CondEq, DefFunction +from devito.tools import dtype_to_cstr +from devito.types import Eq, Inc, Symbol __all__ = ['check_stability', 'error_mapper'] -@iet_pass -def check_stability(iet, **kwargs): +def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs): """ Check if the simulation is stable. If not, return to Python as quickly as possible with an error code. """ - # TODO - return iet, {} + if options['errctl'] != 'max': + return + + _, wmovs = graph.data_movs + + _check_stability(graph, wmovs=wmovs, rcompile=rcompile, sregistry=sregistry) + + +@iet_pass +def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None): + if not isinstance(iet, EntryFunction): + return iet, {} + + # NOTE: Stability is a domain-specific concept, hence looking for time + # Iterations and TimeFunctions is acceptable + efuncs = [] + includes = [] + mapper = {} + for n in FindNodes(Iteration).visit(iet): + if not n.dim.is_Time: + continue + + functions = [f for f in FindSymbols().visit(n) + if f.is_TimeFunction and f.time_dim.is_Stepping] + + # We compute the norm of just one TimeFunction, hence we sort for + # determinism and reproducibility + candidates = sorted(set(functions) & set(wmovs), key=lambda f: f.name) + for f in candidates: + if f in wmovs: + break + else: + continue + + name = sregistry.make_name(prefix='energy') + energy = Symbol(name=name, dtype=f.dtype) + + eqns = [Eq(energy, 0.0), + Inc(energy, Abs(Pow(f.subs(f.time_dim, 0), 2)))] + irs, byproduct = rcompile(eqns) + body = irs.iet.body.body + (Return(energy),) + + name = sregistry.make_name(prefix='compute_energy') + retval = dtype_to_cstr(energy.dtype) + efunc = make_callable(name, body, retval=retval) + + efuncs.extend([i.root for i in byproduct.funcs]) + efuncs.append(efunc) + + includes.extend(byproduct.includes) + + errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[ + Call(efunc.name, efunc.parameters, retobj=energy), + Conditional(Not(DefFunction('isfinite', energy)), + Return(error_mapper['Stability'])) + ])) + errctl = List(header=c.Comment("Stability check"), body=[errctl]) + mapper[n] = n._rebuild(nodes=n.nodes + (errctl,)) + + iet = Transformer(mapper).visit(iet) + + return iet, {'efuncs': efuncs, 'includes': includes} error_mapper = { diff --git a/tests/test_error_checking.py b/tests/test_error_checking.py new file mode 100644 index 0000000000..5a3929f10c --- /dev/null +++ b/tests/test_error_checking.py @@ -0,0 +1,21 @@ +import pytest + +from devito import Grid, Function, TimeFunction, Eq, Operator, switchconfig +from devito.exceptions import ExecutionError + + +@switchconfig(safe_math=True) +def test_stability(): + grid = Grid(shape=(10, 10)) + + f = Function(name='f', grid=grid, space_order=2) + u = TimeFunction(name='u', grid=grid, space_order=2) + + eq = Eq(u.forward, u/f) + + op = Operator(eq, opt=('advanced', {'errctl': 'max'})) + + u.data[:] = 1. + + with pytest.raises(ExecutionError): + op.apply(time_M=200, dt=.1) From aaea94a23cde5d8c78308a57e0e2b25e03d19583 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 12 Mar 2024 09:24:55 +0000 Subject: [PATCH 4/4] compiler: Tweak stability check --- devito/passes/iet/errors.py | 28 +++++++++++++--------------- tests/test_error_checking.py | 12 +++++++++--- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/devito/passes/iet/errors.py b/devito/passes/iet/errors.py index a3810c72eb..a6785f170b 100644 --- a/devito/passes/iet/errors.py +++ b/devito/passes/iet/errors.py @@ -1,14 +1,12 @@ import cgen as c +import numpy as np from sympy import Not -from devito.finite_differences import Abs -from devito.finite_differences.differentiable import Pow from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List, Return, FindNodes, FindSymbols, Transformer, make_callable) from devito.passes.iet.engine import iet_pass from devito.symbolics import CondEq, DefFunction -from devito.tools import dtype_to_cstr from devito.types import Eq, Inc, Symbol __all__ = ['check_stability', 'error_mapper'] @@ -53,27 +51,27 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None): else: continue - name = sregistry.make_name(prefix='energy') - energy = Symbol(name=name, dtype=f.dtype) - - eqns = [Eq(energy, 0.0), - Inc(energy, Abs(Pow(f.subs(f.time_dim, 0), 2)))] + accumulator = Symbol(name='accumulator', dtype=f.dtype) + eqns = [Eq(accumulator, 0.0), + Inc(accumulator, f.subs(f.time_dim, 0))] irs, byproduct = rcompile(eqns) - body = irs.iet.body.body + (Return(energy),) - name = sregistry.make_name(prefix='compute_energy') - retval = dtype_to_cstr(energy.dtype) - efunc = make_callable(name, body, retval=retval) + name = sregistry.make_name(prefix='is_finite') + retval = Return(DefFunction('isfinite', accumulator)) + body = irs.iet.body.body + (retval,) + efunc = make_callable(name, body, retval='int') efuncs.extend([i.root for i in byproduct.funcs]) efuncs.append(efunc) includes.extend(byproduct.includes) + name = sregistry.make_name(prefix='check') + check = Symbol(name=name, dtype=np.int32) + errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[ - Call(efunc.name, efunc.parameters, retobj=energy), - Conditional(Not(DefFunction('isfinite', energy)), - Return(error_mapper['Stability'])) + Call(efunc.name, efunc.parameters, retobj=check), + Conditional(Not(check), Return(error_mapper['Stability'])) ])) errctl = List(header=c.Comment("Stability check"), body=[errctl]) mapper[n] = n._rebuild(nodes=n.nodes + (errctl,)) diff --git a/tests/test_error_checking.py b/tests/test_error_checking.py index 5a3929f10c..061cb2b575 100644 --- a/tests/test_error_checking.py +++ b/tests/test_error_checking.py @@ -5,17 +5,23 @@ @switchconfig(safe_math=True) -def test_stability(): +@pytest.mark.parametrize("expr", [ + 'u/f', + '(u + v)/f', +]) +def test_stability(expr): grid = Grid(shape=(10, 10)) - f = Function(name='f', grid=grid, space_order=2) + f = Function(name='f', grid=grid, space_order=2) # noqa u = TimeFunction(name='u', grid=grid, space_order=2) + v = TimeFunction(name='v', grid=grid, space_order=2) - eq = Eq(u.forward, u/f) + eq = Eq(u.forward, eval(expr)) op = Operator(eq, opt=('advanced', {'errctl': 'max'})) u.data[:] = 1. + v.data[:] = 2. with pytest.raises(ExecutionError): op.apply(time_M=200, dt=.1)