Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next]: make past_to_itir cached in default transforms #1555

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs):
dataclasses.field(default=past_process_args.past_process_args)
)
past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
dataclasses.field(default_factory=lambda: past_to_itir.PastToItirFactory(cached=True))
)

foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = (
Expand Down Expand Up @@ -123,7 +123,7 @@ class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs):
)
)
past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
dataclasses.field(default_factory=lambda: past_to_itir.PastToItirFactory(cached=True))
)


Expand Down Expand Up @@ -167,3 +167,4 @@ def __gt_allocator__(
self,
) -> next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]:
return self.allocator
return self.allocator
54 changes: 36 additions & 18 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,38 +31,56 @@
from gt4py.next.type_system import type_info, type_specifications as ts


@workflow.make_step
def past_to_itir(inp: ffront_stages.PastProgramDefinition) -> stages.ProgramCall:
all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars)
offsets_and_dimensions = transform_utils._filter_closure_vars_by_type(
all_closure_vars, fbuiltins.FieldOffset, common.Dimension
)
grid_type = transform_utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values())

gt_callables = transform_utils._filter_closure_vars_by_type(
all_closure_vars, gtcallable.GTCallable
).values()
lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables]

itir_program = ProgramLowering.apply(
inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type
)

return stages.ProgramCall(
program=itir_program, args=tuple(), kwargs={"column_axis": _column_axis(all_closure_vars)}
)


@dataclasses.dataclass(frozen=True)
class PastToItir(workflow.ChainableWorkflowMixin):
def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall:
all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars)
offsets_and_dimensions = transform_utils._filter_closure_vars_by_type(
all_closure_vars, fbuiltins.FieldOffset, common.Dimension
)
grid_type = transform_utils._deduce_grid_type(
inp.grid_type, offsets_and_dimensions.values()
)

gt_callables = transform_utils._filter_closure_vars_by_type(
all_closure_vars, gtcallable.GTCallable
).values()
lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables]
inner: workflow.Workflow[ffront_stages.PastProgramDefinition, stages.ProgramCall] = past_to_itir

itir_program = ProgramLowering.apply(
inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type
def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall:
program_call = self.inner(
ffront_stages.PastProgramDefinition(inp.past_node, inp.closure_vars, inp.grid_type)
)

if config.DEBUG or "debug" in inp.kwargs:
devtools.debug(itir_program)
devtools.debug(program_call.program)

return stages.ProgramCall(
itir_program, inp.args, inp.kwargs | {"column_axis": _column_axis(all_closure_vars)}
return dataclasses.replace(
program_call, args=inp.args, kwargs=inp.kwargs | program_call.kwargs
)


class PastToItirFactory(factory.Factory):
class Meta:
model = PastToItir

class Params:
cached = factory.Trait(
inner=workflow.CachedStep(past_to_itir, hash_function=ffront_stages.fingerprint_stage)
)

inner = past_to_itir


def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]:
# construct mapping from column axis to scan operators defined on
Expand Down
15 changes: 15 additions & 0 deletions src/gt4py/next/ffront/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@ class PastClosure:
kwargs: dict[str, Any]


def fingerprint_past_closure_noargs(
past_closure: PastClosure, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None
) -> str:
return fingerprint_stage(
obj={
"closure_vars": past_closure.closure_vars,
"past_node": past_closure.past_node,
"grid_type": past_closure.grid_type,
"args": past_closure.args,
"kwargs": past_closure.kwargs,
},
algorithm=algorithm,
)


def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str:
hasher: xtyping.HashlibAlgorithm
if not algorithm:
Expand Down
Loading