diff --git a/devito/core/gpu.py b/devito/core/gpu.py index eb0d0d0dce..4f99eccabf 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -72,7 +72,7 @@ def _normalize_kwargs(cls, **kwargs): o['par-dynamic-work'] = np.inf # Always use static scheduling o['par-nested'] = np.inf # Never use nested parallelism o['par-disabled'] = oo.pop('par-disabled', True) # No host parallelism by default - o['gpu-fit'] = as_tuple(oo.pop('gpu-fit', cls._normalize_gpu_fit(**kwargs))) + o['gpu-fit'] = cls._normalize_gpu_fit(oo, **kwargs) o['gpu-create'] = as_tuple(oo.pop('gpu-create', ())) # Distributed parallelism @@ -95,11 +95,16 @@ def _normalize_kwargs(cls, **kwargs): return kwargs @classmethod - def _normalize_gpu_fit(cls, **kwargs): - if any(i in kwargs['mode'] for i in ['tasking', 'streaming']): - return None - else: - return cls.GPU_FIT + def _normalize_gpu_fit(cls, oo, **kwargs): + try: + gfit = as_tuple(oo.pop('gpu-fit')) + gfit = set().union([f.values() if f.is_AbstractTensor else f for f in gfit]) + return tuple(gfit) + except KeyError: + if any(i in kwargs['mode'] for i in ['tasking', 'streaming']): + return (None,) + else: + return as_tuple(cls.GPU_FIT) @classmethod def _rcompile_wrapper(cls, **kwargs0): diff --git a/devito/types/basic.py b/devito/types/basic.py index e72b531d75..53e8a87189 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -238,6 +238,7 @@ class Basic(CodeSymbol): # Top hierarchy is_AbstractFunction = False + is_AbstractTensor = False is_AbstractObject = False # Symbolic objects created internally by Devito diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 0a794effca..aa66645801 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -8,7 +8,7 @@ from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension, Dimension, MatrixSparseTimeFunction, SparseTimeFunction, SubDimension, SubDomain, SubDomainSet, TimeFunction, - Operator, configuration, switchconfig) + Operator, configuration, switchconfig, TensorTimeFunction) from devito.arch import get_gpu_info from devito.exceptions import InvalidArgument from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols, @@ -1423,6 +1423,18 @@ def test_npthreads(self): with pytest.raises(InvalidArgument): assert op.arguments(time_M=2, npthreads0=5) + def test_gpu_fit_w_tensor_functions(self): + grid = Grid(shape=(10, 10)) + + u = TensorTimeFunction(name='u', grid=grid) + usave = TensorTimeFunction(name="usave", grid=grid, save=10) + + eqns = [Eq(u.forward, u + 1), + Eq(usave, u.forward)] + + op = Operator(eqns, opt=('noop', {'gpu-fit': usave})) + assert set(op._options['gpu-fit']) - set(usave.values()) == set() + class TestMisc(object):