diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py
index 33fe96e7ad..fc01123b62 100644
--- a/src/gt4py/next/backend.py
+++ b/src/gt4py/next/backend.py
@@ -15,30 +15,34 @@
from __future__ import annotations
import dataclasses
-from typing import Any, Generic
+from typing import Generic
from gt4py._core import definitions as core_defs
from gt4py.next import allocators as next_allocators
-from gt4py.next.iterator import ir as itir
+from gt4py.next.ffront import past_process_args, past_to_itir
+from gt4py.next.otf import stages, workflow
from gt4py.next.program_processors import processor_interface as ppi
+DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = (
+ past_process_args.past_process_args.chain(past_to_itir.PastToItirFactory())
+)
+
+
@dataclasses.dataclass(frozen=True)
class Backend(Generic[core_defs.DeviceTypeT]):
executor: ppi.ProgramExecutor
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
+ transformer: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = DEFAULT_TRANSFORMS
- def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None:
- self.executor.__call__(program, *args, **kwargs)
+ def __call__(self, program: stages.PastClosure) -> None:
+ program_call = self.transformer(program)
+ self.executor(program_call.program, *program_call.args, **program_call.kwargs)
@property
def __name__(self) -> str:
return getattr(self.executor, "__name__", None) or repr(self)
- @property
- def kind(self) -> type[ppi.ProgramExecutor]:
- return self.executor.kind
-
@property
def __gt_allocator__(
self,
diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py
index 60e96a6cc2..4dd7c6e399 100644
--- a/src/gt4py/next/ffront/decorator.py
+++ b/src/gt4py/next/ffront/decorator.py
@@ -18,31 +18,35 @@
from __future__ import annotations
-import collections
import dataclasses
import functools
import types
import typing
import warnings
-from collections.abc import Callable, Iterable
+from collections.abc import Callable
from typing import Generator, Generic, TypeVar
-from devtools import debug
-
from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
-from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors
-from gt4py.next.common import Dimension, DimensionKind, GridType
+from gt4py.next import (
+ allocators as next_allocators,
+ backend as next_backend,
+ embedded as next_embedded,
+ errors,
+)
+from gt4py.next.common import Dimension, GridType
from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
dialect_ast_enums,
field_operator_ast as foast,
+ past_process_args,
+ past_to_itir,
program_ast as past,
+ transform_utils,
type_specifications as ts_ffront,
)
-from gt4py.next.ffront.fbuiltins import FieldOffset
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering
from gt4py.next.ffront.func_to_foast import FieldOperatorParser
@@ -52,7 +56,6 @@
ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction,
)
from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction
-from gt4py.next.ffront.past_to_itir import ProgramLowering
from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils.ir_makers import (
@@ -61,6 +64,7 @@
ref,
sym,
)
+from gt4py.next.otf import stages
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
@@ -68,69 +72,6 @@
DEFAULT_BACKEND: Callable = None
-def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]:
- all_closure_vars = collections.ChainMap(closure_vars)
-
- for closure_var in closure_vars.values():
- if isinstance(closure_var, GTCallable):
- # if the closure ref has closure refs by itself, also add them
- if child_closure_vars := closure_var.__gt_closure_vars__():
- all_child_closure_vars = _get_closure_vars_recursively(child_closure_vars)
-
- collisions: list[str] = []
- for potential_collision in set(closure_vars) & set(all_child_closure_vars):
- if (
- closure_vars[potential_collision]
- != all_child_closure_vars[potential_collision]
- ):
- collisions.append(potential_collision)
- if collisions:
- raise NotImplementedError(
- f"Using closure vars with same name but different value "
- f"across functions is not implemented yet. \n"
- f"Collisions: '{', '.join(collisions)}'."
- )
-
- all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars)
- return dict(all_closure_vars)
-
-
-def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> dict[str, Any]:
- return {name: value for name, value in closure_vars.items() if isinstance(value, types)}
-
-
-def _deduce_grid_type(
- requested_grid_type: Optional[GridType],
- offsets_and_dimensions: Iterable[FieldOffset | Dimension],
-) -> GridType:
- """
- Derive grid type from actually occurring dimensions and check against optional user request.
-
- Unstructured grid type is consistent with any kind of offset, cartesian
- is easier to optimize for but only allowed in the absence of unstructured
- dimensions and offsets.
- """
-
- def is_cartesian_offset(o: FieldOffset):
- return len(o.target) == 1 and o.source == o.target[0]
-
- deduced_grid_type = GridType.CARTESIAN
- for o in offsets_and_dimensions:
- if isinstance(o, FieldOffset) and not is_cartesian_offset(o):
- deduced_grid_type = GridType.UNSTRUCTURED
- break
- if isinstance(o, Dimension) and o.kind == DimensionKind.LOCAL:
- deduced_grid_type = GridType.UNSTRUCTURED
- break
-
- if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED:
- raise ValueError(
- "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found."
- )
-
- return deduced_grid_type if requested_grid_type is None else requested_grid_type
-
-
def _field_constituents_shape_and_dims(
arg, arg_type: ts.FieldType | ts.ScalarType | ts.TupleType
) -> Generator[tuple[tuple[int, ...], list[Dimension]]]:
@@ -152,10 +93,6 @@ def _field_constituents_shape_and_dims(
# TODO(tehrengruber): Decide if and how programs can call other programs. As a
# result Program could become a GTCallable.
-# TODO(ricoh): factor out the generated ITIR together with arguments rewriting
-# so that using fencil processors on `some_program.itir` becomes trivial without
-# prior knowledge of the fencil signature rewriting done by `Program`.
-# After that, drop the `.format_itir()` method, since it won't be needed.
@dataclasses.dataclass(frozen=True)
class Program:
"""
@@ -177,14 +114,14 @@ class Program:
past_node: past.Program
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType]
- backend: Optional[ppi.ProgramExecutor]
+ backend: Optional[next_backend.Backend]
grid_type: Optional[GridType]
@classmethod
def from_function(
cls,
definition: types.FunctionType,
- backend: Optional[ppi.ProgramExecutor],
+ backend: Optional[next_backend],
grid_type: Optional[GridType] = None,
) -> Program:
source_def = SourceDefinition.from_function(definition)
@@ -200,7 +137,9 @@ def from_function(
)
def __post_init__(self):
- function_closure_vars = _filter_closure_vars_by_type(self.closure_vars, GTCallable)
+ function_closure_vars = transform_utils._filter_closure_vars_by_type(
+ self.closure_vars, GTCallable
+ )
misnamed_functions = [
f"{name} vs. {func.id}"
for name, func in function_closure_vars.items()
@@ -276,24 +215,21 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs:
@functools.cached_property
def _all_closure_vars(self) -> dict[str, Any]:
- return _get_closure_vars_recursively(self.closure_vars)
+ return transform_utils._get_closure_vars_recursively(self.closure_vars)
@functools.cached_property
def itir(self) -> itir.FencilDefinition:
- offsets_and_dimensions = _filter_closure_vars_by_type(
- self._all_closure_vars, FieldOffset, Dimension
- )
- grid_type = _deduce_grid_type(self.grid_type, offsets_and_dimensions.values())
-
- gt_callables = _filter_closure_vars_by_type(self._all_closure_vars, GTCallable).values()
- lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables]
- return ProgramLowering.apply(
- self.past_node, function_definitions=lowered_funcs, grid_type=grid_type
- )
+ return past_to_itir.PastToItirFactory()(
+ stages.PastClosure(
+ past_node=self.past_node,
+ closure_vars=self.closure_vars,
+ grid_type=self.grid_type,
+ args=[],
+ kwargs={},
+ )
+ ).program
def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None:
- rewritten_args, size_args, kwargs = self._process_args(args, kwargs)
-
if self.backend is None:
warnings.warn(
UserWarning(
@@ -302,122 +238,31 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No
stacklevel=2,
)
with next_embedded.context.new_context(offset_provider=offset_provider) as ctx:
+ # TODO(ricoh): check if rewriting still needed
+ rewritten_args, size_args, kwargs = past_process_args._process_args(
+ self.past_node, args, kwargs
+ )
ctx.run(self.definition, *rewritten_args, **kwargs)
return
- ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor)
- if "debug" in kwargs:
- debug(self.itir)
+ ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor)
self.backend(
- self.itir,
- *rewritten_args,
- *size_args,
- **kwargs,
- offset_provider=offset_provider,
- column_axis=self._column_axis,
- )
-
- def format_itir(
- self,
- *args,
- formatter: ppi.ProgramFormatter,
- offset_provider: dict[str, Dimension],
- **kwargs,
- ) -> str:
- ppi.ensure_processor_kind(formatter, ppi.ProgramFormatter)
- rewritten_args, size_args, kwargs = self._process_args(args, kwargs)
- if "debug" in kwargs:
- debug(self.itir)
- return formatter(
- self.itir,
- *rewritten_args,
- *size_args,
- **kwargs,
- offset_provider=offset_provider,
- )
-
- def _validate_args(self, *args, **kwargs) -> None:
- arg_types = [type_translation.from_value(arg) for arg in args]
- kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()}
-
- try:
- type_info.accepts_args(
- self.past_node.type,
- with_args=arg_types,
- with_kwargs=kwarg_types,
- raise_exception=True,
+ stages.PastClosure(
+ closure_vars=self.closure_vars,
+ past_node=self.past_node,
+ grid_type=self.grid_type,
+ args=args,
+ kwargs=kwargs | {"offset_provider": offset_provider},
)
- except ValueError as err:
- raise errors.DSLError(
- None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}"
- ) from err
-
- def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]:
- self._validate_args(*args, **kwargs)
-
- args, kwargs = type_info.canonicalize_arguments(self.past_node.type, args, kwargs)
-
- implicit_domain = any(
- isinstance(stmt, past.Call) and "domain" not in stmt.kwargs
- for stmt in self.past_node.body
)
- # extract size of all field arguments
- size_args: list[Optional[tuple[int, ...]]] = []
- rewritten_args = list(args)
- for param_idx, param in enumerate(self.past_node.params):
- if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)):
- shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)]
- shape, dims = shapes_and_dims[0]
- if not all(
- el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims
- ):
- raise ValueError(
- "Constituents of composite arguments (e.g. the elements of a"
- " tuple) need to have the same shape and dimensions."
- )
- size_args.extend(shape if shape else [None] * len(dims))
- return tuple(rewritten_args), tuple(size_args), kwargs
-
- @functools.cached_property
- def _column_axis(self):
- # construct mapping from column axis to scan operators defined on
- # that dimension. only one column axis is allowed, but we can use
- # this mapping to provide good error messages.
- scanops_per_axis: dict[Dimension, str] = {}
- for name, gt_callable in _filter_closure_vars_by_type(
- self._all_closure_vars, GTCallable
- ).items():
- if isinstance(
- (type_ := gt_callable.__gt_type__()),
- ts_ffront.ScanOperatorType,
- ):
- scanops_per_axis.setdefault(type_.axis, []).append(name)
-
- if len(scanops_per_axis.values()) == 0:
- return None
-
- if len(scanops_per_axis.values()) != 1:
- scanops_per_axis_strs = [
- f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items()
- ]
-
- raise TypeError(
- "Only 'ScanOperator's defined on the same axis "
- + "can be used in a 'Program', found:\n"
- + "\n".join(scanops_per_axis_strs)
- + "."
- )
-
- return iter(scanops_per_axis.keys()).__next__()
-
@dataclasses.dataclass(frozen=True)
class ProgramWithBoundArgs(Program):
bound_args: dict[str, typing.Union[float, int, bool]] = None
- def _process_args(self, args: tuple, kwargs: dict):
+ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs):
type_ = self.past_node.type
new_type = ts_ffront.ProgramType(
definition=ts.FunctionType(
@@ -468,7 +313,7 @@ def _process_args(self, args: tuple, kwargs: dict):
else:
full_kwargs[str(param.id)] = self.bound_args[param.id]
- return super()._process_args(tuple(full_args), full_kwargs)
+ return super().__call__(*tuple(full_args), offset_provider=offset_provider, **full_kwargs)
@functools.cached_property
def itir(self):
@@ -507,7 +352,8 @@ def program(
def program(
definition=None,
*,
- backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution)
+ # `NOTHING` -> default backend, `None` -> no backend (embedded execution)
+ backend=eve.NOTHING,
grid_type=None,
) -> Program | Callable[[types.FunctionType], Program]:
"""
@@ -653,7 +499,8 @@ def as_program(
pass
loc = self.foast_node.location
- param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching
+ # use a new UID generator to allow caching
+ param_sym_uids = eve_utils.UIDGenerator()
type_ = self.__gt_type__()
params_decl: list[past.Symbol] = [
@@ -711,6 +558,7 @@ def as_program(
backend=self.backend,
grid_type=self.grid_type,
)
+
return self._program_cache[hash_]
def __call__(
diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py
new file mode 100644
index 0000000000..22661502e9
--- /dev/null
+++ b/src/gt4py/next/ffront/past_process_args.py
@@ -0,0 +1,109 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from typing import Any, Iterator, Optional
+
+from gt4py.next import common, errors
+from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront
+from gt4py.next.otf import stages, workflow
+from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
+
+
+@workflow.make_step
+def past_process_args(inp: stages.PastClosure) -> stages.PastClosure:
+ extra_kwarg_names = ["offset_provider", "column_axis"]
+ extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names}
+ kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names}
+ rewritten_args, size_args, kwargs = _process_args(
+ past_node=inp.past_node, args=list(inp.args), kwargs=kwargs
+ )
+ return stages.PastClosure(
+ past_node=inp.past_node,
+ closure_vars=inp.closure_vars,
+ grid_type=inp.grid_type,
+ args=tuple([*rewritten_args, *size_args]),
+ kwargs=kwargs | extra_kwargs,
+ )
+
+
+def _validate_args(past_node: past.Program, args: list, kwargs: dict[str, Any]) -> None:
+ arg_types = [type_translation.from_value(arg) for arg in args]
+ kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()}
+
+ if not isinstance(past_node.type, ts_ffront.ProgramType):
+ raise TypeError("Can not validate arguments for PAST programs prior to type inference.")
+
+ try:
+ type_info.accepts_args(
+ past_node.type,
+ with_args=arg_types,
+ with_kwargs=kwarg_types,
+ raise_exception=True,
+ )
+ except ValueError as err:
+ raise errors.DSLError(
+ None, f"Invalid argument types in call to '{past_node.id}'.\n{err}"
+ ) from err
+
+
+def _process_args(
+ past_node: past.Program, args: list, kwargs: dict[str, Any]
+) -> tuple[tuple, tuple, dict[str, Any]]:
+ if not isinstance(past_node.type, ts_ffront.ProgramType):
+ raise TypeError("Can not process arguments for PAST programs prior to type inference.")
+
+ _validate_args(past_node=past_node, args=args, kwargs=kwargs)
+
+ args, kwargs = type_info.canonicalize_arguments(past_node.type, args, kwargs)
+
+ implicit_domain = any(
+ isinstance(stmt, past.Call) and "domain" not in stmt.kwargs for stmt in past_node.body
+ )
+
+ # extract size of all field arguments
+ size_args: list[Optional[int]] = []
+ rewritten_args = list(args)
+ for param_idx, param in enumerate(past_node.params):
+ if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)):
+ shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)]
+ shape, dims = shapes_and_dims[0]
+ if not all(
+ el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims
+ ):
+ raise ValueError(
+ "Constituents of composite arguments (e.g. the elements of a"
+ " tuple) need to have the same shape and dimensions."
+ )
+ size_args.extend(shape if shape else [None] * len(dims))
+ return tuple(rewritten_args), tuple(size_args), kwargs
+
+
+def _field_constituents_shape_and_dims(
+ arg, arg_type: ts.DataType
+) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]:
+ match arg_type:
+ case ts.TupleType():
+ for el, el_type in zip(arg, arg_type.types):
+ yield from _field_constituents_shape_and_dims(el, el_type)
+ case ts.FieldType():
+ dims = type_info.extract_dims(arg_type)
+ if hasattr(arg, "shape"):
+ assert len(arg.shape) == len(dims)
+ yield (arg.shape, dims)
+ else:
+ yield (tuple(), dims)
+ case ts.ScalarType():
+ yield (tuple(), [])
+ case _:
+ raise ValueError("Expected 'FieldType' or 'TupleType' thereof.")
diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py
index 99534b4e61..0fc9a6280d 100644
--- a/src/gt4py/next/ffront/past_to_itir.py
+++ b/src/gt4py/next/ffront/past_to_itir.py
@@ -14,16 +14,95 @@
from __future__ import annotations
-from typing import Optional, cast
+import dataclasses
+from typing import Any, Optional, cast
+
+import devtools
+import factory
from gt4py.eve import NodeTranslator, concepts, traits
-from gt4py.next.common import Dimension, DimensionKind, GridType
-from gt4py.next.ffront import lowering_utils, program_ast as past, type_specifications as ts_ffront
+from gt4py.next import common, config
+from gt4py.next.ffront import (
+ fbuiltins,
+ gtcallable,
+ lowering_utils,
+ program_ast as past,
+ transform_utils,
+ type_specifications as ts_ffront,
+)
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.otf import stages, workflow
from gt4py.next.type_system import type_info, type_specifications as ts
+@dataclasses.dataclass(frozen=True)
+class PastToItir(workflow.ChainableWorkflowMixin):
+ def __call__(self, inp: stages.PastClosure) -> stages.ProgramCall:
+ all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars)
+ offsets_and_dimensions = transform_utils._filter_closure_vars_by_type(
+ all_closure_vars, fbuiltins.FieldOffset, common.Dimension
+ )
+ grid_type = transform_utils._deduce_grid_type(
+ inp.grid_type, offsets_and_dimensions.values()
+ )
+
+ gt_callables = transform_utils._filter_closure_vars_by_type(
+ all_closure_vars, gtcallable.GTCallable
+ ).values()
+ lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables]
+
+ itir_program = ProgramLowering.apply(
+ inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type
+ )
+
+ if config.DEBUG or "debug" in inp.kwargs:
+ devtools.debug(itir_program)
+
+ return stages.ProgramCall(
+ itir_program,
+ inp.args,
+ inp.kwargs | {"column_axis": _column_axis(all_closure_vars)},
+ )
+
+
+class PastToItirFactory(factory.Factory):
+ class Meta:
+ model = PastToItir
+
+
+def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]:
+ # construct mapping from column axis to scan operators defined on
+ # that dimension. only one column axis is allowed, but we can use
+ # this mapping to provide good error messages.
+ scanops_per_axis: dict[common.Dimension, list[str]] = {}
+ for name, gt_callable in transform_utils._filter_closure_vars_by_type(
+ all_closure_vars, gtcallable.GTCallable
+ ).items():
+ if isinstance(
+ (type_ := gt_callable.__gt_type__()),
+ ts_ffront.ScanOperatorType,
+ ):
+ scanops_per_axis.setdefault(type_.axis, []).append(name)
+
+ if len(scanops_per_axis.values()) == 0:
+ return None
+
+ if len(scanops_per_axis.values()) != 1:
+ scanops_per_axis_strs = [
+ f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items()
+ ]
+
+ raise TypeError(
+ "Only 'ScanOperator's defined on the same axis "
+ + "can be used in a 'Program', found:\n"
+ + "\n".join(scanops_per_axis_strs)
+ + "."
+ )
+
+ return iter(scanops_per_axis.keys()).__next__()
+
+
def _size_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_size_{dim}"
@@ -67,7 +146,7 @@ class ProgramLowering(
... expr=ir.FunCall(fun=ir.SymRef(id="deref"), pos_only_args=[ir.SymRef(id="inp")]),
... ) # doctest: +SKIP
>>> lowered = ProgramLowering.apply(
- ... parsed, [fieldop_def], grid_type=GridType.CARTESIAN
+ ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN
... ) # doctest: +SKIP
>>> type(lowered) # doctest: +SKIP
@@ -85,7 +164,7 @@ def apply(
cls,
node: past.Program,
function_definitions: list[itir.FunctionDefinition],
- grid_type: GridType,
+ grid_type: common.GridType,
) -> itir.FencilDefinition:
return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions)
@@ -97,7 +176,7 @@ def _gen_size_params_from_program(self, node: past.Program):
size_params = []
for param in node.params:
if type_info.is_type_or_tuple_of_type(param.type, ts.FieldType):
- fields_dims: list[list[Dimension]] = (
+ fields_dims: list[list[common.Dimension]] = (
type_info.primitive_constituents(param.type).getattr("dims").to_list()
)
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
@@ -263,8 +342,8 @@ def _construct_itir_domain_arg(
slices[dim_i].upper if slices else None, dim_size, dim_size
)
- if dim.kind == DimensionKind.LOCAL:
- raise ValueError(f"Dimension '{dim.value}' must not be local.")
+ if dim.kind == common.DimensionKind.LOCAL:
+ raise ValueError(f"common.Dimension '{dim.value}' must not be local.")
domain_args.append(
itir.FunCall(
fun=itir.SymRef(id="named_range"),
@@ -273,14 +352,14 @@ def _construct_itir_domain_arg(
)
domain_args_kind.append(dim.kind)
- if self.grid_type == GridType.CARTESIAN:
+ if self.grid_type == common.GridType.CARTESIAN:
domain_builtin = "cartesian_domain"
- elif self.grid_type == GridType.UNSTRUCTURED:
+ elif self.grid_type == common.GridType.UNSTRUCTURED:
domain_builtin = "unstructured_domain"
# for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical)
- if domain_args_kind[0] == DimensionKind.VERTICAL:
+ if domain_args_kind[0] == common.DimensionKind.VERTICAL:
assert len(domain_args) == 2
- assert domain_args_kind[1] == DimensionKind.HORIZONTAL
+ assert domain_args_kind[1] == common.DimensionKind.HORIZONTAL
domain_args[0], domain_args[1] = domain_args[1], domain_args[0]
else:
raise AssertionError()
@@ -294,14 +373,14 @@ def _construct_itir_domain_arg(
def _construct_itir_initialized_domain_arg(
self,
dim_i: int,
- dim: Dimension,
+ dim: common.Dimension,
node_domain: past.Dict,
) -> list[itir.FunCall]:
assert len(node_domain.values_[dim_i].elts) == 2
keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim
if keys_dims_types != dim:
raise ValueError(
- "Dimensions in out field and field domain are not equivalent:"
+ "common.Dimensions in out field and field domain are not equivalent:"
f"expected '{dim}', got '{keys_dims_types}'."
)
diff --git a/src/gt4py/next/ffront/transform_utils.py b/src/gt4py/next/ffront/transform_utils.py
new file mode 100644
index 0000000000..987598b21d
--- /dev/null
+++ b/src/gt4py/next/ffront/transform_utils.py
@@ -0,0 +1,86 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import collections
+from typing import Any, Iterable, Optional
+
+from gt4py.next import common
+from gt4py.next.ffront import fbuiltins
+from gt4py.next.ffront.gtcallable import GTCallable
+
+
+def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]:
+ all_closure_vars = collections.ChainMap(closure_vars)
+
+ for closure_var in closure_vars.values():
+ if isinstance(closure_var, GTCallable):
+ # if the closure ref has closure refs by itself, also add them
+ if child_closure_vars := closure_var.__gt_closure_vars__():
+ all_child_closure_vars = _get_closure_vars_recursively(child_closure_vars)
+
+ collisions: list[str] = []
+ for potential_collision in set(closure_vars) & set(all_child_closure_vars):
+ if (
+ closure_vars[potential_collision]
+ != all_child_closure_vars[potential_collision]
+ ):
+ collisions.append(potential_collision)
+ if collisions:
+ raise NotImplementedError(
+ f"Using closure vars with same name but different value "
+ f"across functions is not implemented yet. \n"
+ f"Collisions: '{', '.join(collisions)}'."
+ )
+
+ all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars)
+ return dict(all_closure_vars)
+
+
+def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> dict[str, Any]:
+ return {name: value for name, value in closure_vars.items() if isinstance(value, types)}
+
+
+def _deduce_grid_type(
+ requested_grid_type: Optional[common.GridType],
+ offsets_and_dimensions: Iterable[fbuiltins.FieldOffset | common.Dimension],
+) -> common.GridType:
+ """
+ Derive grid type from actually occurring dimensions and check against optional user request.
+
+ Unstructured grid type is consistent with any kind of offset, cartesian
+ is easier to optimize for but only allowed in the absence of unstructured
+ dimensions and offsets.
+ """
+
+ def is_cartesian_offset(o: fbuiltins.FieldOffset):
+ return len(o.target) == 1 and o.source == o.target[0]
+
+ deduced_grid_type = common.GridType.CARTESIAN
+ for o in offsets_and_dimensions:
+ if isinstance(o, fbuiltins.FieldOffset) and not is_cartesian_offset(o):
+ deduced_grid_type = common.GridType.UNSTRUCTURED
+ break
+ if isinstance(o, common.Dimension) and o.kind == common.DimensionKind.LOCAL:
+ deduced_grid_type = common.GridType.UNSTRUCTURED
+ break
+
+ if (
+ requested_grid_type == common.GridType.CARTESIAN
+ and deduced_grid_type == common.GridType.UNSTRUCTURED
+ ):
+ raise ValueError(
+ "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found."
+ )
+
+ return deduced_grid_type if requested_grid_type is None else requested_grid_type
diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py
index 5de4839b55..8209c6dd41 100644
--- a/src/gt4py/next/iterator/runtime.py
+++ b/src/gt4py/next/iterator/runtime.py
@@ -91,7 +91,11 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs):
if backend is not None:
ensure_processor_kind(backend, ProgramExecutor)
- backend(self.itir(*args, **kwargs), *args, **kwargs)
+ backend(
+ self.itir(*args, **kwargs),
+ *args,
+ **kwargs,
+ )
else:
if fendef_embedded is None:
raise RuntimeError("Embedded execution is not registered.")
diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py
index e9c7f49c26..88c1b44792 100644
--- a/src/gt4py/next/otf/stages.py
+++ b/src/gt4py/next/otf/stages.py
@@ -17,6 +17,8 @@
import dataclasses
from typing import Any, Generic, Optional, Protocol, TypeVar
+from gt4py.next import common
+from gt4py.next.ffront import program_ast as past
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import languages
from gt4py.next.otf.binding import interface
@@ -30,6 +32,15 @@
SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True)
+@dataclasses.dataclass(frozen=True)
+class PastClosure:
+ closure_vars: dict[str, Any]
+ past_node: past.Program
+ grid_type: common.GridType
+ args: tuple[Any, ...]
+ kwargs: dict[str, Any]
+
+
@dataclasses.dataclass(frozen=True)
class ProgramCall:
"""Iterator IR representaion of a program together with arguments to be passed to it."""
diff --git a/src/gt4py/next/program_processors/modular_executor.py b/src/gt4py/next/program_processors/modular_executor.py
index 19708a656b..b8032c17b8 100644
--- a/src/gt4py/next/program_processors/modular_executor.py
+++ b/src/gt4py/next/program_processors/modular_executor.py
@@ -15,17 +15,11 @@
from __future__ import annotations
import dataclasses
-from typing import Any, Optional, TypeVar
+from typing import Any, Optional
-import gt4py.next.iterator.ir as itir
import gt4py.next.program_processors.processor_interface as ppi
-from gt4py.next.otf import languages, stages, workflow
-
-
-SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL)
-TgtL = TypeVar("TgtL", bound=languages.LanguageTag)
-LS = TypeVar("LS", bound=languages.LanguageSettings)
-HashT = TypeVar("HashT")
+from gt4py.next.iterator import ir as itir
+from gt4py.next.otf import stages, workflow
@dataclasses.dataclass(frozen=True)
@@ -34,10 +28,14 @@ class ModularExecutor(ppi.ProgramExecutor):
name: Optional[str] = None
def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None:
- self.otf_workflow(stages.ProgramCall(program, args, kwargs))(
+ self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))(
*args, offset_provider=kwargs["offset_provider"]
)
@property
def __name__(self) -> str:
return self.name or repr(self)
+
+ @property
+ def kind(self) -> type[ppi.ProgramExecutor]:
+ return ppi.ProgramExecutor
diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py
index 37921cdd35..36870252c5 100644
--- a/src/gt4py/next/program_processors/processor_interface.py
+++ b/src/gt4py/next/program_processors/processor_interface.py
@@ -237,8 +237,10 @@ class ProgramBackend(
def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]:
+ if not hasattr(obj, "executor"):
+ return False
return is_processor_kind(
- obj,
+ obj.executor,
ProgramExecutor, # type: ignore[type-abstract] # ProgramExecutor is abstract
) and next_allocators.is_field_allocator_factory(obj)
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index ae0cfbfee8..0b6f18f5d3 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -26,7 +26,7 @@
import gt4py.next.allocators as next_allocators
import gt4py.next.iterator.ir as itir
import gt4py.next.program_processors.processor_interface as ppi
-from gt4py.next import backend, common
+from gt4py.next import backend as next_backend, common
from gt4py.next.iterator import transforms as itir_transforms
from gt4py.next.otf.compilation import cache as compilation_cache
from gt4py.next.type_system import type_specifications as ts, type_translation
@@ -356,7 +356,8 @@ def build_sdfg_from_itir(
def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
# build parameters
build_cache = kwargs.get("build_cache", None)
- compiler_args = kwargs.get("compiler_args", None) # `None` will take default.
+ # `None` will take default.
+ compiler_args = kwargs.get("compiler_args", None)
build_type = kwargs.get("build_type", "RelWithDebInfo")
on_gpu = kwargs.get("on_gpu", _default_on_gpu)
auto_optimize = kwargs.get("auto_optimize", True)
@@ -436,7 +437,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
)
-run_dace_cpu = backend.Backend(
+run_dace_cpu = next_backend.Backend(
executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"),
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
)
@@ -459,7 +460,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
raise RuntimeError("Missing 'cupy' dependency for GPU execution.")
-run_dace_gpu = backend.Backend(
+run_dace_gpu = next_backend.Backend(
executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"),
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
)
diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py
index 3662020200..e37fb65891 100644
--- a/src/gt4py/next/program_processors/runners/double_roundtrip.py
+++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py
@@ -14,23 +14,13 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any
-
-import gt4py.next.program_processors.processor_interface as ppi
from gt4py.next import backend as next_backend
from gt4py.next.program_processors.runners import roundtrip
-if TYPE_CHECKING:
- import gt4py.next.iterator.ir as itir
-
-
-@ppi.program_executor
-def executor(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None:
- roundtrip.execute_roundtrip(program, *args, dispatch_backend=roundtrip.executor, **kwargs)
-
-
backend = next_backend.Backend(
- executor=executor,
+ executor=roundtrip.RoundtripExecutorFactory(
+ dispatch_backend=roundtrip.RoundtripExecutorFactory(),
+ ),
allocator=roundtrip.backend.allocator,
)
diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py
index 0f07b5519a..eb6c4e9d9e 100644
--- a/src/gt4py/next/program_processors/runners/roundtrip.py
+++ b/src/gt4py/next/program_processors/runners/roundtrip.py
@@ -14,6 +14,7 @@
from __future__ import annotations
+import dataclasses
import importlib.util
import pathlib
import tempfile
@@ -21,16 +22,15 @@
from collections.abc import Callable, Iterable
from typing import Any, Optional
-import gt4py.next.allocators as next_allocators
-import gt4py.next.common as common
-import gt4py.next.iterator.embedded as embedded
-import gt4py.next.iterator.ir as itir
-import gt4py.next.iterator.transforms as itir_transforms
-import gt4py.next.iterator.transforms.global_tmps as gtmps_transform
-import gt4py.next.program_processors.processor_interface as ppi
+import factory
+
from gt4py.eve import codegen
from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako
-from gt4py.next import backend as next_backend
+from gt4py.next import allocators as next_allocators, backend as next_backend, common
+from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms
+from gt4py.next.iterator.transforms import global_tmps as gtmps_transform
+from gt4py.next.otf import stages, workflow
+from gt4py.next.program_processors import modular_executor, processor_interface as ppi
def _create_tmp(axes, origin, shape, dtype):
@@ -200,6 +200,7 @@ def fencil_generator(
return fencil
+@ppi.program_executor # type: ignore[arg-type]
def execute_roundtrip(
ir: itir.Node,
*args,
@@ -227,8 +228,55 @@ def execute_roundtrip(
return fencil(*args, **new_kwargs)
-executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type]
+@dataclasses.dataclass(frozen=True)
+class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]):
+ debug: bool = False
+ lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE
+ use_embedded: bool = True
+
+ def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram:
+ return fencil_generator(
+ inp.program,
+ offset_provider=inp.kwargs.get("offset_provider", None),
+ debug=self.debug,
+ lift_mode=self.lift_mode,
+ use_embedded=self.use_embedded,
+ )
+
+
+class RoundtripFactory(factory.Factory):
+ class Meta:
+ model = Roundtrip
+
+
+@dataclasses.dataclass(frozen=True)
+class RoundtripExecutor(modular_executor.ModularExecutor):
+ dispatch_backend: Optional[ppi.ProgramExecutor] = None
+
+ def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None:
+ kwargs["backend"] = self.dispatch_backend
+ self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))(
+ *args, **kwargs
+ )
+
+
+class RoundtripExecutorFactory(factory.Factory):
+ class Meta:
+ model = RoundtripExecutor
+
+ class Params:
+ use_embedded = factory.LazyAttribute(lambda o: o.dispatch_backend is None)
+ roundtrip_workflow = factory.SubFactory(
+ RoundtripFactory, use_embedded=factory.SelfAttribute("..use_embedded")
+ )
+
+ dispatch_backend = None
+ otf_workflow = factory.LazyAttribute(lambda o: o.roundtrip_workflow)
+
+
+executor = RoundtripExecutorFactory(name="roundtrip")
backend = next_backend.Backend(
- executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator()
+ executor=executor,
+ allocator=next_allocators.StandardCPUFieldBufferAllocator(),
)
diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py
index fd11a421c0..88c1d31e9d 100644
--- a/src/gt4py/next/type_system/type_info.py
+++ b/src/gt4py/next/type_system/type_info.py
@@ -667,7 +667,7 @@ def function_signature_incompatibilities_func(
and not is_concretizable(a_arg, to_type=b_arg)
):
if i < len(func_type.pos_only_args):
- arg_repr = f"{_number_to_ordinal_number(i+1)} argument"
+ arg_repr = f"{_number_to_ordinal_number(i + 1)} argument"
else:
arg_repr = f"argument '{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}'"
yield f"Expected {arg_repr} to be of type '{a_arg}', got '{b_arg}'."
diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py
index 3456a3cede..f178a5752f 100644
--- a/src/gt4py/next/type_system/type_specifications.py
+++ b/src/gt4py/next/type_system/type_specifications.py
@@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
from dataclasses import dataclass
-from typing import Optional
+from typing import Iterator, Optional
from gt4py.eve.type_definitions import IntEnum
from gt4py.next import common as func_common
@@ -99,6 +99,9 @@ class TupleType(DataType):
def __str__(self):
return f"tuple[{', '.join(map(str, self.types))}]"
+ def __iter__(self) -> Iterator[DataType]:
+ yield from self.types
+
@dataclass(frozen=True)
class FieldType(DataType, CallableType):
diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py
index c11c1ac256..4b50e21260 100644
--- a/tests/next_tests/integration_tests/cases.py
+++ b/tests/next_tests/integration_tests/cases.py
@@ -28,9 +28,14 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import Self
-from gt4py.next import allocators as next_allocators, common, constructors, field_utils
+from gt4py.next import (
+ allocators as next_allocators,
+ backend as next_backend,
+ common,
+ constructors,
+ field_utils,
+)
from gt4py.next.ffront import decorator
-from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_specifications as ts, type_translation
from next_tests import definitions as test_definitions
@@ -477,7 +482,7 @@ def cartesian_case(
exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor,
):
yield Case(
- exec_alloc_descriptor.executor,
+ exec_alloc_descriptor if exec_alloc_descriptor.executor else None,
offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim},
default_sizes={IDim: 10, JDim: 10, KDim: 10},
grid_type=common.GridType.CARTESIAN,
@@ -491,7 +496,7 @@ def unstructured_case(
exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor,
):
yield Case(
- exec_alloc_descriptor.executor,
+ exec_alloc_descriptor if exec_alloc_descriptor.executor else None,
offset_provider=mesh_descriptor.offset_provider,
default_sizes={
Vertex: mesh_descriptor.num_vertices,
@@ -601,7 +606,7 @@ def get_default_data(
class Case:
"""Parametrizable components for single feature integration tests."""
- executor: Optional[ppi.ProgramProcessor]
+ executor: Optional[next_backend.Backend]
offset_provider: dict[str, common.Connectivity | gtx.Dimension]
default_sizes: dict[gtx.Dimension, int]
grid_type: common.GridType
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
index 22aec8f838..7260ef15af 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
@@ -20,7 +20,7 @@
import pytest
import gt4py.next as gtx
-from gt4py.next import common
+from gt4py.next import backend as next_backend, common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import ir as itir
from gt4py.next.program_processors import processor_interface as ppi
@@ -39,11 +39,19 @@
@ppi.program_executor
-def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None:
+def no_exec(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None:
"""Temporary default backend to not accidentally test the wrong backend."""
raise ValueError("No backend selected! Backend selection is mandatory in tests.")
+class NoBackend(next_backend.Backend):
+ def __call__(self, program) -> None:
+ raise ValueError("No backend selected! Backend selection is mandatory in tests.")
+
+
+no_backend = NoBackend(executor=no_exec, transformer=None, allocator=None)
+
+
OPTIONAL_PROCESSORS = []
if dace_iterator:
OPTIONAL_PROCESSORS.append(next_tests.definitions.OptionalProgramBackendId.DACE_CPU)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
index e5a821de52..b8d9841616 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
@@ -286,7 +286,7 @@ def test_call_bound_program_with_wrong_args(cartesian_case, bound_args_testee):
out = cases.allocate(cartesian_case, bound_args_testee, "out")()
with pytest.raises(TypeError) as exc_info:
- program_with_bound_arg(out, offset_provider={})
+ program_with_bound_arg.with_backend(cartesian_case.executor)(out, offset_provider={})
assert (
re.search(
@@ -302,7 +302,9 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te
out = cases.allocate(cartesian_case, bound_args_testee, "out")()
with pytest.raises(TypeError) as exc_info:
- program_with_bound_arg(True, out, arg2=True, offset_provider={})
+ program_with_bound_arg.with_backend(cartesian_case.executor)(
+ True, out, arg2=True, offset_provider={}
+ )
assert (
re.search(
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index eb9e4275ff..d571c61590 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -22,7 +22,9 @@
astype,
broadcast,
common,
+ constructors,
errors,
+ field_utils,
float32,
float64,
int32,
@@ -33,6 +35,7 @@
)
from gt4py.next.ffront.experimental import as_offset
from gt4py.next.program_processors.runners import gtfn
+from gt4py.next.type_system import type_specifications as ts
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
@@ -47,6 +50,7 @@
Koff,
V2EDim,
Vertex,
+ Edge,
cartesian_case,
unstructured_case,
)
@@ -636,7 +640,7 @@ def simple_scan_operator(carry: float) -> float:
@pytest.mark.uses_lift_expressions
@pytest.mark.uses_scan_nested
def test_solve_triag(cartesian_case):
- if cartesian_case.executor == gtfn.run_gtfn_with_temporaries.executor:
+ if cartesian_case.executor == gtfn.run_gtfn_with_temporaries:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
@gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0))
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
index 1bf640f93d..191f8ee739 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
@@ -77,7 +77,7 @@ def prog(
def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor):
unstructured_case = Case(
- run_gtfn_with_temporaries_and_symbolic_sizes.executor,
+ run_gtfn_with_temporaries_and_symbolic_sizes,
offset_provider=mesh_descriptor.offset_provider,
default_sizes={
Vertex: mesh_descriptor.num_vertices,
diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py
index 5c80d9e415..b82cea4b22 100644
--- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py
+++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py
@@ -64,5 +64,5 @@ def test_cartesian_offset_provider():
fencil(out, inp, backend=roundtrip.executor)
assert out[0][0] == 42
- fencil(out, inp, backend=double_roundtrip.executor)
+ fencil(out, inp, backend=double_roundtrip.backend.executor)
assert out[0][0] == 42
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
index f1a5b41f81..8da95712f4 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
@@ -195,7 +195,7 @@ def reference(
@pytest.fixture
def test_setup(exec_alloc_descriptor):
test_case = cases.Case(
- exec_alloc_descriptor.executor,
+ exec_alloc_descriptor if exec_alloc_descriptor.executor else None,
offset_provider={"Koff": KDim},
default_sizes={Cell: 14, KDim: 10},
grid_type=common.GridType.UNSTRUCTURED,
@@ -249,13 +249,13 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup):
def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup):
if (
test_setup.case.executor
- == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor
+ == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load()
):
pytest.xfail(
"Needs implementation of scan projector. Breaks in type inference as executed"
"again after CollapseTuple."
)
- if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor:
+ if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load():
pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].")
cases.verify(
@@ -276,7 +276,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup):
def test_solve_nonhydro_stencil_52_like(test_setup):
if (
test_setup.case.executor
- == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor
+ == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load()
):
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
@@ -298,10 +298,10 @@ def test_solve_nonhydro_stencil_52_like(test_setup):
def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup):
if (
test_setup.case.executor
- == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor
+ == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load()
):
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
- if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor:
+ if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load():
pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].")
cases.run(
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py
index 9a1bc6deb6..bcea9e0901 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py
@@ -83,9 +83,9 @@ def test_anton_toy(program_processor, lift_mode):
program_processor, validate = program_processor
if program_processor in [
- gtfn.run_gtfn,
- gtfn.run_gtfn_imperative,
- gtfn.run_gtfn_with_temporaries,
+ gtfn.run_gtfn.executor,
+ gtfn.run_gtfn_imperative.executor,
+ gtfn.run_gtfn_with_temporaries.executor,
]:
from gt4py.next.iterator import transforms
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py
index 9bba1ab89c..5d369c3a8f 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py
@@ -75,9 +75,9 @@ def hdiff(inp, coeff, out, x, y):
def test_hdiff(hdiff_reference, program_processor, lift_mode):
program_processor, validate = program_processor
if program_processor in [
- gtfn.run_gtfn,
- gtfn.run_gtfn_imperative,
- gtfn.run_gtfn_with_temporaries,
+ gtfn.run_gtfn.executor,
+ gtfn.run_gtfn_imperative.executor,
+ gtfn.run_gtfn_with_temporaries.executor,
]:
# TODO(tehrengruber): check if still true
from gt4py.next.iterator import transforms
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py
index b28c98ab38..46038832d1 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py
@@ -119,15 +119,18 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode):
if (
program_processor
in [
- gtfn.run_gtfn,
- gtfn.run_gtfn_imperative,
- gtfn.run_gtfn_with_temporaries,
+ gtfn.run_gtfn.executor,
+ gtfn.run_gtfn_imperative.executor,
+ gtfn.run_gtfn_with_temporaries.executor,
gtfn_formatters.format_cpp,
]
and lift_mode == LiftMode.FORCE_INLINE
):
pytest.skip("gtfn does only support lifted scans when using temporaries")
- if program_processor == gtfn.run_gtfn_with_temporaries or lift_mode == LiftMode.USE_TEMPORARIES:
+ if (
+ program_processor == gtfn.run_gtfn_with_temporaries.executor
+ or lift_mode == LiftMode.USE_TEMPORARIES
+ ):
pytest.xfail("tuple_get on columns not supported.")
a, b, c, d, x = tridiag_reference
shape = a.shape
diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py
index 17418a9ca6..c9406884e6 100644
--- a/tests/next_tests/unit_tests/conftest.py
+++ b/tests/next_tests/unit_tests/conftest.py
@@ -88,6 +88,8 @@ def program_processor(request) -> tuple[ppi.ProgramProcessor, bool]:
processor = processor_id.load()
assert is_backend == ppi.is_program_backend(processor)
+ if is_backend:
+ processor = processor.executor
for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get(
processor_id, []
diff --git a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py
index 75e23545de..a586f09038 100644
--- a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py
+++ b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py
@@ -15,7 +15,7 @@
import pytest
import gt4py.next as gtx
-from gt4py.next.ffront.decorator import _deduce_grid_type
+from gt4py.next.ffront.transform_utils import _deduce_grid_type
Dim = gtx.Dimension("Dim")
diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py
index 1ba35da7c6..dc09d9fe76 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py
@@ -12,6 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
+import dataclasses
+
import pytest
import gt4py.next.allocators as next_allocators
@@ -95,8 +97,13 @@ class DummyAllocatorFactory:
assert not is_program_backend(DummyAllocatorFactory())
- class DummyBackend(DummyProgramExecutor, DummyAllocatorFactory):
- def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None:
- return None
+ @dataclasses.dataclass
+ class DummyBackend:
+ executor: DummyProgramExecutor = dataclasses.field(default_factory=DummyProgramExecutor)
+ allocator: DummyAllocatorFactory = dataclasses.field(default_factory=DummyAllocatorFactory)
+
+ @property
+ def __gt_allocator__(self):
+ return self.allocator.__gt_allocator__
assert is_program_backend(DummyBackend())
diff --git a/tox.ini b/tox.ini
index 4da10f7d8d..ff67764b05 100644
--- a/tox.ini
+++ b/tox.ini
@@ -43,7 +43,7 @@ extras =
cuda12x: cuda12x
package = wheel
wheel_build_env = .pkg
-pass_env = NUM_PROCESSES
+pass_env = NUM_PROCESSES, GT4PY_BUILD_CACHE_LIFETIME, GT4PY_BUILD_CACHE_DIR
set_env =
PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning}