Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Add 'groupby' mode to MapNodes visitor #2502

Merged
merged 2 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ def default_retval(cls):
the nodes of type ``child_types`` retrieved by the search. This behaviour
can be changed through this parameter. Accepted values are:
- 'immediate': only the closest matching ancestor is mapped.
- 'groupby': the matching ancestors are grouped together as a single key.
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
Expand All @@ -885,7 +886,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None):
assert issubclass(parent_type, Node)
self.parent_type = parent_type
self.child_types = as_tuple(child_types) or (Call, Expression)
assert mode in (None, 'immediate')
assert mode in (None, 'immediate', 'groupby')
self.mode = mode

def visit_object(self, o, ret=None, **kwargs):
Expand All @@ -902,7 +903,9 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
if parents is None:
parents = []
if isinstance(o, self.child_types):
if self.mode == 'immediate':
if self.mode == 'groupby':
ret.setdefault(as_tuple(parents), []).append(o)
elif self.mode == 'immediate':
if in_parent:
ret.setdefault(parents[-1], []).append(o)
else:
Expand Down
21 changes: 20 additions & 1 deletion tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.ir.equations import DummyEq
from devito.ir.iet import (Block, Expression, Callable, FindNodes, FindSections,
FindSymbols, IsPerfectIteration, Transformer,
Conditional, printAST, Iteration)
Conditional, printAST, Iteration, MapNodes, Call)
from devito.types import SpaceDimension, Array


Expand Down Expand Up @@ -376,3 +376,22 @@ def test_find_symbols_with_duplicates():
# So we expect FindSymbols to catch five Indexeds in total
symbols = FindSymbols('indexeds').visit(op)
assert len(symbols) == 5


def test_map_nodes(block1):
"""
Tests MapNodes visitor. When MapNodes is created with mode='groupby',
matching ancestors are grouped together under a single key.
This can be useful, for example, when applying transformations to the
outermost Iteration containing a specific node.
"""
map_nodes = MapNodes(Iteration, Expression, mode='groupby').visit(block1)

assert len(map_nodes.keys()) == 1

for iters, (expr,) in map_nodes.items():
# Replace the outermost `Iteration` with a `Call`
callback = Callable('solver', iters[0], 'void', ())
processed = Transformer({iters[0]: Call(callback.name)}).visit(block1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is fine!
Nipick: Maybe you could add:

+assert len(iters) == 3
assert str(processed) == 'solver();'

but sure it is not needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have assert str(processed) == 'solver();' just below

assert str(processed) == 'solver();'
Loading