From 32f20ee8caa716a74a4b184bcaa520b44298792a Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 16 Jan 2025 21:48:23 -0500 Subject: [PATCH] symbolics: move printers rogether through registry --- devito/ir/iet/visitors.py | 72 +++++++++------------ devito/operator/operator.py | 19 +++--- devito/passes/iet/definitions.py | 4 +- devito/passes/iet/dtypes.py | 24 ++++--- devito/passes/iet/languages/C.py | 19 ------ devito/passes/iet/languages/CXX.py | 22 ------- devito/passes/iet/languages/openacc.py | 7 +-- devito/passes/iet/languages/targets.py | 9 +-- devito/symbolics/printer.py | 87 +++++++++++++++++++++----- tests/test_dtypes.py | 15 ++--- tests/test_symbolics.py | 5 +- 11 files changed, 129 insertions(+), 154 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 622b1ed95b..7ab38b5191 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -13,14 +13,13 @@ from sympy import IndexedBase from sympy.core.function import Application -from devito.parameters import configuration, switchconfig from devito.exceptions import CompilationError from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) from devito.ir.support.space import Backward from devito.symbolics import (FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace) -from devito.symbolics.printer import _DevitoPrinterBase +from devito.symbolics.printer import ccode from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import (GenericVisitor, as_tuple, filter_ordered, filter_sorted, flatten, is_external_ctype, @@ -177,10 +176,8 @@ class CGen(Visitor): Return a representation of the Iteration/Expression tree as a :module:`cgen` tree. """ - def __init__(self, *args, compiler=None, printer=None, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._compiler = compiler or configuration['compiler'] - self._printer = printer or _DevitoPrinterBase # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) @@ -192,19 +189,6 @@ def __init__(self, *args, compiler=None, printer=None, **kwargs): } _restrict_keyword = 'restrict' - @property - def compiler(self): - return self._compiler - - def ccode(self, expr, **settings): - return self._printer(settings=settings).doprint(expr, None) - - def visit(self, o, *args, **kwargs): - # Make sure the visitor always is within the generating compiler - # in case the configuration is accessed - with switchconfig(compiler=self.compiler.name): - return super().visit(o, *args, **kwargs) - def _gen_struct_decl(self, obj, masked=()): """ Convert ctypes.Struct -> cgen.Structure. @@ -238,7 +222,7 @@ def _gen_struct_decl(self, obj, masked=()): try: entries.append(self._gen_value(i, 0, masked=('const',))) except AttributeError: - cstr = self.ccode(ct) + cstr = ccode(ct) if ct is c_restrict_void_p: cstr = '%srestrict' % cstr entries.append(c.Value(cstr, n)) @@ -260,10 +244,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = self.ccode(obj._C_typedata) - strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape) + strtype = ccode(obj._C_typedata) + strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) else: - strtype = self.ccode(obj._C_ctype) + strtype = ccode(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -277,7 +261,7 @@ def _gen_value(self, obj, mode=1, masked=()): strobj = '%s%s' % (strname, strshape) if obj.is_LocalObject and obj.cargs and mode == 1: - arguments = [self.ccode(i) for i in obj.cargs] + arguments = [ccode(i) for i in obj.cargs] strobj = MultilineCall(strobj, arguments, True) value = c.Value(strtype, strobj) @@ -291,9 +275,9 @@ def _gen_value(self, obj, mode=1, masked=()): if obj.is_Array and obj.initvalue is not None and mode == 1: init = ListInitializer(obj.initvalue) if not obj._mem_constant or init.is_numeric: - value = c.Initializer(value, self.ccode(init)) + value = c.Initializer(value, ccode(init)) elif obj.is_LocalObject and obj.initvalue is not None and mode == 1: - value = c.Initializer(value, self.ccode(obj.initvalue)) + value = c.Initializer(value, ccode(obj.initvalue)) return value @@ -327,7 +311,7 @@ def _args_call(self, args): else: ret.append(i._C_name) except AttributeError: - ret.append(self.ccode(i)) + ret.append(ccode(i)) return ret def _gen_signature(self, o, is_declaration=False): @@ -393,7 +377,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = self.ccode(i._C_typedata) + cstr = ccode(i._C_typedata) if f.is_PointerArray: # lvalue @@ -415,7 +399,7 @@ def visit_PointerCast(self, o): else: v = f.name if o.flat is None: - shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) + shape = ''.join("[%s]" % ccode(i) for i in o.castshape) rshape = '(*)%s' % shape lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: @@ -448,9 +432,9 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = self.ccode(i._C_typedata) + cstr = ccode(i._C_typedata) if o.flat is None: - shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:]) + shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) @@ -489,8 +473,8 @@ def visit_Definition(self, o): return self._gen_value(o.function) def visit_Expression(self, o): - lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + lhs = ccode(o.expr.lhs, dtype=o.dtype) + rhs = ccode(o.expr.rhs, dtype=o.dtype) if o.init: code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs) @@ -503,8 +487,8 @@ def visit_Expression(self, o): return code def visit_AugmentedExpression(self, o): - c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler) - c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler) + c_lhs = ccode(o.expr.lhs, dtype=o.dtype) + c_rhs = ccode(o.expr.rhs, dtype=o.dtype) code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs)) if o.pragmas: code = c.Module(self._visit(o.pragmas) + (code,)) @@ -523,7 +507,7 @@ def visit_Call(self, o, nested_call=False): o.templates) if retobj.is_Indexed or \ isinstance(retobj, (FieldFromComposite, FieldFromPointer)): - return c.Assign(self.ccode(retobj), call) + return c.Assign(ccode(retobj), call) else: return c.Initializer(c.Value(rettype, retobj._C_name), call) @@ -537,9 +521,9 @@ def visit_Conditional(self, o): then_body = c.Block(self._visit(then_body)) if else_body: else_body = c.Block(self._visit(else_body)) - return c.If(self.ccode(o.condition), then_body, else_body) + return c.If(ccode(o.condition), then_body, else_body) else: - return c.If(self.ccode(o.condition), then_body) + return c.If(ccode(o.condition), then_body) def visit_Iteration(self, o): body = flatten(self._visit(i) for i in self._blankline_logic(o.children)) @@ -549,23 +533,23 @@ def visit_Iteration(self, o): # For backward direction flip loop bounds if o.direction == Backward: - loop_init = 'int %s = %s' % (o.index, self.ccode(_max)) - loop_cond = '%s >= %s' % (o.index, self.ccode(_min)) + loop_init = 'int %s = %s' % (o.index, ccode(_max)) + loop_cond = '%s >= %s' % (o.index, ccode(_min)) loop_inc = '%s -= %s' % (o.index, o.limits[2]) else: - loop_init = 'int %s = %s' % (o.index, self.ccode(_min)) - loop_cond = '%s <= %s' % (o.index, self.ccode(_max)) + loop_init = 'int %s = %s' % (o.index, ccode(_min)) + loop_cond = '%s <= %s' % (o.index, ccode(_max)) loop_inc = '%s += %s' % (o.index, o.limits[2]) # Append unbounded indices, if any if o.uindices: - uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices] + uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices] loop_init = c.Line(', '.join([loop_init] + uinit)) ustep = [] for i in o.uindices: op = '=' if i.is_Modulo else '+=' - ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr))) + ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr))) loop_inc = c.Line(', '.join([loop_inc] + ustep)) # Create For header+body @@ -582,7 +566,7 @@ def visit_Pragma(self, o): return c.Pragma(o._generate) def visit_While(self, o): - condition = self.ccode(o.condition) + condition = ccode(o.condition) if o.body: body = flatten(self._visit(i) for i in o.children) return c.While(condition, c.Block(body)) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index a5a8ad57a0..c4ce4dc9aa 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -23,7 +23,7 @@ from devito.operator.profiling import create_profile from devito.operator.registry import operator_selector from devito.mpi import MPI -from devito.parameters import configuration +from devito.parameters import configuration, switchconfig from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) @@ -758,19 +758,14 @@ def _soname(self): """A unique name for the shared object resulting from JIT compilation.""" return Signer._digest(self, configuration) - @property - def printer(self): - return self._Target.Printer - @cached_property def ccode(self): - try: - return self._ccode_handler(compiler=self._compiler, - printer=self.printer).visit(self) - except (AttributeError, TypeError): - from devito.ir.iet.visitors import CGen - return CGen(compiler=self._compiler, - printer=self.printer).visit(self) + with switchconfig(compiler=self._compiler, language=self._language): + try: + return self._ccode_handler().visit(self) + except (AttributeError, TypeError): + from devito.ir.iet.visitors import CGen + return CGen().visit(self) def _jit_compile(self): """ diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 608c9f3662..67f0441ec4 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -75,12 +75,10 @@ class DataManager: The language used to express data allocations, deletions, and host-device transfers. """ - def __init__(self, rcompile=None, sregistry=None, platform=None, - compiler=None, **kwargs): + def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs): self.rcompile = rcompile self.sregistry = sregistry self.platform = platform - self.compiler = compiler def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ diff --git a/devito/passes/iet/dtypes.py b/devito/passes/iet/dtypes.py index f775e20ea4..7f617b1c41 100644 --- a/devito/passes/iet/dtypes.py +++ b/devito/passes/iet/dtypes.py @@ -9,16 +9,6 @@ __all__ = ['lower_dtypes'] -def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None, - sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]: - """ - Lowers float16 scalar types to pointers since we can't directly pass their - value. Also includes headers for complex arithmetic if needed. - """ - # Complex numbers - _complex_includes(graph, lang=lang, compiler=compiler) - - @iet_pass def _complex_includes(iet: Callable, lang: type[LangBB] = None, compiler: Compiler = None) -> tuple[Callable, dict]: @@ -50,3 +40,17 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None, metadata['includes'] = lib return iet, metadata + + +dtype_passes = [_complex_includes] + + +def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None, + sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]: + """ + Lowers float16 scalar types to pointers since we can't directly pass their + value. Also includes headers for complex arithmetic if needed. + """ + + for dtype_pass in dtype_passes: + dtype_pass(graph, lang=lang, compiler=compiler) diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 4285a673e1..ff50e54205 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -1,11 +1,7 @@ -import numpy as np - from devito.ir import Call from devito.passes.iet.definitions import DataManager from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.langbase import LangBB -from devito.symbolics.extended_dtypes import c_complex, c_double_complex -from devito.symbolics.printer import _DevitoPrinterBase __all__ = ['CBB', 'CDataManager', 'COrchestrator'] @@ -35,18 +31,3 @@ class CDataManager(DataManager): class COrchestrator(Orchestrator): lang = CBB - - -class CDevitoPrinter(_DevitoPrinterBase): - - # These cannot go through _print_xxx because they are classes not - # instances - type_mappings = {**_DevitoPrinterBase.type_mappings, - c_complex: 'float _Complex', - c_double_complex: 'double _Complex'} - - _func_prefix = {**_DevitoPrinterBase._func_prefix, np.complex64: 'c', - np.complex128: 'c'} - - def _print_ImaginaryUnit(self, expr): - return '_Complex_I' diff --git a/devito/passes/iet/languages/CXX.py b/devito/passes/iet/languages/CXX.py index a6e1715e33..17003c0d8f 100644 --- a/devito/passes/iet/languages/CXX.py +++ b/devito/passes/iet/languages/CXX.py @@ -1,9 +1,5 @@ -from sympy.printing.cxx import CXX11CodePrinter - from devito.ir import Call, UsingNamespace from devito.passes.iet.langbase import LangBB -from devito.symbolics.printer import _DevitoPrinterBase -from devito.symbolics.extended_dtypes import c_complex, c_double_complex __all__ = ['CXXBB'] @@ -64,21 +60,3 @@ class CXXBB(LangBB): 'complex-namespace': [UsingNamespace('std::complex_literals')], 'def-complex': std_arith, } - - -class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): - - _default_settings = {**_DevitoPrinterBase._default_settings, - **CXX11CodePrinter._default_settings} - _ns = "std::" - _func_litterals = {} - - # These cannot go through _print_xxx because they are classes not - # instances - type_mappings = {**_DevitoPrinterBase.type_mappings, - c_complex: 'std::complex', - c_double_complex: 'std::complex', - **CXX11CodePrinter.type_mappings} - - def _print_ImaginaryUnit(self, expr): - return f'1i{self.prec_literal(expr).lower()}' diff --git a/devito/passes/iet/languages/openacc.py b/devito/passes/iet/languages/openacc.py index 1718a5269a..bcf5660ac7 100644 --- a/devito/passes/iet/languages/openacc.py +++ b/devito/passes/iet/languages/openacc.py @@ -9,7 +9,7 @@ from devito.passes.iet.orchestration import Orchestrator from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB, PragmaIteration, PragmaTransfer) -from devito.passes.iet.languages.CXX import CXXBB, CXXDevitoPrinter +from devito.passes.iet.languages.CXX import CXXBB from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration from devito.symbolics import FieldFromPointer, Macro, cast_mapper from devito.tools import filter_ordered, UnboundTuple @@ -263,8 +263,3 @@ def place_devptr(self, iet, **kwargs): class AccOrchestrator(Orchestrator): lang = AccBB - - -class AccDevitoPrinter(CXXDevitoPrinter): - - pass diff --git a/devito/passes/iet/languages/targets.py b/devito/passes/iet/languages/targets.py index 66137a53e7..4ac8d94398 100644 --- a/devito/passes/iet/languages/targets.py +++ b/devito/passes/iet/languages/targets.py @@ -1,9 +1,9 @@ -from devito.passes.iet.languages.C import CDataManager, COrchestrator, CDevitoPrinter +from devito.passes.iet.languages.C import CDataManager, COrchestrator from devito.passes.iet.languages.openmp import (SimdOmpizer, Ompizer, DeviceOmpizer, OmpDataManager, DeviceOmpDataManager, OmpOrchestrator, DeviceOmpOrchestrator) from devito.passes.iet.languages.openacc import (DeviceAccizer, DeviceAccDataManager, - AccOrchestrator, AccDevitoPrinter) + AccOrchestrator) from devito.passes.iet.instrument import instrument __all__ = ['CTarget', 'OmpTarget', 'DeviceOmpTarget', 'DeviceAccTarget'] @@ -13,7 +13,6 @@ class Target: Parizer = None DataManager = None Orchestrator = None - Printer = None @classmethod def lang(cls): @@ -28,25 +27,21 @@ class CTarget(Target): Parizer = SimdOmpizer DataManager = CDataManager Orchestrator = COrchestrator - Printer = CDevitoPrinter class OmpTarget(Target): Parizer = Ompizer DataManager = OmpDataManager Orchestrator = OmpOrchestrator - Printer = CDevitoPrinter class DeviceOmpTarget(Target): Parizer = DeviceOmpizer DataManager = DeviceOmpDataManager Orchestrator = DeviceOmpOrchestrator - Printer = CDevitoPrinter class DeviceAccTarget(Target): Parizer = DeviceAccizer DataManager = DeviceAccDataManager Orchestrator = AccOrchestrator - Printer = AccDevitoPrinter diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 75566d0145..5c19ec02b8 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -7,17 +7,19 @@ from mpmath.libmp import prec_to_dps, to_str from packaging.version import Version -from numbers import Real from sympy.core import S from sympy.core.numbers import equal_valued, Float +from sympy.printing.codeprinter import CodePrinter from sympy.printing.c import C99CodePrinter +from sympy.printing.cxx import CXX11CodePrinter from sympy.logic.boolalg import BooleanFunction from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from devito import configuration from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype +from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.types.basic import AbstractFunction from devito.tools import ctypes_to_cstr @@ -27,7 +29,7 @@ _prec_litterals = {np.float16: 'F16', np.float32: 'F', np.complex64: 'F'} -class _DevitoPrinterBase(C99CodePrinter): +class _DevitoPrinterBase(CodePrinter): """ Decorator for sympy.printing.ccode.CCodePrinter. @@ -38,10 +40,10 @@ class _DevitoPrinterBase(C99CodePrinter): Options for code printing. """ _default_settings = {'compiler': None, 'dtype': np.float32, - **C99CodePrinter._default_settings} + **CodePrinter._default_settings} - _func_prefix = {np.float32: 'f', np.float64: 'f'} - _func_litterals = {np.float32: 'f', np.complex64: 'f', Real: 'f'} + _func_prefix = {} + _func_litterals = {} @property def dtype(self): @@ -65,8 +67,10 @@ def _prec(self, expr): dtype = sympy_dtype(expr, default=self.dtype) if dtype is None or np.issubdtype(dtype, np.integer): real = any(isinstance(i, Float) for i in expr.atoms()) - stype = self.dtype if real else np.int32 - return np.result_type(dtype or stype, stype).type + if real: + return self.dtype + else: + return dtype or self.dtype else: return dtype or self.dtype @@ -206,11 +210,13 @@ def _print_Max(self, expr): def _print_Abs(self, expr): """Print an absolute value. Use `abs` if can infer it is an Integer""" + # Unary function, single argument + arg = expr.args[0] # AOMPCC errors with abs, always use fabs if isinstance(self.compiler, AOMPCompiler): - return "fabs(%s)" % self._print(expr.args[0]) - func = f'{self.func_prefix(expr, abs=True)}abs{self.func_literal(expr)}' - return f"{self._ns}{func}({self._print(expr.args[0])})" + return "fabs(%s)" % self._print(arg) + func = f'{self.func_prefix(arg, abs=True)}abs{self.func_literal(arg)}' + return f"{self._ns}{func}({self._print(arg)})" def _print_Add(self, expr, order=None): """" @@ -346,6 +352,58 @@ def _print_Fallback(self, expr): PRECEDENCE_VALUES['InlineIf'] = 1 +# Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here +# to always use the correct one from our printer +if Version(sympy.__version__) >= Version("1.11"): + setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add) + + +class CDevitoPrinter(_DevitoPrinterBase, C99CodePrinter): + + _default_settings = {**_DevitoPrinterBase._default_settings, + **C99CodePrinter._default_settings} + _func_litterals = {np.float32: 'f', np.complex64: 'f'} + _func_prefix = {np.float32: 'f', np.float64: 'f', + np.complex64: 'c', np.complex128: 'c'} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**C99CodePrinter.type_mappings, + c_complex: 'float _Complex', + c_double_complex: 'double _Complex'} + + def _print_ImaginaryUnit(self, expr): + return '_Complex_I' + + +class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter): + + _default_settings = {**_DevitoPrinterBase._default_settings, + **CXX11CodePrinter._default_settings} + _ns = "std::" + _func_litterals = {} + _func_prefix = {np.float32: 'f', np.float64: 'f'} + + # These cannot go through _print_xxx because they are classes not + # instances + type_mappings = {**CXX11CodePrinter.type_mappings, + c_complex: 'std::complex', + c_double_complex: 'std::complex'} + + def _print_ImaginaryUnit(self, expr): + return f'1i{self.prec_literal(expr).lower()}' + + +class AccDevitoPrinter(CXXDevitoPrinter): + + pass + + +printer_registry: dict[str, type[_DevitoPrinterBase]] = { + 'C': CDevitoPrinter, 'openmp': CDevitoPrinter, + 'openacc': AccDevitoPrinter} + + def ccode(expr, **settings): """Generate C++ code from an expression. @@ -362,10 +420,5 @@ def ccode(expr, **settings): The resulting code as a C++ string. If something went south, returns the input ``expr`` itself. """ - return _DevitoPrinterBase(settings=settings).doprint(expr, None) - - -# Sympy 1.11 has introduced a bug in `_print_Add`, so we enforce here -# to always use the correct one from our printer -if Version(sympy.__version__) >= Version("1.11"): - setattr(sympy.printing.str.StrPrinter, '_print_Add', _DevitoPrinterBase._print_Add) + printer = printer_registry.get(configuration['language'], CDevitoPrinter) + return printer(settings=settings).doprint(expr, None) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index ffc67a4d2b..88d9299ee5 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -4,11 +4,11 @@ from devito import Constant, Eq, Function, Grid, Operator from devito.passes.iet.langbase import LangBB -from devito.passes.iet.languages.C import CBB, CDevitoPrinter -from devito.passes.iet.languages.openacc import AccBB, AccDevitoPrinter +from devito.passes.iet.languages.C import CBB +from devito.passes.iet.languages.openacc import AccBB from devito.passes.iet.languages.openmp import OmpBB from devito.symbolics.extended_dtypes import ctypes_vector_mapper -from devito.symbolics.printer import _DevitoPrinterBase +from devito.symbolics.printer import printer_registry, _DevitoPrinterBase from devito.types.basic import Basic, Scalar, Symbol from devito.types.dense import TimeFunction @@ -20,13 +20,6 @@ } -_printers: dict[str, type[_DevitoPrinterBase]] = { - 'C': CDevitoPrinter, - 'openmp': CDevitoPrinter, - 'openacc': AccDevitoPrinter -} - - def _get_language(language: str, **_) -> type[LangBB]: """ Gets the language building block type from parametrized kwargs. @@ -40,7 +33,7 @@ def _get_printer(language: str, **_) -> type[_DevitoPrinterBase]: Gets the printer building block type from parametrized kwargs. """ - return _printers[language] + return printer_registry[language] def _config_kwargs(platform: str, language: str) -> dict[str, str]: diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 66a7b3b28c..31fb79eb48 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -12,7 +12,7 @@ from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, - ReservedWord, ListInitializer, ccode, uxreplace, + ReservedWord, ListInitializer, uxreplace, ccode, retrieve_derivatives) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, @@ -298,8 +298,7 @@ def test_extended_sympy_arithmetic(): def test_integer_abs(): i1 = Dimension(name="i1") assert ccode(Abs(i1 - 1)) == "abs(i1 - 1)" - # .5 is a standard python Float, i.e an np.float64 - assert ccode(Abs(i1 - .5)) == "fabs(i1 - 5.0e-1)" + assert ccode(Abs(i1 - .5)) == "fabsf(i1 - 5.0e-1F)" assert ccode( Abs(i1 - Constant('half', dtype=np.float64, default_value=0.5)) ) == "fabs(i1 - half)"