diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 33fe96e7ad..fc01123b62 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -15,30 +15,34 @@ from __future__ import annotations import dataclasses -from typing import Any, Generic +from typing import Generic from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.iterator import ir as itir +from gt4py.next.ffront import past_process_args, past_to_itir +from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import processor_interface as ppi +DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( + past_process_args.past_process_args.chain(past_to_itir.PastToItirFactory()) +) + + @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] + transformer: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = DEFAULT_TRANSFORMS - def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: - self.executor.__call__(program, *args, **kwargs) + def __call__(self, program: stages.PastClosure) -> None: + program_call = self.transformer(program) + self.executor(program_call.program, *program_call.args, **program_call.kwargs) @property def __name__(self) -> str: return getattr(self.executor, "__name__", None) or repr(self) - @property - def kind(self) -> type[ppi.ProgramExecutor]: - return self.executor.kind - @property def __gt_allocator__( self, diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 60e96a6cc2..4dd7c6e399 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -18,31 +18,35 @@ from __future__ import annotations -import collections import dataclasses import functools import types import typing import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable from typing import Generator, Generic, TypeVar -from devtools import debug - from gt4py import eve from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors -from gt4py.next.common import Dimension, DimensionKind, GridType +from gt4py.next import ( + allocators as next_allocators, + backend as next_backend, + embedded as next_embedded, + errors, +) +from gt4py.next.common import Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, + past_process_args, + past_to_itir, program_ast as past, + transform_utils, type_specifications as ts_ffront, ) -from gt4py.next.ffront.fbuiltins import FieldOffset from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -52,7 +56,6 @@ ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction, ) from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction -from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.ir_makers import ( @@ -61,6 +64,7 @@ ref, sym, ) +from gt4py.next.otf import stages from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -68,69 +72,6 @@ DEFAULT_BACKEND: Callable = None -def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: - all_closure_vars = collections.ChainMap(closure_vars) - - for closure_var in closure_vars.values(): - if isinstance(closure_var, GTCallable): - # if the closure ref has closure refs by itself, also add them - if child_closure_vars := closure_var.__gt_closure_vars__(): - all_child_closure_vars = _get_closure_vars_recursively(child_closure_vars) - - collisions: list[str] = [] - for potential_collision in set(closure_vars) & set(all_child_closure_vars): - if ( - closure_vars[potential_collision] - != all_child_closure_vars[potential_collision] - ): - collisions.append(potential_collision) - if collisions: - raise NotImplementedError( - f"Using closure vars with same name but different value " - f"across functions is not implemented yet. \n" - f"Collisions: '{', '.join(collisions)}'." - ) - - all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) - return dict(all_closure_vars) - - -def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> dict[str, Any]: - return {name: value for name, value in closure_vars.items() if isinstance(value, types)} - - -def _deduce_grid_type( - requested_grid_type: Optional[GridType], - offsets_and_dimensions: Iterable[FieldOffset | Dimension], -) -> GridType: - """ - Derive grid type from actually occurring dimensions and check against optional user request. - - Unstructured grid type is consistent with any kind of offset, cartesian - is easier to optimize for but only allowed in the absence of unstructured - dimensions and offsets. - """ - - def is_cartesian_offset(o: FieldOffset): - return len(o.target) == 1 and o.source == o.target[0] - - deduced_grid_type = GridType.CARTESIAN - for o in offsets_and_dimensions: - if isinstance(o, FieldOffset) and not is_cartesian_offset(o): - deduced_grid_type = GridType.UNSTRUCTURED - break - if isinstance(o, Dimension) and o.kind == DimensionKind.LOCAL: - deduced_grid_type = GridType.UNSTRUCTURED - break - - if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED: - raise ValueError( - "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." - ) - - return deduced_grid_type if requested_grid_type is None else requested_grid_type - - def _field_constituents_shape_and_dims( arg, arg_type: ts.FieldType | ts.ScalarType | ts.TupleType ) -> Generator[tuple[tuple[int, ...], list[Dimension]]]: @@ -152,10 +93,6 @@ def _field_constituents_shape_and_dims( # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. -# TODO(ricoh): factor out the generated ITIR together with arguments rewriting -# so that using fencil processors on `some_program.itir` becomes trivial without -# prior knowledge of the fencil signature rewriting done by `Program`. -# After that, drop the `.format_itir()` method, since it won't be needed. @dataclasses.dataclass(frozen=True) class Program: """ @@ -177,14 +114,14 @@ class Program: past_node: past.Program closure_vars: dict[str, Any] definition: Optional[types.FunctionType] - backend: Optional[ppi.ProgramExecutor] + backend: Optional[next_backend.Backend] grid_type: Optional[GridType] @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor], + backend: Optional[next_backend], grid_type: Optional[GridType] = None, ) -> Program: source_def = SourceDefinition.from_function(definition) @@ -200,7 +137,9 @@ def from_function( ) def __post_init__(self): - function_closure_vars = _filter_closure_vars_by_type(self.closure_vars, GTCallable) + function_closure_vars = transform_utils._filter_closure_vars_by_type( + self.closure_vars, GTCallable + ) misnamed_functions = [ f"{name} vs. {func.id}" for name, func in function_closure_vars.items() @@ -276,24 +215,21 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: @functools.cached_property def _all_closure_vars(self) -> dict[str, Any]: - return _get_closure_vars_recursively(self.closure_vars) + return transform_utils._get_closure_vars_recursively(self.closure_vars) @functools.cached_property def itir(self) -> itir.FencilDefinition: - offsets_and_dimensions = _filter_closure_vars_by_type( - self._all_closure_vars, FieldOffset, Dimension - ) - grid_type = _deduce_grid_type(self.grid_type, offsets_and_dimensions.values()) - - gt_callables = _filter_closure_vars_by_type(self._all_closure_vars, GTCallable).values() - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] - return ProgramLowering.apply( - self.past_node, function_definitions=lowered_funcs, grid_type=grid_type - ) + return past_to_itir.PastToItirFactory()( + stages.PastClosure( + past_node=self.past_node, + closure_vars=self.closure_vars, + grid_type=self.grid_type, + args=[], + kwargs={}, + ) + ).program def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None: - rewritten_args, size_args, kwargs = self._process_args(args, kwargs) - if self.backend is None: warnings.warn( UserWarning( @@ -302,122 +238,31 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No stacklevel=2, ) with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: + # TODO(ricoh): check if rewriting still needed + rewritten_args, size_args, kwargs = past_process_args._process_args( + self.past_node, args, kwargs + ) ctx.run(self.definition, *rewritten_args, **kwargs) return - ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) - if "debug" in kwargs: - debug(self.itir) + ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend( - self.itir, - *rewritten_args, - *size_args, - **kwargs, - offset_provider=offset_provider, - column_axis=self._column_axis, - ) - - def format_itir( - self, - *args, - formatter: ppi.ProgramFormatter, - offset_provider: dict[str, Dimension], - **kwargs, - ) -> str: - ppi.ensure_processor_kind(formatter, ppi.ProgramFormatter) - rewritten_args, size_args, kwargs = self._process_args(args, kwargs) - if "debug" in kwargs: - debug(self.itir) - return formatter( - self.itir, - *rewritten_args, - *size_args, - **kwargs, - offset_provider=offset_provider, - ) - - def _validate_args(self, *args, **kwargs) -> None: - arg_types = [type_translation.from_value(arg) for arg in args] - kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} - - try: - type_info.accepts_args( - self.past_node.type, - with_args=arg_types, - with_kwargs=kwarg_types, - raise_exception=True, + stages.PastClosure( + closure_vars=self.closure_vars, + past_node=self.past_node, + grid_type=self.grid_type, + args=args, + kwargs=kwargs | {"offset_provider": offset_provider}, ) - except ValueError as err: - raise errors.DSLError( - None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}" - ) from err - - def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: - self._validate_args(*args, **kwargs) - - args, kwargs = type_info.canonicalize_arguments(self.past_node.type, args, kwargs) - - implicit_domain = any( - isinstance(stmt, past.Call) and "domain" not in stmt.kwargs - for stmt in self.past_node.body ) - # extract size of all field arguments - size_args: list[Optional[tuple[int, ...]]] = [] - rewritten_args = list(args) - for param_idx, param in enumerate(self.past_node.params): - if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): - shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)] - shape, dims = shapes_and_dims[0] - if not all( - el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims - ): - raise ValueError( - "Constituents of composite arguments (e.g. the elements of a" - " tuple) need to have the same shape and dimensions." - ) - size_args.extend(shape if shape else [None] * len(dims)) - return tuple(rewritten_args), tuple(size_args), kwargs - - @functools.cached_property - def _column_axis(self): - # construct mapping from column axis to scan operators defined on - # that dimension. only one column axis is allowed, but we can use - # this mapping to provide good error messages. - scanops_per_axis: dict[Dimension, str] = {} - for name, gt_callable in _filter_closure_vars_by_type( - self._all_closure_vars, GTCallable - ).items(): - if isinstance( - (type_ := gt_callable.__gt_type__()), - ts_ffront.ScanOperatorType, - ): - scanops_per_axis.setdefault(type_.axis, []).append(name) - - if len(scanops_per_axis.values()) == 0: - return None - - if len(scanops_per_axis.values()) != 1: - scanops_per_axis_strs = [ - f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() - ] - - raise TypeError( - "Only 'ScanOperator's defined on the same axis " - + "can be used in a 'Program', found:\n" - + "\n".join(scanops_per_axis_strs) - + "." - ) - - return iter(scanops_per_axis.keys()).__next__() - @dataclasses.dataclass(frozen=True) class ProgramWithBoundArgs(Program): bound_args: dict[str, typing.Union[float, int, bool]] = None - def _process_args(self, args: tuple, kwargs: dict): + def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): type_ = self.past_node.type new_type = ts_ffront.ProgramType( definition=ts.FunctionType( @@ -468,7 +313,7 @@ def _process_args(self, args: tuple, kwargs: dict): else: full_kwargs[str(param.id)] = self.bound_args[param.id] - return super()._process_args(tuple(full_args), full_kwargs) + return super().__call__(*tuple(full_args), offset_provider=offset_provider, **full_kwargs) @functools.cached_property def itir(self): @@ -507,7 +352,8 @@ def program( def program( definition=None, *, - backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) + # `NOTHING` -> default backend, `None` -> no backend (embedded execution) + backend=eve.NOTHING, grid_type=None, ) -> Program | Callable[[types.FunctionType], Program]: """ @@ -653,7 +499,8 @@ def as_program( pass loc = self.foast_node.location - param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching + # use a new UID generator to allow caching + param_sym_uids = eve_utils.UIDGenerator() type_ = self.__gt_type__() params_decl: list[past.Symbol] = [ @@ -711,6 +558,7 @@ def as_program( backend=self.backend, grid_type=self.grid_type, ) + return self._program_cache[hash_] def __call__( diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py new file mode 100644 index 0000000000..22661502e9 --- /dev/null +++ b/src/gt4py/next/ffront/past_process_args.py @@ -0,0 +1,109 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Iterator, Optional + +from gt4py.next import common, errors +from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront +from gt4py.next.otf import stages, workflow +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation + + +@workflow.make_step +def past_process_args(inp: stages.PastClosure) -> stages.PastClosure: + extra_kwarg_names = ["offset_provider", "column_axis"] + extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names} + kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names} + rewritten_args, size_args, kwargs = _process_args( + past_node=inp.past_node, args=list(inp.args), kwargs=kwargs + ) + return stages.PastClosure( + past_node=inp.past_node, + closure_vars=inp.closure_vars, + grid_type=inp.grid_type, + args=tuple([*rewritten_args, *size_args]), + kwargs=kwargs | extra_kwargs, + ) + + +def _validate_args(past_node: past.Program, args: list, kwargs: dict[str, Any]) -> None: + arg_types = [type_translation.from_value(arg) for arg in args] + kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} + + if not isinstance(past_node.type, ts_ffront.ProgramType): + raise TypeError("Can not validate arguments for PAST programs prior to type inference.") + + try: + type_info.accepts_args( + past_node.type, + with_args=arg_types, + with_kwargs=kwarg_types, + raise_exception=True, + ) + except ValueError as err: + raise errors.DSLError( + None, f"Invalid argument types in call to '{past_node.id}'.\n{err}" + ) from err + + +def _process_args( + past_node: past.Program, args: list, kwargs: dict[str, Any] +) -> tuple[tuple, tuple, dict[str, Any]]: + if not isinstance(past_node.type, ts_ffront.ProgramType): + raise TypeError("Can not process arguments for PAST programs prior to type inference.") + + _validate_args(past_node=past_node, args=args, kwargs=kwargs) + + args, kwargs = type_info.canonicalize_arguments(past_node.type, args, kwargs) + + implicit_domain = any( + isinstance(stmt, past.Call) and "domain" not in stmt.kwargs for stmt in past_node.body + ) + + # extract size of all field arguments + size_args: list[Optional[int]] = [] + rewritten_args = list(args) + for param_idx, param in enumerate(past_node.params): + if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): + shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)] + shape, dims = shapes_and_dims[0] + if not all( + el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims + ): + raise ValueError( + "Constituents of composite arguments (e.g. the elements of a" + " tuple) need to have the same shape and dimensions." + ) + size_args.extend(shape if shape else [None] * len(dims)) + return tuple(rewritten_args), tuple(size_args), kwargs + + +def _field_constituents_shape_and_dims( + arg, arg_type: ts.DataType +) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]: + match arg_type: + case ts.TupleType(): + for el, el_type in zip(arg, arg_type.types): + yield from _field_constituents_shape_and_dims(el, el_type) + case ts.FieldType(): + dims = type_info.extract_dims(arg_type) + if hasattr(arg, "shape"): + assert len(arg.shape) == len(dims) + yield (arg.shape, dims) + else: + yield (tuple(), dims) + case ts.ScalarType(): + yield (tuple(), []) + case _: + raise ValueError("Expected 'FieldType' or 'TupleType' thereof.") diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 99534b4e61..0fc9a6280d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,16 +14,95 @@ from __future__ import annotations -from typing import Optional, cast +import dataclasses +from typing import Any, Optional, cast + +import devtools +import factory from gt4py.eve import NodeTranslator, concepts, traits -from gt4py.next.common import Dimension, DimensionKind, GridType -from gt4py.next.ffront import lowering_utils, program_ast as past, type_specifications as ts_ffront +from gt4py.next import common, config +from gt4py.next.ffront import ( + fbuiltins, + gtcallable, + lowering_utils, + program_ast as past, + transform_utils, + type_specifications as ts_ffront, +) from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts +@dataclasses.dataclass(frozen=True) +class PastToItir(workflow.ChainableWorkflowMixin): + def __call__(self, inp: stages.PastClosure) -> stages.ProgramCall: + all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) + offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( + all_closure_vars, fbuiltins.FieldOffset, common.Dimension + ) + grid_type = transform_utils._deduce_grid_type( + inp.grid_type, offsets_and_dimensions.values() + ) + + gt_callables = transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).values() + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + + itir_program = ProgramLowering.apply( + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ) + + if config.DEBUG or "debug" in inp.kwargs: + devtools.debug(itir_program) + + return stages.ProgramCall( + itir_program, + inp.args, + inp.kwargs | {"column_axis": _column_axis(all_closure_vars)}, + ) + + +class PastToItirFactory(factory.Factory): + class Meta: + model = PastToItir + + +def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]: + # construct mapping from column axis to scan operators defined on + # that dimension. only one column axis is allowed, but we can use + # this mapping to provide good error messages. + scanops_per_axis: dict[common.Dimension, list[str]] = {} + for name, gt_callable in transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).items(): + if isinstance( + (type_ := gt_callable.__gt_type__()), + ts_ffront.ScanOperatorType, + ): + scanops_per_axis.setdefault(type_.axis, []).append(name) + + if len(scanops_per_axis.values()) == 0: + return None + + if len(scanops_per_axis.values()) != 1: + scanops_per_axis_strs = [ + f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() + ] + + raise TypeError( + "Only 'ScanOperator's defined on the same axis " + + "can be used in a 'Program', found:\n" + + "\n".join(scanops_per_axis_strs) + + "." + ) + + return iter(scanops_per_axis.keys()).__next__() + + def _size_arg_from_field(field_name: str, dim: int) -> str: return f"__{field_name}_size_{dim}" @@ -67,7 +146,7 @@ class ProgramLowering( ... expr=ir.FunCall(fun=ir.SymRef(id="deref"), pos_only_args=[ir.SymRef(id="inp")]), ... ) # doctest: +SKIP >>> lowered = ProgramLowering.apply( - ... parsed, [fieldop_def], grid_type=GridType.CARTESIAN + ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN ... ) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP @@ -85,7 +164,7 @@ def apply( cls, node: past.Program, function_definitions: list[itir.FunctionDefinition], - grid_type: GridType, + grid_type: common.GridType, ) -> itir.FencilDefinition: return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) @@ -97,7 +176,7 @@ def _gen_size_params_from_program(self, node: past.Program): size_params = [] for param in node.params: if type_info.is_type_or_tuple_of_type(param.type, ts.FieldType): - fields_dims: list[list[Dimension]] = ( + fields_dims: list[list[common.Dimension]] = ( type_info.primitive_constituents(param.type).getattr("dims").to_list() ) assert all(field_dims == fields_dims[0] for field_dims in fields_dims) @@ -263,8 +342,8 @@ def _construct_itir_domain_arg( slices[dim_i].upper if slices else None, dim_size, dim_size ) - if dim.kind == DimensionKind.LOCAL: - raise ValueError(f"Dimension '{dim.value}' must not be local.") + if dim.kind == common.DimensionKind.LOCAL: + raise ValueError(f"common.Dimension '{dim.value}' must not be local.") domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), @@ -273,14 +352,14 @@ def _construct_itir_domain_arg( ) domain_args_kind.append(dim.kind) - if self.grid_type == GridType.CARTESIAN: + if self.grid_type == common.GridType.CARTESIAN: domain_builtin = "cartesian_domain" - elif self.grid_type == GridType.UNSTRUCTURED: + elif self.grid_type == common.GridType.UNSTRUCTURED: domain_builtin = "unstructured_domain" # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) - if domain_args_kind[0] == DimensionKind.VERTICAL: + if domain_args_kind[0] == common.DimensionKind.VERTICAL: assert len(domain_args) == 2 - assert domain_args_kind[1] == DimensionKind.HORIZONTAL + assert domain_args_kind[1] == common.DimensionKind.HORIZONTAL domain_args[0], domain_args[1] = domain_args[1], domain_args[0] else: raise AssertionError() @@ -294,14 +373,14 @@ def _construct_itir_domain_arg( def _construct_itir_initialized_domain_arg( self, dim_i: int, - dim: Dimension, + dim: common.Dimension, node_domain: past.Dict, ) -> list[itir.FunCall]: assert len(node_domain.values_[dim_i].elts) == 2 keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim if keys_dims_types != dim: raise ValueError( - "Dimensions in out field and field domain are not equivalent:" + "common.Dimensions in out field and field domain are not equivalent:" f"expected '{dim}', got '{keys_dims_types}'." ) diff --git a/src/gt4py/next/ffront/transform_utils.py b/src/gt4py/next/ffront/transform_utils.py new file mode 100644 index 0000000000..987598b21d --- /dev/null +++ b/src/gt4py/next/ffront/transform_utils.py @@ -0,0 +1,86 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import collections +from typing import Any, Iterable, Optional + +from gt4py.next import common +from gt4py.next.ffront import fbuiltins +from gt4py.next.ffront.gtcallable import GTCallable + + +def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: + all_closure_vars = collections.ChainMap(closure_vars) + + for closure_var in closure_vars.values(): + if isinstance(closure_var, GTCallable): + # if the closure ref has closure refs by itself, also add them + if child_closure_vars := closure_var.__gt_closure_vars__(): + all_child_closure_vars = _get_closure_vars_recursively(child_closure_vars) + + collisions: list[str] = [] + for potential_collision in set(closure_vars) & set(all_child_closure_vars): + if ( + closure_vars[potential_collision] + != all_child_closure_vars[potential_collision] + ): + collisions.append(potential_collision) + if collisions: + raise NotImplementedError( + f"Using closure vars with same name but different value " + f"across functions is not implemented yet. \n" + f"Collisions: '{', '.join(collisions)}'." + ) + + all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) + return dict(all_closure_vars) + + +def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> dict[str, Any]: + return {name: value for name, value in closure_vars.items() if isinstance(value, types)} + + +def _deduce_grid_type( + requested_grid_type: Optional[common.GridType], + offsets_and_dimensions: Iterable[fbuiltins.FieldOffset | common.Dimension], +) -> common.GridType: + """ + Derive grid type from actually occurring dimensions and check against optional user request. + + Unstructured grid type is consistent with any kind of offset, cartesian + is easier to optimize for but only allowed in the absence of unstructured + dimensions and offsets. + """ + + def is_cartesian_offset(o: fbuiltins.FieldOffset): + return len(o.target) == 1 and o.source == o.target[0] + + deduced_grid_type = common.GridType.CARTESIAN + for o in offsets_and_dimensions: + if isinstance(o, fbuiltins.FieldOffset) and not is_cartesian_offset(o): + deduced_grid_type = common.GridType.UNSTRUCTURED + break + if isinstance(o, common.Dimension) and o.kind == common.DimensionKind.LOCAL: + deduced_grid_type = common.GridType.UNSTRUCTURED + break + + if ( + requested_grid_type == common.GridType.CARTESIAN + and deduced_grid_type == common.GridType.UNSTRUCTURED + ): + raise ValueError( + "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." + ) + + return deduced_grid_type if requested_grid_type is None else requested_grid_type diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 5de4839b55..8209c6dd41 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -91,7 +91,11 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs): if backend is not None: ensure_processor_kind(backend, ProgramExecutor) - backend(self.itir(*args, **kwargs), *args, **kwargs) + backend( + self.itir(*args, **kwargs), + *args, + **kwargs, + ) else: if fendef_embedded is None: raise RuntimeError("Embedded execution is not registered.") diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index e9c7f49c26..88c1b44792 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -17,6 +17,8 @@ import dataclasses from typing import Any, Generic, Optional, Protocol, TypeVar +from gt4py.next import common +from gt4py.next.ffront import program_ast as past from gt4py.next.iterator import ir as itir from gt4py.next.otf import languages from gt4py.next.otf.binding import interface @@ -30,6 +32,15 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) +@dataclasses.dataclass(frozen=True) +class PastClosure: + closure_vars: dict[str, Any] + past_node: past.Program + grid_type: common.GridType + args: tuple[Any, ...] + kwargs: dict[str, Any] + + @dataclasses.dataclass(frozen=True) class ProgramCall: """Iterator IR representaion of a program together with arguments to be passed to it.""" diff --git a/src/gt4py/next/program_processors/modular_executor.py b/src/gt4py/next/program_processors/modular_executor.py index 19708a656b..b8032c17b8 100644 --- a/src/gt4py/next/program_processors/modular_executor.py +++ b/src/gt4py/next/program_processors/modular_executor.py @@ -15,17 +15,11 @@ from __future__ import annotations import dataclasses -from typing import Any, Optional, TypeVar +from typing import Any, Optional -import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.processor_interface as ppi -from gt4py.next.otf import languages, stages, workflow - - -SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL) -TgtL = TypeVar("TgtL", bound=languages.LanguageTag) -LS = TypeVar("LS", bound=languages.LanguageSettings) -HashT = TypeVar("HashT") +from gt4py.next.iterator import ir as itir +from gt4py.next.otf import stages, workflow @dataclasses.dataclass(frozen=True) @@ -34,10 +28,14 @@ class ModularExecutor(ppi.ProgramExecutor): name: Optional[str] = None def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: - self.otf_workflow(stages.ProgramCall(program, args, kwargs))( + self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( *args, offset_provider=kwargs["offset_provider"] ) @property def __name__(self) -> str: return self.name or repr(self) + + @property + def kind(self) -> type[ppi.ProgramExecutor]: + return ppi.ProgramExecutor diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index 37921cdd35..36870252c5 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -237,8 +237,10 @@ class ProgramBackend( def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]: + if not hasattr(obj, "executor"): + return False return is_processor_kind( - obj, + obj.executor, ProgramExecutor, # type: ignore[type-abstract] # ProgramExecutor is abstract ) and next_allocators.is_field_allocator_factory(obj) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index ae0cfbfee8..0b6f18f5d3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -26,7 +26,7 @@ import gt4py.next.allocators as next_allocators import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.processor_interface as ppi -from gt4py.next import backend, common +from gt4py.next import backend as next_backend, common from gt4py.next.iterator import transforms as itir_transforms from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation @@ -356,7 +356,8 @@ def build_sdfg_from_itir( def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): # build parameters build_cache = kwargs.get("build_cache", None) - compiler_args = kwargs.get("compiler_args", None) # `None` will take default. + # `None` will take default. + compiler_args = kwargs.get("compiler_args", None) build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", _default_on_gpu) auto_optimize = kwargs.get("auto_optimize", True) @@ -436,7 +437,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: ) -run_dace_cpu = backend.Backend( +run_dace_cpu = next_backend.Backend( executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"), allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) @@ -459,7 +460,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: raise RuntimeError("Missing 'cupy' dependency for GPU execution.") -run_dace_gpu = backend.Backend( +run_dace_gpu = next_backend.Backend( executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"), allocator=next_allocators.StandardGPUFieldBufferAllocator(), ) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 3662020200..e37fb65891 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -14,23 +14,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - -import gt4py.next.program_processors.processor_interface as ppi from gt4py.next import backend as next_backend from gt4py.next.program_processors.runners import roundtrip -if TYPE_CHECKING: - import gt4py.next.iterator.ir as itir - - -@ppi.program_executor -def executor(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: - roundtrip.execute_roundtrip(program, *args, dispatch_backend=roundtrip.executor, **kwargs) - - backend = next_backend.Backend( - executor=executor, + executor=roundtrip.RoundtripExecutorFactory( + dispatch_backend=roundtrip.RoundtripExecutorFactory(), + ), allocator=roundtrip.backend.allocator, ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 0f07b5519a..eb6c4e9d9e 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -14,6 +14,7 @@ from __future__ import annotations +import dataclasses import importlib.util import pathlib import tempfile @@ -21,16 +22,15 @@ from collections.abc import Callable, Iterable from typing import Any, Optional -import gt4py.next.allocators as next_allocators -import gt4py.next.common as common -import gt4py.next.iterator.embedded as embedded -import gt4py.next.iterator.ir as itir -import gt4py.next.iterator.transforms as itir_transforms -import gt4py.next.iterator.transforms.global_tmps as gtmps_transform -import gt4py.next.program_processors.processor_interface as ppi +import factory + from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from gt4py.next import backend as next_backend +from gt4py.next import allocators as next_allocators, backend as next_backend, common +from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms +from gt4py.next.iterator.transforms import global_tmps as gtmps_transform +from gt4py.next.otf import stages, workflow +from gt4py.next.program_processors import modular_executor, processor_interface as ppi def _create_tmp(axes, origin, shape, dtype): @@ -200,6 +200,7 @@ def fencil_generator( return fencil +@ppi.program_executor # type: ignore[arg-type] def execute_roundtrip( ir: itir.Node, *args, @@ -227,8 +228,55 @@ def execute_roundtrip( return fencil(*args, **new_kwargs) -executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type] +@dataclasses.dataclass(frozen=True) +class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): + debug: bool = False + lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE + use_embedded: bool = True + + def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: + return fencil_generator( + inp.program, + offset_provider=inp.kwargs.get("offset_provider", None), + debug=self.debug, + lift_mode=self.lift_mode, + use_embedded=self.use_embedded, + ) + + +class RoundtripFactory(factory.Factory): + class Meta: + model = Roundtrip + + +@dataclasses.dataclass(frozen=True) +class RoundtripExecutor(modular_executor.ModularExecutor): + dispatch_backend: Optional[ppi.ProgramExecutor] = None + + def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: + kwargs["backend"] = self.dispatch_backend + self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( + *args, **kwargs + ) + + +class RoundtripExecutorFactory(factory.Factory): + class Meta: + model = RoundtripExecutor + + class Params: + use_embedded = factory.LazyAttribute(lambda o: o.dispatch_backend is None) + roundtrip_workflow = factory.SubFactory( + RoundtripFactory, use_embedded=factory.SelfAttribute("..use_embedded") + ) + + dispatch_backend = None + otf_workflow = factory.LazyAttribute(lambda o: o.roundtrip_workflow) + + +executor = RoundtripExecutorFactory(name="roundtrip") backend = next_backend.Backend( - executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() + executor=executor, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index fd11a421c0..88c1d31e9d 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -667,7 +667,7 @@ def function_signature_incompatibilities_func( and not is_concretizable(a_arg, to_type=b_arg) ): if i < len(func_type.pos_only_args): - arg_repr = f"{_number_to_ordinal_number(i+1)} argument" + arg_repr = f"{_number_to_ordinal_number(i + 1)} argument" else: arg_repr = f"argument '{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}'" yield f"Expected {arg_repr} to be of type '{a_arg}', got '{b_arg}'." diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 3456a3cede..f178a5752f 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from dataclasses import dataclass -from typing import Optional +from typing import Iterator, Optional from gt4py.eve.type_definitions import IntEnum from gt4py.next import common as func_common @@ -99,6 +99,9 @@ class TupleType(DataType): def __str__(self): return f"tuple[{', '.join(map(str, self.types))}]" + def __iter__(self) -> Iterator[DataType]: + yield from self.types + @dataclass(frozen=True) class FieldType(DataType, CallableType): diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index c11c1ac256..4b50e21260 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,9 +28,14 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import allocators as next_allocators, common, constructors, field_utils +from gt4py.next import ( + allocators as next_allocators, + backend as next_backend, + common, + constructors, + field_utils, +) from gt4py.next.ffront import decorator -from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation from next_tests import definitions as test_definitions @@ -477,7 +482,7 @@ def cartesian_case( exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, ): yield Case( - exec_alloc_descriptor.executor, + exec_alloc_descriptor if exec_alloc_descriptor.executor else None, offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim}, default_sizes={IDim: 10, JDim: 10, KDim: 10}, grid_type=common.GridType.CARTESIAN, @@ -491,7 +496,7 @@ def unstructured_case( exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, ): yield Case( - exec_alloc_descriptor.executor, + exec_alloc_descriptor if exec_alloc_descriptor.executor else None, offset_provider=mesh_descriptor.offset_provider, default_sizes={ Vertex: mesh_descriptor.num_vertices, @@ -601,7 +606,7 @@ def get_default_data( class Case: """Parametrizable components for single feature integration tests.""" - executor: Optional[ppi.ProgramProcessor] + executor: Optional[next_backend.Backend] offset_provider: dict[str, common.Connectivity | gtx.Dimension] default_sizes: dict[gtx.Dimension, int] grid_type: common.GridType diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 22aec8f838..7260ef15af 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -20,7 +20,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import common +from gt4py.next import backend as next_backend, common from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir from gt4py.next.program_processors import processor_interface as ppi @@ -39,11 +39,19 @@ @ppi.program_executor -def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: +def no_exec(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: """Temporary default backend to not accidentally test the wrong backend.""" raise ValueError("No backend selected! Backend selection is mandatory in tests.") +class NoBackend(next_backend.Backend): + def __call__(self, program) -> None: + raise ValueError("No backend selected! Backend selection is mandatory in tests.") + + +no_backend = NoBackend(executor=no_exec, transformer=None, allocator=None) + + OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append(next_tests.definitions.OptionalProgramBackendId.DACE_CPU) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index e5a821de52..b8d9841616 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -286,7 +286,7 @@ def test_call_bound_program_with_wrong_args(cartesian_case, bound_args_testee): out = cases.allocate(cartesian_case, bound_args_testee, "out")() with pytest.raises(TypeError) as exc_info: - program_with_bound_arg(out, offset_provider={}) + program_with_bound_arg.with_backend(cartesian_case.executor)(out, offset_provider={}) assert ( re.search( @@ -302,7 +302,9 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te out = cases.allocate(cartesian_case, bound_args_testee, "out")() with pytest.raises(TypeError) as exc_info: - program_with_bound_arg(True, out, arg2=True, offset_provider={}) + program_with_bound_arg.with_backend(cartesian_case.executor)( + True, out, arg2=True, offset_provider={} + ) assert ( re.search( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index eb9e4275ff..d571c61590 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -22,7 +22,9 @@ astype, broadcast, common, + constructors, errors, + field_utils, float32, float64, int32, @@ -33,6 +35,7 @@ ) from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn +from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -47,6 +50,7 @@ Koff, V2EDim, Vertex, + Edge, cartesian_case, unstructured_case, ) @@ -636,7 +640,7 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.executor == gtfn.run_gtfn_with_temporaries.executor: + if cartesian_case.executor == gtfn.run_gtfn_with_temporaries: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 1bf640f93d..191f8ee739 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -77,7 +77,7 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): unstructured_case = Case( - run_gtfn_with_temporaries_and_symbolic_sizes.executor, + run_gtfn_with_temporaries_and_symbolic_sizes, offset_provider=mesh_descriptor.offset_provider, default_sizes={ Vertex: mesh_descriptor.num_vertices, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 5c80d9e415..b82cea4b22 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -64,5 +64,5 @@ def test_cartesian_offset_provider(): fencil(out, inp, backend=roundtrip.executor) assert out[0][0] == 42 - fencil(out, inp, backend=double_roundtrip.executor) + fencil(out, inp, backend=double_roundtrip.backend.executor) assert out[0][0] == 42 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index f1a5b41f81..8da95712f4 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -195,7 +195,7 @@ def reference( @pytest.fixture def test_setup(exec_alloc_descriptor): test_case = cases.Case( - exec_alloc_descriptor.executor, + exec_alloc_descriptor if exec_alloc_descriptor.executor else None, offset_provider={"Koff": KDim}, default_sizes={Cell: 14, KDim: 10}, grid_type=common.GridType.UNSTRUCTURED, @@ -249,13 +249,13 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): if ( test_setup.case.executor - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." ) - if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: + if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") cases.verify( @@ -276,7 +276,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): def test_solve_nonhydro_stencil_52_like(test_setup): if ( test_setup.case.executor - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @@ -298,10 +298,10 @@ def test_solve_nonhydro_stencil_52_like(test_setup): def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): if ( test_setup.case.executor - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: + if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") cases.run( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 9a1bc6deb6..bcea9e0901 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -83,9 +83,9 @@ def test_anton_toy(program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn.run_gtfn_with_temporaries.executor, ]: from gt4py.next.iterator import transforms diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 9bba1ab89c..5d369c3a8f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -75,9 +75,9 @@ def hdiff(inp, coeff, out, x, y): def test_hdiff(hdiff_reference, program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn.run_gtfn_with_temporaries.executor, ]: # TODO(tehrengruber): check if still true from gt4py.next.iterator import transforms diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index b28c98ab38..46038832d1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -119,15 +119,18 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): if ( program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn.run_gtfn_with_temporaries.executor, gtfn_formatters.format_cpp, ] and lift_mode == LiftMode.FORCE_INLINE ): pytest.skip("gtfn does only support lifted scans when using temporaries") - if program_processor == gtfn.run_gtfn_with_temporaries or lift_mode == LiftMode.USE_TEMPORARIES: + if ( + program_processor == gtfn.run_gtfn_with_temporaries.executor + or lift_mode == LiftMode.USE_TEMPORARIES + ): pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference shape = a.shape diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 17418a9ca6..c9406884e6 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -88,6 +88,8 @@ def program_processor(request) -> tuple[ppi.ProgramProcessor, bool]: processor = processor_id.load() assert is_backend == ppi.is_program_backend(processor) + if is_backend: + processor = processor.executor for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( processor_id, [] diff --git a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py index 75e23545de..a586f09038 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py @@ -15,7 +15,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.ffront.decorator import _deduce_grid_type +from gt4py.next.ffront.transform_utils import _deduce_grid_type Dim = gtx.Dimension("Dim") diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py index 1ba35da7c6..dc09d9fe76 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py +++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses + import pytest import gt4py.next.allocators as next_allocators @@ -95,8 +97,13 @@ class DummyAllocatorFactory: assert not is_program_backend(DummyAllocatorFactory()) - class DummyBackend(DummyProgramExecutor, DummyAllocatorFactory): - def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: - return None + @dataclasses.dataclass + class DummyBackend: + executor: DummyProgramExecutor = dataclasses.field(default_factory=DummyProgramExecutor) + allocator: DummyAllocatorFactory = dataclasses.field(default_factory=DummyAllocatorFactory) + + @property + def __gt_allocator__(self): + return self.allocator.__gt_allocator__ assert is_program_backend(DummyBackend()) diff --git a/tox.ini b/tox.ini index 4da10f7d8d..ff67764b05 100644 --- a/tox.ini +++ b/tox.ini @@ -43,7 +43,7 @@ extras = cuda12x: cuda12x package = wheel wheel_build_env = .pkg -pass_env = NUM_PROCESSES +pass_env = NUM_PROCESSES, GT4PY_BUILD_CACHE_LIFETIME, GT4PY_BUILD_CACHE_DIR set_env = PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning}