From d31dd7b2990396ac6f76a5cbaa34c131372b54b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Sun, 15 Sep 2024 08:58:17 +0200 Subject: [PATCH] Maps With Zero Parameters (#1649) Before a map without any parameter was considered not invalid, it would pass validation, but most likly compilation would fail (except it is a serial map). This PR adds: - Disallows such maps. - Fixes a small bug in the constructor of the `Map` object. - It updates `TrivialMapElimination` such that it correctly handles the case if it has dynamic map ranges. - It removes the `TrivialMapRangeElimination` transformation as it is redundant and contained a bug. --------- Co-authored-by: Tal Ben-Nun --- dace/sdfg/nodes.py | 9 +- dace/transformation/dataflow/__init__.py | 1 - .../dataflow/trivial_map_elimination.py | 106 ++++++++++++------ .../dataflow/trivial_map_range_elimination.py | 48 -------- tests/trivial_map_elimination_test.py | 67 ++++++++++- tests/trivial_map_range_elimination_test.py | 58 ---------- 6 files changed, 142 insertions(+), 147 deletions(-) delete mode 100644 dace/transformation/dataflow/trivial_map_range_elimination.py delete mode 100644 tests/trivial_map_range_elimination_test.py diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 143b60a30f..409d30c57a 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -932,7 +932,7 @@ def __init__(self, self.label = label self.schedule = schedule self.unroll = unroll - self.collapse = 1 + self.collapse = collapse self.params = params self.range = ndrange self.debuginfo = debuginfo @@ -948,7 +948,12 @@ def __repr__(self): def validate(self, sdfg, state, node): if not dtypes.validate_name(self.label): - raise NameError('Invalid map name "%s"' % self.label) + raise NameError(f'Invalid map name "{self.label}"') + if self.get_param_num() == 0: + raise ValueError('There must be at least one parameter in a map.') + if self.get_param_num() != self.range.dims(): + raise ValueError(f'There are {self.get_param_num()} parameters but the range' + f' has {self.range.dims()} dimensions.') def get_param_num(self): """ Returns the number of map dimension parameters/symbols. """ diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index db4c928481..4ed7fd6283 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -12,7 +12,6 @@ from .map_fission import MapFission from .map_unroll import MapUnroll from .trivial_map_elimination import TrivialMapElimination -from .trivial_map_range_elimination import TrivialMapRangeElimination from .otf_map_fusion import OTFMapFusion # Data movement diff --git a/dace/transformation/dataflow/trivial_map_elimination.py b/dace/transformation/dataflow/trivial_map_elimination.py index 9387cfce23..69f445fd96 100644 --- a/dace/transformation/dataflow/trivial_map_elimination.py +++ b/dace/transformation/dataflow/trivial_map_elimination.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ Contains classes that implement the trivial-map-elimination transformation. """ +import dace from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.transformation import transformation @@ -10,12 +11,17 @@ @make_properties class TrivialMapElimination(transformation.SingleStateTransformation): - """ Implements the Trivial-Map Elimination pattern. + """Implements the Trivial-Map Elimination pattern. - Trivial-Map Elimination removes all dimensions containing only one - element from a map. If this applies to all ranges the map is removed. - Example: Map[i=0:I,j=7] -> Map[i=0:I] - Example: Map[i=0 ,j=7] -> nothing + Trivial-Map Elimination removes all dimensions containing only one + element from a map. If this applies to all ranges the map is removed. + Example: Map[i=0:I,j=7] -> Map[i=0:I] + Example: Map[i=0 ,j=7] -> nothing + + There are some special cases: + - GPU maps are ignored as they are syntactically needed. + - If all map ranges are trivial and the map has dynamic map ranges, + the map is not removed, and one map parameter is retained. """ map_entry = transformation.PatternNode(nodes.MapEntry) @@ -26,52 +32,78 @@ def expressions(cls): def can_be_applied(self, graph, expr_index, sdfg, permissive=False): map_entry = self.map_entry - return any(r[0] == r[1] for r in map_entry.map.range) + + if map_entry.map.schedule in (dace.dtypes.GPU_SCHEDULES + [dace.ScheduleType.GPU_Default]): + return False + if not any(r[0] == r[1] for r in map_entry.map.range): + return False + if (map_entry.map.get_param_num()) == 1 and ( + any(not e.dst_conn.startswith("IN_") for e in graph.in_edges(map_entry) if not e.data.is_empty()) + ): + # There is only one map parameter and there are dynamic map ranges, this can not be resolved. + return False + return True def apply(self, graph, sdfg): map_entry = self.map_entry - map_exit = graph.exit_node(map_entry) remaining_ranges = [] remaining_params = [] + scope = graph.scope_subgraph(map_entry) for map_param, ranges in zip(map_entry.map.params, map_entry.map.range.ranges): map_from, map_to, _ = ranges if map_from == map_to: # Replace the map index variable with the value it obtained - scope = graph.scope_subgraph(map_entry) scope.replace(map_param, map_from) else: remaining_ranges.append(ranges) remaining_params.append(map_param) - map_entry.map.range.ranges = remaining_ranges + map_entry.map.range = remaining_ranges map_entry.map.params = remaining_params - if len(remaining_ranges) == 0: - # Redirect map entry's out edges - write_only_map = True - for edge in graph.out_edges(map_entry): - path = graph.memlet_path(edge) - index = path.index(edge) - - if not edge.data.is_empty(): - # Add an edge directly from the previous source connector to the destination - graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) - write_only_map = False - - # Redirect map exit's in edges. - for edge in graph.in_edges(map_exit): - path = graph.memlet_path(edge) - index = path.index(edge) - - # Add an edge directly from the source to the next destination connector - if len(path) > index + 1: - graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data) - if write_only_map: - outer_exit = path[index+1].dst - outer_entry = graph.entry_node(outer_exit) - if outer_entry is not None: - graph.add_edge(outer_entry, None, edge.src, None, Memlet()) - - # Remove map - graph.remove_nodes_from([map_entry, map_exit]) + if len(remaining_params) != 0: + # There are still some dimensions left, so no need to remove the map + pass + + elif any(not e.dst_conn.startswith("IN_") for e in graph.in_edges(map_entry) if not e.data.is_empty()): + # The map has dynamic map ranges, thus we can not remove the map. + # Instead we add one dimension back to keep the SDFG valid. + map_entry.map.params = [map_param] + map_entry.map.range = [ranges] + + else: + # The map is empty and there are no dynamic map ranges. + self.remove_empty_map(graph, sdfg) + + def remove_empty_map(self, graph, sdfg): + map_entry = self.map_entry + map_exit = graph.exit_node(map_entry) + + # Redirect map entry's out edges + write_only_map = True + for edge in graph.out_edges(map_entry): + if edge.data.is_empty(): + continue + # Add an edge directly from the previous source connector to the destination + path = graph.memlet_path(edge) + index = path.index(edge) + graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) + write_only_map = False + + # Redirect map exit's in edges. + for edge in graph.in_edges(map_exit): + path = graph.memlet_path(edge) + index = path.index(edge) + + # Add an edge directly from the source to the next destination connector + if len(path) > index + 1: + graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data) + if write_only_map: + outer_exit = path[index+1].dst + outer_entry = graph.entry_node(outer_exit) + if outer_entry is not None: + graph.add_edge(outer_entry, None, edge.src, None, Memlet()) + + # Remove map + graph.remove_nodes_from([map_entry, map_exit]) diff --git a/dace/transformation/dataflow/trivial_map_range_elimination.py b/dace/transformation/dataflow/trivial_map_range_elimination.py deleted file mode 100644 index 1de1f0de90..0000000000 --- a/dace/transformation/dataflow/trivial_map_range_elimination.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" Contains classes that implement the trivial map range elimination transformation. """ - -from dace import registry -from dace.sdfg import nodes -from dace.sdfg import utils as sdutil -from dace.transformation import transformation -from dace.properties import make_properties - - -@make_properties -class TrivialMapRangeElimination(transformation.SingleStateTransformation): - """ Implements the Trivial Map Range Elimination pattern. - - Trivial Map Range Elimination takes a multi-dimensional map with - a range containing one element and removes the corresponding dimension. - Example: Map[i=0:I,j=0] -> Map[i=0:I] - """ - - map_entry = transformation.PatternNode(nodes.MapEntry) - - @classmethod - def expressions(cls): - return [sdutil.node_path_graph(cls.map_entry)] - - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - map_entry = self.map_entry - if len(map_entry.map.range) <= 1: - return False # only acts on multi-dimensional maps - return any(frm == to for frm, to, _ in map_entry.map.range) - - def apply(self, graph, sdfg): - map_entry = self.map_entry - - remaining_ranges = [] - remaining_params = [] - for map_param, ranges in zip(map_entry.map.params, map_entry.map.range.ranges): - map_from, map_to, _ = ranges - if map_from == map_to: - # Replace the map index variable with the value it obtained - scope = graph.scope_subgraph(map_entry) - scope.replace(map_param, map_from) - else: - remaining_ranges.append(ranges) - remaining_params.append(map_param) - - map_entry.map.range.ranges = remaining_ranges - map_entry.map.params = remaining_params diff --git a/tests/trivial_map_elimination_test.py b/tests/trivial_map_elimination_test.py index 52ab4c1557..f159dc6e6a 100644 --- a/tests/trivial_map_elimination_test.py +++ b/tests/trivial_map_elimination_test.py @@ -52,6 +52,37 @@ def trivial_map_init_sdfg(): return sdfg +def trivial_map_with_dynamic_map_range_sdfg(): + sdfg = dace.SDFG("trivial_map_with_dynamic_map_range") + state = sdfg.add_state("state1", is_start_block=True) + + for name in "ABC": + sdfg.add_scalar(name, dtype=dace.float32, transient=False) + A, B, C = (state.add_access(name) for name in "ABC") + + _, me, _ = state.add_mapped_tasklet( + name="MAP", + map_ranges=[("__i", "0:1"), ("__j", "10:11")], + inputs={"__in": dace.Memlet("A[0]")}, + input_nodes={"A": A}, + code="__out = __in + 1", + outputs={"__out": dace.Memlet("B[0]")}, + output_nodes={"B": B}, + external_edges=True, + ) + state.add_edge( + C, + None, + me, + "dynamic_variable", + dace.Memlet("C[0]"), + ) + me.add_in_connector("dynamic_variable") + sdfg.validate() + + return sdfg + + def trivial_map_pseudo_init_sdfg(): sdfg = dace.SDFG('trivial_map_range_expanded') sdfg.add_array('A', [5, 1], dace.float64) @@ -160,7 +191,6 @@ def test_can_be_applied(self): count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) graph.validate() - #graph.view() self.assertGreater(count, 0) @@ -188,5 +218,40 @@ def test_reconnects_edges(self): self.assertEqual(len(state.out_edges(map_entries[0])), 1) +class TrivialMapEliminationWithDynamicMapRangesTest(unittest.TestCase): + """ + Tests the case where the map has trivial ranges and dynamic map ranges. + """ + + def test_can_be_applied(self): + graph = trivial_map_with_dynamic_map_range_sdfg() + + count = graph.apply_transformations(TrivialMapElimination) + graph.validate() + + self.assertEqual(count, 1) + + + def test_removes_map(self): + graph = trivial_map_with_dynamic_map_range_sdfg() + + graph.apply_transformations(TrivialMapElimination) + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + self.assertEqual(state.in_degree(map_entries[0]), 2) + self.assertTrue(any(e.dst_conn.startswith("IN_") for e in state.in_edges(map_entries[0]))) + self.assertTrue(any(not e.dst_conn.startswith("IN_") for e in state.in_edges(map_entries[0]))) + + def test_not_remove_dynamic_map_range(self): + graph = trivial_map_with_dynamic_map_range_sdfg() + + count1 = graph.apply_transformations(TrivialMapElimination) + self.assertEqual(count1, 1) + + count2 = graph.apply_transformations(TrivialMapElimination) + self.assertEqual(count2, 0) + if __name__ == '__main__': unittest.main() diff --git a/tests/trivial_map_range_elimination_test.py b/tests/trivial_map_range_elimination_test.py deleted file mode 100644 index 5be1e6a2bf..0000000000 --- a/tests/trivial_map_range_elimination_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace.sdfg import nodes -from dace.transformation.dataflow import TrivialMapRangeElimination -import unittest - - -def trivial_map_range_sdfg(): - sdfg = dace.SDFG('trivial_map_range') - sdfg.add_array('A', [5], dace.float64) - sdfg.add_array('B', [5], dace.float64) - state = sdfg.add_state() - - # Nodes - read = state.add_read('A') - map_entry, map_exit = state.add_map('map', dict(i='0:1', j='0:5')) - tasklet = state.add_tasklet('tasklet', {'a'}, {'b'}, 'b = a') - write = state.add_write('B') - - # Edges - state.add_memlet_path(read, map_entry, tasklet, memlet=dace.Memlet.simple('A', '0'), dst_conn='a') - state.add_memlet_path(tasklet, map_exit, write, memlet=dace.Memlet.simple('B', 'i'), src_conn='b') - - sdfg.validate() - return sdfg - - -class TrivialMapRangeEliminationTest(unittest.TestCase): - def test_can_be_applied(self): - graph = trivial_map_range_sdfg() - - count = graph.apply_transformations(TrivialMapRangeElimination) - - self.assertGreater(count, 0) - - def test_transforms_map(self): - graph = trivial_map_range_sdfg() - - graph.apply_transformations(TrivialMapRangeElimination) - - state = graph.nodes()[0] - map_entry = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)][0] - self.assertEqual(map_entry.map.params, ['j']) - self.assertEqual(map_entry.map.range, dace.subsets.Range([(0, 4, 1)])) - - def test_raplaces_map_params_in_scope(self): - graph = trivial_map_range_sdfg() - - graph.apply_transformations(TrivialMapRangeElimination) - - state = graph.nodes()[0] - map_exit = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapExit)][0] - out_memlet = state.in_edges(map_exit)[0] - self.assertEqual(out_memlet.data.subset, dace.subsets.Range([(0, 0, 1)])) - - -if __name__ == '__main__': - unittest.main()