From e351cf7de89e46fa5902d32000b516403e1f9a43 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 28 Feb 2024 11:34:35 +0100 Subject: [PATCH 01/47] initial steps --- src/gt4py/next/ffront/decorator.py | 90 ++++------------- src/gt4py/next/otf/stages.py | 11 +++ src/gt4py/next/otf/transforms.py | 96 +++++++++++++++++++ .../next/program_processors/runners/gtfn.py | 5 +- 4 files changed, 131 insertions(+), 71 deletions(-) create mode 100644 src/gt4py/next/otf/transforms.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 0cf1611bb1..1468f7b1ab 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -61,76 +61,14 @@ ref, sym, ) -from gt4py.next.program_processors import processor_interface as ppi +from gt4py.next.otf import transforms as otf_transforms, stages +from gt4py.next.program_processors import processor_interface as ppi, otf_compile_executor from gt4py.next.type_system import type_info, type_specifications as ts, type_translation DEFAULT_BACKEND: Callable = None -def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: - all_closure_vars = collections.ChainMap(closure_vars) - - for closure_var in closure_vars.values(): - if isinstance(closure_var, GTCallable): - # if the closure ref has closure refs by itself, also add them - if child_closure_vars := closure_var.__gt_closure_vars__(): - all_child_closure_vars = _get_closure_vars_recursively(child_closure_vars) - - collisions: list[str] = [] - for potential_collision in set(closure_vars) & set(all_child_closure_vars): - if ( - closure_vars[potential_collision] - != all_child_closure_vars[potential_collision] - ): - collisions.append(potential_collision) - if collisions: - raise NotImplementedError( - f"Using closure vars with same name but different value " - f"across functions is not implemented yet. \n" - f"Collisions: '{', '.join(collisions)}'." - ) - - all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) - return dict(all_closure_vars) - - -def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> dict[str, Any]: - return {name: value for name, value in closure_vars.items() if isinstance(value, types)} - - -def _deduce_grid_type( - requested_grid_type: Optional[GridType], - offsets_and_dimensions: Iterable[FieldOffset | Dimension], -) -> GridType: - """ - Derive grid type from actually occurring dimensions and check against optional user request. - - Unstructured grid type is consistent with any kind of offset, cartesian - is easier to optimize for but only allowed in the absence of unstructured - dimensions and offsets. - """ - - def is_cartesian_offset(o: FieldOffset): - return len(o.target) == 1 and o.source == o.target[0] - - deduced_grid_type = GridType.CARTESIAN - for o in offsets_and_dimensions: - if isinstance(o, FieldOffset) and not is_cartesian_offset(o): - deduced_grid_type = GridType.UNSTRUCTURED - break - if isinstance(o, Dimension) and o.kind == DimensionKind.LOCAL: - deduced_grid_type = GridType.UNSTRUCTURED - break - - if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED: - raise ValueError( - "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." - ) - - return deduced_grid_type if requested_grid_type is None else requested_grid_type - - def _field_constituents_shape_and_dims( arg, arg_type: ts.FieldType | ts.ScalarType | ts.TupleType ) -> Generator[tuple[tuple[int, ...], list[Dimension]]]: @@ -200,7 +138,7 @@ def from_function( ) def __post_init__(self): - function_closure_vars = _filter_closure_vars_by_type(self.closure_vars, GTCallable) + function_closure_vars = otf_transforms._filter_closure_vars_by_type(self.closure_vars, GTCallable) misnamed_functions = [ f"{name} vs. {func.id}" for name, func in function_closure_vars.items() @@ -276,16 +214,16 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: @functools.cached_property def _all_closure_vars(self) -> dict[str, Any]: - return _get_closure_vars_recursively(self.closure_vars) + return otf_transforms._get_closure_vars_recursively(self.closure_vars) @functools.cached_property def itir(self) -> itir.FencilDefinition: - offsets_and_dimensions = _filter_closure_vars_by_type( + offsets_and_dimensions = otf_transforms._filter_closure_vars_by_type( self._all_closure_vars, FieldOffset, Dimension ) - grid_type = _deduce_grid_type(self.grid_type, offsets_and_dimensions.values()) + grid_type = otf_transforms._deduce_grid_type(self.grid_type, offsets_and_dimensions.values()) - gt_callables = _filter_closure_vars_by_type(self._all_closure_vars, GTCallable).values() + gt_callables = otf_transforms._filter_closure_vars_by_type(self._all_closure_vars, GTCallable).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] return ProgramLowering.apply( self.past_node, function_definitions=lowered_funcs, grid_type=grid_type @@ -303,6 +241,20 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: ctx.run(self.definition, *rewritten_args, **kwargs) return + elif isinstance(self.backend, otf_compile_executor.OTFBackend): + self.backend( + stages.ProgramIRStage( + definition=self.definition, + past_node=self.past_node, + grid_type=self.grid_type + ), + *rewritten_args, + *size_args, + **kwargs, + offset_provider=offset_provider, + column_axis=self._column_axis, + ) + return ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) if "debug" in kwargs: diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index bd7f59e7aa..01f9447901 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -15,9 +15,12 @@ from __future__ import annotations import dataclasses +import types from typing import Any, Generic, Optional, Protocol, TypeVar +from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.ffront import program_ast as past from gt4py.next.otf import languages from gt4py.next.otf.binding import interface @@ -29,6 +32,14 @@ TgtL_co = TypeVar("TgtL_co", bound=languages.LanguageTag, covariant=True) SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) +@dataclasses.dataclass(frozen=True) +class ProgramIRStage: + """Program IR representation of a program together with the DSL function definition for it.""" + + definition: types.FunctionType + past_node: past.Program + grid_type: common.GridType + @dataclasses.dataclass(frozen=True) class ProgramCall: diff --git a/src/gt4py/next/otf/transforms.py b/src/gt4py/next/otf/transforms.py new file mode 100644 index 0000000000..945361065f --- /dev/null +++ b/src/gt4py/next/otf/transforms.py @@ -0,0 +1,96 @@ +import dataclasses +import collections + +import factory + +from gt4py.eve.extended_typing import Any, Optional +from gt4py.next import common +from collections.abc import Iterable +from gt4py.next.ffront.gtcallable import GTCallable +from gt4py.next.ffront.fbuiltins import FieldOffset +from gt4py.next.ffront.past_to_itir import ProgramLowering +from gt4py.next.iterator import ir as itir +from gt4py.next.otf import stages +from gt4py.next.ffront.source_utils import get_closure_vars_from_function + +def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: + all_closure_vars = collections.ChainMap(closure_vars) + + for closure_var in closure_vars.values(): + if isinstance(closure_var, GTCallable): + # if the closure ref has closure refs by itself, also add them + if child_closure_vars := closure_var.__gt_closure_vars__(): + all_child_closure_vars = _get_closure_vars_recursively(child_closure_vars) + + collisions: list[str] = [] + for potential_collision in set(closure_vars) & set(all_child_closure_vars): + if ( + closure_vars[potential_collision] + != all_child_closure_vars[potential_collision] + ): + collisions.append(potential_collision) + if collisions: + raise NotImplementedError( + f"Using closure vars with same name but different value " + f"across functions is not implemented yet. \n" + f"Collisions: '{', '.join(collisions)}'." + ) + + all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) + return dict(all_closure_vars) + +def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> dict[str, Any]: + return {name: value for name, value in closure_vars.items() if isinstance(value, types)} + + +def _deduce_grid_type( + requested_grid_type: Optional[common.GridType], + offsets_and_dimensions: Iterable[FieldOffset | common.Dimension], +) -> common.GridType: + """ + Derive grid type from actually occurring dimensions and check against optional user request. + + Unstructured grid type is consistent with any kind of offset, cartesian + is easier to optimize for but only allowed in the absence of unstructured + dimensions and offsets. + """ + + def is_cartesian_offset(o: FieldOffset): + return len(o.target) == 1 and o.source == o.target[0] + + deduced_grid_type = common.GridType.CARTESIAN + for o in offsets_and_dimensions: + if isinstance(o, FieldOffset) and not is_cartesian_offset(o): + deduced_grid_type = common.GridType.UNSTRUCTURED + break + if isinstance(o, common.Dimension) and o.kind == common.DimensionKind.LOCAL: + deduced_grid_type = common.GridType.UNSTRUCTURED + break + + if requested_grid_type == common.GridType.CARTESIAN and deduced_grid_type == common.GridType.UNSTRUCTURED: + raise ValueError( + "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." + ) + + return deduced_grid_type if requested_grid_type is None else requested_grid_type + + +@dataclasses.dataclass(frozen=True) +class PastToItir: + def __call__(self, inp: stages.ProgramIRStage) -> itir.FencilDefinition: + closure_vars = _get_closure_vars_recursively(get_closure_vars_from_function(inp.definition)) + offsets_and_dimensions = _filter_closure_vars_by_type( + _get_closure_vars_recursively(closure_vars), FieldOffset, common.Dimension + ) + grid_type = _deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) + + gt_callables = _filter_closure_vars_by_type(_get_closure_vars_recursively(closure_vars), GTCallable).values() + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + return ProgramLowering.apply( + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ) + + +class PastToItirFactory(factory.Factory): + class Meta: + model = PastToItir diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 4a65f6d049..258f373d97 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -25,7 +25,7 @@ from gt4py.next import common, config from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next.otf import recipes, stages, workflow, transforms as otf_transforms from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb @@ -180,9 +180,10 @@ class Params: name = factory.LazyAttribute( lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) + transform_workflow = factory.SubFactory(otf_transforms.PastToItirFactory) executor = factory.LazyAttribute( - lambda o: otf_compile_executor.OTFCompileExecutor(otf_workflow=o.otf_workflow, name=o.name) + lambda o: otf_compile_executor.OTFCompileExecutor(otf_workflow=workflow.StepSequence([o.transform_workflow, o.otf_workflow]), name=o.name) ) allocator = next_allocators.StandardCPUFieldBufferAllocator() From ca4768f1ab79c4fa00d52b0d1e74880020530e8c Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 28 Feb 2024 16:16:25 +0100 Subject: [PATCH 02/47] progress --- src/gt4py/next/ffront/decorator.py | 16 +++++++++------ src/gt4py/next/otf/stages.py | 7 +++++++ src/gt4py/next/otf/transforms.py | 20 ++++++++++++------- .../otf_compile_executor.py | 7 +++++-- .../next/program_processors/runners/gtfn.py | 2 +- 5 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 1468f7b1ab..377645723c 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -241,12 +241,16 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: ctx.run(self.definition, *rewritten_args, **kwargs) return - elif isinstance(self.backend, otf_compile_executor.OTFBackend): + elif isinstance(self.backend, otf_compile_executor.OTFCompileExecutor): self.backend( - stages.ProgramIRStage( - definition=self.definition, - past_node=self.past_node, - grid_type=self.grid_type + stages.PastClosure( + pirs = stages.ProgramIRStage( + definition=self.definition, + past_node=self.past_node, + grid_type=self.grid_type + ), + args = rewritten_args, + kwargs = kwargs | {"offset_provider": offset_provider, "column_axis": self._column_axis} ), *rewritten_args, *size_args, @@ -337,7 +341,7 @@ def _column_axis(self): # that dimension. only one column axis is allowed, but we can use # this mapping to provide good error messages. scanops_per_axis: dict[Dimension, str] = {} - for name, gt_callable in _filter_closure_vars_by_type( + for name, gt_callable in otf_transforms._filter_closure_vars_by_type( self._all_closure_vars, GTCallable ).items(): if isinstance( diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 01f9447901..9efd0ba638 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -41,6 +41,13 @@ class ProgramIRStage: grid_type: common.GridType +@dataclasses.dataclass(frozen=True) +class PastClosure: + pirs: ProgramIRStage + args: tuple[Any, ...] + kwargs: dict[str, Any] + + @dataclasses.dataclass(frozen=True) class ProgramCall: """Iterator IR representaion of a program together with arguments to be passed to it.""" diff --git a/src/gt4py/next/otf/transforms.py b/src/gt4py/next/otf/transforms.py index 945361065f..6b9a43cf47 100644 --- a/src/gt4py/next/otf/transforms.py +++ b/src/gt4py/next/otf/transforms.py @@ -10,8 +10,10 @@ from gt4py.next.ffront.fbuiltins import FieldOffset from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir -from gt4py.next.otf import stages +from gt4py.next.otf import stages, workflow from gt4py.next.ffront.source_utils import get_closure_vars_from_function +from gt4py.next.otf.stages import ProgramCall + def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: all_closure_vars = collections.ChainMap(closure_vars) @@ -76,18 +78,22 @@ def is_cartesian_offset(o: FieldOffset): @dataclasses.dataclass(frozen=True) -class PastToItir: - def __call__(self, inp: stages.ProgramIRStage) -> itir.FencilDefinition: - closure_vars = _get_closure_vars_recursively(get_closure_vars_from_function(inp.definition)) +class PastToItir(workflow.ChainableWorkflowMixin): + def __call__(self, inp: stages.PastClosure) -> itir.FencilDefinition: + closure_vars = _get_closure_vars_recursively(get_closure_vars_from_function(inp.pirs.definition)) offsets_and_dimensions = _filter_closure_vars_by_type( _get_closure_vars_recursively(closure_vars), FieldOffset, common.Dimension ) - grid_type = _deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) + grid_type = _deduce_grid_type(inp.pirs.grid_type, offsets_and_dimensions.values()) gt_callables = _filter_closure_vars_by_type(_get_closure_vars_recursively(closure_vars), GTCallable).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] - return ProgramLowering.apply( - inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + return ProgramCall( + ProgramLowering.apply( + inp.pirs.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ), + inp.args, + inp.kwargs ) diff --git a/src/gt4py/next/program_processors/otf_compile_executor.py b/src/gt4py/next/program_processors/otf_compile_executor.py index 8dff34a35d..f8d9419c70 100644 --- a/src/gt4py/next/program_processors/otf_compile_executor.py +++ b/src/gt4py/next/program_processors/otf_compile_executor.py @@ -35,10 +35,13 @@ class OTFCompileExecutor(ppi.ProgramExecutor): otf_workflow: recipes.OTFCompileWorkflow name: Optional[str] = None - def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: - self.otf_workflow(stages.ProgramCall(program, args, kwargs))( + def __call__(self, program: stages.ProgramIRStage, *args, **kwargs: Any) -> None: + self.otf_workflow(program)( *args, offset_provider=kwargs["offset_provider"] ) + # self.otf_workflow(stages.ProgramCall(program, args, kwargs))( + # *args, offset_provider=kwargs["offset_provider"] + # ) @property def __name__(self) -> str: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 258f373d97..2d099ae82a 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -183,7 +183,7 @@ class Params: transform_workflow = factory.SubFactory(otf_transforms.PastToItirFactory) executor = factory.LazyAttribute( - lambda o: otf_compile_executor.OTFCompileExecutor(otf_workflow=workflow.StepSequence([o.transform_workflow, o.otf_workflow]), name=o.name) + lambda o: otf_compile_executor.OTFCompileExecutor(otf_workflow=o.transform_workflow.chain(o.otf_workflow), name=o.name) ) allocator = next_allocators.StandardCPUFieldBufferAllocator() From 4d05daaa9a29680818b60d811b1bf05af056c520 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 28 Feb 2024 16:43:07 +0100 Subject: [PATCH 03/47] fix for test --- src/gt4py/next/ffront/decorator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 377645723c..7f4da7f413 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -249,7 +249,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No past_node=self.past_node, grid_type=self.grid_type ), - args = rewritten_args, + args = [*rewritten_args, *size_args], kwargs = kwargs | {"offset_provider": offset_provider, "column_axis": self._column_axis} ), *rewritten_args, From ee6a764c689dc1608dddd76480a50367f12dfbaa Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 29 Feb 2024 09:22:53 +0100 Subject: [PATCH 04/47] ran pre-commit with ruff --- src/gt4py/next/ffront/decorator.py | 30 +++++++++------ src/gt4py/next/otf/stages.py | 3 +- src/gt4py/next/otf/transforms.py | 38 +++++++++++++++---- .../otf_compile_executor.py | 4 +- .../next/program_processors/runners/gtfn.py | 6 ++- 5 files changed, 55 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 07deb5f042..aa55ce3355 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -18,13 +18,12 @@ from __future__ import annotations -import collections import dataclasses import functools import types import typing import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable from typing import Generator, Generic, TypeVar from devtools import debug @@ -34,7 +33,7 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors -from gt4py.next.common import Dimension, DimensionKind, GridType +from gt4py.next.common import Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( dialect_ast_enums, @@ -61,8 +60,8 @@ ref, sym, ) -from gt4py.next.otf import transforms as otf_transforms, stages -from gt4py.next.program_processors import processor_interface as ppi, otf_compile_executor +from gt4py.next.otf import stages, transforms as otf_transforms +from gt4py.next.program_processors import otf_compile_executor, processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -138,7 +137,9 @@ def from_function( ) def __post_init__(self): - function_closure_vars = otf_transforms._filter_closure_vars_by_type(self.closure_vars, GTCallable) + function_closure_vars = otf_transforms._filter_closure_vars_by_type( + self.closure_vars, GTCallable + ) misnamed_functions = [ f"{name} vs. {func.id}" for name, func in function_closure_vars.items() @@ -221,9 +222,13 @@ def itir(self) -> itir.FencilDefinition: offsets_and_dimensions = otf_transforms._filter_closure_vars_by_type( self._all_closure_vars, FieldOffset, Dimension ) - grid_type = otf_transforms._deduce_grid_type(self.grid_type, offsets_and_dimensions.values()) + grid_type = otf_transforms._deduce_grid_type( + self.grid_type, offsets_and_dimensions.values() + ) - gt_callables = otf_transforms._filter_closure_vars_by_type(self._all_closure_vars, GTCallable).values() + gt_callables = otf_transforms._filter_closure_vars_by_type( + self._all_closure_vars, GTCallable + ).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] return ProgramLowering.apply( self.past_node, function_definitions=lowered_funcs, grid_type=grid_type @@ -245,13 +250,14 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No elif isinstance(self.backend, otf_compile_executor.OTFCompileExecutor): self.backend( stages.PastClosure( - pirs = stages.ProgramIRStage( + pirs=stages.ProgramIRStage( definition=self.definition, past_node=self.past_node, - grid_type=self.grid_type + grid_type=self.grid_type, ), - args = [*rewritten_args, *size_args], - kwargs = kwargs | {"offset_provider": offset_provider, "column_axis": self._column_axis} + args=[*rewritten_args, *size_args], + kwargs=kwargs + | {"offset_provider": offset_provider, "column_axis": self._column_axis}, ), *rewritten_args, *size_args, diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 82f9584917..5e1a94ce81 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -19,8 +19,8 @@ from typing import Any, Generic, Optional, Protocol, TypeVar from gt4py.next import common -from gt4py.next.iterator import ir as itir from gt4py.next.ffront import program_ast as past +from gt4py.next.iterator import ir as itir from gt4py.next.otf import languages from gt4py.next.otf.binding import interface @@ -32,6 +32,7 @@ TgtL_co = TypeVar("TgtL_co", bound=languages.LanguageTag, covariant=True) SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) + @dataclasses.dataclass(frozen=True) class ProgramIRStage: """Program IR representation of a program together with the DSL function definition for it.""" diff --git a/src/gt4py/next/otf/transforms.py b/src/gt4py/next/otf/transforms.py index 6b9a43cf47..efecdc5663 100644 --- a/src/gt4py/next/otf/transforms.py +++ b/src/gt4py/next/otf/transforms.py @@ -1,17 +1,31 @@ -import dataclasses +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + import collections +import dataclasses +from collections.abc import Iterable import factory from gt4py.eve.extended_typing import Any, Optional from gt4py.next import common -from collections.abc import Iterable -from gt4py.next.ffront.gtcallable import GTCallable from gt4py.next.ffront.fbuiltins import FieldOffset +from gt4py.next.ffront.gtcallable import GTCallable from gt4py.next.ffront.past_to_itir import ProgramLowering +from gt4py.next.ffront.source_utils import get_closure_vars_from_function from gt4py.next.iterator import ir as itir from gt4py.next.otf import stages, workflow -from gt4py.next.ffront.source_utils import get_closure_vars_from_function from gt4py.next.otf.stages import ProgramCall @@ -41,6 +55,7 @@ def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) return dict(all_closure_vars) + def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> dict[str, Any]: return {name: value for name, value in closure_vars.items() if isinstance(value, types)} @@ -69,7 +84,10 @@ def is_cartesian_offset(o: FieldOffset): deduced_grid_type = common.GridType.UNSTRUCTURED break - if requested_grid_type == common.GridType.CARTESIAN and deduced_grid_type == common.GridType.UNSTRUCTURED: + if ( + requested_grid_type == common.GridType.CARTESIAN + and deduced_grid_type == common.GridType.UNSTRUCTURED + ): raise ValueError( "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." ) @@ -80,20 +98,24 @@ def is_cartesian_offset(o: FieldOffset): @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> itir.FencilDefinition: - closure_vars = _get_closure_vars_recursively(get_closure_vars_from_function(inp.pirs.definition)) + closure_vars = _get_closure_vars_recursively( + get_closure_vars_from_function(inp.pirs.definition) + ) offsets_and_dimensions = _filter_closure_vars_by_type( _get_closure_vars_recursively(closure_vars), FieldOffset, common.Dimension ) grid_type = _deduce_grid_type(inp.pirs.grid_type, offsets_and_dimensions.values()) - gt_callables = _filter_closure_vars_by_type(_get_closure_vars_recursively(closure_vars), GTCallable).values() + gt_callables = _filter_closure_vars_by_type( + _get_closure_vars_recursively(closure_vars), GTCallable + ).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] return ProgramCall( ProgramLowering.apply( inp.pirs.past_node, function_definitions=lowered_funcs, grid_type=grid_type ), inp.args, - inp.kwargs + inp.kwargs, ) diff --git a/src/gt4py/next/program_processors/otf_compile_executor.py b/src/gt4py/next/program_processors/otf_compile_executor.py index f8d9419c70..e7f18cbb50 100644 --- a/src/gt4py/next/program_processors/otf_compile_executor.py +++ b/src/gt4py/next/program_processors/otf_compile_executor.py @@ -36,9 +36,7 @@ class OTFCompileExecutor(ppi.ProgramExecutor): name: Optional[str] = None def __call__(self, program: stages.ProgramIRStage, *args, **kwargs: Any) -> None: - self.otf_workflow(program)( - *args, offset_provider=kwargs["offset_provider"] - ) + self.otf_workflow(program)(*args, offset_provider=kwargs["offset_provider"]) # self.otf_workflow(stages.ProgramCall(program, args, kwargs))( # *args, offset_provider=kwargs["offset_provider"] # ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a4a3c7f9dd..c8fa53036b 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -25,7 +25,7 @@ from gt4py.next import common, config from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps -from gt4py.next.otf import recipes, stages, workflow, transforms as otf_transforms +from gt4py.next.otf import recipes, stages, transforms as otf_transforms, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb @@ -182,7 +182,9 @@ class Params: transform_workflow = factory.SubFactory(otf_transforms.PastToItirFactory) executor = factory.LazyAttribute( - lambda o: otf_compile_executor.OTFCompileExecutor(otf_workflow=o.transform_workflow.chain(o.otf_workflow), name=o.name) + lambda o: otf_compile_executor.OTFCompileExecutor( + otf_workflow=o.transform_workflow.chain(o.otf_workflow), name=o.name + ) ) allocator = next_allocators.StandardCPUFieldBufferAllocator() From 6302c2af7054f8be970000df10feba1f460b753e Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 29 Feb 2024 13:27:16 +0100 Subject: [PATCH 05/47] fixed ruff errors --- src/gt4py/next/otf/transforms.py | 3 +-- src/gt4py/next/program_processors/otf_compile_executor.py | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/otf/transforms.py b/src/gt4py/next/otf/transforms.py index efecdc5663..128b1b6cfb 100644 --- a/src/gt4py/next/otf/transforms.py +++ b/src/gt4py/next/otf/transforms.py @@ -24,7 +24,6 @@ from gt4py.next.ffront.gtcallable import GTCallable from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import get_closure_vars_from_function -from gt4py.next.iterator import ir as itir from gt4py.next.otf import stages, workflow from gt4py.next.otf.stages import ProgramCall @@ -97,7 +96,7 @@ def is_cartesian_offset(o: FieldOffset): @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): - def __call__(self, inp: stages.PastClosure) -> itir.FencilDefinition: + def __call__(self, inp: stages.PastClosure) -> ProgramCall: closure_vars = _get_closure_vars_recursively( get_closure_vars_from_function(inp.pirs.definition) ) diff --git a/src/gt4py/next/program_processors/otf_compile_executor.py b/src/gt4py/next/program_processors/otf_compile_executor.py index e7f18cbb50..89c18a5068 100644 --- a/src/gt4py/next/program_processors/otf_compile_executor.py +++ b/src/gt4py/next/program_processors/otf_compile_executor.py @@ -35,11 +35,10 @@ class OTFCompileExecutor(ppi.ProgramExecutor): otf_workflow: recipes.OTFCompileWorkflow name: Optional[str] = None - def __call__(self, program: stages.ProgramIRStage, *args, **kwargs: Any) -> None: + def __call__( + self, program: stages.ProgramIRStage | itir.FencilDefinition, *args, **kwargs: Any + ) -> None: self.otf_workflow(program)(*args, offset_provider=kwargs["offset_provider"]) - # self.otf_workflow(stages.ProgramCall(program, args, kwargs))( - # *args, offset_provider=kwargs["offset_provider"] - # ) @property def __name__(self) -> str: From 1cdbefd716691f888e023b5ad52edcee4fd42c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Thu, 29 Feb 2024 13:57:11 +0100 Subject: [PATCH 06/47] streamline PAST stage --- src/gt4py/next/ffront/decorator.py | 8 +++----- src/gt4py/next/otf/stages.py | 9 +-------- src/gt4py/next/otf/transforms.py | 10 +++++----- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index aa55ce3355..e3c6c2414c 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -250,11 +250,9 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No elif isinstance(self.backend, otf_compile_executor.OTFCompileExecutor): self.backend( stages.PastClosure( - pirs=stages.ProgramIRStage( - definition=self.definition, - past_node=self.past_node, - grid_type=self.grid_type, - ), + definition=self.definition, + past_node=self.past_node, + grid_type=self.grid_type, args=[*rewritten_args, *size_args], kwargs=kwargs | {"offset_provider": offset_provider, "column_axis": self._column_axis}, diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 5e1a94ce81..e690e7c0e6 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -34,17 +34,10 @@ @dataclasses.dataclass(frozen=True) -class ProgramIRStage: - """Program IR representation of a program together with the DSL function definition for it.""" - +class PastClosure: definition: types.FunctionType past_node: past.Program grid_type: common.GridType - - -@dataclasses.dataclass(frozen=True) -class PastClosure: - pirs: ProgramIRStage args: tuple[Any, ...] kwargs: dict[str, Any] diff --git a/src/gt4py/next/otf/transforms.py b/src/gt4py/next/otf/transforms.py index efecdc5663..e9a883787a 100644 --- a/src/gt4py/next/otf/transforms.py +++ b/src/gt4py/next/otf/transforms.py @@ -99,20 +99,20 @@ def is_cartesian_offset(o: FieldOffset): class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> itir.FencilDefinition: closure_vars = _get_closure_vars_recursively( - get_closure_vars_from_function(inp.pirs.definition) + get_closure_vars_from_function(inp.definition) ) offsets_and_dimensions = _filter_closure_vars_by_type( - _get_closure_vars_recursively(closure_vars), FieldOffset, common.Dimension + closure_vars, FieldOffset, common.Dimension ) - grid_type = _deduce_grid_type(inp.pirs.grid_type, offsets_and_dimensions.values()) + grid_type = _deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) gt_callables = _filter_closure_vars_by_type( - _get_closure_vars_recursively(closure_vars), GTCallable + closure_vars, GTCallable ).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] return ProgramCall( ProgramLowering.apply( - inp.pirs.past_node, function_definitions=lowered_funcs, grid_type=grid_type + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type ), inp.args, inp.kwargs, From 209faa3336d613ccbfb3dd854ac363753ab14e39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Thu, 29 Feb 2024 16:53:15 +0100 Subject: [PATCH 07/47] first attempt at as_program function generator --- src/gt4py/next/ffront/decorator.py | 34 ++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index e3c6c2414c..c834944582 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -20,6 +20,7 @@ import dataclasses import functools +import textwrap import types import typing import warnings @@ -30,7 +31,7 @@ from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.eve import utils as eve_utils +from gt4py.eve import codegen, utils as eve_utils from gt4py.eve.extended_typing import Any, Optional from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors from gt4py.next.common import Dimension, GridType @@ -663,14 +664,23 @@ def as_program( ) untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) + dims = set( + i for j in [type_info.extract_dims(arg_type) for arg_type in arg_types] for i in j + ) + source_code = ProgamFuncGen.apply(past_node) + local_context = {dim.value: dim for dim in dims} + local_context[self.definition.__name__] = self.definition + exec(source_code, {}, local_context) + function_definition = local_context[past_node.id] self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, - definition=None, + definition=function_definition, backend=self.backend, grid_type=self.grid_type, ) + return self._program_cache[hash_] def __call__( @@ -828,3 +838,23 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: ) return scan_operator_inner if definition is None else scan_operator_inner(definition) + + +class ProgamFuncGen(codegen.TemplatedGenerator): + def visit_Program(self, node: past.Program, **kwargs) -> str: + imports = "from gt4py.next import *" + params = self.visit(node.params) + signature = ", ".join(params) + body = textwrap.indent("\n".join(self.visit(node.body)), prefix=" " * 4) + return f"{imports}\n\n\ndef {node.id}({signature}) -> None:\n{body}" + + Symbol = codegen.FormatTemplate("{id}: {type}") + + def visit_Call(self, node: past.Call, **kwargs) -> str: + args = ", ".join(self.visit(node.args)) + kwargs_list = [f"{name}={self.visit(value)}" for name, value in node.kwargs.items()] + kwargs = ", ".join(kwargs_list) + params = ", ".join([args, kwargs]) + return f"{self.visit(node.func)}({params})" + + Name = codegen.FormatTemplate("{id}") From 065fe50cfded0b17469777c1f3dcaa1dba7ef7f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Fri, 1 Mar 2024 11:33:11 +0100 Subject: [PATCH 08/47] try linecache fix --- src/gt4py/next/ffront/decorator.py | 31 +++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index dd7e0c2f8d..6f96b790ec 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -668,17 +668,34 @@ def as_program( i for j in [type_info.extract_dims(arg_type) for arg_type in arg_types] for i in j ) source_code = ProgamFuncGen.apply(past_node) + + import linecache + + filename = "" local_context = {dim.value: dim for dim in dims} local_context[self.definition.__name__] = self.definition - exec(source_code, {}, local_context) + code_obj = compile(source_code, filename, "exec") + exec(code_obj, {}, local_context) + lines = [line + "\n" for line in source_code.splitlines()] + linecache.cache[filename] = (len(source_code), None, lines, filename) + function_definition = local_context[past_node.id] - self._program_cache[hash_] = Program( - past_node=past_node, - closure_vars=closure_vars, - definition=function_definition, - backend=self.backend, - grid_type=self.grid_type, + # self._program_cache[hash_] = Program( + # past_node=past_node, + # closure_vars=closure_vars, + # definition=function_definition, + # backend=self.backend, + # grid_type=self.grid_type, + # ) + linecache.cache[filename] = ( + len(source_code), + None, + [line + "\n" for line in source_code.splitlines()], + filename, + ) + self._program_cache[hash_] = Program.from_function( + function_definition, backend=self.backend ) return self._program_cache[hash_] From 35e90b27b7a58385b4ac65d24c4a666963573e05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Fri, 1 Mar 2024 13:55:43 +0100 Subject: [PATCH 09/47] fix __globals__ of generated function --- src/gt4py/next/ffront/decorator.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 6f96b790ec..efd24257a7 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -671,15 +671,22 @@ def as_program( import linecache + import gt4py.next as gtx + filename = "" - local_context = {dim.value: dim for dim in dims} - local_context[self.definition.__name__] = self.definition + globalns = {dim.value: dim for dim in dims} + globalns[self.definition.__name__] = self.definition + globalns |= gtx.__dict__ + localns = {} code_obj = compile(source_code, filename, "exec") - exec(code_obj, {}, local_context) + exec(code_obj, globalns, localns) + # exec(code_obj) lines = [line + "\n" for line in source_code.splitlines()] linecache.cache[filename] = (len(source_code), None, lines, filename) - function_definition = local_context[past_node.id] + function_definition = localns[past_node.id] + # function_definition = locals()[past_node.id] + # function_definition.__globals__ = globalns # self._program_cache[hash_] = Program( # past_node=past_node, @@ -859,7 +866,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: class ProgamFuncGen(codegen.TemplatedGenerator): def visit_Program(self, node: past.Program, **kwargs) -> str: - imports = "from gt4py.next import *" + imports = "from __future__ import annotations\nfrom gt4py.next import *" params = self.visit(node.params) signature = ", ".join(params) body = textwrap.indent("\n".join(self.visit(node.body)), prefix=" " * 4) From 54ebebd2548b13f4ba18a706a2fddc39078d3292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Fri, 1 Mar 2024 14:05:18 +0100 Subject: [PATCH 10/47] fix as_program (?) --- src/gt4py/next/ffront/decorator.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index efd24257a7..a72aae41f3 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -675,26 +675,15 @@ def as_program( filename = "" globalns = {dim.value: dim for dim in dims} - globalns[self.definition.__name__] = self.definition + globalns[self.definition.__name__] = self globalns |= gtx.__dict__ localns = {} code_obj = compile(source_code, filename, "exec") exec(code_obj, globalns, localns) - # exec(code_obj) lines = [line + "\n" for line in source_code.splitlines()] linecache.cache[filename] = (len(source_code), None, lines, filename) function_definition = localns[past_node.id] - # function_definition = locals()[past_node.id] - # function_definition.__globals__ = globalns - - # self._program_cache[hash_] = Program( - # past_node=past_node, - # closure_vars=closure_vars, - # definition=function_definition, - # backend=self.backend, - # grid_type=self.grid_type, - # ) linecache.cache[filename] = ( len(source_code), None, From df34976996edefc31d8b8bc7f4b03e0707e2c625 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Fri, 1 Mar 2024 14:20:35 +0100 Subject: [PATCH 11/47] roundtrip conversion skeleton --- .../program_processors/runners/roundtrip.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 0f07b5519a..191cbf53f9 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -227,7 +227,24 @@ def execute_roundtrip( return fencil(*args, **new_kwargs) -executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type] +# executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type] +@dataclasses.dataclass +class ExecuteRoundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): + debug: ... + column_axis: Optional[common.Dimension] + lift_mode: ... + dispatch_backend: ... + + def __call__(self, inp: ...): + execute_roundtrip( + inp.program, *inp.args, column_axis=self.column_axis, debug=self.debug, **inp.kwargs + ) + + +executor = modular_executor.ModularExecutor( + otf_workflow=PastToItir().chain(ExecuteRoundtrip()), + name="roundtrip" +) backend = next_backend.Backend( executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() From c925663a529b2d3029aae356518b6b8f65f91adc Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Fri, 1 Mar 2024 16:53:34 +0100 Subject: [PATCH 12/47] edits --- src/gt4py/next/otf/transforms.py | 8 +-- .../runners/double_roundtrip.py | 15 +++-- .../program_processors/runners/roundtrip.py | 55 ++++++------------- 3 files changed, 27 insertions(+), 51 deletions(-) diff --git a/src/gt4py/next/otf/transforms.py b/src/gt4py/next/otf/transforms.py index 928c401245..cbdc4c999f 100644 --- a/src/gt4py/next/otf/transforms.py +++ b/src/gt4py/next/otf/transforms.py @@ -97,17 +97,13 @@ def is_cartesian_offset(o: FieldOffset): @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> ProgramCall: - closure_vars = _get_closure_vars_recursively( - get_closure_vars_from_function(inp.definition) - ) + closure_vars = _get_closure_vars_recursively(get_closure_vars_from_function(inp.definition)) offsets_and_dimensions = _filter_closure_vars_by_type( closure_vars, FieldOffset, common.Dimension ) grid_type = _deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) - gt_callables = _filter_closure_vars_by_type( - closure_vars, GTCallable - ).values() + gt_callables = _filter_closure_vars_by_type(closure_vars, GTCallable).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] return ProgramCall( ProgramLowering.apply( diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 3662020200..3e6f7539dc 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -19,17 +19,20 @@ import gt4py.next.program_processors.processor_interface as ppi from gt4py.next import backend as next_backend from gt4py.next.program_processors.runners import roundtrip - +from gt4py.next.program_processors import modular_executor +from gt4py.next.otf import transforms if TYPE_CHECKING: import gt4py.next.iterator.ir as itir -@ppi.program_executor -def executor(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: - roundtrip.execute_roundtrip(program, *args, dispatch_backend=roundtrip.executor, **kwargs) - - +#@ppi.program_executor +def executor() -> None: + #roundtrip.execute_roundtrip(program, *args, dispatch_backend=roundtrip.executor, **kwargs) + modular_executor.ModularExecutor( + otf_workflow=transforms.PastToItir().chain(roundtrip.ExecuteRoundtrip()), + name="roundtrip" + ) backend = next_backend.Backend( executor=executor, allocator=roundtrip.backend.allocator, diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 191cbf53f9..c1b6c429c8 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -20,6 +20,7 @@ import textwrap from collections.abc import Callable, Iterable from typing import Any, Optional +import dataclasses import gt4py.next.allocators as next_allocators import gt4py.next.common as common @@ -31,6 +32,8 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import backend as next_backend +from gt4py.next.otf import workflow, stages, transforms +from gt4py.next.program_processors import modular_executor def _create_tmp(axes, origin, shape, dtype): @@ -200,50 +203,24 @@ def fencil_generator( return fencil -def execute_roundtrip( - ir: itir.Node, - *args, - column_axis: Optional[common.Dimension] = None, - offset_provider: dict[str, embedded.NeighborTableOffsetProvider], - debug: bool = False, - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, - dispatch_backend: Optional[ppi.ProgramExecutor] = None, -) -> None: - fencil = fencil_generator( - ir, - offset_provider=offset_provider, - debug=debug, - lift_mode=lift_mode, - use_embedded=dispatch_backend is None, - ) - - new_kwargs: dict[str, Any] = { - "offset_provider": offset_provider, - "column_axis": column_axis, - } - if dispatch_backend: - new_kwargs["backend"] = dispatch_backend - - return fencil(*args, **new_kwargs) - - -# executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type] @dataclasses.dataclass class ExecuteRoundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: ... - column_axis: Optional[common.Dimension] - lift_mode: ... - dispatch_backend: ... - - def __call__(self, inp: ...): - execute_roundtrip( - inp.program, *inp.args, column_axis=self.column_axis, debug=self.debug, **inp.kwargs + debug: bool = False + lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE + dispatch_backend: Optional[ppi.ProgramExecutor] = None + + def __call__(self, inp) -> stages.CompiledProgram: + return fencil_generator( + inp.program, + offset_provider=inp.kwargs["offset_provider"], + debug=self.debug, + lift_mode=self.lift_mode, + use_embedded=self.dispatch_backend is None, ) - executor = modular_executor.ModularExecutor( - otf_workflow=PastToItir().chain(ExecuteRoundtrip()), - name="roundtrip" + otf_workflow=transforms.PastToItir().chain(ExecuteRoundtrip()), + name=_BACKEND_NAME ) backend = next_backend.Backend( From aef1438966e9ad8ddfc31b3ae603529c90040f96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Mon, 4 Mar 2024 10:55:52 +0100 Subject: [PATCH 13/47] workflowify double roundtrip --- .../runners/double_roundtrip.py | 4 +- .../program_processors/runners/roundtrip.py | 61 ++++++++++++++----- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 3662020200..9b611453f3 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -31,6 +31,8 @@ def executor(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: backend = next_backend.Backend( - executor=executor, + executor=roundtrip.RoundtripExecutorFactory( + dispatch_backend=roundtrip.execute_roundtrip, + ), allocator=roundtrip.backend.allocator, ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 191cbf53f9..26839a7f97 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -14,6 +14,7 @@ from __future__ import annotations +import dataclasses import importlib.util import pathlib import tempfile @@ -21,16 +22,19 @@ from collections.abc import Callable, Iterable from typing import Any, Optional +import factory + import gt4py.next.allocators as next_allocators import gt4py.next.common as common import gt4py.next.iterator.embedded as embedded import gt4py.next.iterator.ir as itir import gt4py.next.iterator.transforms as itir_transforms import gt4py.next.iterator.transforms.global_tmps as gtmps_transform -import gt4py.next.program_processors.processor_interface as ppi +from gt4py.next.program_processors import modular_executor, processor_interface as ppi from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import backend as next_backend +from gt4py.next.otf import stages, transforms as otf_transforms, workflow def _create_tmp(axes, origin, shape, dtype): @@ -228,23 +232,50 @@ def execute_roundtrip( # executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type] -@dataclasses.dataclass -class ExecuteRoundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: ... - column_axis: Optional[common.Dimension] - lift_mode: ... - dispatch_backend: ... - - def __call__(self, inp: ...): - execute_roundtrip( - inp.program, *inp.args, column_axis=self.column_axis, debug=self.debug, **inp.kwargs +@dataclasses.dataclass(frozen=True) +class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): + debug: bool = None + lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE + dispatch_backend: Optional[ppi.ProgramExecutor] = None + + def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: + return fencil_generator( + inp.program, + offset_provider=inp.kwargs.get("offset_provider", None), + debug=self.debug, + lift_mode=self.lift_mode, + use_embedded=self.dispatch_backend is None ) -executor = modular_executor.ModularExecutor( - otf_workflow=PastToItir().chain(ExecuteRoundtrip()), - name="roundtrip" -) +class RoundtripFactory(factory.Factory): + class Meta: + model = Roundtrip + + +@dataclasses.dataclass(frozen=True) +class RoundtripExecutor(modular_executor.ModularExecutor): + dispatch_backend: Optional[ppi.ProgramExecutor] = None + + def __call__(self, program: stages.PastClosure, *args, **kwargs) -> None: + kwargs["backend"] = self.dispatch_backend + super().__call__(program, *args, **kwargs) + + +class RoundtripExecutorFactory(factory.Factory): + class Meta: + model = RoundtripExecutor + + class Params: + transform_workflow = factory.SubFactory(otf_transforms.PastToItirFactory) + roundtrip_workflow = factory.SubFactory(RoundtripFactory) + + otf_workflow = factory.LazyAttribute( + lambda o: o.transform_workflow.chain(o.roundtrip_workflow) + ) + + +executor = RoundtripExecutorFactory(name="roundtrip") backend = next_backend.Backend( executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() From f9b419823962bf785458019eb49d4f7f0dbfbf2d Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Mon, 4 Mar 2024 11:35:33 +0100 Subject: [PATCH 14/47] edit for dims --- src/gt4py/next/ffront/decorator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a72aae41f3..fcc01e0907 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -664,8 +664,9 @@ def as_program( ) untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) + inout_types = [*arg_types, out_sym.type] dims = set( - i for j in [type_info.extract_dims(arg_type) for arg_type in arg_types] for i in j + i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j ) source_code = ProgamFuncGen.apply(past_node) From 55e24f27dd6271d419d8ea4714fbf6762f4d2c4a Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Mon, 4 Mar 2024 11:56:07 +0100 Subject: [PATCH 15/47] small fix --- src/gt4py/next/ffront/decorator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index fcc01e0907..fa0fd4e5ab 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -664,7 +664,11 @@ def as_program( ) untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) - inout_types = [*arg_types, out_sym.type] + + inout_types = [ + *arg_types, + *(out_sym.type.types if isinstance(out_sym.type, ts.TupleType) else [out_sym.type]), + ] dims = set( i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j ) From b382ba29baf4af2198b07714a4466b939a5d3b8e Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Mon, 4 Mar 2024 13:55:34 +0100 Subject: [PATCH 16/47] added code for nested tuples --- src/gt4py/next/ffront/decorator.py | 7 ++++++- src/gt4py/next/type_system/type_info.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index fa0fd4e5ab..31684922a2 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -667,8 +667,13 @@ def as_program( inout_types = [ *arg_types, - *(out_sym.type.types if isinstance(out_sym.type, ts.TupleType) else [out_sym.type]), + *( + list(type_info.flatten(out_sym.type.types)) + if isinstance(out_sym.type, ts.TupleType) + else [out_sym.type] + ), ] + list(type_info.flatten(out_sym.type.types)) dims = set( i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j ) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index fd11a421c0..3587f66ce5 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -754,3 +754,18 @@ def accepts_args( return True return next(errors, None) is None + + +def flatten(arg: list | tuple | ts.TupleType): + if ( + not isinstance(arg, ts.TupleType) + and not isinstance(arg, list) + and not isinstance(arg, tuple) + ): + yield arg + elif isinstance(arg, list) or isinstance(arg, tuple): + for sub in arg: + yield from flatten(sub) + elif isinstance(arg, ts.TupleType): + for sub in arg.types: + yield from flatten(sub) From b4ee386d9639ff35fbd3973a4b8d12fd79b79978 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Mon, 4 Mar 2024 13:57:43 +0100 Subject: [PATCH 17/47] left over code --- src/gt4py/next/ffront/decorator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 31684922a2..b2624d708c 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -673,7 +673,6 @@ def as_program( else [out_sym.type] ), ] - list(type_info.flatten(out_sym.type.types)) dims = set( i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j ) From 27963276a2730c15cef65156610c1bb86557c6b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Mon, 4 Mar 2024 14:55:33 +0100 Subject: [PATCH 18/47] separate transforms from compile workflow, fix fendef call --- src/gt4py/next/backend.py | 7 ++++--- src/gt4py/next/iterator/runtime.py | 7 ++++++- src/gt4py/next/program_processors/modular_executor.py | 4 ++-- .../program_processors/runners/dace_iterator/__init__.py | 3 +++ .../next/program_processors/runners/double_roundtrip.py | 2 ++ src/gt4py/next/program_processors/runners/gtfn.py | 6 ++---- src/gt4py/next/program_processors/runners/roundtrip.py | 7 ++++--- 7 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 33fe96e7ad..dd2169f3e4 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -19,17 +19,18 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.iterator import ir as itir +from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import processor_interface as ppi @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): + transformer: workflow.Workflow[stages.PastClosure, stages.ProgramCall] executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: - self.executor.__call__(program, *args, **kwargs) + def __call__(self, program: stages.PastClosure, *args, **kwargs: Any) -> None: + self.executor(self.transformer(program), *args, **kwargs) @property def __name__(self) -> str: diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 5de4839b55..651d6a3097 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -24,6 +24,7 @@ from gt4py.next.iterator import builtins from gt4py.next.iterator.builtins import BackendNotSelectedError, builtin_dispatch from gt4py.next.iterator.ir import FencilDefinition +from gt4py.next.otf import stages from gt4py.next.program_processors.processor_interface import ( ProgramExecutor, ProgramFormatter, @@ -91,7 +92,11 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs): if backend is not None: ensure_processor_kind(backend, ProgramExecutor) - backend(self.itir(*args, **kwargs), *args, **kwargs) + backend.executor( + stages.ProgramCall(program=self.itir(*args, **kwargs), args=args, kwargs=kwargs), + *args, + **kwargs, + ) else: if fendef_embedded is None: raise RuntimeError("Embedded execution is not registered.") diff --git a/src/gt4py/next/program_processors/modular_executor.py b/src/gt4py/next/program_processors/modular_executor.py index 41f19b27a1..7419406f53 100644 --- a/src/gt4py/next/program_processors/modular_executor.py +++ b/src/gt4py/next/program_processors/modular_executor.py @@ -23,10 +23,10 @@ @dataclasses.dataclass(frozen=True) class ModularExecutor(ppi.ProgramExecutor): - otf_workflow: workflow.Workflow[stages.PastClosure, stages.CompiledProgram] + otf_workflow: workflow.Workflow[stages.ProgramCall, stages.CompiledProgram] name: Optional[str] = None - def __call__(self, program: stages.PastClosure, *args, **kwargs: Any) -> None: + def __call__(self, program: stages.ProgramCall, *args, **kwargs: Any) -> None: self.otf_workflow(program)(*args, offset_provider=kwargs["offset_provider"]) @property diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index ae0cfbfee8..2e960d763f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -28,6 +28,7 @@ import gt4py.next.program_processors.processor_interface as ppi from gt4py.next import backend, common from gt4py.next.iterator import transforms as itir_transforms +from gt4py.next.otf import transforms as otf_transforms from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation @@ -437,6 +438,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_cpu = backend.Backend( + transformer=otf_transforms.PastToItirFactory(), executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"), allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) @@ -460,6 +462,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_gpu = backend.Backend( + transformer=otf_transforms.PastToItirFactory(), executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"), allocator=next_allocators.StandardGPUFieldBufferAllocator(), ) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 6a8d7a4c98..1a09ba04e7 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -15,10 +15,12 @@ from __future__ import annotations from gt4py.next import backend as next_backend +from gt4py.next.otf import transforms as otf_transforms from gt4py.next.program_processors.runners import roundtrip backend = next_backend.Backend( + transformer=otf_transforms.PastToItirFactory(), executor=roundtrip.RoundtripExecutorFactory( dispatch_backend=roundtrip.execute_roundtrip, ), diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 76ae2de5a2..c58dd6e080 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -179,12 +179,10 @@ class Params: name = factory.LazyAttribute( lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) - transform_workflow = factory.SubFactory(otf_transforms.PastToItirFactory) + transformer = factory.SubFactory(otf_transforms.PastToItirFactory) executor = factory.LazyAttribute( - lambda o: modular_executor.ModularExecutor( - otf_workflow=o.transform_workflow.chain(o.otf_workflow), name=o.name - ) + lambda o: modular_executor.ModularExecutor(otf_workflow=o.otf_workflow, name=o.name) ) allocator = next_allocators.StandardCPUFieldBufferAllocator() diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 964f35357a..703a3e4f41 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -266,14 +266,15 @@ class Meta: model = RoundtripExecutor class Params: - transform_workflow = factory.SubFactory(otf_transforms.PastToItirFactory) roundtrip_workflow = factory.SubFactory(RoundtripFactory) - otf_workflow = factory.LazyAttribute(lambda o: o.transform_workflow.chain(o.roundtrip_workflow)) + otf_workflow = factory.LazyAttribute(lambda o: o.roundtrip_workflow) executor = RoundtripExecutorFactory(name="roundtrip") backend = next_backend.Backend( - executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() + transformer=otf_transforms.PastToItirFactory(), + executor=executor, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) From 7ff19a3c61bc037ed4a8a5513f184bb0556f55a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Mon, 4 Mar 2024 16:42:40 +0100 Subject: [PATCH 19/47] [wip] refactor backend to split transforms and executor --- src/gt4py/next/ffront/decorator.py | 44 +++++++------- src/gt4py/next/otf/stages.py | 2 +- src/gt4py/next/otf/transforms/__init__.py | 4 ++ src/gt4py/next/otf/transforms/past_to_func.py | 0 src/gt4py/next/otf/transforms/past_to_itir.py | 51 +++++++++++++++++ .../{transforms.py => transforms/utils.py} | 57 ++----------------- 6 files changed, 86 insertions(+), 72 deletions(-) create mode 100644 src/gt4py/next/otf/transforms/__init__.py create mode 100644 src/gt4py/next/otf/transforms/past_to_func.py create mode 100644 src/gt4py/next/otf/transforms/past_to_itir.py rename src/gt4py/next/otf/{transforms.py => transforms/utils.py} (56%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index b2624d708c..bd0ffd9f15 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -33,7 +33,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import codegen, utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors +from gt4py.next import allocators as next_allocators, backend as next_backend, embedded as next_embedded, errors from gt4py.next.common import Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -115,14 +115,14 @@ class Program: past_node: past.Program closure_vars: dict[str, Any] definition: Optional[types.FunctionType] - backend: Optional[ppi.ProgramExecutor] + backend: Optional[next_backend.Backend] grid_type: Optional[GridType] @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor], + backend: Optional[next_backend], grid_type: Optional[GridType] = None, ) -> Program: source_def = SourceDefinition.from_function(definition) @@ -138,7 +138,7 @@ def from_function( ) def __post_init__(self): - function_closure_vars = otf_transforms._filter_closure_vars_by_type( + function_closure_vars = otf_transforms.utils._filter_closure_vars_by_type( self.closure_vars, GTCallable ) misnamed_functions = [ @@ -161,6 +161,7 @@ def __post_init__(self): f"The following closure variables are undefined: {', '.join(undefined_symbols)}." ) + @property def __name__(self) -> str: return self.definition.__name__ @@ -216,18 +217,18 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: @functools.cached_property def _all_closure_vars(self) -> dict[str, Any]: - return otf_transforms._get_closure_vars_recursively(self.closure_vars) + return otf_transforms.utils._get_closure_vars_recursively(self.closure_vars) @functools.cached_property def itir(self) -> itir.FencilDefinition: - offsets_and_dimensions = otf_transforms._filter_closure_vars_by_type( + offsets_and_dimensions = otf_transforms.utils._filter_closure_vars_by_type( self._all_closure_vars, FieldOffset, Dimension ) - grid_type = otf_transforms._deduce_grid_type( + grid_type = otf_transforms.utils._deduce_grid_type( self.grid_type, offsets_and_dimensions.values() ) - gt_callables = otf_transforms._filter_closure_vars_by_type( + gt_callables = otf_transforms.utils._filter_closure_vars_by_type( self._all_closure_vars, GTCallable ).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] @@ -251,7 +252,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No elif isinstance(self.backend, modular_executor.ModularExecutor): self.backend( stages.PastClosure( - definition=self.definition, + closure_vars=self.closure_vars, past_node=self.past_node, grid_type=self.grid_type, args=[*rewritten_args, *size_args], @@ -347,7 +348,7 @@ def _column_axis(self): # that dimension. only one column axis is allowed, but we can use # this mapping to provide good error messages. scanops_per_axis: dict[Dimension, str] = {} - for name, gt_callable in otf_transforms._filter_closure_vars_by_type( + for name, gt_callable in otf_transforms.utils._filter_closure_vars_by_type( self._all_closure_vars, GTCallable ).items(): if isinstance( @@ -665,6 +666,19 @@ def as_program( untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) + self.past_to_func(arg_types, out_sym, past_node) + + self._program_cache[hash_] = Program( + past_node=past_node, + closure_vars=closure_vars, + definition=None, + backend=self.backend, + grid_type=self.grid_type, + ) + + return self._program_cache[hash_] + + def past_to_func(self, arg_types, out_sym, past_node): inout_types = [ *arg_types, *( @@ -677,11 +691,8 @@ def as_program( i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j ) source_code = ProgamFuncGen.apply(past_node) - import linecache - import gt4py.next as gtx - filename = "" globalns = {dim.value: dim for dim in dims} globalns[self.definition.__name__] = self @@ -691,7 +702,6 @@ def as_program( exec(code_obj, globalns, localns) lines = [line + "\n" for line in source_code.splitlines()] linecache.cache[filename] = (len(source_code), None, lines, filename) - function_definition = localns[past_node.id] linecache.cache[filename] = ( len(source_code), @@ -699,11 +709,7 @@ def as_program( [line + "\n" for line in source_code.splitlines()], filename, ) - self._program_cache[hash_] = Program.from_function( - function_definition, backend=self.backend - ) - - return self._program_cache[hash_] + return function_definition def __call__( self, diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index e690e7c0e6..a54ba24610 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -35,7 +35,7 @@ @dataclasses.dataclass(frozen=True) class PastClosure: - definition: types.FunctionType + closure_vars: dict[str, Any] past_node: past.Program grid_type: common.GridType args: tuple[Any, ...] diff --git a/src/gt4py/next/otf/transforms/__init__.py b/src/gt4py/next/otf/transforms/__init__.py new file mode 100644 index 0000000000..896c9d93b4 --- /dev/null +++ b/src/gt4py/next/otf/transforms/__init__.py @@ -0,0 +1,4 @@ +from gt4py.next.otf.transforms.past_to_itir import PastToItir, PastToItirFactory + + +__all__ = ["PastToItir", "PastToItirFactory"] diff --git a/src/gt4py/next/otf/transforms/past_to_func.py b/src/gt4py/next/otf/transforms/past_to_func.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/gt4py/next/otf/transforms/past_to_itir.py b/src/gt4py/next/otf/transforms/past_to_itir.py new file mode 100644 index 0000000000..d4f6a6cf7b --- /dev/null +++ b/src/gt4py/next/otf/transforms/past_to_itir.py @@ -0,0 +1,51 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses + +import factory + +from gt4py.next import common +from gt4py.next.ffront.fbuiltins import FieldOffset +from gt4py.next.ffront.gtcallable import GTCallable +from gt4py.next.ffront.past_to_itir import ProgramLowering +from gt4py.next.ffront.source_utils import get_closure_vars_from_function +from gt4py.next.otf import stages, workflow +from gt4py.next.otf.stages import ProgramCall +from . import utils + + +@dataclasses.dataclass(frozen=True) +class PastToItir(workflow.ChainableWorkflowMixin): + def __call__(self, inp: stages.PastClosure) -> ProgramCall: + closure_vars = utils._get_closure_vars_recursively(get_closure_vars_from_function(inp.definition)) + offsets_and_dimensions = utils._filter_closure_vars_by_type( + closure_vars, FieldOffset, common.Dimension + ) + grid_type = utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) + + gt_callables = utils._filter_closure_vars_by_type(closure_vars, GTCallable).values() + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + return ProgramCall( + ProgramLowering.apply( + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ), + inp.args, + inp.kwargs, + ) + + +class PastToItirFactory(factory.Factory): + class Meta: + model = PastToItir diff --git a/src/gt4py/next/otf/transforms.py b/src/gt4py/next/otf/transforms/utils.py similarity index 56% rename from src/gt4py/next/otf/transforms.py rename to src/gt4py/next/otf/transforms/utils.py index cbdc4c999f..ef4951beba 100644 --- a/src/gt4py/next/otf/transforms.py +++ b/src/gt4py/next/otf/transforms/utils.py @@ -1,31 +1,9 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - import collections -import dataclasses -from collections.abc import Iterable - -import factory +from typing import Any, Optional, Iterable -from gt4py.eve.extended_typing import Any, Optional from gt4py.next import common -from gt4py.next.ffront.fbuiltins import FieldOffset +from gt4py.next.ffront import fbuiltins from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.ffront.source_utils import get_closure_vars_from_function -from gt4py.next.otf import stages, workflow -from gt4py.next.otf.stages import ProgramCall def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: @@ -61,7 +39,7 @@ def _filter_closure_vars_by_type(closure_vars: dict[str, Any], *types: type) -> def _deduce_grid_type( requested_grid_type: Optional[common.GridType], - offsets_and_dimensions: Iterable[FieldOffset | common.Dimension], + offsets_and_dimensions: Iterable[fbuiltins.FieldOffset | common.Dimension], ) -> common.GridType: """ Derive grid type from actually occurring dimensions and check against optional user request. @@ -71,12 +49,12 @@ def _deduce_grid_type( dimensions and offsets. """ - def is_cartesian_offset(o: FieldOffset): + def is_cartesian_offset(o: fbuiltins.FieldOffset): return len(o.target) == 1 and o.source == o.target[0] deduced_grid_type = common.GridType.CARTESIAN for o in offsets_and_dimensions: - if isinstance(o, FieldOffset) and not is_cartesian_offset(o): + if isinstance(o, fbuiltins.FieldOffset) and not is_cartesian_offset(o): deduced_grid_type = common.GridType.UNSTRUCTURED break if isinstance(o, common.Dimension) and o.kind == common.DimensionKind.LOCAL: @@ -92,28 +70,3 @@ def is_cartesian_offset(o: FieldOffset): ) return deduced_grid_type if requested_grid_type is None else requested_grid_type - - -@dataclasses.dataclass(frozen=True) -class PastToItir(workflow.ChainableWorkflowMixin): - def __call__(self, inp: stages.PastClosure) -> ProgramCall: - closure_vars = _get_closure_vars_recursively(get_closure_vars_from_function(inp.definition)) - offsets_and_dimensions = _filter_closure_vars_by_type( - closure_vars, FieldOffset, common.Dimension - ) - grid_type = _deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) - - gt_callables = _filter_closure_vars_by_type(closure_vars, GTCallable).values() - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] - return ProgramCall( - ProgramLowering.apply( - inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type - ), - inp.args, - inp.kwargs, - ) - - -class PastToItirFactory(factory.Factory): - class Meta: - model = PastToItir From 8218da63d29d92a485bd0fa91dd7fe64c7c258bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Tue, 5 Mar 2024 11:36:23 +0100 Subject: [PATCH 20/47] [wip] refactor backend / executor --- src/gt4py/next/backend.py | 7 +-- src/gt4py/next/ffront/decorator.py | 44 ++++++++----------- src/gt4py/next/otf/transforms/past_to_itir.py | 22 ++++++---- .../program_processors/modular_executor.py | 7 ++- .../program_processors/runners/roundtrip.py | 4 +- tests/next_tests/integration_tests/cases.py | 15 ++++--- 6 files changed, 54 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index dd2169f3e4..fed2686f08 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -15,7 +15,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Generic +from typing import Generic from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators @@ -29,8 +29,9 @@ class Backend(Generic[core_defs.DeviceTypeT]): executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - def __call__(self, program: stages.PastClosure, *args, **kwargs: Any) -> None: - self.executor(self.transformer(program), *args, **kwargs) + def __call__(self, program: stages.PastClosure) -> None: + program_call = self.transformer(program) + self.executor(program_call.program, *program_call.args, **program_call.kwargs) @property def __name__(self) -> str: diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index bd0ffd9f15..8f02a9036f 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -33,7 +33,12 @@ from gt4py._core import definitions as core_defs from gt4py.eve import codegen, utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, backend as next_backend, embedded as next_embedded, errors +from gt4py.next import ( + allocators as next_allocators, + backend as next_backend, + embedded as next_embedded, + errors, +) from gt4py.next.common import Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -62,7 +67,7 @@ sym, ) from gt4py.next.otf import stages, transforms as otf_transforms -from gt4py.next.program_processors import modular_executor, processor_interface as ppi +from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -161,7 +166,6 @@ def __post_init__(self): f"The following closure variables are undefined: {', '.join(undefined_symbols)}." ) - @property def __name__(self) -> str: return self.definition.__name__ @@ -239,7 +243,7 @@ def itir(self) -> itir.FencilDefinition: def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None: rewritten_args, size_args, kwargs = self._process_args(args, kwargs) - if self.backend is None: + if self.backend is None or self.backend.executor is None: warnings.warn( UserWarning( f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend." @@ -249,30 +253,18 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: ctx.run(self.definition, *rewritten_args, **kwargs) return - elif isinstance(self.backend, modular_executor.ModularExecutor): - self.backend( - stages.PastClosure( - closure_vars=self.closure_vars, - past_node=self.past_node, - grid_type=self.grid_type, - args=[*rewritten_args, *size_args], - kwargs=kwargs - | {"offset_provider": offset_provider, "column_axis": self._column_axis}, - ), - *rewritten_args, - *size_args, - **kwargs, - offset_provider=offset_provider, - column_axis=self._column_axis, - ) - return - ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) - if "debug" in kwargs: - debug(self.itir) + ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend( - self.itir, + stages.PastClosure( + closure_vars=self.closure_vars, + past_node=self.past_node, + grid_type=self.grid_type, + args=[*rewritten_args, *size_args], + kwargs=kwargs + | {"offset_provider": offset_provider, "column_axis": self._column_axis}, + ), *rewritten_args, *size_args, **kwargs, @@ -692,7 +684,9 @@ def past_to_func(self, arg_types, out_sym, past_node): ) source_code = ProgamFuncGen.apply(past_node) import linecache + import gt4py.next as gtx + filename = "" globalns = {dim.value: dim for dim in dims} globalns[self.definition.__name__] = self diff --git a/src/gt4py/next/otf/transforms/past_to_itir.py b/src/gt4py/next/otf/transforms/past_to_itir.py index d4f6a6cf7b..2b95ed61b9 100644 --- a/src/gt4py/next/otf/transforms/past_to_itir.py +++ b/src/gt4py/next/otf/transforms/past_to_itir.py @@ -14,33 +14,39 @@ import dataclasses +import devtools import factory -from gt4py.next import common +from gt4py.next import common, config from gt4py.next.ffront.fbuiltins import FieldOffset from gt4py.next.ffront.gtcallable import GTCallable from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.ffront.source_utils import get_closure_vars_from_function from gt4py.next.otf import stages, workflow from gt4py.next.otf.stages import ProgramCall + from . import utils @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> ProgramCall: - closure_vars = utils._get_closure_vars_recursively(get_closure_vars_from_function(inp.definition)) offsets_and_dimensions = utils._filter_closure_vars_by_type( - closure_vars, FieldOffset, common.Dimension + inp.closure_vars, FieldOffset, common.Dimension ) grid_type = utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) - gt_callables = utils._filter_closure_vars_by_type(closure_vars, GTCallable).values() + gt_callables = utils._filter_closure_vars_by_type(inp.closure_vars, GTCallable).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + + itir_program = ProgramLowering.apply( + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ) + + if config.DEBUG or "debug" in inp.kwargs: + devtools.debug(itir_program) + return ProgramCall( - ProgramLowering.apply( - inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type - ), + itir_program, inp.args, inp.kwargs, ) diff --git a/src/gt4py/next/program_processors/modular_executor.py b/src/gt4py/next/program_processors/modular_executor.py index 7419406f53..a5d958803a 100644 --- a/src/gt4py/next/program_processors/modular_executor.py +++ b/src/gt4py/next/program_processors/modular_executor.py @@ -18,6 +18,7 @@ from typing import Any, Optional import gt4py.next.program_processors.processor_interface as ppi +from gt4py.next.iterator import ir as itir from gt4py.next.otf import stages, workflow @@ -26,8 +27,10 @@ class ModularExecutor(ppi.ProgramExecutor): otf_workflow: workflow.Workflow[stages.ProgramCall, stages.CompiledProgram] name: Optional[str] = None - def __call__(self, program: stages.ProgramCall, *args, **kwargs: Any) -> None: - self.otf_workflow(program)(*args, offset_provider=kwargs["offset_provider"]) + def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None: + self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( + *args, offset_provider=kwargs["offset_provider"] + ) @property def __name__(self) -> str: diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 703a3e4f41..31f6b0d76d 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -233,7 +233,7 @@ def execute_roundtrip( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: bool = None + debug: bool = False lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE dispatch_backend: Optional[ppi.ProgramExecutor] = None @@ -256,7 +256,7 @@ class Meta: class RoundtripExecutor(modular_executor.ModularExecutor): dispatch_backend: Optional[ppi.ProgramExecutor] = None - def __call__(self, program: stages.PastClosure, *args, **kwargs) -> None: + def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: kwargs["backend"] = self.dispatch_backend super().__call__(program, *args, **kwargs) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index c11c1ac256..4cb1cfceb4 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,9 +28,14 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import allocators as next_allocators, common, constructors, field_utils +from gt4py.next import ( + allocators as next_allocators, + backend as next_backend, + common, + constructors, + field_utils, +) from gt4py.next.ffront import decorator -from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation from next_tests import definitions as test_definitions @@ -477,7 +482,7 @@ def cartesian_case( exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, ): yield Case( - exec_alloc_descriptor.executor, + exec_alloc_descriptor, offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim}, default_sizes={IDim: 10, JDim: 10, KDim: 10}, grid_type=common.GridType.CARTESIAN, @@ -491,7 +496,7 @@ def unstructured_case( exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, ): yield Case( - exec_alloc_descriptor.executor, + exec_alloc_descriptor, offset_provider=mesh_descriptor.offset_provider, default_sizes={ Vertex: mesh_descriptor.num_vertices, @@ -601,7 +606,7 @@ def get_default_data( class Case: """Parametrizable components for single feature integration tests.""" - executor: Optional[ppi.ProgramProcessor] + executor: Optional[next_backend.Backend] offset_provider: dict[str, common.Connectivity | gtx.Dimension] default_sizes: dict[gtx.Dimension, int] grid_type: common.GridType From 701b3304e8cd030420f62a73686b81ea9f5a2912 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Tue, 5 Mar 2024 17:06:01 +0100 Subject: [PATCH 21/47] edits for past_to_func --- src/gt4py/next/ffront/decorator.py | 84 ++++++------------- src/gt4py/next/otf/stages.py | 1 - src/gt4py/next/otf/transforms/__init__.py | 14 ++++ src/gt4py/next/otf/transforms/past_to_func.py | 73 ++++++++++++++++ 4 files changed, 113 insertions(+), 59 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index bd0ffd9f15..085b45ac85 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -20,7 +20,6 @@ import dataclasses import functools -import textwrap import types import typing import warnings @@ -31,9 +30,14 @@ from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.eve import codegen, utils as eve_utils +from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, backend as next_backend, embedded as next_embedded, errors +from gt4py.next import ( + allocators as next_allocators, + backend as next_backend, + embedded as next_embedded, + errors, +) from gt4py.next.common import Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -62,6 +66,7 @@ sym, ) from gt4py.next.otf import stages, transforms as otf_transforms +from gt4py.next.otf.transforms.past_to_func import past_to_fun_def from gt4py.next.program_processors import modular_executor, processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -161,7 +166,6 @@ def __post_init__(self): f"The following closure variables are undefined: {', '.join(undefined_symbols)}." ) - @property def __name__(self) -> str: return self.definition.__name__ @@ -247,6 +251,15 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No stacklevel=2, ) with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: + if self.definition is None: + self.definition = past_to_fun_def(stages.PastClosure( + closure_vars=self.closure_vars, + past_node=self.past_node, + grid_type=self.grid_type, + args=[*rewritten_args, *size_args], + kwargs=kwargs + | {"offset_provider": offset_provider, "column_axis": self._column_axis}, + )) ctx.run(self.definition, *rewritten_args, **kwargs) return elif isinstance(self.backend, modular_executor.ModularExecutor): @@ -666,7 +679,15 @@ def as_program( untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) - self.past_to_func(arg_types, out_sym, past_node) + # past_closure = stages.PastClosure( + # closure_vars=closure_vars, + # past_node=past_node, + # grid_type=self.grid_type, + # args=past_node.params, + # kwargs={}, + # ) + # + # past_to_fun_def(past_closure) self._program_cache[hash_] = Program( past_node=past_node, @@ -678,39 +699,6 @@ def as_program( return self._program_cache[hash_] - def past_to_func(self, arg_types, out_sym, past_node): - inout_types = [ - *arg_types, - *( - list(type_info.flatten(out_sym.type.types)) - if isinstance(out_sym.type, ts.TupleType) - else [out_sym.type] - ), - ] - dims = set( - i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j - ) - source_code = ProgamFuncGen.apply(past_node) - import linecache - import gt4py.next as gtx - filename = "" - globalns = {dim.value: dim for dim in dims} - globalns[self.definition.__name__] = self - globalns |= gtx.__dict__ - localns = {} - code_obj = compile(source_code, filename, "exec") - exec(code_obj, globalns, localns) - lines = [line + "\n" for line in source_code.splitlines()] - linecache.cache[filename] = (len(source_code), None, lines, filename) - function_definition = localns[past_node.id] - linecache.cache[filename] = ( - len(source_code), - None, - [line + "\n" for line in source_code.splitlines()], - filename, - ) - return function_definition - def __call__( self, *args, @@ -866,23 +854,3 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: ) return scan_operator_inner if definition is None else scan_operator_inner(definition) - - -class ProgamFuncGen(codegen.TemplatedGenerator): - def visit_Program(self, node: past.Program, **kwargs) -> str: - imports = "from __future__ import annotations\nfrom gt4py.next import *" - params = self.visit(node.params) - signature = ", ".join(params) - body = textwrap.indent("\n".join(self.visit(node.body)), prefix=" " * 4) - return f"{imports}\n\n\ndef {node.id}({signature}) -> None:\n{body}" - - Symbol = codegen.FormatTemplate("{id}: {type}") - - def visit_Call(self, node: past.Call, **kwargs) -> str: - args = ", ".join(self.visit(node.args)) - kwargs_list = [f"{name}={self.visit(value)}" for name, value in node.kwargs.items()] - kwargs = ", ".join(kwargs_list) - params = ", ".join([args, kwargs]) - return f"{self.visit(node.func)}({params})" - - Name = codegen.FormatTemplate("{id}") diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index a54ba24610..88c1b44792 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -15,7 +15,6 @@ from __future__ import annotations import dataclasses -import types from typing import Any, Generic, Optional, Protocol, TypeVar from gt4py.next import common diff --git a/src/gt4py/next/otf/transforms/__init__.py b/src/gt4py/next/otf/transforms/__init__.py index 896c9d93b4..32817c283b 100644 --- a/src/gt4py/next/otf/transforms/__init__.py +++ b/src/gt4py/next/otf/transforms/__init__.py @@ -1,3 +1,17 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + from gt4py.next.otf.transforms.past_to_itir import PastToItir, PastToItirFactory diff --git a/src/gt4py/next/otf/transforms/past_to_func.py b/src/gt4py/next/otf/transforms/past_to_func.py index e69de29bb2..02cdb5f7ff 100644 --- a/src/gt4py/next/otf/transforms/past_to_func.py +++ b/src/gt4py/next/otf/transforms/past_to_func.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import linecache +import textwrap + +import gt4py.next as gtx +from gt4py.eve import codegen +from gt4py.next.ffront import program_ast as past +from gt4py.next.otf import stages +from gt4py.next.type_system import type_info + + +def past_to_fun_def(past_closure: stages.PastClosure): + node = past_closure.past_node + inout_types_ls = [ + type_info.apply_to_primitive_constituents(arg.type, lambda primitive_type: primitive_type) + for arg in past_closure.args + ] + inout_types = list(type_info.flatten(inout_types_ls)) + dims = set( + i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j + ) + source_code = ProgamFuncGen.apply(node) + + filename = "" + globalns = {dim.value: dim for dim in dims} + # globalns[node.id] = node + globalns |= gtx.__dict__ + localns = {} + code_obj = compile(source_code, filename, "exec") + exec(code_obj, globalns, localns) + lines = [line + "\n" for line in source_code.splitlines()] + linecache.cache[filename] = (len(source_code), None, lines, filename) + function_definition = localns[str(node.id)] + linecache.cache[filename] = ( + len(source_code), + None, + [line + "\n" for line in source_code.splitlines()], + filename, + ) + return function_definition + + +class ProgamFuncGen(codegen.TemplatedGenerator): + def visit_Program(self, node: past.Program, **kwargs) -> str: + imports = "from __future__ import annotations\nfrom gt4py.next import *" + params = self.visit(node.params) + signature = ", ".join(params) + body = textwrap.indent("\n".join(self.visit(node.body)), prefix=" " * 4) + return f"{imports}\n\n\ndef {node.id}({signature}) -> None:\n{body}" + + Symbol = codegen.FormatTemplate("{id}: {type}") + + def visit_Call(self, node: past.Call, **kwargs) -> str: + args = ", ".join(self.visit(node.args)) + kwargs_list = [f"{name}={self.visit(value)}" for name, value in node.kwargs.items()] + kwargs = ", ".join(kwargs_list) + params = ", ".join([args, kwargs]) + return f"{self.visit(node.func)}({params})" + + Name = codegen.FormatTemplate("{id}") From 81e6a9914e8944bc7bc1740bed1d22f3ce4e17e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rico=20H=C3=A4uselmann?= Date: Wed, 6 Mar 2024 10:20:43 +0100 Subject: [PATCH 22/47] [wip] refactor backend to start from PAST --- src/gt4py/next/backend.py | 4 ---- src/gt4py/next/ffront/decorator.py | 9 +-------- src/gt4py/next/iterator/runtime.py | 5 ++--- src/gt4py/next/otf/transforms/past_to_itir.py | 5 +++-- .../next/program_processors/modular_executor.py | 4 ++++ .../next/program_processors/processor_interface.py | 4 +++- .../program_processors/runners/double_roundtrip.py | 2 +- .../next/program_processors/runners/roundtrip.py | 14 ++++++++++++-- tests/next_tests/integration_tests/cases.py | 4 ++-- .../feature_tests/ffront_tests/test_execution.py | 2 +- tests/next_tests/unit_tests/conftest.py | 5 +++++ .../test_decorator_domain_deduction.py | 2 +- 12 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index fed2686f08..80fe7f1043 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -37,10 +37,6 @@ def __call__(self, program: stages.PastClosure) -> None: def __name__(self) -> str: return getattr(self.executor, "__name__", None) or repr(self) - @property - def kind(self) -> type[ppi.ProgramExecutor]: - return self.executor.kind - @property def __gt_allocator__( self, diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 8f02a9036f..a7e63d527e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -243,7 +243,7 @@ def itir(self) -> itir.FencilDefinition: def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None: rewritten_args, size_args, kwargs = self._process_args(args, kwargs) - if self.backend is None or self.backend.executor is None: + if self.backend is None: warnings.warn( UserWarning( f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend." @@ -265,11 +265,6 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No kwargs=kwargs | {"offset_provider": offset_provider, "column_axis": self._column_axis}, ), - *rewritten_args, - *size_args, - **kwargs, - offset_provider=offset_provider, - column_axis=self._column_axis, ) def format_itir( @@ -658,8 +653,6 @@ def as_program( untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) - self.past_to_func(arg_types, out_sym, past_node) - self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 651d6a3097..8209c6dd41 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -24,7 +24,6 @@ from gt4py.next.iterator import builtins from gt4py.next.iterator.builtins import BackendNotSelectedError, builtin_dispatch from gt4py.next.iterator.ir import FencilDefinition -from gt4py.next.otf import stages from gt4py.next.program_processors.processor_interface import ( ProgramExecutor, ProgramFormatter, @@ -92,8 +91,8 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs): if backend is not None: ensure_processor_kind(backend, ProgramExecutor) - backend.executor( - stages.ProgramCall(program=self.itir(*args, **kwargs), args=args, kwargs=kwargs), + backend( + self.itir(*args, **kwargs), *args, **kwargs, ) diff --git a/src/gt4py/next/otf/transforms/past_to_itir.py b/src/gt4py/next/otf/transforms/past_to_itir.py index 2b95ed61b9..7e27bd9f3a 100644 --- a/src/gt4py/next/otf/transforms/past_to_itir.py +++ b/src/gt4py/next/otf/transforms/past_to_itir.py @@ -30,12 +30,13 @@ @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> ProgramCall: + all_closure_vars = utils._get_closure_vars_recursively(inp.closure_vars) offsets_and_dimensions = utils._filter_closure_vars_by_type( - inp.closure_vars, FieldOffset, common.Dimension + all_closure_vars, FieldOffset, common.Dimension ) grid_type = utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) - gt_callables = utils._filter_closure_vars_by_type(inp.closure_vars, GTCallable).values() + gt_callables = utils._filter_closure_vars_by_type(all_closure_vars, GTCallable).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] itir_program = ProgramLowering.apply( diff --git a/src/gt4py/next/program_processors/modular_executor.py b/src/gt4py/next/program_processors/modular_executor.py index a5d958803a..b8032c17b8 100644 --- a/src/gt4py/next/program_processors/modular_executor.py +++ b/src/gt4py/next/program_processors/modular_executor.py @@ -35,3 +35,7 @@ def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None @property def __name__(self) -> str: return self.name or repr(self) + + @property + def kind(self) -> type[ppi.ProgramExecutor]: + return ppi.ProgramExecutor diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index 37921cdd35..36870252c5 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -237,8 +237,10 @@ class ProgramBackend( def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]: + if not hasattr(obj, "executor"): + return False return is_processor_kind( - obj, + obj.executor, ProgramExecutor, # type: ignore[type-abstract] # ProgramExecutor is abstract ) and next_allocators.is_field_allocator_factory(obj) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 1a09ba04e7..a05243bb43 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -22,7 +22,7 @@ backend = next_backend.Backend( transformer=otf_transforms.PastToItirFactory(), executor=roundtrip.RoundtripExecutorFactory( - dispatch_backend=roundtrip.execute_roundtrip, + dispatch_backend=roundtrip.RoundtripExecutorFactory(), ), allocator=roundtrip.backend.allocator, ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 31f6b0d76d..abdacf2bf3 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -183,6 +183,10 @@ def fencil_generator( source_file.write("\n") source_file.write(program) + import black + + print(black.format_str(program, mode=black.mode.Mode())) + try: spec = importlib.util.spec_from_file_location("module.name", source_file_name) mod = importlib.util.module_from_spec(spec) # type: ignore @@ -204,6 +208,7 @@ def fencil_generator( return fencil +@ppi.program_executor def execute_roundtrip( ir: itir.Node, *args, @@ -258,7 +263,9 @@ class RoundtripExecutor(modular_executor.ModularExecutor): def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: kwargs["backend"] = self.dispatch_backend - super().__call__(program, *args, **kwargs) + self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( + *args, **kwargs + ) class RoundtripExecutorFactory(factory.Factory): @@ -266,8 +273,11 @@ class Meta: model = RoundtripExecutor class Params: - roundtrip_workflow = factory.SubFactory(RoundtripFactory) + roundtrip_workflow = factory.SubFactory( + RoundtripFactory, dispatch_backend=factory.SelfAttribute("..dispatch_backend") + ) + dispatch_backend = None otf_workflow = factory.LazyAttribute(lambda o: o.roundtrip_workflow) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 4cb1cfceb4..4b50e21260 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -482,7 +482,7 @@ def cartesian_case( exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, ): yield Case( - exec_alloc_descriptor, + exec_alloc_descriptor if exec_alloc_descriptor.executor else None, offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim}, default_sizes={IDim: 10, JDim: 10, KDim: 10}, grid_type=common.GridType.CARTESIAN, @@ -496,7 +496,7 @@ def unstructured_case( exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, ): yield Case( - exec_alloc_descriptor, + exec_alloc_descriptor if exec_alloc_descriptor.executor else None, offset_provider=mesh_descriptor.offset_provider, default_sizes={ Vertex: mesh_descriptor.num_vertices, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 5db9886966..deeddc67bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -637,7 +637,7 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.executor == gtfn.run_gtfn_with_temporaries.executor: + if cartesian_case.executor == gtfn.run_gtfn_with_temporaries: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 17418a9ca6..6b1fe3d37e 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -88,6 +88,8 @@ def program_processor(request) -> tuple[ppi.ProgramProcessor, bool]: processor = processor_id.load() assert is_backend == ppi.is_program_backend(processor) + if is_backend: + processor = processor.executor for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( processor_id, [] @@ -104,6 +106,9 @@ def run_processor( *args, **kwargs, ) -> None: + import devtools + + devtools.debug(processor) if processor is None or ppi.is_processor_kind(processor, ppi.ProgramExecutor): program(*args, backend=processor, **kwargs) elif ppi.is_processor_kind(processor, ppi.ProgramFormatter): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py index 75e23545de..5620efc0c7 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py @@ -15,7 +15,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.ffront.decorator import _deduce_grid_type +from gt4py.next.otf.transforms.utils import _deduce_grid_type Dim = gtx.Dimension("Dim") From e6410e6122f89181f723d7011c3f5c7147ca2969 Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 6 Mar 2024 11:23:25 +0100 Subject: [PATCH 23/47] fixed some pre-commit errors --- src/gt4py/next/ffront/decorator.py | 31 ++++++++----------- src/gt4py/next/otf/transforms/past_to_func.py | 9 +++--- src/gt4py/next/otf/transforms/past_to_itir.py | 5 ++- src/gt4py/next/otf/transforms/utils.py | 16 +++++++++- .../program_processors/runners/roundtrip.py | 4 +-- 5 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 085b45ac85..5557b8f9f4 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -252,14 +252,19 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No ) with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: if self.definition is None: - self.definition = past_to_fun_def(stages.PastClosure( - closure_vars=self.closure_vars, - past_node=self.past_node, - grid_type=self.grid_type, - args=[*rewritten_args, *size_args], - kwargs=kwargs - | {"offset_provider": offset_provider, "column_axis": self._column_axis}, - )) + self.definition = past_to_fun_def( + stages.PastClosure( + closure_vars=self.closure_vars, + past_node=self.past_node, + grid_type=self.grid_type, + args=[*rewritten_args, *size_args], + kwargs=kwargs + | { + "offset_provider": offset_provider, + "column_axis": self._column_axis, + }, + ) + ) ctx.run(self.definition, *rewritten_args, **kwargs) return elif isinstance(self.backend, modular_executor.ModularExecutor): @@ -679,16 +684,6 @@ def as_program( untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) - # past_closure = stages.PastClosure( - # closure_vars=closure_vars, - # past_node=past_node, - # grid_type=self.grid_type, - # args=past_node.params, - # kwargs={}, - # ) - # - # past_to_fun_def(past_closure) - self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, diff --git a/src/gt4py/next/otf/transforms/past_to_func.py b/src/gt4py/next/otf/transforms/past_to_func.py index 02cdb5f7ff..915a99cc21 100644 --- a/src/gt4py/next/otf/transforms/past_to_func.py +++ b/src/gt4py/next/otf/transforms/past_to_func.py @@ -36,9 +36,8 @@ def past_to_fun_def(past_closure: stages.PastClosure): filename = "" globalns = {dim.value: dim for dim in dims} - # globalns[node.id] = node globalns |= gtx.__dict__ - localns = {} + localns: dict = {} code_obj = compile(source_code, filename, "exec") exec(code_obj, globalns, localns) lines = [line + "\n" for line in source_code.splitlines()] @@ -64,10 +63,10 @@ def visit_Program(self, node: past.Program, **kwargs) -> str: Symbol = codegen.FormatTemplate("{id}: {type}") def visit_Call(self, node: past.Call, **kwargs) -> str: - args = ", ".join(self.visit(node.args)) + args_joined = ", ".join(self.visit(node.args)) kwargs_list = [f"{name}={self.visit(value)}" for name, value in node.kwargs.items()] - kwargs = ", ".join(kwargs_list) - params = ", ".join([args, kwargs]) + kwargs_joined = ", ".join(kwargs_list) + params = ", ".join([args_joined, kwargs_joined]) return f"{self.visit(node.func)}({params})" Name = codegen.FormatTemplate("{id}") diff --git a/src/gt4py/next/otf/transforms/past_to_itir.py b/src/gt4py/next/otf/transforms/past_to_itir.py index d4f6a6cf7b..296ceaf4fd 100644 --- a/src/gt4py/next/otf/transforms/past_to_itir.py +++ b/src/gt4py/next/otf/transforms/past_to_itir.py @@ -23,13 +23,16 @@ from gt4py.next.ffront.source_utils import get_closure_vars_from_function from gt4py.next.otf import stages, workflow from gt4py.next.otf.stages import ProgramCall + from . import utils @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> ProgramCall: - closure_vars = utils._get_closure_vars_recursively(get_closure_vars_from_function(inp.definition)) + closure_vars = utils._get_closure_vars_recursively( + get_closure_vars_from_function(inp.definition) + ) offsets_and_dimensions = utils._filter_closure_vars_by_type( closure_vars, FieldOffset, common.Dimension ) diff --git a/src/gt4py/next/otf/transforms/utils.py b/src/gt4py/next/otf/transforms/utils.py index ef4951beba..987598b21d 100644 --- a/src/gt4py/next/otf/transforms/utils.py +++ b/src/gt4py/next/otf/transforms/utils.py @@ -1,5 +1,19 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + import collections -from typing import Any, Optional, Iterable +from typing import Any, Iterable, Optional from gt4py.next import common from gt4py.next.ffront import fbuiltins diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 703a3e4f41..68181ea7aa 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -233,7 +233,7 @@ def execute_roundtrip( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: bool = None + debug: bool = False lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE dispatch_backend: Optional[ppi.ProgramExecutor] = None @@ -256,7 +256,7 @@ class Meta: class RoundtripExecutor(modular_executor.ModularExecutor): dispatch_backend: Optional[ppi.ProgramExecutor] = None - def __call__(self, program: stages.PastClosure, *args, **kwargs) -> None: + def __call__(self, program: stages.ProgramCall, *args, **kwargs) -> None: kwargs["backend"] = self.dispatch_backend super().__call__(program, *args, **kwargs) From 625c3372ecb330912df4fdbb69a601a3796deeca Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 6 Mar 2024 11:38:11 +0100 Subject: [PATCH 24/47] update remaining tests and test utils --- .../ffront_tests/ffront_test_utils.py | 12 ++++- .../test_temporaries_with_sizes.py | 7 ++- .../test_cartesian_offset_provider.py | 2 +- .../ffront_tests/test_icon_like_scan.py | 47 ++++++++++++------- .../iterator_tests/test_vertical_advection.py | 17 ++++--- .../test_processor_interface.py | 21 +++++++-- tox.ini | 2 +- 7 files changed, 75 insertions(+), 33 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 00a193be0b..051a9ece7d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -21,7 +21,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import common +from gt4py.next import backend as next_backend, common from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir from gt4py.next.program_processors import processor_interface as ppi @@ -40,11 +40,19 @@ @ppi.program_executor -def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: +def no_exec(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: """Temporary default backend to not accidentally test the wrong backend.""" raise ValueError("No backend selected! Backend selection is mandatory in tests.") +class NoBackend(next_backend.Backend): + def __call__(self, program) -> None: + raise ValueError("No backend selected! Backend selection is mandatory in tests.") + + +no_backend = NoBackend(executor=no_exec, transformer=None, allocator=None) + + OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append(next_tests.definitions.OptionalProgramBackendId.DACE_CPU) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 1bf640f93d..9d4cfc90a9 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -17,6 +17,7 @@ from gt4py import next as gtx from gt4py.next import backend, common +from gt4py.next.otf import transforms as otf_transforms from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.program_processors import modular_executor from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries @@ -39,6 +40,7 @@ @pytest.fixture def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( + transformer=otf_transforms.PastToItirFactory(), executor=modular_executor.ModularExecutor( name="run_gtfn_with_temporaries_and_sizes", otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( @@ -77,7 +79,7 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): unstructured_case = Case( - run_gtfn_with_temporaries_and_symbolic_sizes.executor, + run_gtfn_with_temporaries_and_symbolic_sizes, offset_provider=mesh_descriptor.offset_provider, default_sizes={ Vertex: mesh_descriptor.num_vertices, @@ -92,7 +94,8 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = ( + mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 5c80d9e415..b82cea4b22 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -64,5 +64,5 @@ def test_cartesian_offset_provider(): fencil(out, inp, backend=roundtrip.executor) assert out[0][0] == 42 - fencil(out, inp, backend=double_roundtrip.executor) + fencil(out, inp, backend=double_roundtrip.backend.executor) assert out[0][0] == 42 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index f1a5b41f81..e23199d8ab 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -59,7 +59,8 @@ def _solve_nonhydro_stencil_52_like( z_q: gtx.Field[[Cell, KDim], float], w: gtx.Field[[Cell, KDim], float], ) -> tuple[ - gtx.Field[[Cell, KDim], float], gtx.Field[[Cell, KDim], float], gtx.Field[[Cell, KDim], bool] + gtx.Field[[Cell, KDim], float], gtx.Field[[ + Cell, KDim], float], gtx.Field[[Cell, KDim], bool] ]: """No projector required as we write all output of the scan (including dummy field)""" z_a = z_beta(Koff[-1]) * z_alpha(Koff[-1]) @@ -139,7 +140,8 @@ def solve_nonhydro_stencil_52_like_z_q( w: gtx.Field[[Cell, KDim], float], z_q_out: gtx.Field[[Cell, KDim], float], ): - _solve_nonhydro_stencil_52_like_z_q(z_alpha, z_beta, z_q, w, out=z_q_out[:, 1:]) + _solve_nonhydro_stencil_52_like_z_q( + z_alpha, z_beta, z_q, w, out=z_q_out[:, 1:]) @gtx.field_operator @@ -164,7 +166,8 @@ def solve_nonhydro_stencil_52_like_z_q_tup( w: gtx.Field[[Cell, KDim], float], z_q_out: gtx.Field[[Cell, KDim], float], ): - _solve_nonhydro_stencil_52_like_z_q_tup(z_alpha, z_beta, z_q, w, out=(z_q_out[:, 1:],)) + _solve_nonhydro_stencil_52_like_z_q_tup( + z_alpha, z_beta, z_q, w, out=(z_q_out[:, 1:],)) def reference( @@ -195,7 +198,7 @@ def reference( @pytest.fixture def test_setup(exec_alloc_descriptor): test_case = cases.Case( - exec_alloc_descriptor.executor, + exec_alloc_descriptor if exec_alloc_descriptor.executor else None, offset_provider={"Koff": KDim}, default_sizes={Cell: 14, KDim: 10}, grid_type=common.GridType.UNSTRUCTURED, @@ -208,20 +211,27 @@ class setup: cell_size = test_case.default_sizes[Cell] k_size = test_case.default_sizes[KDim] z_alpha = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) + [Cell, KDim], np.random.default_rng().uniform( + size=(cell_size, k_size + 1)) ) z_beta = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + [Cell, KDim], np.random.default_rng().uniform( + size=(cell_size, k_size)) ) z_q = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + [Cell, KDim], np.random.default_rng().uniform( + size=(cell_size, k_size)) ) w = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + [Cell, KDim], np.random.default_rng().uniform( + size=(cell_size, k_size)) ) - z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) + z_q_ref, w_ref = reference( + z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) + dummy = test_case.as_field( + [Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = test_case.as_field( + [Cell, KDim], np.zeros((cell_size, k_size))) return setup() @@ -242,7 +252,8 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) - assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:].asnumpy()) + assert np.allclose( + test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:].asnumpy()) @pytest.mark.uses_tuple_returns @@ -256,7 +267,8 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): "again after CollapseTuple." ) if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: - pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") + pytest.xfail( + "Needs proper handling of tuple[Column] <-> Column[tuple].") cases.verify( test_setup.case, @@ -278,7 +290,8 @@ def test_solve_nonhydro_stencil_52_like(test_setup): test_setup.case.executor == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") + pytest.xfail( + "Temporary extraction does not work correctly in combination with scans.") cases.run( test_setup.case, @@ -300,9 +313,11 @@ def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): test_setup.case.executor == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") + pytest.xfail( + "Temporary extraction does not work correctly in combination with scans.") if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: - pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") + pytest.xfail( + "Needs proper handling of tuple[Column] <-> Column[tuple].") cases.run( test_setup.case, diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index b28c98ab38..4a95b0fdeb 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -30,7 +30,8 @@ def tridiag_forward(state, a, b, c, d): return make_tuple( deref(c) / (deref(b) - deref(a) * tuple_get(0, state)), - (deref(d) - deref(a) * tuple_get(1, state)) / (deref(b) - deref(a) * tuple_get(0, state)), + (deref(d) - deref(a) * tuple_get(1, state)) / + (deref(b) - deref(a) * tuple_get(0, state)), ) @@ -119,15 +120,19 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): if ( program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn.run_gtfn_with_temporaries.executor, gtfn_formatters.format_cpp, ] and lift_mode == LiftMode.FORCE_INLINE ): - pytest.skip("gtfn does only support lifted scans when using temporaries") - if program_processor == gtfn.run_gtfn_with_temporaries or lift_mode == LiftMode.USE_TEMPORARIES: + pytest.skip( + "gtfn does only support lifted scans when using temporaries") + if ( + program_processor == gtfn.run_gtfn_with_temporaries.executor + or lift_mode == LiftMode.USE_TEMPORARIES + ): pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference shape = a.shape diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py index 1ba35da7c6..cb80d94c3f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py +++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses + import pytest import gt4py.next.allocators as next_allocators @@ -48,13 +50,15 @@ def other_func(program: itir.FencilDefinition, *args, **kwargs) -> str: assert processor.__name__ == "new_name" assert processor(None) == other_func(None) assert processor(1, 2, a="A", b="B") == other_func(1, 2, a="A", b="B") - assert processor(1, 2, 3, 4, a="A", b="B", c="C") != other_func(1, 2, 3, 4, a="A", b="B", c="C") + assert processor(1, 2, 3, 4, a="A", b="B", c="C") != other_func( + 1, 2, 3, 4, a="A", b="B", c="C") with pytest.raises(ValueError, match="accepted arguments cannot be a negative number"): make_program_processor(my_func, ProgramFormatter, accept_args=-1) with pytest.raises(ValueError, match="invalid list of keyword argument names"): - make_program_processor(my_func, ProgramFormatter, accept_kwargs=["a", None]) + make_program_processor(my_func, ProgramFormatter, + accept_kwargs=["a", None]) @pytest.fixture @@ -95,8 +99,15 @@ class DummyAllocatorFactory: assert not is_program_backend(DummyAllocatorFactory()) - class DummyBackend(DummyProgramExecutor, DummyAllocatorFactory): - def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> None: - return None + @dataclasses.dataclass + class DummyBackend: + executor: DummyProgramExecutor = dataclasses.field( + default_factory=DummyProgramExecutor) + allocator: DummyAllocatorFactory = dataclasses.field( + default_factory=DummyAllocatorFactory) + + @property + def __gt_allocator__(self): + return self.allocator.__gt_allocator__ assert is_program_backend(DummyBackend()) diff --git a/tox.ini b/tox.ini index 4da10f7d8d..ff67764b05 100644 --- a/tox.ini +++ b/tox.ini @@ -43,7 +43,7 @@ extras = cuda12x: cuda12x package = wheel wheel_build_env = .pkg -pass_env = NUM_PROCESSES +pass_env = NUM_PROCESSES, GT4PY_BUILD_CACHE_LIFETIME, GT4PY_BUILD_CACHE_DIR set_env = PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning} From 1561581debf3633421b2c9bcdc9e8d2c9f8cc96e Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 6 Mar 2024 13:44:15 +0100 Subject: [PATCH 25/47] tests: fix xfail conditions in test_icon_like_scan --- .../ffront_tests/test_icon_like_scan.py | 55 +++++++------------ 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index e23199d8ab..8da95712f4 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -59,8 +59,7 @@ def _solve_nonhydro_stencil_52_like( z_q: gtx.Field[[Cell, KDim], float], w: gtx.Field[[Cell, KDim], float], ) -> tuple[ - gtx.Field[[Cell, KDim], float], gtx.Field[[ - Cell, KDim], float], gtx.Field[[Cell, KDim], bool] + gtx.Field[[Cell, KDim], float], gtx.Field[[Cell, KDim], float], gtx.Field[[Cell, KDim], bool] ]: """No projector required as we write all output of the scan (including dummy field)""" z_a = z_beta(Koff[-1]) * z_alpha(Koff[-1]) @@ -140,8 +139,7 @@ def solve_nonhydro_stencil_52_like_z_q( w: gtx.Field[[Cell, KDim], float], z_q_out: gtx.Field[[Cell, KDim], float], ): - _solve_nonhydro_stencil_52_like_z_q( - z_alpha, z_beta, z_q, w, out=z_q_out[:, 1:]) + _solve_nonhydro_stencil_52_like_z_q(z_alpha, z_beta, z_q, w, out=z_q_out[:, 1:]) @gtx.field_operator @@ -166,8 +164,7 @@ def solve_nonhydro_stencil_52_like_z_q_tup( w: gtx.Field[[Cell, KDim], float], z_q_out: gtx.Field[[Cell, KDim], float], ): - _solve_nonhydro_stencil_52_like_z_q_tup( - z_alpha, z_beta, z_q, w, out=(z_q_out[:, 1:],)) + _solve_nonhydro_stencil_52_like_z_q_tup(z_alpha, z_beta, z_q, w, out=(z_q_out[:, 1:],)) def reference( @@ -211,27 +208,20 @@ class setup: cell_size = test_case.default_sizes[Cell] k_size = test_case.default_sizes[KDim] z_alpha = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform( - size=(cell_size, k_size + 1)) + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) z_beta = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform( - size=(cell_size, k_size)) + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) z_q = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform( - size=(cell_size, k_size)) + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) w = test_case.as_field( - [Cell, KDim], np.random.default_rng().uniform( - size=(cell_size, k_size)) + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q_ref, w_ref = reference( - z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = test_case.as_field( - [Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = test_case.as_field( - [Cell, KDim], np.zeros((cell_size, k_size))) + z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) + dummy = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() @@ -252,23 +242,21 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) - assert np.allclose( - test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:].asnumpy()) + assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:].asnumpy()) @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): if ( test_setup.case.executor - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." ) - if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: - pytest.xfail( - "Needs proper handling of tuple[Column] <-> Column[tuple].") + if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load(): + pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") cases.verify( test_setup.case, @@ -288,10 +276,9 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): def test_solve_nonhydro_stencil_52_like(test_setup): if ( test_setup.case.executor - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): - pytest.xfail( - "Temporary extraction does not work correctly in combination with scans.") + pytest.xfail("Temporary extraction does not work correctly in combination with scans.") cases.run( test_setup.case, @@ -311,13 +298,11 @@ def test_solve_nonhydro_stencil_52_like(test_setup): def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): if ( test_setup.case.executor - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() ): - pytest.xfail( - "Temporary extraction does not work correctly in combination with scans.") - if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: - pytest.xfail( - "Needs proper handling of tuple[Column] <-> Column[tuple].") + pytest.xfail("Temporary extraction does not work correctly in combination with scans.") + if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load(): + pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") cases.run( test_setup.case, From 413aea364717d9aa9e7965d5bd7b3c529f161395 Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 6 Mar 2024 13:47:29 +0100 Subject: [PATCH 26/47] tests: fix xfail conditions in test_anton_toy --- .../multi_feature_tests/iterator_tests/test_anton_toy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 9a1bc6deb6..bcea9e0901 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -83,9 +83,9 @@ def test_anton_toy(program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn.run_gtfn_with_temporaries.executor, ]: from gt4py.next.iterator import transforms From 86765b49c3443a71d66d428d7ce6e14a50de068a Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 6 Mar 2024 13:49:04 +0100 Subject: [PATCH 27/47] tests: fix xfail conditions in test_hdiff --- .../multi_feature_tests/iterator_tests/test_hdiff.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 9bba1ab89c..5d369c3a8f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -75,9 +75,9 @@ def hdiff(inp, coeff, out, x, y): def test_hdiff(hdiff_reference, program_processor, lift_mode): program_processor, validate = program_processor if program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn.run_gtfn_with_temporaries.executor, ]: # TODO(tehrengruber): check if still true from gt4py.next.iterator import transforms From 08072082aaa5d9ad8c015671acb6a4ca12353a2f Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 6 Mar 2024 14:25:43 +0100 Subject: [PATCH 28/47] edits and pre-commit --- src/gt4py/next/ffront/decorator.py | 109 +++++++++--------- src/gt4py/next/otf/transforms/past_to_func.py | 13 ++- src/gt4py/next/otf/transforms/past_to_itir.py | 13 +-- .../program_processors/runners/roundtrip.py | 26 ++--- .../test_temporaries_with_sizes.py | 3 +- .../iterator_tests/test_vertical_advection.py | 6 +- .../test_processor_interface.py | 12 +- 7 files changed, 82 insertions(+), 100 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 08bd0896cc..20714c5023 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -177,8 +177,7 @@ def __gt_allocator__( if self.backend: return self.backend.__gt_allocator__ else: - raise RuntimeError( - f"Program '{self}' does not have a backend set.") + raise RuntimeError(f"Program '{self}' does not have a backend set.") def with_backend(self, backend: ppi.ProgramExecutor) -> Program: return dataclasses.replace(self, backend=backend) @@ -213,8 +212,7 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: """ for key in kwargs.keys(): if all(key != param.id for param in self.past_node.params): - raise TypeError( - f"Keyword argument '{key}' is not a valid program parameter.") + raise TypeError(f"Keyword argument '{key}' is not a valid program parameter.") return ProgramWithBoundArgs( bound_args=kwargs, @@ -237,8 +235,7 @@ def itir(self) -> itir.FencilDefinition: gt_callables = otf_transforms.utils._filter_closure_vars_by_type( self._all_closure_vars, GTCallable ).values() - lowered_funcs = [gt_callable.__gt_itir__() - for gt_callable in gt_callables] + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] return ProgramLowering.apply( self.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) @@ -255,12 +252,12 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No ) with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: if self.definition is None: - self.definition = past_to_fun_def( + definition = past_to_fun_def( stages.PastClosure( closure_vars=self.closure_vars, past_node=self.past_node, grid_type=self.grid_type, - args=[*rewritten_args, *size_args], + args=args, kwargs=kwargs | { "offset_provider": offset_provider, @@ -268,20 +265,39 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No }, ) ) - ctx.run(self.definition, *rewritten_args, **kwargs) + else: + definition = self.definition + ctx.run(definition, *rewritten_args, **kwargs) + return + elif isinstance(self.backend, modular_executor.ModularExecutor): + self.backend( + stages.PastClosure( + closure_vars=self.closure_vars, + past_node=self.past_node, + grid_type=self.grid_type, + args=[*rewritten_args, *size_args], + kwargs=kwargs + | {"offset_provider": offset_provider, "column_axis": self._column_axis}, + ), + *rewritten_args, + *size_args, + **kwargs, + offset_provider=offset_provider, + column_axis=self._column_axis, + ) return - ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) + ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) + if "debug" in kwargs: + debug(self.itir) self.backend( - stages.PastClosure( - closure_vars=self.closure_vars, - past_node=self.past_node, - grid_type=self.grid_type, - args=[*rewritten_args, *size_args], - kwargs=kwargs - | {"offset_provider": offset_provider, "column_axis": self._column_axis}, - ), + self.itir, + *rewritten_args, + *size_args, + **kwargs, + offset_provider=offset_provider, + column_axis=self._column_axis, ) def format_itir( @@ -305,8 +321,7 @@ def format_itir( def _validate_args(self, *args, **kwargs) -> None: arg_types = [type_translation.from_value(arg) for arg in args] - kwarg_types = {k: type_translation.from_value( - v) for k, v in kwargs.items()} + kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} try: type_info.accepts_args( @@ -323,8 +338,7 @@ def _validate_args(self, *args, **kwargs) -> None: def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) - args, kwargs = type_info.canonicalize_arguments( - self.past_node.type, args, kwargs) + args, kwargs = type_info.canonicalize_arguments(self.past_node.type, args, kwargs) implicit_domain = any( isinstance(stmt, past.Call) and "domain" not in stmt.kwargs @@ -336,8 +350,7 @@ def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[s rewritten_args = list(args) for param_idx, param in enumerate(self.past_node.params): if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): - shapes_and_dims = [ - *_field_constituents_shape_and_dims(args[param_idx], param.type)] + shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)] shape, dims = shapes_and_dims[0] if not all( el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims @@ -406,16 +419,14 @@ def _process_args(self, args: tuple, kwargs: dict): ) arg_types = [type_translation.from_value(arg) for arg in args] - kwarg_types = {k: type_translation.from_value( - v) for k, v in kwargs.items()} + kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} try: # This error is also catched using `accepts_args`, but we do it manually here to give # a better error message. for name in self.bound_args.keys(): if name in kwargs: - raise ValueError( - f"Parameter '{name}' already set as a bound argument.") + raise ValueError(f"Parameter '{name}' already set as a bound argument.") type_info.accepts_args( new_type, @@ -424,8 +435,7 @@ def _process_args(self, args: tuple, kwargs: dict): raise_exception=True, ) except ValueError as err: - bound_arg_names = ", ".join( - [f"'{bound_arg}'" for bound_arg in self.bound_args.keys()]) + bound_arg_names = ", ".join([f"'{bound_arg}'" for bound_arg in self.bound_args.keys()]) raise TypeError( f"Invalid argument types in call to program '{self.past_node.id}' with " f"bound arguments '{bound_arg_names}'." @@ -479,8 +489,7 @@ def program( def program( definition=None, *, - # `NOTHING` -> default backend, `None` -> no backend (embedded execution) - backend=eve.NOTHING, + backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) grid_type=None, ) -> Program | Callable[[types.FunctionType], Program]: """ @@ -559,12 +568,10 @@ def from_function( source_def = SourceDefinition.from_function(definition) closure_vars = get_closure_vars_from_function(definition) annotations = typing.get_type_hints(definition) - foast_definition_node = FieldOperatorParser.apply( - source_def, closure_vars, annotations) + foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations) loc = foast_definition_node.location operator_attribute_nodes = { - key: foast.Constant( - value=value, type=type_translation.from_value(value), location=loc) + key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc) for key, value in operator_attributes.items() } untyped_foast_node = operator_node_cls( @@ -602,8 +609,7 @@ def __gt_itir__(self) -> itir.FunctionDefinition: if hasattr(self, "__cached_itir"): return getattr(self, "__cached_itir") - itir_node: itir.FunctionDefinition = FieldOperatorLowering.apply( - self.foast_node) + itir_node: itir.FunctionDefinition = FieldOperatorLowering.apply(self.foast_node) object.__setattr__(self, "__cached_itir", itir_node) @@ -629,8 +635,7 @@ def as_program( pass loc = self.foast_node.location - # use a new UID generator to allow caching - param_sym_uids = eve_utils.UIDGenerator() + param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching type_ = self.__gt_type__() params_decl: list[past.Symbol] = [ @@ -642,20 +647,17 @@ def as_program( ) for arg_type in arg_types ] - params_ref = [past.Name(id=pdecl.id, location=loc) - for pdecl in params_decl] + params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl] out_sym: past.Symbol = past.DataSymbol( id="out", - type=type_info.return_type( - type_, with_args=arg_types, with_kwargs=kwarg_types), + type=type_info.return_type(type_, with_args=arg_types, with_kwargs=kwarg_types), namespace=dialect_ast_enums.Namespace.LOCAL, location=loc, ) out_ref = past.Name(id="out", location=loc) if self.foast_node.id in self.closure_vars: - raise RuntimeError( - "A closure variable has the same name as the field operator itself.") + raise RuntimeError("A closure variable has the same name as the field operator itself.") closure_vars = {self.foast_node.id: self} closure_symbols = [ past.Symbol( @@ -681,8 +683,7 @@ def as_program( closure_vars=closure_symbols, location=loc, ) - untyped_past_node = ProgramClosureVarTypeDeduction.apply( - untyped_past_node, closure_vars) + untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) self._program_cache[hash_] = Program( @@ -703,15 +704,13 @@ def __call__( if not next_embedded.context.within_context() and self.backend is not None: # non embedded execution if "offset_provider" not in kwargs: - raise errors.MissingArgumentError( - None, "offset_provider", True) + raise errors.MissingArgumentError(None, "offset_provider", True) offset_provider = kwargs.pop("offset_provider") if "out" not in kwargs: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") - args, kwargs = type_info.canonicalize_arguments( - self.foast_node.type, args, kwargs) + args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given # deduce argument types arg_types = [] @@ -735,8 +734,7 @@ def __call__( forward = self.operator_attributes["forward"] init = self.operator_attributes["init"] axis = self.operator_attributes["axis"] - op = embedded_operators.ScanOperator( - self.definition, forward, init, axis) + op = embedded_operators.ScanOperator(self.definition, forward, init, axis) else: op = embedded_operators.EmbeddedOperator(self.definition) return embedded_operators.field_operator_call(op, args, kwargs) @@ -849,8 +847,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type, operator_node_cls=foast.ScanOperator, - operator_attributes={"axis": axis, - "forward": forward, "init": init}, + operator_attributes={"axis": axis, "forward": forward, "init": init}, ) return scan_operator_inner if definition is None else scan_operator_inner(definition) diff --git a/src/gt4py/next/otf/transforms/past_to_func.py b/src/gt4py/next/otf/transforms/past_to_func.py index 915a99cc21..ddaf18addc 100644 --- a/src/gt4py/next/otf/transforms/past_to_func.py +++ b/src/gt4py/next/otf/transforms/past_to_func.py @@ -17,18 +17,20 @@ import gt4py.next as gtx from gt4py.eve import codegen -from gt4py.next.ffront import program_ast as past +from gt4py.next.ffront import program_ast as past, type_translation from gt4py.next.otf import stages from gt4py.next.type_system import type_info def past_to_fun_def(past_closure: stages.PastClosure): node = past_closure.past_node - inout_types_ls = [ - type_info.apply_to_primitive_constituents(arg.type, lambda primitive_type: primitive_type) - for arg in past_closure.args + arg_types = [type_translation.from_value(arg) for arg in past_closure.args] + kwarg_types = [ + type_translation.from_value(v) + for k, v in past_closure.kwargs.items() + if k not in ("offset_provider", "column_axis") ] - inout_types = list(type_info.flatten(inout_types_ls)) + inout_types = list(type_info.flatten(arg_types + kwarg_types)) dims = set( i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j ) @@ -37,6 +39,7 @@ def past_to_fun_def(past_closure: stages.PastClosure): filename = "" globalns = {dim.value: dim for dim in dims} globalns |= gtx.__dict__ + globalns |= past_closure.closure_vars localns: dict = {} code_obj = compile(source_code, filename, "exec") exec(code_obj, globalns, localns) diff --git a/src/gt4py/next/otf/transforms/past_to_itir.py b/src/gt4py/next/otf/transforms/past_to_itir.py index 3fcd13ca49..7e27bd9f3a 100644 --- a/src/gt4py/next/otf/transforms/past_to_itir.py +++ b/src/gt4py/next/otf/transforms/past_to_itir.py @@ -14,6 +14,7 @@ import dataclasses +import devtools import factory from gt4py.next import common, config @@ -29,18 +30,14 @@ @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> ProgramCall: - all_closure_vars = utils._get_closure_vars_recursively( - inp.closure_vars) + all_closure_vars = utils._get_closure_vars_recursively(inp.closure_vars) offsets_and_dimensions = utils._filter_closure_vars_by_type( all_closure_vars, FieldOffset, common.Dimension ) - grid_type = utils._deduce_grid_type( - inp.grid_type, offsets_and_dimensions.values()) + grid_type = utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) - gt_callables = utils._filter_closure_vars_by_type( - all_closure_vars, GTCallable).values() - lowered_funcs = [gt_callable.__gt_itir__() - for gt_callable in gt_callables] + gt_callables = utils._filter_closure_vars_by_type(all_closure_vars, GTCallable).values() + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] itir_program = ProgramLowering.apply( inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 7a18437096..ddef29c181 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -55,8 +55,7 @@ class EmbeddedDSL(codegen.TemplatedGenerator): AxisLiteral = as_fmt("{value}") FunCall = as_fmt("{fun}({','.join(args)})") Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") - StencilClosure = as_mako( - "closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])") + StencilClosure = as_mako("closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])") FencilDefinition = as_mako( """ ${''.join(function_definitions)} @@ -95,16 +94,13 @@ def visit_Temporary(self, node, **kwargs): "unstructured_domain", ) assert all( - isinstance(r, itir.FunCall) and r.fun == itir.SymRef( - id="named_range") + isinstance(r, itir.FunCall) and r.fun == itir.SymRef(id="named_range") for r in node.domain.args ) domain_ranges = [self.visit(r.args) for r in node.domain.args] axes = ", ".join(label for label, _, _ in domain_ranges) - origin = "{" + ", ".join(f"{label}: -{start}" for label, - start, _ in domain_ranges) + "}" - shape = "(" + ", ".join(f"{stop}-{start}" for _, - start, stop in domain_ranges) + ")" + origin = "{" + ", ".join(f"{label}: -{start}" for label, start, _ in domain_ranges) + "}" + shape = "(" + ", ".join(f"{stop}-{start}" for _, start, stop in domain_ranges) + ")" return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" @@ -134,8 +130,7 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, lift_mode, debug, use_embedded, - tuple(offset_provider.items()))) + cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) if cache_key in _FENCIL_CACHE: return _FENCIL_CACHE[cache_key] @@ -189,16 +184,14 @@ def fencil_generator( source_file.write(program) try: - spec = importlib.util.spec_from_file_location( - "module.name", source_file_name) + spec = importlib.util.spec_from_file_location("module.name", source_file_name) mod = importlib.util.module_from_spec(spec) # type: ignore spec.loader.exec_module(mod) # type: ignore finally: if not debug: pathlib.Path(source_file_name).unlink(missing_ok=True) - assert isinstance(ir, (itir.FencilDefinition, - gtmps_transform.FencilWithTemporaries)) + assert isinstance(ir, (itir.FencilDefinition, gtmps_transform.FencilWithTemporaries)) fencil_name = ( ir.fencil.id + "_wrapper" if isinstance(ir, gtmps_transform.FencilWithTemporaries) @@ -211,7 +204,7 @@ def fencil_generator( return fencil -@ppi.program_executor +@ppi.program_executor # type: ignore[arg-type] def execute_roundtrip( ir: itir.Node, *args, @@ -277,8 +270,7 @@ class Meta: class Params: roundtrip_workflow = factory.SubFactory( - RoundtripFactory, dispatch_backend=factory.SelfAttribute( - "..dispatch_backend") + RoundtripFactory, dispatch_backend=factory.SelfAttribute("..dispatch_backend") ) dispatch_backend = None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 9d4cfc90a9..2dd103a3fa 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -94,8 +94,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = ( - mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 4a95b0fdeb..46038832d1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -30,8 +30,7 @@ def tridiag_forward(state, a, b, c, d): return make_tuple( deref(c) / (deref(b) - deref(a) * tuple_get(0, state)), - (deref(d) - deref(a) * tuple_get(1, state)) / - (deref(b) - deref(a) * tuple_get(0, state)), + (deref(d) - deref(a) * tuple_get(1, state)) / (deref(b) - deref(a) * tuple_get(0, state)), ) @@ -127,8 +126,7 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): ] and lift_mode == LiftMode.FORCE_INLINE ): - pytest.skip( - "gtfn does only support lifted scans when using temporaries") + pytest.skip("gtfn does only support lifted scans when using temporaries") if ( program_processor == gtfn.run_gtfn_with_temporaries.executor or lift_mode == LiftMode.USE_TEMPORARIES diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py index cb80d94c3f..dc09d9fe76 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py +++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py @@ -50,15 +50,13 @@ def other_func(program: itir.FencilDefinition, *args, **kwargs) -> str: assert processor.__name__ == "new_name" assert processor(None) == other_func(None) assert processor(1, 2, a="A", b="B") == other_func(1, 2, a="A", b="B") - assert processor(1, 2, 3, 4, a="A", b="B", c="C") != other_func( - 1, 2, 3, 4, a="A", b="B", c="C") + assert processor(1, 2, 3, 4, a="A", b="B", c="C") != other_func(1, 2, 3, 4, a="A", b="B", c="C") with pytest.raises(ValueError, match="accepted arguments cannot be a negative number"): make_program_processor(my_func, ProgramFormatter, accept_args=-1) with pytest.raises(ValueError, match="invalid list of keyword argument names"): - make_program_processor(my_func, ProgramFormatter, - accept_kwargs=["a", None]) + make_program_processor(my_func, ProgramFormatter, accept_kwargs=["a", None]) @pytest.fixture @@ -101,10 +99,8 @@ class DummyAllocatorFactory: @dataclasses.dataclass class DummyBackend: - executor: DummyProgramExecutor = dataclasses.field( - default_factory=DummyProgramExecutor) - allocator: DummyAllocatorFactory = dataclasses.field( - default_factory=DummyAllocatorFactory) + executor: DummyProgramExecutor = dataclasses.field(default_factory=DummyProgramExecutor) + allocator: DummyAllocatorFactory = dataclasses.field(default_factory=DummyAllocatorFactory) @property def __gt_allocator__(self): From c748ad76d42e083b85d9bfcb5c59c22d89ed46ae Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 6 Mar 2024 15:08:17 +0100 Subject: [PATCH 29/47] fixed wrong code merge --- src/gt4py/next/ffront/decorator.py | 33 ++++++++---------------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 20714c5023..8cd072de3b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -67,7 +67,7 @@ ) from gt4py.next.otf import stages, transforms as otf_transforms from gt4py.next.otf.transforms.past_to_func import past_to_fun_def -from gt4py.next.program_processors import modular_executor, processor_interface as ppi +from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -269,35 +269,20 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No definition = self.definition ctx.run(definition, *rewritten_args, **kwargs) return - elif isinstance(self.backend, modular_executor.ModularExecutor): - self.backend( - stages.PastClosure( - closure_vars=self.closure_vars, - past_node=self.past_node, - grid_type=self.grid_type, - args=[*rewritten_args, *size_args], - kwargs=kwargs - | {"offset_provider": offset_provider, "column_axis": self._column_axis}, - ), - *rewritten_args, - *size_args, - **kwargs, - offset_provider=offset_provider, - column_axis=self._column_axis, - ) - return ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) if "debug" in kwargs: debug(self.itir) self.backend( - self.itir, - *rewritten_args, - *size_args, - **kwargs, - offset_provider=offset_provider, - column_axis=self._column_axis, + stages.PastClosure( + closure_vars=self.closure_vars, + past_node=self.past_node, + grid_type=self.grid_type, + args=[*rewritten_args, *size_args], + kwargs=kwargs + | {"offset_provider": offset_provider, "column_axis": self._column_axis}, + ) ) def format_itir( From 00bf734145e06f5a9ab3fc287abb6d70afb1c1db Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 6 Mar 2024 15:18:44 +0100 Subject: [PATCH 30/47] small edit --- src/gt4py/next/ffront/decorator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 8cd072de3b..86f9113fe2 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -271,8 +271,6 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No return ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) - if "debug" in kwargs: - debug(self.itir) self.backend( stages.PastClosure( From 12b90fc11976cd700c1e26a5dbdeadb8bedc47da Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Wed, 6 Mar 2024 15:19:49 +0100 Subject: [PATCH 31/47] small edit --- src/gt4py/next/ffront/decorator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 86f9113fe2..708837ed52 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -270,7 +270,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No ctx.run(definition, *rewritten_args, **kwargs) return - ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) + ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend( stages.PastClosure( From 2bd16f51684e10ede9b86970fd47e405c89b184e Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 7 Mar 2024 15:32:12 +0100 Subject: [PATCH 32/47] start factoring _process_args out from Program --- src/gt4py/next/backend.py | 6 +- src/gt4py/next/ffront/decorator.py | 45 ++++---- src/gt4py/next/otf/transforms/__init__.py | 19 ++- .../next/otf/transforms/past_process_args.py | 109 ++++++++++++++++++ src/gt4py/next/otf/transforms/past_to_itir.py | 17 ++- .../runners/dace_iterator/__init__.py | 7 +- .../runners/double_roundtrip.py | 2 +- .../next/program_processors/runners/gtfn.py | 2 +- .../program_processors/runners/roundtrip.py | 2 +- .../ffront_tests/test_arg_call_interface.py | 6 +- .../test_temporaries_with_sizes.py | 2 +- 11 files changed, 170 insertions(+), 47 deletions(-) create mode 100644 src/gt4py/next/otf/transforms/past_process_args.py diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 80fe7f1043..ee5ab4ebcd 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -19,15 +19,17 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.otf import stages, workflow +from gt4py.next.otf import stages, transforms as otf_transforms, workflow from gt4py.next.program_processors import processor_interface as ppi @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): - transformer: workflow.Workflow[stages.PastClosure, stages.ProgramCall] executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] + transformer: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( + otf_transforms.DEFAULT_TRANSFORMS + ) def __call__(self, program: stages.PastClosure) -> None: program_call = self.transformer(program) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 708837ed52..753c3b0d9a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -46,7 +46,6 @@ program_ast as past, type_specifications as ts_ffront, ) -from gt4py.next.ffront.fbuiltins import FieldOffset from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -56,7 +55,6 @@ ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction, ) from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction -from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.ir_makers import ( @@ -66,7 +64,6 @@ sym, ) from gt4py.next.otf import stages, transforms as otf_transforms -from gt4py.next.otf.transforms.past_to_func import past_to_fun_def from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -225,24 +222,17 @@ def _all_closure_vars(self) -> dict[str, Any]: @functools.cached_property def itir(self) -> itir.FencilDefinition: - offsets_and_dimensions = otf_transforms.utils._filter_closure_vars_by_type( - self._all_closure_vars, FieldOffset, Dimension - ) - grid_type = otf_transforms.utils._deduce_grid_type( - self.grid_type, offsets_and_dimensions.values() - ) - - gt_callables = otf_transforms.utils._filter_closure_vars_by_type( - self._all_closure_vars, GTCallable - ).values() - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] - return ProgramLowering.apply( - self.past_node, function_definitions=lowered_funcs, grid_type=grid_type - ) + return otf_transforms.PastToItirFactory()( + stages.PastClosure( + past_node=self.past_node, + closure_vars=self.closure_vars, + grid_type=self.grid_type, + args=[], + kwargs={}, + ) + ).program def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None: - rewritten_args, size_args, kwargs = self._process_args(args, kwargs) - if self.backend is None: warnings.warn( UserWarning( @@ -251,8 +241,9 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No stacklevel=2, ) with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: + # TODO(ricoh): move into test if self.definition is None: - definition = past_to_fun_def( + definition = otf_transforms.past_to_fun_def( stages.PastClosure( closure_vars=self.closure_vars, past_node=self.past_node, @@ -267,6 +258,8 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No ) else: definition = self.definition + # TODO(ricoh): check if rewriting still needed + rewritten_args, size_args, kwargs = self._process_args(args, kwargs) ctx.run(definition, *rewritten_args, **kwargs) return @@ -277,7 +270,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No closure_vars=self.closure_vars, past_node=self.past_node, grid_type=self.grid_type, - args=[*rewritten_args, *size_args], + args=args, kwargs=kwargs | {"offset_provider": offset_provider, "column_axis": self._column_axis}, ) @@ -382,7 +375,7 @@ def _column_axis(self): class ProgramWithBoundArgs(Program): bound_args: dict[str, typing.Union[float, int, bool]] = None - def _process_args(self, args: tuple, kwargs: dict): + def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): type_ = self.past_node.type new_type = ts_ffront.ProgramType( definition=ts.FunctionType( @@ -433,7 +426,7 @@ def _process_args(self, args: tuple, kwargs: dict): else: full_kwargs[str(param.id)] = self.bound_args[param.id] - return super()._process_args(tuple(full_args), full_kwargs) + return super().__call__(*tuple(full_args), offset_provider=offset_provider, **full_kwargs) @functools.cached_property def itir(self): @@ -472,7 +465,8 @@ def program( def program( definition=None, *, - backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) + # `NOTHING` -> default backend, `None` -> no backend (embedded execution) + backend=eve.NOTHING, grid_type=None, ) -> Program | Callable[[types.FunctionType], Program]: """ @@ -618,7 +612,8 @@ def as_program( pass loc = self.foast_node.location - param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching + # use a new UID generator to allow caching + param_sym_uids = eve_utils.UIDGenerator() type_ = self.__gt_type__() params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/otf/transforms/__init__.py b/src/gt4py/next/otf/transforms/__init__.py index 32817c283b..31086499a7 100644 --- a/src/gt4py/next/otf/transforms/__init__.py +++ b/src/gt4py/next/otf/transforms/__init__.py @@ -12,7 +12,22 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.otf.transforms.past_to_itir import PastToItir, PastToItirFactory +from gt4py.next.otf import stages, workflow +from .past_process_args import past_process_args +from .past_to_func import past_to_fun_def +from .past_to_itir import PastToItir, PastToItirFactory -__all__ = ["PastToItir", "PastToItirFactory"] + +__all__ = [ + "PastToItir", + "PastToItirFactory", + "past_to_fun_def", + "past_process_args", + "DEFAULT_TRANSFORMS", +] + + +DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( + past_process_args.chain(PastToItirFactory()) +) diff --git a/src/gt4py/next/otf/transforms/past_process_args.py b/src/gt4py/next/otf/transforms/past_process_args.py new file mode 100644 index 0000000000..c242dc41ce --- /dev/null +++ b/src/gt4py/next/otf/transforms/past_process_args.py @@ -0,0 +1,109 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Iterator, Optional + +from gt4py.next import common, errors +from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront +from gt4py.next.otf import stages, workflow +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation + + +@workflow.make_step +def past_process_args(inp: stages.PastClosure) -> stages.PastClosure: + extra_kwarg_names = ["offset_provider", "column_axis"] + extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names} + kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names} + rewritten_args, size_args, kwargs = _process_args( + past_node=inp.past_node, args=list(inp.args), kwargs=kwargs + ) + return stages.PastClosure( + past_node=inp.past_node, + closure_vars=inp.closure_vars, + grid_type=inp.grid_type, + args=tuple(*rewritten_args, *size_args), + kwargs=kwargs | extra_kwargs, + ) + + +def _validate_args(past_node: past.Program, args: list, kwargs: dict[str, Any]) -> None: + arg_types = [type_translation.from_value(arg) for arg in args] + kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} + + if not isinstance(past_node.type, ts_ffront.ProgramType): + raise TypeError("Can not validate arguments for PAST programs prior to type inference.") + + try: + type_info.accepts_args( + past_node.type, + with_args=arg_types, + with_kwargs=kwarg_types, + raise_exception=True, + ) + except ValueError as err: + raise errors.DSLError( + None, f"Invalid argument types in call to '{past_node.id}'.\n{err}" + ) from err + + +def _process_args( + past_node: past.Program, args: list, kwargs: dict[str, Any] +) -> tuple[tuple, tuple, dict[str, Any]]: + if not isinstance(past_node.type, ts_ffront.ProgramType): + raise TypeError("Can not process arguments for PAST programs prior to type inference.") + + _validate_args(past_node=past_node, args=args, kwargs=kwargs) + + args, kwargs = type_info.canonicalize_arguments(past_node.type, args, kwargs) + + implicit_domain = any( + isinstance(stmt, past.Call) and "domain" not in stmt.kwargs for stmt in past_node.body + ) + + # extract size of all field arguments + size_args: list[Optional[int]] = [] + rewritten_args = list(args) + for param_idx, param in enumerate(past_node.params): + if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): + shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)] + shape, dims = shapes_and_dims[0] + if not all( + el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims + ): + raise ValueError( + "Constituents of composite arguments (e.g. the elements of a" + " tuple) need to have the same shape and dimensions." + ) + size_args.extend(shape if shape else [None] * len(dims)) + return tuple(rewritten_args), tuple(size_args), kwargs + + +def _field_constituents_shape_and_dims( + arg, arg_type: ts.DataType +) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]: + match arg_type: + case ts.TupleType(): + for el, el_type in zip(arg, arg_type.types): + yield from _field_constituents_shape_and_dims(el, el_type) + case ts.FieldType(): + dims = type_info.extract_dims(arg_type) + if hasattr(arg, "shape"): + assert len(arg.shape) == len(dims) + yield (arg.shape, dims) + else: + yield (tuple(), dims) + case ts.ScalarType(): + yield (tuple(), []) + case _: + raise ValueError("Expected 'FieldType' or 'TupleType' thereof.") diff --git a/src/gt4py/next/otf/transforms/past_to_itir.py b/src/gt4py/next/otf/transforms/past_to_itir.py index 7e27bd9f3a..a0a216baa8 100644 --- a/src/gt4py/next/otf/transforms/past_to_itir.py +++ b/src/gt4py/next/otf/transforms/past_to_itir.py @@ -18,35 +18,34 @@ import factory from gt4py.next import common, config -from gt4py.next.ffront.fbuiltins import FieldOffset -from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.ffront.past_to_itir import ProgramLowering +from gt4py.next.ffront import fbuiltins, gtcallable, past_to_itir from gt4py.next.otf import stages, workflow -from gt4py.next.otf.stages import ProgramCall from . import utils @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): - def __call__(self, inp: stages.PastClosure) -> ProgramCall: + def __call__(self, inp: stages.PastClosure) -> stages.ProgramCall: all_closure_vars = utils._get_closure_vars_recursively(inp.closure_vars) offsets_and_dimensions = utils._filter_closure_vars_by_type( - all_closure_vars, FieldOffset, common.Dimension + all_closure_vars, fbuiltins.FieldOffset, common.Dimension ) grid_type = utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) - gt_callables = utils._filter_closure_vars_by_type(all_closure_vars, GTCallable).values() + gt_callables = utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] - itir_program = ProgramLowering.apply( + itir_program = past_to_itir.ProgramLowering.apply( inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) if config.DEBUG or "debug" in inp.kwargs: devtools.debug(itir_program) - return ProgramCall( + return stages.ProgramCall( itir_program, inp.args, inp.kwargs, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 2e960d763f..7698629ccd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -357,7 +357,8 @@ def build_sdfg_from_itir( def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): # build parameters build_cache = kwargs.get("build_cache", None) - compiler_args = kwargs.get("compiler_args", None) # `None` will take default. + # `None` will take default. + compiler_args = kwargs.get("compiler_args", None) build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", _default_on_gpu) auto_optimize = kwargs.get("auto_optimize", True) @@ -438,7 +439,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_cpu = backend.Backend( - transformer=otf_transforms.PastToItirFactory(), + transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"), allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) @@ -462,7 +463,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_gpu = backend.Backend( - transformer=otf_transforms.PastToItirFactory(), + transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"), allocator=next_allocators.StandardGPUFieldBufferAllocator(), ) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index a05243bb43..f8f7864a6e 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -20,7 +20,7 @@ backend = next_backend.Backend( - transformer=otf_transforms.PastToItirFactory(), + transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=roundtrip.RoundtripExecutorFactory( dispatch_backend=roundtrip.RoundtripExecutorFactory(), ), diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 1b5c90a669..c242b30b4d 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -182,7 +182,7 @@ class Params: lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) - transformer = factory.SubFactory(otf_transforms.PastToItirFactory) + transformer = otf_transforms.DEFAULT_TRANSFORMS executor = factory.LazyAttribute( lambda o: modular_executor.ModularExecutor(otf_workflow=o.otf_workflow, name=o.name) ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index ddef29c181..131ea52c54 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -280,7 +280,7 @@ class Params: executor = RoundtripExecutorFactory(name="roundtrip") backend = next_backend.Backend( - transformer=otf_transforms.PastToItirFactory(), + transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index e5a821de52..b8d9841616 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -286,7 +286,7 @@ def test_call_bound_program_with_wrong_args(cartesian_case, bound_args_testee): out = cases.allocate(cartesian_case, bound_args_testee, "out")() with pytest.raises(TypeError) as exc_info: - program_with_bound_arg(out, offset_provider={}) + program_with_bound_arg.with_backend(cartesian_case.executor)(out, offset_provider={}) assert ( re.search( @@ -302,7 +302,9 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te out = cases.allocate(cartesian_case, bound_args_testee, "out")() with pytest.raises(TypeError) as exc_info: - program_with_bound_arg(True, out, arg2=True, offset_provider={}) + program_with_bound_arg.with_backend(cartesian_case.executor)( + True, out, arg2=True, offset_provider={} + ) assert ( re.search( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 2dd103a3fa..a16f6fb845 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -40,7 +40,7 @@ @pytest.fixture def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( - transformer=otf_transforms.PastToItirFactory(), + transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=modular_executor.ModularExecutor( name="run_gtfn_with_temporaries_and_sizes", otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( From 260c323b699aea4256631c9ae7ddb4967fa5b15d Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 7 Mar 2024 15:49:17 +0100 Subject: [PATCH 33/47] fix and remove _process_args etc from decorator --- src/gt4py/next/ffront/decorator.py | 51 +++---------------- src/gt4py/next/otf/transforms/__init__.py | 6 +-- .../next/otf/transforms/past_process_args.py | 4 +- 3 files changed, 11 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 753c3b0d9a..806c8cc987 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -259,7 +259,9 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No else: definition = self.definition # TODO(ricoh): check if rewriting still needed - rewritten_args, size_args, kwargs = self._process_args(args, kwargs) + rewritten_args, size_args, kwargs = otf_transforms.past_process_args._process_args( + self.past_node, args, kwargs + ) ctx.run(definition, *rewritten_args, **kwargs) return @@ -284,7 +286,9 @@ def format_itir( **kwargs, ) -> str: ppi.ensure_processor_kind(formatter, ppi.ProgramFormatter) - rewritten_args, size_args, kwargs = self._process_args(args, kwargs) + rewritten_args, size_args, kwargs = otf_transforms.past_process_args._process_args( + self.past_node, args, kwargs + ) if "debug" in kwargs: debug(self.itir) return formatter( @@ -295,49 +299,6 @@ def format_itir( offset_provider=offset_provider, ) - def _validate_args(self, *args, **kwargs) -> None: - arg_types = [type_translation.from_value(arg) for arg in args] - kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} - - try: - type_info.accepts_args( - self.past_node.type, - with_args=arg_types, - with_kwargs=kwarg_types, - raise_exception=True, - ) - except ValueError as err: - raise errors.DSLError( - None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}" - ) from err - - def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: - self._validate_args(*args, **kwargs) - - args, kwargs = type_info.canonicalize_arguments(self.past_node.type, args, kwargs) - - implicit_domain = any( - isinstance(stmt, past.Call) and "domain" not in stmt.kwargs - for stmt in self.past_node.body - ) - - # extract size of all field arguments - size_args: list[Optional[tuple[int, ...]]] = [] - rewritten_args = list(args) - for param_idx, param in enumerate(self.past_node.params): - if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): - shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)] - shape, dims = shapes_and_dims[0] - if not all( - el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims - ): - raise ValueError( - "Constituents of composite arguments (e.g. the elements of a" - " tuple) need to have the same shape and dimensions." - ) - size_args.extend(shape if shape else [None] * len(dims)) - return tuple(rewritten_args), tuple(size_args), kwargs - @functools.cached_property def _column_axis(self): # construct mapping from column axis to scan operators defined on diff --git a/src/gt4py/next/otf/transforms/__init__.py b/src/gt4py/next/otf/transforms/__init__.py index 31086499a7..7844c6f04d 100644 --- a/src/gt4py/next/otf/transforms/__init__.py +++ b/src/gt4py/next/otf/transforms/__init__.py @@ -14,7 +14,7 @@ from gt4py.next.otf import stages, workflow -from .past_process_args import past_process_args +from .past_process_args import past_process_args_wf from .past_to_func import past_to_fun_def from .past_to_itir import PastToItir, PastToItirFactory @@ -23,11 +23,11 @@ "PastToItir", "PastToItirFactory", "past_to_fun_def", - "past_process_args", + "past_process_args_wf", "DEFAULT_TRANSFORMS", ] DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( - past_process_args.chain(PastToItirFactory()) + past_process_args_wf.chain(PastToItirFactory()) ) diff --git a/src/gt4py/next/otf/transforms/past_process_args.py b/src/gt4py/next/otf/transforms/past_process_args.py index c242dc41ce..a29730c6d0 100644 --- a/src/gt4py/next/otf/transforms/past_process_args.py +++ b/src/gt4py/next/otf/transforms/past_process_args.py @@ -21,7 +21,7 @@ @workflow.make_step -def past_process_args(inp: stages.PastClosure) -> stages.PastClosure: +def past_process_args_wf(inp: stages.PastClosure) -> stages.PastClosure: extra_kwarg_names = ["offset_provider", "column_axis"] extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names} kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names} @@ -32,7 +32,7 @@ def past_process_args(inp: stages.PastClosure) -> stages.PastClosure: past_node=inp.past_node, closure_vars=inp.closure_vars, grid_type=inp.grid_type, - args=tuple(*rewritten_args, *size_args), + args=tuple([*rewritten_args, *size_args]), kwargs=kwargs | extra_kwargs, ) From 2b1b368a1652dda5c931c5853497edd48ab456bd Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 7 Mar 2024 16:21:55 +0100 Subject: [PATCH 34/47] refactor otf.transforms -> ffront --- src/gt4py/next/backend.py | 12 +++++++---- src/gt4py/next/ffront/decorator.py | 20 +++++++++++-------- .../past_process_args_wf.py} | 2 +- .../past_to_func_wf.py} | 2 +- .../past_to_itir_wf.py} | 14 ++++++------- .../utils.py => ffront/transform_utils.py} | 0 src/gt4py/next/otf/transforms/__init__.py | 18 +++++------------ .../runners/dace_iterator/__init__.py | 9 +++------ .../runners/double_roundtrip.py | 2 -- .../next/program_processors/runners/gtfn.py | 3 +-- .../program_processors/runners/roundtrip.py | 3 +-- .../test_temporaries_with_sizes.py | 2 -- .../test_decorator_domain_deduction.py | 2 +- 13 files changed, 40 insertions(+), 49 deletions(-) rename src/gt4py/next/{otf/transforms/past_process_args.py => ffront/past_process_args_wf.py} (98%) rename src/gt4py/next/{otf/transforms/past_to_func.py => ffront/past_to_func_wf.py} (98%) rename src/gt4py/next/{otf/transforms/past_to_itir.py => ffront/past_to_itir_wf.py} (81%) rename src/gt4py/next/{otf/transforms/utils.py => ffront/transform_utils.py} (100%) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index ee5ab4ebcd..d9699cfc7a 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -19,17 +19,21 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.otf import stages, transforms as otf_transforms, workflow +from gt4py.next.ffront import past_process_args_wf, past_to_itir_wf +from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import processor_interface as ppi +DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( + past_process_args_wf.past_process_args.chain(past_to_itir_wf.PastToItirFactory()) +) + + @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - transformer: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( - otf_transforms.DEFAULT_TRANSFORMS - ) + transformer: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = DEFAULT_TRANSFORMS def __call__(self, program: stages.PastClosure) -> None: program_call = self.transformer(program) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 806c8cc987..231418eadd 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -43,7 +43,11 @@ from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, + past_process_args_wf, + past_to_func_wf, + past_to_itir_wf, program_ast as past, + transform_utils, type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction @@ -63,7 +67,7 @@ ref, sym, ) -from gt4py.next.otf import stages, transforms as otf_transforms +from gt4py.next.otf import stages from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -140,7 +144,7 @@ def from_function( ) def __post_init__(self): - function_closure_vars = otf_transforms.utils._filter_closure_vars_by_type( + function_closure_vars = transform_utils._filter_closure_vars_by_type( self.closure_vars, GTCallable ) misnamed_functions = [ @@ -218,11 +222,11 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: @functools.cached_property def _all_closure_vars(self) -> dict[str, Any]: - return otf_transforms.utils._get_closure_vars_recursively(self.closure_vars) + return transform_utils._get_closure_vars_recursively(self.closure_vars) @functools.cached_property def itir(self) -> itir.FencilDefinition: - return otf_transforms.PastToItirFactory()( + return past_to_itir_wf.PastToItirFactory()( stages.PastClosure( past_node=self.past_node, closure_vars=self.closure_vars, @@ -243,7 +247,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: # TODO(ricoh): move into test if self.definition is None: - definition = otf_transforms.past_to_fun_def( + definition = past_to_func_wf.past_to_func( stages.PastClosure( closure_vars=self.closure_vars, past_node=self.past_node, @@ -259,7 +263,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No else: definition = self.definition # TODO(ricoh): check if rewriting still needed - rewritten_args, size_args, kwargs = otf_transforms.past_process_args._process_args( + rewritten_args, size_args, kwargs = past_process_args_wf._process_args( self.past_node, args, kwargs ) ctx.run(definition, *rewritten_args, **kwargs) @@ -286,7 +290,7 @@ def format_itir( **kwargs, ) -> str: ppi.ensure_processor_kind(formatter, ppi.ProgramFormatter) - rewritten_args, size_args, kwargs = otf_transforms.past_process_args._process_args( + rewritten_args, size_args, kwargs = past_process_args_wf._process_args( self.past_node, args, kwargs ) if "debug" in kwargs: @@ -305,7 +309,7 @@ def _column_axis(self): # that dimension. only one column axis is allowed, but we can use # this mapping to provide good error messages. scanops_per_axis: dict[Dimension, str] = {} - for name, gt_callable in otf_transforms.utils._filter_closure_vars_by_type( + for name, gt_callable in transform_utils._filter_closure_vars_by_type( self._all_closure_vars, GTCallable ).items(): if isinstance( diff --git a/src/gt4py/next/otf/transforms/past_process_args.py b/src/gt4py/next/ffront/past_process_args_wf.py similarity index 98% rename from src/gt4py/next/otf/transforms/past_process_args.py rename to src/gt4py/next/ffront/past_process_args_wf.py index a29730c6d0..22661502e9 100644 --- a/src/gt4py/next/otf/transforms/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args_wf.py @@ -21,7 +21,7 @@ @workflow.make_step -def past_process_args_wf(inp: stages.PastClosure) -> stages.PastClosure: +def past_process_args(inp: stages.PastClosure) -> stages.PastClosure: extra_kwarg_names = ["offset_provider", "column_axis"] extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names} kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names} diff --git a/src/gt4py/next/otf/transforms/past_to_func.py b/src/gt4py/next/ffront/past_to_func_wf.py similarity index 98% rename from src/gt4py/next/otf/transforms/past_to_func.py rename to src/gt4py/next/ffront/past_to_func_wf.py index ddaf18addc..4aaa45ae44 100644 --- a/src/gt4py/next/otf/transforms/past_to_func.py +++ b/src/gt4py/next/ffront/past_to_func_wf.py @@ -22,7 +22,7 @@ from gt4py.next.type_system import type_info -def past_to_fun_def(past_closure: stages.PastClosure): +def past_to_func(past_closure: stages.PastClosure): node = past_closure.past_node arg_types = [type_translation.from_value(arg) for arg in past_closure.args] kwarg_types = [ diff --git a/src/gt4py/next/otf/transforms/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir_wf.py similarity index 81% rename from src/gt4py/next/otf/transforms/past_to_itir.py rename to src/gt4py/next/ffront/past_to_itir_wf.py index a0a216baa8..a2aed45d90 100644 --- a/src/gt4py/next/otf/transforms/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir_wf.py @@ -18,22 +18,22 @@ import factory from gt4py.next import common, config -from gt4py.next.ffront import fbuiltins, gtcallable, past_to_itir +from gt4py.next.ffront import fbuiltins, gtcallable, past_to_itir, transform_utils from gt4py.next.otf import stages, workflow -from . import utils - @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): def __call__(self, inp: stages.PastClosure) -> stages.ProgramCall: - all_closure_vars = utils._get_closure_vars_recursively(inp.closure_vars) - offsets_and_dimensions = utils._filter_closure_vars_by_type( + all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) + offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( all_closure_vars, fbuiltins.FieldOffset, common.Dimension ) - grid_type = utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) + grid_type = transform_utils._deduce_grid_type( + inp.grid_type, offsets_and_dimensions.values() + ) - gt_callables = utils._filter_closure_vars_by_type( + gt_callables = transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).values() lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] diff --git a/src/gt4py/next/otf/transforms/utils.py b/src/gt4py/next/ffront/transform_utils.py similarity index 100% rename from src/gt4py/next/otf/transforms/utils.py rename to src/gt4py/next/ffront/transform_utils.py diff --git a/src/gt4py/next/otf/transforms/__init__.py b/src/gt4py/next/otf/transforms/__init__.py index 7844c6f04d..c3deb135fc 100644 --- a/src/gt4py/next/otf/transforms/__init__.py +++ b/src/gt4py/next/otf/transforms/__init__.py @@ -12,22 +12,14 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.otf import stages, workflow - -from .past_process_args import past_process_args_wf -from .past_to_func import past_to_fun_def -from .past_to_itir import PastToItir, PastToItirFactory +from gt4py.next.ffront.past_process_args_wf import past_process_args +from gt4py.next.ffront.past_to_func_wf import past_to_func +from gt4py.next.ffront.past_to_itir_wf import PastToItir, PastToItirFactory __all__ = [ "PastToItir", "PastToItirFactory", - "past_to_fun_def", - "past_process_args_wf", - "DEFAULT_TRANSFORMS", + "past_to_func", + "past_process_args", ] - - -DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( - past_process_args_wf.chain(PastToItirFactory()) -) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 7698629ccd..0b6f18f5d3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -26,9 +26,8 @@ import gt4py.next.allocators as next_allocators import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.processor_interface as ppi -from gt4py.next import backend, common +from gt4py.next import backend as next_backend, common from gt4py.next.iterator import transforms as itir_transforms -from gt4py.next.otf import transforms as otf_transforms from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation @@ -438,8 +437,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: ) -run_dace_cpu = backend.Backend( - transformer=otf_transforms.DEFAULT_TRANSFORMS, +run_dace_cpu = next_backend.Backend( executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"), allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) @@ -462,8 +460,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: raise RuntimeError("Missing 'cupy' dependency for GPU execution.") -run_dace_gpu = backend.Backend( - transformer=otf_transforms.DEFAULT_TRANSFORMS, +run_dace_gpu = next_backend.Backend( executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"), allocator=next_allocators.StandardGPUFieldBufferAllocator(), ) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index f8f7864a6e..e37fb65891 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -15,12 +15,10 @@ from __future__ import annotations from gt4py.next import backend as next_backend -from gt4py.next.otf import transforms as otf_transforms from gt4py.next.program_processors.runners import roundtrip backend = next_backend.Backend( - transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=roundtrip.RoundtripExecutorFactory( dispatch_backend=roundtrip.RoundtripExecutorFactory(), ), diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c242b30b4d..7bd65d791e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -25,7 +25,7 @@ from gt4py.next import backend, common, config from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps -from gt4py.next.otf import recipes, stages, transforms as otf_transforms, workflow +from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb @@ -182,7 +182,6 @@ class Params: lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) - transformer = otf_transforms.DEFAULT_TRANSFORMS executor = factory.LazyAttribute( lambda o: modular_executor.ModularExecutor(otf_workflow=o.otf_workflow, name=o.name) ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 131ea52c54..97debccc40 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -33,7 +33,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import backend as next_backend -from gt4py.next.otf import stages, transforms as otf_transforms, workflow +from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import modular_executor, processor_interface as ppi @@ -280,7 +280,6 @@ class Params: executor = RoundtripExecutorFactory(name="roundtrip") backend = next_backend.Backend( - transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index a16f6fb845..191f8ee739 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -17,7 +17,6 @@ from gt4py import next as gtx from gt4py.next import backend, common -from gt4py.next.otf import transforms as otf_transforms from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.program_processors import modular_executor from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries @@ -40,7 +39,6 @@ @pytest.fixture def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( - transformer=otf_transforms.DEFAULT_TRANSFORMS, executor=modular_executor.ModularExecutor( name="run_gtfn_with_temporaries_and_sizes", otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( diff --git a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py index 5620efc0c7..a586f09038 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py @@ -15,7 +15,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.otf.transforms.utils import _deduce_grid_type +from gt4py.next.ffront.transform_utils import _deduce_grid_type Dim = gtx.Dimension("Dim") From 8f08fdb64166c22892d00d3fa3af4451abcc0514 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 8 Mar 2024 10:00:27 +0100 Subject: [PATCH 35/47] remove empty `otf.transforms` --- src/gt4py/next/otf/transforms/__init__.py | 25 ----------------------- 1 file changed, 25 deletions(-) delete mode 100644 src/gt4py/next/otf/transforms/__init__.py diff --git a/src/gt4py/next/otf/transforms/__init__.py b/src/gt4py/next/otf/transforms/__init__.py deleted file mode 100644 index c3deb135fc..0000000000 --- a/src/gt4py/next/otf/transforms/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from gt4py.next.ffront.past_process_args_wf import past_process_args -from gt4py.next.ffront.past_to_func_wf import past_to_func -from gt4py.next.ffront.past_to_itir_wf import PastToItir, PastToItirFactory - - -__all__ = [ - "PastToItir", - "PastToItirFactory", - "past_to_func", - "past_process_args", -] From 9e70d69d916b4d77b0620b2bc8c564fde23f82cb Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 8 Mar 2024 10:58:42 +0100 Subject: [PATCH 36/47] minor roundtrip workflow refactor --- .../program_processors/runners/roundtrip.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 97debccc40..eb6c4e9d9e 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -24,15 +24,11 @@ import factory -import gt4py.next.allocators as next_allocators -import gt4py.next.common as common -import gt4py.next.iterator.embedded as embedded -import gt4py.next.iterator.ir as itir -import gt4py.next.iterator.transforms as itir_transforms -import gt4py.next.iterator.transforms.global_tmps as gtmps_transform from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from gt4py.next import backend as next_backend +from gt4py.next import allocators as next_allocators, backend as next_backend, common +from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms +from gt4py.next.iterator.transforms import global_tmps as gtmps_transform from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import modular_executor, processor_interface as ppi @@ -236,7 +232,7 @@ def execute_roundtrip( class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): debug: bool = False lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE - dispatch_backend: Optional[ppi.ProgramExecutor] = None + use_embedded: bool = True def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: return fencil_generator( @@ -244,7 +240,7 @@ def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: offset_provider=inp.kwargs.get("offset_provider", None), debug=self.debug, lift_mode=self.lift_mode, - use_embedded=self.dispatch_backend is None, + use_embedded=self.use_embedded, ) @@ -269,8 +265,9 @@ class Meta: model = RoundtripExecutor class Params: + use_embedded = factory.LazyAttribute(lambda o: o.dispatch_backend is None) roundtrip_workflow = factory.SubFactory( - RoundtripFactory, dispatch_backend=factory.SelfAttribute("..dispatch_backend") + RoundtripFactory, use_embedded=factory.SelfAttribute("..use_embedded") ) dispatch_backend = None From 33c3ebc3f8151dd5b500c1165c77e153f251180a Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 8 Mar 2024 11:30:18 +0100 Subject: [PATCH 37/47] remove debug lines, use flatten instead of comprehension --- src/gt4py/next/ffront/past_to_func_wf.py | 6 +++--- src/gt4py/next/type_system/type_info.py | 2 +- tests/next_tests/unit_tests/conftest.py | 3 --- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_func_wf.py b/src/gt4py/next/ffront/past_to_func_wf.py index 4aaa45ae44..16362bc0b3 100644 --- a/src/gt4py/next/ffront/past_to_func_wf.py +++ b/src/gt4py/next/ffront/past_to_func_wf.py @@ -15,8 +15,8 @@ import linecache import textwrap -import gt4py.next as gtx -from gt4py.eve import codegen +from gt4py import next as gtx +from gt4py.eve import codegen, utils as eve_utils from gt4py.next.ffront import program_ast as past, type_translation from gt4py.next.otf import stages from gt4py.next.type_system import type_info @@ -32,7 +32,7 @@ def past_to_func(past_closure: stages.PastClosure): ] inout_types = list(type_info.flatten(arg_types + kwarg_types)) dims = set( - i for j in [type_info.extract_dims(inout_type) for inout_type in inout_types] for i in j + eve_utils.flatten_iter(type_info.extract_dims(inout_type) for inout_type in inout_types) ) source_code = ProgamFuncGen.apply(node) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 3587f66ce5..68eb0df7af 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -667,7 +667,7 @@ def function_signature_incompatibilities_func( and not is_concretizable(a_arg, to_type=b_arg) ): if i < len(func_type.pos_only_args): - arg_repr = f"{_number_to_ordinal_number(i+1)} argument" + arg_repr = f"{_number_to_ordinal_number(i + 1)} argument" else: arg_repr = f"argument '{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}'" yield f"Expected {arg_repr} to be of type '{a_arg}', got '{b_arg}'." diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 6b1fe3d37e..c9406884e6 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -106,9 +106,6 @@ def run_processor( *args, **kwargs, ) -> None: - import devtools - - devtools.debug(processor) if processor is None or ppi.is_processor_kind(processor, ppi.ProgramExecutor): program(*args, backend=processor, **kwargs) elif ppi.is_processor_kind(processor, ppi.ProgramFormatter): From 30cff5990872c8718a5f0524c4bd95ae05a92e1b Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Fri, 8 Mar 2024 11:43:44 +0100 Subject: [PATCH 38/47] test --- .../feature_tests/ffront_tests/test_execution.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 56c7f04182..0206ff3cdf 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -75,6 +75,18 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases. cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b: (a, b)) +def test_as_program(cartesian_case): + from gt4py.next.type_system import type_specifications as ts + @gtx.field_operator(backend=None) + def testee(a: cases.IJField) -> cases.IJField: + return a + + a = cases.allocate(cartesian_case, testee, "a")() + + t_pr = testee.as_program(arg_types=[ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))], + kwarg_types={}) + t_pr(a, out=a, offset_provider={}) + @pytest.mark.uses_cartesian_shift def test_cartesian_shift(cartesian_case): From 6ed1dea2acfd4a782dffe127905f5218f1c59da7 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 8 Mar 2024 14:43:21 +0100 Subject: [PATCH 39/47] run formatting --- .../feature_tests/ffront_tests/test_execution.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 0206ff3cdf..2858f06834 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -75,16 +75,20 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases. cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b: (a, b)) + def test_as_program(cartesian_case): from gt4py.next.type_system import type_specifications as ts + @gtx.field_operator(backend=None) def testee(a: cases.IJField) -> cases.IJField: return a a = cases.allocate(cartesian_case, testee, "a")() - t_pr = testee.as_program(arg_types=[ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))], - kwarg_types={}) + t_pr = testee.as_program( + arg_types=[ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))], + kwarg_types={}, + ) t_pr(a, out=a, offset_provider={}) From f0659a90015b394cdcc2a8d426fe04eef5a71c85 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 8 Mar 2024 16:27:10 +0100 Subject: [PATCH 40/47] refactor type_info.flatten -> eve.utils.flatten_iter --- src/gt4py/next/ffront/past_to_func_wf.py | 7 ++- src/gt4py/next/type_system/type_info.py | 15 ------ .../next/type_system/type_specifications.py | 5 +- .../ffront_tests/test_execution.py | 46 +++++++++++++++---- 4 files changed, 45 insertions(+), 28 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_func_wf.py b/src/gt4py/next/ffront/past_to_func_wf.py index 16362bc0b3..9bd39cc0c8 100644 --- a/src/gt4py/next/ffront/past_to_func_wf.py +++ b/src/gt4py/next/ffront/past_to_func_wf.py @@ -30,7 +30,7 @@ def past_to_func(past_closure: stages.PastClosure): for k, v in past_closure.kwargs.items() if k not in ("offset_provider", "column_axis") ] - inout_types = list(type_info.flatten(arg_types + kwarg_types)) + inout_types = eve_utils.flatten_iter(arg_types + kwarg_types) dims = set( eve_utils.flatten_iter(type_info.extract_dims(inout_type) for inout_type in inout_types) ) @@ -38,7 +38,7 @@ def past_to_func(past_closure: stages.PastClosure): filename = "" globalns = {dim.value: dim for dim in dims} - globalns |= gtx.__dict__ + globalns |= {k: v for k, v in gtx.__dict__.items() if not k.startswith("__")} globalns |= past_closure.closure_vars localns: dict = {} code_obj = compile(source_code, filename, "exec") @@ -57,11 +57,10 @@ def past_to_func(past_closure: stages.PastClosure): class ProgamFuncGen(codegen.TemplatedGenerator): def visit_Program(self, node: past.Program, **kwargs) -> str: - imports = "from __future__ import annotations\nfrom gt4py.next import *" params = self.visit(node.params) signature = ", ".join(params) body = textwrap.indent("\n".join(self.visit(node.body)), prefix=" " * 4) - return f"{imports}\n\n\ndef {node.id}({signature}) -> None:\n{body}" + return f"def {node.id}({signature}) -> None:\n{body}" Symbol = codegen.FormatTemplate("{id}: {type}") diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 68eb0df7af..88c1d31e9d 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -754,18 +754,3 @@ def accepts_args( return True return next(errors, None) is None - - -def flatten(arg: list | tuple | ts.TupleType): - if ( - not isinstance(arg, ts.TupleType) - and not isinstance(arg, list) - and not isinstance(arg, tuple) - ): - yield arg - elif isinstance(arg, list) or isinstance(arg, tuple): - for sub in arg: - yield from flatten(sub) - elif isinstance(arg, ts.TupleType): - for sub in arg.types: - yield from flatten(sub) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 3456a3cede..f178a5752f 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from dataclasses import dataclass -from typing import Optional +from typing import Iterator, Optional from gt4py.eve.type_definitions import IntEnum from gt4py.next import common as func_common @@ -99,6 +99,9 @@ class TupleType(DataType): def __str__(self): return f"tuple[{', '.join(map(str, self.types))}]" + def __iter__(self) -> Iterator[DataType]: + yield from self.types + @dataclass(frozen=True) class FieldType(DataType, CallableType): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 2858f06834..0820bfd70c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -22,7 +22,9 @@ astype, broadcast, common, + constructors, errors, + field_utils, float32, float64, int32, @@ -33,6 +35,7 @@ ) from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn +from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -47,6 +50,7 @@ Koff, V2EDim, Vertex, + Edge, cartesian_case, unstructured_case, ) @@ -76,20 +80,22 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases. cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b: (a, b)) -def test_as_program(cartesian_case): - from gt4py.next.type_system import type_specifications as ts - - @gtx.field_operator(backend=None) - def testee(a: cases.IJField) -> cases.IJField: - return a +@pytest.mark.uses_tuple_returns +def test_as_program_tuple_return(cartesian_case): + @gtx.field_operator(backend=cartesian_case.executor) + def testee(a: cases.IJField) -> tuple[cases.IJField, cases.IJField]: + return a + 1, a a = cases.allocate(cartesian_case, testee, "a")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() - t_pr = testee.as_program( + testee_program = testee.as_program( arg_types=[ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))], kwarg_types={}, ) - t_pr(a, out=a, offset_provider={}) + testee_program(a, out=out, offset_provider={}) + assert np.all(a + 1 == out[0]) + assert np.all(a == out[1]) @pytest.mark.uses_cartesian_shift @@ -117,6 +123,30 @@ def testee(a: cases.VField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift +def test_as_program_with_unstructured_shift_no_return_hint(unstructured_case): + @gtx.field_operator(backend=unstructured_case.executor) + def testee(a: cases.VField): + return a(E2V[0]) + + testee_program = testee.as_program( + arg_types=[ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))], + kwarg_types={}, + ) + a = cases.allocate(unstructured_case, testee, "a")() + out = constructors.zeros( + domain=common.domain({Edge: unstructured_case.default_sizes[Edge]}), + allocator=unstructured_case.allocator, + dtype=int32, + ) + + testee_program(a, out=out, offset_provider=unstructured_case.offset_provider) + + ref = field_utils.asnumpy(a)[unstructured_case.offset_provider["E2V"].table[:, 0]] + + assert np.all(ref == field_utils.asnumpy(out)) + + @pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator From c907da0fbe6a4a09072891c4f747c373a11c614d Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 8 Mar 2024 16:32:30 +0100 Subject: [PATCH 41/47] split PAST -> function into separate branch --- src/gt4py/next/ffront/decorator.py | 20 +---- src/gt4py/next/ffront/past_to_func_wf.py | 74 ------------------- .../ffront_tests/test_execution.py | 42 ----------- 3 files changed, 1 insertion(+), 135 deletions(-) delete mode 100644 src/gt4py/next/ffront/past_to_func_wf.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 231418eadd..6235abc52b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -44,7 +44,6 @@ dialect_ast_enums, field_operator_ast as foast, past_process_args_wf, - past_to_func_wf, past_to_itir_wf, program_ast as past, transform_utils, @@ -245,28 +244,11 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No stacklevel=2, ) with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: - # TODO(ricoh): move into test - if self.definition is None: - definition = past_to_func_wf.past_to_func( - stages.PastClosure( - closure_vars=self.closure_vars, - past_node=self.past_node, - grid_type=self.grid_type, - args=args, - kwargs=kwargs - | { - "offset_provider": offset_provider, - "column_axis": self._column_axis, - }, - ) - ) - else: - definition = self.definition # TODO(ricoh): check if rewriting still needed rewritten_args, size_args, kwargs = past_process_args_wf._process_args( self.past_node, args, kwargs ) - ctx.run(definition, *rewritten_args, **kwargs) + ctx.run(self.definition, *rewritten_args, **kwargs) return ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) diff --git a/src/gt4py/next/ffront/past_to_func_wf.py b/src/gt4py/next/ffront/past_to_func_wf.py deleted file mode 100644 index 9bd39cc0c8..0000000000 --- a/src/gt4py/next/ffront/past_to_func_wf.py +++ /dev/null @@ -1,74 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import linecache -import textwrap - -from gt4py import next as gtx -from gt4py.eve import codegen, utils as eve_utils -from gt4py.next.ffront import program_ast as past, type_translation -from gt4py.next.otf import stages -from gt4py.next.type_system import type_info - - -def past_to_func(past_closure: stages.PastClosure): - node = past_closure.past_node - arg_types = [type_translation.from_value(arg) for arg in past_closure.args] - kwarg_types = [ - type_translation.from_value(v) - for k, v in past_closure.kwargs.items() - if k not in ("offset_provider", "column_axis") - ] - inout_types = eve_utils.flatten_iter(arg_types + kwarg_types) - dims = set( - eve_utils.flatten_iter(type_info.extract_dims(inout_type) for inout_type in inout_types) - ) - source_code = ProgamFuncGen.apply(node) - - filename = "" - globalns = {dim.value: dim for dim in dims} - globalns |= {k: v for k, v in gtx.__dict__.items() if not k.startswith("__")} - globalns |= past_closure.closure_vars - localns: dict = {} - code_obj = compile(source_code, filename, "exec") - exec(code_obj, globalns, localns) - lines = [line + "\n" for line in source_code.splitlines()] - linecache.cache[filename] = (len(source_code), None, lines, filename) - function_definition = localns[str(node.id)] - linecache.cache[filename] = ( - len(source_code), - None, - [line + "\n" for line in source_code.splitlines()], - filename, - ) - return function_definition - - -class ProgamFuncGen(codegen.TemplatedGenerator): - def visit_Program(self, node: past.Program, **kwargs) -> str: - params = self.visit(node.params) - signature = ", ".join(params) - body = textwrap.indent("\n".join(self.visit(node.body)), prefix=" " * 4) - return f"def {node.id}({signature}) -> None:\n{body}" - - Symbol = codegen.FormatTemplate("{id}: {type}") - - def visit_Call(self, node: past.Call, **kwargs) -> str: - args_joined = ", ".join(self.visit(node.args)) - kwargs_list = [f"{name}={self.visit(value)}" for name, value in node.kwargs.items()] - kwargs_joined = ", ".join(kwargs_list) - params = ", ".join([args_joined, kwargs_joined]) - return f"{self.visit(node.func)}({params})" - - Name = codegen.FormatTemplate("{id}") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 0820bfd70c..d571c61590 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -80,24 +80,6 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases. cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b: (a, b)) -@pytest.mark.uses_tuple_returns -def test_as_program_tuple_return(cartesian_case): - @gtx.field_operator(backend=cartesian_case.executor) - def testee(a: cases.IJField) -> tuple[cases.IJField, cases.IJField]: - return a + 1, a - - a = cases.allocate(cartesian_case, testee, "a")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() - - testee_program = testee.as_program( - arg_types=[ts.FieldType(dims=[IDim, JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))], - kwarg_types={}, - ) - testee_program(a, out=out, offset_provider={}) - assert np.all(a + 1 == out[0]) - assert np.all(a == out[1]) - - @pytest.mark.uses_cartesian_shift def test_cartesian_shift(cartesian_case): @gtx.field_operator @@ -123,30 +105,6 @@ def testee(a: cases.VField) -> cases.EField: ) -@pytest.mark.uses_unstructured_shift -def test_as_program_with_unstructured_shift_no_return_hint(unstructured_case): - @gtx.field_operator(backend=unstructured_case.executor) - def testee(a: cases.VField): - return a(E2V[0]) - - testee_program = testee.as_program( - arg_types=[ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))], - kwarg_types={}, - ) - a = cases.allocate(unstructured_case, testee, "a")() - out = constructors.zeros( - domain=common.domain({Edge: unstructured_case.default_sizes[Edge]}), - allocator=unstructured_case.allocator, - dtype=int32, - ) - - testee_program(a, out=out, offset_provider=unstructured_case.offset_provider) - - ref = field_utils.asnumpy(a)[unstructured_case.offset_provider["E2V"].table[:, 0]] - - assert np.all(ref == field_utils.asnumpy(out)) - - @pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator From 5f4a4d4405dd2576c7736f4d8cc3669340dc3669 Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 7 Mar 2024 17:06:40 +0100 Subject: [PATCH 42/47] remove Program.format_itir --- src/gt4py/next/ffront/decorator.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 6235abc52b..01cb890b8b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -26,8 +26,6 @@ from collections.abc import Callable from typing import Generator, Generic, TypeVar -from devtools import debug - from gt4py import eve from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils @@ -95,10 +93,6 @@ def _field_constituents_shape_and_dims( # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. -# TODO(ricoh): factor out the generated ITIR together with arguments rewriting -# so that using fencil processors on `some_program.itir` becomes trivial without -# prior knowledge of the fencil signature rewriting done by `Program`. -# After that, drop the `.format_itir()` method, since it won't be needed. @dataclasses.dataclass(frozen=True) class Program: """ @@ -264,27 +258,6 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No ) ) - def format_itir( - self, - *args, - formatter: ppi.ProgramFormatter, - offset_provider: dict[str, Dimension], - **kwargs, - ) -> str: - ppi.ensure_processor_kind(formatter, ppi.ProgramFormatter) - rewritten_args, size_args, kwargs = past_process_args_wf._process_args( - self.past_node, args, kwargs - ) - if "debug" in kwargs: - debug(self.itir) - return formatter( - self.itir, - *rewritten_args, - *size_args, - **kwargs, - offset_provider=offset_provider, - ) - @functools.cached_property def _column_axis(self): # construct mapping from column axis to scan operators defined on From ef077d6a212e7bb693189cc2be8a02c077edf04a Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 8 Mar 2024 09:50:16 +0100 Subject: [PATCH 43/47] [wip] refactor column_axis out of decorator --- src/gt4py/next/ffront/decorator.py | 35 +------------------ src/gt4py/next/ffront/past_to_itir_wf.py | 43 ++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 01cb890b8b..4eeb71e5dd 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -253,43 +253,10 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No past_node=self.past_node, grid_type=self.grid_type, args=args, - kwargs=kwargs - | {"offset_provider": offset_provider, "column_axis": self._column_axis}, + kwargs=kwargs | {"offset_provider": offset_provider}, ) ) - @functools.cached_property - def _column_axis(self): - # construct mapping from column axis to scan operators defined on - # that dimension. only one column axis is allowed, but we can use - # this mapping to provide good error messages. - scanops_per_axis: dict[Dimension, str] = {} - for name, gt_callable in transform_utils._filter_closure_vars_by_type( - self._all_closure_vars, GTCallable - ).items(): - if isinstance( - (type_ := gt_callable.__gt_type__()), - ts_ffront.ScanOperatorType, - ): - scanops_per_axis.setdefault(type_.axis, []).append(name) - - if len(scanops_per_axis.values()) == 0: - return None - - if len(scanops_per_axis.values()) != 1: - scanops_per_axis_strs = [ - f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() - ] - - raise TypeError( - "Only 'ScanOperator's defined on the same axis " - + "can be used in a 'Program', found:\n" - + "\n".join(scanops_per_axis_strs) - + "." - ) - - return iter(scanops_per_axis.keys()).__next__() - @dataclasses.dataclass(frozen=True) class ProgramWithBoundArgs(Program): diff --git a/src/gt4py/next/ffront/past_to_itir_wf.py b/src/gt4py/next/ffront/past_to_itir_wf.py index a2aed45d90..7fbe852d93 100644 --- a/src/gt4py/next/ffront/past_to_itir_wf.py +++ b/src/gt4py/next/ffront/past_to_itir_wf.py @@ -13,12 +13,19 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses +from typing import Any import devtools import factory from gt4py.next import common, config -from gt4py.next.ffront import fbuiltins, gtcallable, past_to_itir, transform_utils +from gt4py.next.ffront import ( + fbuiltins, + gtcallable, + past_to_itir, + transform_utils, + type_specifications as ts_ffront, +) from gt4py.next.otf import stages, workflow @@ -48,10 +55,42 @@ def __call__(self, inp: stages.PastClosure) -> stages.ProgramCall: return stages.ProgramCall( itir_program, inp.args, - inp.kwargs, + inp.kwargs | {"column_axis": _column_axis(all_closure_vars)}, ) class PastToItirFactory(factory.Factory): class Meta: model = PastToItir + + +def _column_axis(all_closure_vars: dict[str, Any]) -> common.Dimension: + # construct mapping from column axis to scan operators defined on + # that dimension. only one column axis is allowed, but we can use + # this mapping to provide good error messages. + scanops_per_axis: dict[common.Dimension, str] = {} + for name, gt_callable in transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).items(): + if isinstance( + (type_ := gt_callable.__gt_type__()), + ts_ffront.ScanOperatorType, + ): + scanops_per_axis.setdefault(type_.axis, []).append(name) + + if len(scanops_per_axis.values()) == 0: + return None + + if len(scanops_per_axis.values()) != 1: + scanops_per_axis_strs = [ + f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() + ] + + raise TypeError( + "Only 'ScanOperator's defined on the same axis " + + "can be used in a 'Program', found:\n" + + "\n".join(scanops_per_axis_strs) + + "." + ) + + return iter(scanops_per_axis.keys()).__next__() From 1b72c843294c73cc42603d914453d3d268487ad9 Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 13 Mar 2024 11:12:25 +0100 Subject: [PATCH 44/47] refactor: past_to_itir_wf -> past_to_itir --- src/gt4py/next/backend.py | 4 +- src/gt4py/next/ffront/decorator.py | 4 +- src/gt4py/next/ffront/past_to_itir.py | 111 +++++++++++++++++++---- src/gt4py/next/ffront/past_to_itir_wf.py | 96 -------------------- 4 files changed, 99 insertions(+), 116 deletions(-) delete mode 100644 src/gt4py/next/ffront/past_to_itir_wf.py diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index d9699cfc7a..3b5db57647 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -19,13 +19,13 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.ffront import past_process_args_wf, past_to_itir_wf +from gt4py.next.ffront import past_process_args_wf, past_to_itir from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import processor_interface as ppi DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( - past_process_args_wf.past_process_args.chain(past_to_itir_wf.PastToItirFactory()) + past_process_args_wf.past_process_args.chain(past_to_itir.PastToItirFactory()) ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 4eeb71e5dd..cb68223d08 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -42,7 +42,7 @@ dialect_ast_enums, field_operator_ast as foast, past_process_args_wf, - past_to_itir_wf, + past_to_itir, program_ast as past, transform_utils, type_specifications as ts_ffront, @@ -219,7 +219,7 @@ def _all_closure_vars(self) -> dict[str, Any]: @functools.cached_property def itir(self) -> itir.FencilDefinition: - return past_to_itir_wf.PastToItirFactory()( + return past_to_itir.PastToItirFactory()( stages.PastClosure( past_node=self.past_node, closure_vars=self.closure_vars, diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 99534b4e61..45953cdd50 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,16 +14,95 @@ from __future__ import annotations -from typing import Optional, cast +import dataclasses +from typing import Any, Optional, cast + +import devtools +import factory from gt4py.eve import NodeTranslator, concepts, traits -from gt4py.next.common import Dimension, DimensionKind, GridType -from gt4py.next.ffront import lowering_utils, program_ast as past, type_specifications as ts_ffront +from gt4py.next import common, config +from gt4py.next.ffront import ( + fbuiltins, + gtcallable, + lowering_utils, + program_ast as past, + transform_utils, + type_specifications as ts_ffront, +) from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts +@dataclasses.dataclass(frozen=True) +class PastToItir(workflow.ChainableWorkflowMixin): + def __call__(self, inp: stages.PastClosure) -> stages.ProgramCall: + all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) + offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( + all_closure_vars, fbuiltins.FieldOffset, common.Dimension + ) + grid_type = transform_utils._deduce_grid_type( + inp.grid_type, offsets_and_dimensions.values() + ) + + gt_callables = transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).values() + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + + itir_program = ProgramLowering.apply( + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ) + + if config.DEBUG or "debug" in inp.kwargs: + devtools.debug(itir_program) + + return stages.ProgramCall( + itir_program, + inp.args, + inp.kwargs | {"column_axis": _column_axis(all_closure_vars)}, + ) + + +class PastToItirFactory(factory.Factory): + class Meta: + model = PastToItir + + +def _column_axis(all_closure_vars: dict[str, Any]) -> common.Dimension: + # construct mapping from column axis to scan operators defined on + # that dimension. only one column axis is allowed, but we can use + # this mapping to provide good error messages. + scanops_per_axis: dict[common.Dimension, str] = {} + for name, gt_callable in transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).items(): + if isinstance( + (type_ := gt_callable.__gt_type__()), + ts_ffront.ScanOperatorType, + ): + scanops_per_axis.setdefault(type_.axis, []).append(name) + + if len(scanops_per_axis.values()) == 0: + return None + + if len(scanops_per_axis.values()) != 1: + scanops_per_axis_strs = [ + f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() + ] + + raise TypeError( + "Only 'ScanOperator's defined on the same axis " + + "can be used in a 'Program', found:\n" + + "\n".join(scanops_per_axis_strs) + + "." + ) + + return iter(scanops_per_axis.keys()).__next__() + + def _size_arg_from_field(field_name: str, dim: int) -> str: return f"__{field_name}_size_{dim}" @@ -51,10 +130,10 @@ class ProgramLowering( -------- >>> from gt4py.next.ffront.func_to_past import ProgramParser >>> from gt4py.next.iterator import ir - >>> from gt4py.next import Dimension, Field + >>> from gt4py.next import common.Dimension, Field >>> >>> float64 = float - >>> IDim = Dimension("IDim") + >>> IDim = common.Dimension("IDim") >>> >>> def fieldop(inp: Field[[IDim], "float64"]) -> Field[[IDim], "float64"]: ... >>> def program(inp: Field[[IDim], "float64"], out: Field[[IDim], "float64"]): @@ -67,7 +146,7 @@ class ProgramLowering( ... expr=ir.FunCall(fun=ir.SymRef(id="deref"), pos_only_args=[ir.SymRef(id="inp")]), ... ) # doctest: +SKIP >>> lowered = ProgramLowering.apply( - ... parsed, [fieldop_def], grid_type=GridType.CARTESIAN + ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN ... ) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP @@ -85,7 +164,7 @@ def apply( cls, node: past.Program, function_definitions: list[itir.FunctionDefinition], - grid_type: GridType, + grid_type: common.GridType, ) -> itir.FencilDefinition: return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) @@ -97,7 +176,7 @@ def _gen_size_params_from_program(self, node: past.Program): size_params = [] for param in node.params: if type_info.is_type_or_tuple_of_type(param.type, ts.FieldType): - fields_dims: list[list[Dimension]] = ( + fields_dims: list[list[common.Dimension]] = ( type_info.primitive_constituents(param.type).getattr("dims").to_list() ) assert all(field_dims == fields_dims[0] for field_dims in fields_dims) @@ -263,8 +342,8 @@ def _construct_itir_domain_arg( slices[dim_i].upper if slices else None, dim_size, dim_size ) - if dim.kind == DimensionKind.LOCAL: - raise ValueError(f"Dimension '{dim.value}' must not be local.") + if dim.kind == common.DimensionKind.LOCAL: + raise ValueError(f"common.Dimension '{dim.value}' must not be local.") domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), @@ -273,14 +352,14 @@ def _construct_itir_domain_arg( ) domain_args_kind.append(dim.kind) - if self.grid_type == GridType.CARTESIAN: + if self.grid_type == common.GridType.CARTESIAN: domain_builtin = "cartesian_domain" - elif self.grid_type == GridType.UNSTRUCTURED: + elif self.grid_type == common.GridType.UNSTRUCTURED: domain_builtin = "unstructured_domain" # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) - if domain_args_kind[0] == DimensionKind.VERTICAL: + if domain_args_kind[0] == common.DimensionKind.VERTICAL: assert len(domain_args) == 2 - assert domain_args_kind[1] == DimensionKind.HORIZONTAL + assert domain_args_kind[1] == common.DimensionKind.HORIZONTAL domain_args[0], domain_args[1] = domain_args[1], domain_args[0] else: raise AssertionError() @@ -294,14 +373,14 @@ def _construct_itir_domain_arg( def _construct_itir_initialized_domain_arg( self, dim_i: int, - dim: Dimension, + dim: common.Dimension, node_domain: past.Dict, ) -> list[itir.FunCall]: assert len(node_domain.values_[dim_i].elts) == 2 keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim if keys_dims_types != dim: raise ValueError( - "Dimensions in out field and field domain are not equivalent:" + "common.Dimensions in out field and field domain are not equivalent:" f"expected '{dim}', got '{keys_dims_types}'." ) diff --git a/src/gt4py/next/ffront/past_to_itir_wf.py b/src/gt4py/next/ffront/past_to_itir_wf.py deleted file mode 100644 index 7fbe852d93..0000000000 --- a/src/gt4py/next/ffront/past_to_itir_wf.py +++ /dev/null @@ -1,96 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import dataclasses -from typing import Any - -import devtools -import factory - -from gt4py.next import common, config -from gt4py.next.ffront import ( - fbuiltins, - gtcallable, - past_to_itir, - transform_utils, - type_specifications as ts_ffront, -) -from gt4py.next.otf import stages, workflow - - -@dataclasses.dataclass(frozen=True) -class PastToItir(workflow.ChainableWorkflowMixin): - def __call__(self, inp: stages.PastClosure) -> stages.ProgramCall: - all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) - offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( - all_closure_vars, fbuiltins.FieldOffset, common.Dimension - ) - grid_type = transform_utils._deduce_grid_type( - inp.grid_type, offsets_and_dimensions.values() - ) - - gt_callables = transform_utils._filter_closure_vars_by_type( - all_closure_vars, gtcallable.GTCallable - ).values() - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] - - itir_program = past_to_itir.ProgramLowering.apply( - inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type - ) - - if config.DEBUG or "debug" in inp.kwargs: - devtools.debug(itir_program) - - return stages.ProgramCall( - itir_program, - inp.args, - inp.kwargs | {"column_axis": _column_axis(all_closure_vars)}, - ) - - -class PastToItirFactory(factory.Factory): - class Meta: - model = PastToItir - - -def _column_axis(all_closure_vars: dict[str, Any]) -> common.Dimension: - # construct mapping from column axis to scan operators defined on - # that dimension. only one column axis is allowed, but we can use - # this mapping to provide good error messages. - scanops_per_axis: dict[common.Dimension, str] = {} - for name, gt_callable in transform_utils._filter_closure_vars_by_type( - all_closure_vars, gtcallable.GTCallable - ).items(): - if isinstance( - (type_ := gt_callable.__gt_type__()), - ts_ffront.ScanOperatorType, - ): - scanops_per_axis.setdefault(type_.axis, []).append(name) - - if len(scanops_per_axis.values()) == 0: - return None - - if len(scanops_per_axis.values()) != 1: - scanops_per_axis_strs = [ - f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() - ] - - raise TypeError( - "Only 'ScanOperator's defined on the same axis " - + "can be used in a 'Program', found:\n" - + "\n".join(scanops_per_axis_strs) - + "." - ) - - return iter(scanops_per_axis.keys()).__next__() From e148c8623a940403dfb46df0812ea94ce659ff1e Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 13 Mar 2024 11:12:55 +0100 Subject: [PATCH 45/47] rename: ffront.*_wf -> ffront.* --- src/gt4py/next/backend.py | 4 ++-- src/gt4py/next/ffront/decorator.py | 4 ++-- .../ffront/{past_process_args_wf.py => past_process_args.py} | 0 3 files changed, 4 insertions(+), 4 deletions(-) rename src/gt4py/next/ffront/{past_process_args_wf.py => past_process_args.py} (100%) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 3b5db57647..fc01123b62 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -19,13 +19,13 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.ffront import past_process_args_wf, past_to_itir +from gt4py.next.ffront import past_process_args, past_to_itir from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import processor_interface as ppi DEFAULT_TRANSFORMS: workflow.Workflow[stages.PastClosure, stages.ProgramCall] = ( - past_process_args_wf.past_process_args.chain(past_to_itir.PastToItirFactory()) + past_process_args.past_process_args.chain(past_to_itir.PastToItirFactory()) ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index cb68223d08..4dd7c6e399 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -41,7 +41,7 @@ from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, - past_process_args_wf, + past_process_args, past_to_itir, program_ast as past, transform_utils, @@ -239,7 +239,7 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No ) with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: # TODO(ricoh): check if rewriting still needed - rewritten_args, size_args, kwargs = past_process_args_wf._process_args( + rewritten_args, size_args, kwargs = past_process_args._process_args( self.past_node, args, kwargs ) ctx.run(self.definition, *rewritten_args, **kwargs) diff --git a/src/gt4py/next/ffront/past_process_args_wf.py b/src/gt4py/next/ffront/past_process_args.py similarity index 100% rename from src/gt4py/next/ffront/past_process_args_wf.py rename to src/gt4py/next/ffront/past_process_args.py From 57cb2edd5837872883e8a453015eedeae4901aa0 Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 13 Mar 2024 14:33:23 +0100 Subject: [PATCH 46/47] fix: undo false positive replacement from refactoring --- src/gt4py/next/ffront/past_to_itir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 45953cdd50..a5d266c4c0 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -130,10 +130,10 @@ class ProgramLowering( -------- >>> from gt4py.next.ffront.func_to_past import ProgramParser >>> from gt4py.next.iterator import ir - >>> from gt4py.next import common.Dimension, Field + >>> from gt4py.next import Dimension, Field >>> >>> float64 = float - >>> IDim = common.Dimension("IDim") + >>> IDim = Dimension("IDim") >>> >>> def fieldop(inp: Field[[IDim], "float64"]) -> Field[[IDim], "float64"]: ... >>> def program(inp: Field[[IDim], "float64"], out: Field[[IDim], "float64"]): From 2cff0203bb0157c26bb0c12e25d66c0602b66704 Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 13 Mar 2024 16:00:10 +0100 Subject: [PATCH 47/47] typing: fix issues in `_column_axis` --- src/gt4py/next/ffront/past_to_itir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a5d266c4c0..0fc9a6280d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -71,11 +71,11 @@ class Meta: model = PastToItir -def _column_axis(all_closure_vars: dict[str, Any]) -> common.Dimension: +def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]: # construct mapping from column axis to scan operators defined on # that dimension. only one column axis is allowed, but we can use # this mapping to provide good error messages. - scanops_per_axis: dict[common.Dimension, str] = {} + scanops_per_axis: dict[common.Dimension, list[str]] = {} for name, gt_callable in transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).items():