Skip to content

Commit

Permalink
symbolics: move printers rogether through registry
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 17, 2025
1 parent 62f2deb commit 32f20ee
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 154 deletions.
72 changes: 28 additions & 44 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
from sympy import IndexedBase
from sympy.core.function import Application

from devito.parameters import configuration, switchconfig
from devito.exceptions import CompilationError
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
Call, Lambda, BlankLine, Section, ListMajor)
from devito.ir.support.space import Backward
from devito.symbolics import (FieldFromComposite, FieldFromPointer,
ListInitializer, uxreplace)
from devito.symbolics.printer import _DevitoPrinterBase
from devito.symbolics.printer import ccode
from devito.symbolics.extended_dtypes import NoDeclStruct
from devito.tools import (GenericVisitor, as_tuple, filter_ordered,
filter_sorted, flatten, is_external_ctype,
Expand Down Expand Up @@ -177,10 +176,8 @@ class CGen(Visitor):
Return a representation of the Iteration/Expression tree as a :module:`cgen` tree.
"""

def __init__(self, *args, compiler=None, printer=None, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._compiler = compiler or configuration['compiler']
self._printer = printer or _DevitoPrinterBase

# The following mappers may be customized by subclasses (that is,
# backend-specific CGen-erators)
Expand All @@ -192,19 +189,6 @@ def __init__(self, *args, compiler=None, printer=None, **kwargs):
}
_restrict_keyword = 'restrict'

@property
def compiler(self):
return self._compiler

def ccode(self, expr, **settings):
return self._printer(settings=settings).doprint(expr, None)

def visit(self, o, *args, **kwargs):
# Make sure the visitor always is within the generating compiler
# in case the configuration is accessed
with switchconfig(compiler=self.compiler.name):
return super().visit(o, *args, **kwargs)

def _gen_struct_decl(self, obj, masked=()):
"""
Convert ctypes.Struct -> cgen.Structure.
Expand Down Expand Up @@ -238,7 +222,7 @@ def _gen_struct_decl(self, obj, masked=()):
try:
entries.append(self._gen_value(i, 0, masked=('const',)))
except AttributeError:
cstr = self.ccode(ct)
cstr = ccode(ct)
if ct is c_restrict_void_p:
cstr = '%srestrict' % cstr
entries.append(c.Value(cstr, n))
Expand All @@ -260,10 +244,10 @@ def _gen_value(self, obj, mode=1, masked=()):
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = self.ccode(obj._C_typedata)
strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape)
strtype = ccode(obj._C_typedata)
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
else:
strtype = self.ccode(obj._C_ctype)
strtype = ccode(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
if not obj._mem_stack:
Expand All @@ -277,7 +261,7 @@ def _gen_value(self, obj, mode=1, masked=()):
strobj = '%s%s' % (strname, strshape)

if obj.is_LocalObject and obj.cargs and mode == 1:
arguments = [self.ccode(i) for i in obj.cargs]
arguments = [ccode(i) for i in obj.cargs]
strobj = MultilineCall(strobj, arguments, True)

value = c.Value(strtype, strobj)
Expand All @@ -291,9 +275,9 @@ def _gen_value(self, obj, mode=1, masked=()):
if obj.is_Array and obj.initvalue is not None and mode == 1:
init = ListInitializer(obj.initvalue)
if not obj._mem_constant or init.is_numeric:
value = c.Initializer(value, self.ccode(init))
value = c.Initializer(value, ccode(init))
elif obj.is_LocalObject and obj.initvalue is not None and mode == 1:
value = c.Initializer(value, self.ccode(obj.initvalue))
value = c.Initializer(value, ccode(obj.initvalue))

return value

Expand Down Expand Up @@ -327,7 +311,7 @@ def _args_call(self, args):
else:
ret.append(i._C_name)
except AttributeError:
ret.append(self.ccode(i))
ret.append(ccode(i))
return ret

def _gen_signature(self, o, is_declaration=False):
Expand Down Expand Up @@ -393,7 +377,7 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = self.ccode(i._C_typedata)
cstr = ccode(i._C_typedata)

if f.is_PointerArray:
# lvalue
Expand All @@ -415,7 +399,7 @@ def visit_PointerCast(self, o):
else:
v = f.name
if o.flat is None:
shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape)
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
Expand Down Expand Up @@ -448,9 +432,9 @@ def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = self.ccode(i._C_typedata)
cstr = ccode(i._C_typedata)
if o.flat is None:
shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:])
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
Expand Down Expand Up @@ -489,8 +473,8 @@ def visit_Definition(self, o):
return self._gen_value(o.function)

def visit_Expression(self, o):
lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)
lhs = ccode(o.expr.lhs, dtype=o.dtype)
rhs = ccode(o.expr.rhs, dtype=o.dtype)

if o.init:
code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs)
Expand All @@ -503,8 +487,8 @@ def visit_Expression(self, o):
return code

def visit_AugmentedExpression(self, o):
c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype, compiler=self._compiler)
c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype, compiler=self._compiler)
c_lhs = ccode(o.expr.lhs, dtype=o.dtype)
c_rhs = ccode(o.expr.rhs, dtype=o.dtype)
code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs))
if o.pragmas:
code = c.Module(self._visit(o.pragmas) + (code,))
Expand All @@ -523,7 +507,7 @@ def visit_Call(self, o, nested_call=False):
o.templates)
if retobj.is_Indexed or \
isinstance(retobj, (FieldFromComposite, FieldFromPointer)):
return c.Assign(self.ccode(retobj), call)
return c.Assign(ccode(retobj), call)
else:
return c.Initializer(c.Value(rettype, retobj._C_name), call)

Expand All @@ -537,9 +521,9 @@ def visit_Conditional(self, o):
then_body = c.Block(self._visit(then_body))
if else_body:
else_body = c.Block(self._visit(else_body))
return c.If(self.ccode(o.condition), then_body, else_body)
return c.If(ccode(o.condition), then_body, else_body)
else:
return c.If(self.ccode(o.condition), then_body)
return c.If(ccode(o.condition), then_body)

def visit_Iteration(self, o):
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
Expand All @@ -549,23 +533,23 @@ def visit_Iteration(self, o):

# For backward direction flip loop bounds
if o.direction == Backward:
loop_init = 'int %s = %s' % (o.index, self.ccode(_max))
loop_cond = '%s >= %s' % (o.index, self.ccode(_min))
loop_init = 'int %s = %s' % (o.index, ccode(_max))
loop_cond = '%s >= %s' % (o.index, ccode(_min))
loop_inc = '%s -= %s' % (o.index, o.limits[2])
else:
loop_init = 'int %s = %s' % (o.index, self.ccode(_min))
loop_cond = '%s <= %s' % (o.index, self.ccode(_max))
loop_init = 'int %s = %s' % (o.index, ccode(_min))
loop_cond = '%s <= %s' % (o.index, ccode(_max))
loop_inc = '%s += %s' % (o.index, o.limits[2])

# Append unbounded indices, if any
if o.uindices:
uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices]
uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices]
loop_init = c.Line(', '.join([loop_init] + uinit))

ustep = []
for i in o.uindices:
op = '=' if i.is_Modulo else '+='
ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr)))
ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr)))
loop_inc = c.Line(', '.join([loop_inc] + ustep))

# Create For header+body
Expand All @@ -582,7 +566,7 @@ def visit_Pragma(self, o):
return c.Pragma(o._generate)

def visit_While(self, o):
condition = self.ccode(o.condition)
condition = ccode(o.condition)
if o.body:
body = flatten(self._visit(i) for i in o.children)
return c.While(condition, c.Block(body))
Expand Down
19 changes: 7 additions & 12 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from devito.operator.profiling import create_profile
from devito.operator.registry import operator_selector
from devito.mpi import MPI
from devito.parameters import configuration
from devito.parameters import configuration, switchconfig
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
Expand Down Expand Up @@ -758,19 +758,14 @@ def _soname(self):
"""A unique name for the shared object resulting from JIT compilation."""
return Signer._digest(self, configuration)

@property
def printer(self):
return self._Target.Printer

@cached_property
def ccode(self):
try:
return self._ccode_handler(compiler=self._compiler,
printer=self.printer).visit(self)
except (AttributeError, TypeError):
from devito.ir.iet.visitors import CGen
return CGen(compiler=self._compiler,
printer=self.printer).visit(self)
with switchconfig(compiler=self._compiler, language=self._language):
try:
return self._ccode_handler().visit(self)
except (AttributeError, TypeError):
from devito.ir.iet.visitors import CGen
return CGen().visit(self)

def _jit_compile(self):
"""
Expand Down
4 changes: 1 addition & 3 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,10 @@ class DataManager:
The language used to express data allocations, deletions, and host-device transfers.
"""

def __init__(self, rcompile=None, sregistry=None, platform=None,
compiler=None, **kwargs):
def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
self.rcompile = rcompile
self.sregistry = sregistry
self.platform = platform
self.compiler = compiler

def _alloc_object_on_low_lat_mem(self, site, obj, storage):
"""
Expand Down
24 changes: 14 additions & 10 deletions devito/passes/iet/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,6 @@
__all__ = ['lower_dtypes']


def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None,
sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]:
"""
Lowers float16 scalar types to pointers since we can't directly pass their
value. Also includes headers for complex arithmetic if needed.
"""
# Complex numbers
_complex_includes(graph, lang=lang, compiler=compiler)


@iet_pass
def _complex_includes(iet: Callable, lang: type[LangBB] = None,
compiler: Compiler = None) -> tuple[Callable, dict]:
Expand Down Expand Up @@ -50,3 +40,17 @@ def _complex_includes(iet: Callable, lang: type[LangBB] = None,
metadata['includes'] = lib

return iet, metadata


dtype_passes = [_complex_includes]


def lower_dtypes(graph: Callable, lang: type[LangBB] = None, compiler: Compiler = None,
sregistry: SymbolRegistry = None, **kwargs) -> tuple[Callable, dict]:
"""
Lowers float16 scalar types to pointers since we can't directly pass their
value. Also includes headers for complex arithmetic if needed.
"""

for dtype_pass in dtype_passes:
dtype_pass(graph, lang=lang, compiler=compiler)
19 changes: 0 additions & 19 deletions devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import numpy as np

from devito.ir import Call
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.langbase import LangBB
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
from devito.symbolics.printer import _DevitoPrinterBase

__all__ = ['CBB', 'CDataManager', 'COrchestrator']

Expand Down Expand Up @@ -35,18 +31,3 @@ class CDataManager(DataManager):

class COrchestrator(Orchestrator):
lang = CBB


class CDevitoPrinter(_DevitoPrinterBase):

# These cannot go through _print_xxx because they are classes not
# instances
type_mappings = {**_DevitoPrinterBase.type_mappings,
c_complex: 'float _Complex',
c_double_complex: 'double _Complex'}

_func_prefix = {**_DevitoPrinterBase._func_prefix, np.complex64: 'c',
np.complex128: 'c'}

def _print_ImaginaryUnit(self, expr):
return '_Complex_I'
22 changes: 0 additions & 22 deletions devito/passes/iet/languages/CXX.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from sympy.printing.cxx import CXX11CodePrinter

from devito.ir import Call, UsingNamespace
from devito.passes.iet.langbase import LangBB
from devito.symbolics.printer import _DevitoPrinterBase
from devito.symbolics.extended_dtypes import c_complex, c_double_complex

__all__ = ['CXXBB']

Expand Down Expand Up @@ -64,21 +60,3 @@ class CXXBB(LangBB):
'complex-namespace': [UsingNamespace('std::complex_literals')],
'def-complex': std_arith,
}


class CXXDevitoPrinter(_DevitoPrinterBase, CXX11CodePrinter):

_default_settings = {**_DevitoPrinterBase._default_settings,
**CXX11CodePrinter._default_settings}
_ns = "std::"
_func_litterals = {}

# These cannot go through _print_xxx because they are classes not
# instances
type_mappings = {**_DevitoPrinterBase.type_mappings,
c_complex: 'std::complex<float>',
c_double_complex: 'std::complex<double>',
**CXX11CodePrinter.type_mappings}

def _print_ImaginaryUnit(self, expr):
return f'1i{self.prec_literal(expr).lower()}'
7 changes: 1 addition & 6 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from devito.passes.iet.orchestration import Orchestrator
from devito.passes.iet.parpragma import (PragmaDeviceAwareTransformer, PragmaLangBB,
PragmaIteration, PragmaTransfer)
from devito.passes.iet.languages.CXX import CXXBB, CXXDevitoPrinter
from devito.passes.iet.languages.CXX import CXXBB
from devito.passes.iet.languages.openmp import OmpRegion, OmpIteration
from devito.symbolics import FieldFromPointer, Macro, cast_mapper
from devito.tools import filter_ordered, UnboundTuple
Expand Down Expand Up @@ -263,8 +263,3 @@ def place_devptr(self, iet, **kwargs):

class AccOrchestrator(Orchestrator):
lang = AccBB


class AccDevitoPrinter(CXXDevitoPrinter):

pass
Loading

0 comments on commit 32f20ee

Please sign in to comment.