diff --git a/AUTHORS b/AUTHORS index 3275288476..9b7763593e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -34,5 +34,6 @@ Cliff Hodel Tiancheng Chen Reid Wahl Yihang Luo +Alexandru Calotoiu and other contributors listed in https://github.com/spcl/dace/graphs/contributors diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index b557cb185e..a72a6d7e54 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -3,6 +3,9 @@ Functionality that allows users to "cut out" parts of an SDFG in a smart way (i.e., memory preserving) for localized testing or optimization. """ +import networkx as nx +from networkx.algorithms.flow import edmondskarp +import sympy as sp from collections import deque import copy from typing import Deque, Dict, List, Set, Tuple, Union, Optional, Any @@ -15,6 +18,7 @@ PatternTransformation, SubgraphTransformation, SingleStateTransformation) +from dace.transformation.interstate.loop_detection import DetectLoop from dace.transformation.passes.analysis import StateReachability @@ -77,8 +81,10 @@ def translate_transformation_into(self, transformation: Union[PatternTransformat # Ignore. 
pass elif isinstance(transformation, MultiStateTransformation): - transformation._sdfg = self - transformation.sdfg_id = 0 + new_sdfg_id = self._in_translation[transformation.sdfg_id] + new_sdfg = self.sdfg_list[new_sdfg_id] + transformation._sdfg = new_sdfg + transformation.sdfg_id = new_sdfg_id for k in transformation.subgraph.keys(): old_state = self._base_sdfg.node(transformation.subgraph[k]) try: @@ -111,18 +117,40 @@ def from_json(cls, json_obj, context_info=None): @classmethod def from_transformation( cls, sdfg: SDFG, transformation: Union[PatternTransformation, SubgraphTransformation], - make_side_effects_global = True, use_alibi_nodes: bool = True + make_side_effects_global = True, use_alibi_nodes: bool = True, reduce_input_config = True, + symbols_map: Optional[Dict[str, Any]] = None ) -> Union['SDFGCutout', SDFG]: + """ + Create a cutout from a transformation's set of affected graph elements. + + :param sdfg: The SDFG to create the cutout from. + :param transformation: The transformation to create the cutout from. + :param make_side_effects_global: Whether to make side effect data containers global, i.e. non-transient. + :param use_alibi_nodes: Whether to use alibi nodes for the cutout across scope borders. + :param reduce_input_config: Whether to reduce the input configuration where possible in singlestate cutouts. + :param symbols_map: A mapping of symbols to values to use for the cutout. Optional, only used when reducing the + input configuration. + :return: The cutout. 
+ """ affected_nodes = _transformation_determine_affected_nodes(sdfg, transformation) + if len(affected_nodes) == 0: + cut_sdfg = copy.deepcopy(sdfg) + transformation._sdfg = cut_sdfg + return cut_sdfg + target_sdfg = sdfg if transformation.sdfg_id >= 0 and target_sdfg.sdfg_list is not None: target_sdfg = target_sdfg.sdfg_list[transformation.sdfg_id] - if isinstance(transformation, (SubgraphTransformation, SingleStateTransformation)): - state = target_sdfg.node(transformation.state_id) + if (all(isinstance(n, nd.Node) for n in affected_nodes) or + isinstance(transformation, (SubgraphTransformation, SingleStateTransformation))): + state = target_sdfg.parent + if transformation.state_id >= 0: + state = target_sdfg.node(transformation.state_id) cutout = cls.singlestate_cutout(state, *affected_nodes, make_side_effects_global=make_side_effects_global, - use_alibi_nodes=use_alibi_nodes) + use_alibi_nodes=use_alibi_nodes, reduce_input_config=reduce_input_config, + symbols_map=symbols_map) cutout.translate_transformation_into(transformation) return cutout elif isinstance(transformation, MultiStateTransformation): @@ -132,14 +160,16 @@ def from_transformation( cutout.translate_transformation_into(transformation) return cutout raise Exception('Unsupported transformation type: {}'.format(type(transformation))) - + @classmethod def singlestate_cutout(cls, state: SDFGState, *nodes: nd.Node, make_copy: bool = True, make_side_effects_global: bool = True, - use_alibi_nodes: bool = True) -> 'SDFGCutout': + use_alibi_nodes: bool = True, + reduce_input_config: bool = False, + symbols_map: Optional[Dict[str, Any]] = None) -> 'SDFGCutout': """ Cut out a subgraph of a state from an SDFG to run separately for localized testing or optimization. The subgraph defined by the list of nodes will be extended to include access nodes of data containers necessary @@ -155,8 +185,13 @@ def singlestate_cutout(cls, inside the cutout but may be read _after_ the cutout, are made global. 
:param use_alibi_nodes: If True, do not extend the cutout with access nodes that span outside of a scope, but introduce alibi nodes instead that represent only the accesses subset. + :param reduce_input_config: Whether to reduce the input configuration where possible in singlestate cutouts. + :param symbols_map: A mapping of symbols to values to use for the cutout. Optional, only used when reducing the + input configuration. :return: The created SDFGCutout. """ + if reduce_input_config: + nodes = _reduce_in_configuration(state, nodes, use_alibi_nodes, symbols_map) create_element = copy.deepcopy if make_copy else (lambda x: x) sdfg = state.parent subgraph: StateSubgraphView = StateSubgraphView(state, nodes) @@ -272,6 +307,15 @@ def singlestate_cutout(cls, cutout._in_translation = in_translation cutout._out_translation = out_translation + # Translate in nested SDFG nodes and their SDFGs (their list id, specifically). + cutout.reset_sdfg_list() + outers = set(in_translation.keys()) + for outer in outers: + if isinstance(outer, nd.NestedSDFG): + inner: nd.NestedSDFG = in_translation[outer] + cutout._in_translation[outer.sdfg.sdfg_id] = inner.sdfg.sdfg_id + _recursively_set_nsdfg_parents(cutout) + return cutout @classmethod @@ -319,14 +363,14 @@ def multistate_cutout(cls, frontier, frontier_edges = bfs_queue.popleft() if len(frontier_edges) == 0: # No explicit start state, but also no frontier to select from. - return sdfg + return copy.deepcopy(sdfg) elif len(frontier_edges) == 1: # If there is only one predecessor frontier edge, its destination must be the start state. start_state = list(frontier_edges)[0].dst else: if len(frontier) == 0: # No explicit start state, but also no frontier to select from. - return sdfg + return copy.deepcopy(sdfg) if len(frontier) == 1: # For many frontier edges but only one frontier state, the frontier state is the new start state # and is included in the cutout. 
@@ -349,8 +393,18 @@ def multistate_cutout(cls, state_defined_symbols = state.defined_symbols() for sym in state_defined_symbols: defined_symbols[sym] = state_defined_symbols[sym] + for edge in subgraph.edges(): + is_edge: InterstateEdge = edge.data + available_symbols = sdfg.symbols.keys() + free_symbols |= (is_edge.free_symbols & available_symbols) + for rmem in is_edge.get_read_memlets(sdfg.arrays): + if rmem.data in cutout.arrays: + continue + new_desc = sdfg.arrays[rmem.data].clone() + cutout.add_datadesc(rmem.data, new_desc) for sym in free_symbols: - cutout.add_symbol(sym, defined_symbols[sym]) + if not sym in cutout.symbols: + cutout.add_symbol(sym, defined_symbols[sym]) for state in cutout_states: for dnode in state.data_nodes(): @@ -413,6 +467,9 @@ def multistate_cutout(cls, cutout._in_translation = in_translation cutout._out_translation = out_translation + cutout.reset_sdfg_list() + _recursively_set_nsdfg_parents(cutout) + return cutout @@ -447,6 +504,27 @@ def _transformation_determine_affected_nodes( except KeyError: # Ignored. pass + + # Transformations that modify a loop in any way must also include the loop init node, i.e. the state directly + # before the loop guard. Also make sure that ALL loop body states are part of the set of affected nodes. + # TODO: This is hacky and should be replaced with a more general mechanism - this is something that + # transformation intents / transactions will need to solve. 
+ if isinstance(transformation, DetectLoop): + if transformation.loop_guard is not None and transformation.loop_guard in target_sdfg.nodes(): + for iedge in target_sdfg.in_edges(transformation.loop_guard): + affected_nodes.add(iedge.src) + if transformation.loop_begin is not None and transformation.loop_begin in target_sdfg.nodes(): + to_visit = [transformation.loop_begin] + while to_visit: + state = to_visit.pop(0) + for _, dst, _ in target_sdfg.out_edges(state): + if dst not in affected_nodes and dst is not transformation.loop_guard: + to_visit.append(dst) + affected_nodes.add(state) + + if len(affected_nodes) == 0 and transformation.state_id < 0 and target_sdfg.parent_nsdfg_node is not None: + # This is a transformation that affects a nested SDFG node, grab that NSDFG node. + affected_nodes.add(target_sdfg.parent_nsdfg_node) else: if transformation.sdfg_id >= 0 and target_sdfg.sdfg_list: target_sdfg = target_sdfg.sdfg_list[transformation.sdfg_id] @@ -478,6 +556,202 @@ def _transformation_determine_affected_nodes( return affected_nodes +def _reduce_in_configuration(state: SDFGState, affected_nodes: Set[nd.Node], use_alibi_nodes: bool = False, + symbols_map: Optional[Dict[str, Any]] = None) -> Set[nd.Node]: + """ + For a given set of nodes that should be cut out in a single state cutout, try to reduce the size of the input + configuration as much as possible by adding more nodes to find an S-T minimum 2-cut in the state. + + :param state: The state in which to cut out. + :param affected_nodes: The set of nodes that should be cut out. + :param use_alibi_nodes: If True, use alibi nodes across scope borders. + :param symbols_map: A map of symbols to values. An assumption will be made about symbol values if None is provided. + :return: A new set of nodes greater than or equal to the initial cutout nodes, which makes up a minimized cutout. 
+ """ + subgraph: StateSubgraphView = StateSubgraphView(state, affected_nodes) + subgraph = _extend_subgraph_with_access_nodes(state, subgraph, use_alibi_nodes) + subgraph_nodes = set(subgraph.nodes()) + + # For the given state, determine what should count as the input configuration if we were to cut out the entire + # state. + state_reachability_dict = StateReachability().apply_pass(state.parent, None) + state_reach = state_reachability_dict[state.parent.sdfg_id] + reaching_cutout: Set[SDFGState] = set() + for k, v in state_reach.items(): + if state in v: + reaching_cutout.add(k) + state_input_configuration = set() + check_for_write_before = set() + for dn in state.data_nodes(): + if state.out_degree(dn) > 0: + # This is read from, add to the system state if it is written anywhere else in the graph. + # Except if it is also written to at the same time and is scalar or of size 1. + array = state.parent.arrays[dn.data] + if state.in_degree(dn) > 0 and (array.total_size == 1 or isinstance(array, data.Scalar)): + continue + elif not array.transient: + # Non-transients are always part of the input config if they are read and not overwritten anyway. + state_input_configuration.add(dn.data) + else: + check_for_write_before.add(dn.data) + for pre_state in reaching_cutout: + for dn in pre_state.data_nodes(): + if pre_state.in_degree(dn) > 0: + # For any writes, check if they are reads from the cutout that need to be checked. If they are, they're + # part of the system state. + if dn.data in check_for_write_before: + state_input_configuration.add(dn.data) + + # If no explicit symbol map was provided, we have to make an assumption about symbol values to determine a minimum + # cut. + # TODO: This is a hack. Ideally, we should be able to determine the minimum cut without having to make assumptions + # about symbol values. Not sure how to do that yet. 
+ if symbols_map is None: + symbols_map = dict() + consts = state.parent.constants + for s in state.parent.symbols: + if s in consts: + symbols_map[s] = consts[s] + else: + symbols_map[s] = 20 + + # Use a proxy graph to compute the minimum cut. + proxy_graph = nx.DiGraph() + + # By expanding over the borders of a scope (e.g. over the entry of a map), we know that we universally can only + # increase the size of the input configuration. Consequently, we can use the outer-most scope entry node as our + # source node for the minimum cut, if there is such a unique outer entry node. + source_candidates = set() + for n in subgraph_nodes: + source_candidates.add(state.entry_node(n)) + + source = None + scope_children = state.scope_children() + transitive_scope_children: Dict[SDFGState, Set[SDFGState]] = dict() + for k, v in scope_children.items(): + queue = deque(v) + k_children = set(v) + while queue: + child = queue.popleft() + if child in scope_children: + n_children = set(scope_children[child]) + queue.extend(n_children) + k_children.update(n_children) + transitive_scope_children[k] = k_children + if len(source_candidates) > 1: + for cand in source_candidates: + if all(other_cand in transitive_scope_children[cand] for other_cand in source_candidates): + source = cand + break + elif len(source_candidates) == 1: + source = list(source_candidates)[0] + + # If there is no unique outer entry node, we use a proxy node as the source. + scope_nodes: Set[nd.Node] = set() + if source == None: + source = nd.Node() + scope_nodes = set(scope_children[None]) + else: + scope_nodes = set(scope_children[source]) + scope_nodes.add(source) + expand_with = set() + for n in scope_nodes: + if isinstance(n, nd.EntryNode): + exit = state.exit_node(n) + expand_with.add(exit) + scope_nodes.update(expand_with) + scope_subgraph = StateSubgraphView(state, scope_nodes) + + # Add the source and a proxy sink to the proxy graph. 
+ proxy_graph.add_node(source) + sink = nd.Node() + proxy_graph.add_node(sink) + + # Build up the proxy graph. + for edge in scope_subgraph.edges(): + proxy_edge_src = edge.src + proxy_edge_dst = edge.dst + + vol = 0 + memlet: Memlet = edge.data + if memlet.data: + vol = memlet.volume + if isinstance(vol, sp.Expr): + vol = vol.subs(symbols_map) + + remain_free = False + if edge.src in subgraph_nodes and edge.dst in subgraph_nodes: + # Edge completely in subgraph, don't do anything. Unless the destination is an access node which is in the + # state input configuration, in which case we add an edge from the source to the sink with that volume. + if isinstance(edge.dst, nd.AccessNode) and memlet.data in state_input_configuration: + if proxy_graph.has_edge(source, sink): + proxy_graph[source][sink]['capacity'] += vol + else: + proxy_graph.add_node(source) + proxy_graph.add_node(sink) + proxy_graph.add_edge(source, sink, capacity=vol) + continue + elif edge.src in subgraph_nodes: + # Edge starts in subgraph, ends outside. + # If there's no path back inside, its source is the proxy sink. Otherwise, its source is set to the proxy + # source and the volume is made 0, since the value will already be part of the cutout. + if any([n in nx.descendants(state.nx, proxy_edge_src) for n in subgraph_nodes]): + proxy_edge_src = source + vol = 0 + remain_free = True + else: + proxy_edge_src = sink + elif edge.dst in subgraph_nodes: + # Edge starts outside, ends in the subgraph. Its destination thus is the proxy sink. + proxy_edge_dst = sink + + if isinstance(proxy_edge_dst, nd.AccessNode) and memlet.data in state_input_configuration: + # If the destination is an access node that is part of the state input configuration, we add an edge from + # the source with that volume. 
+ if proxy_graph.has_edge(source, proxy_edge_dst): + proxy_graph[source][proxy_edge_dst]['capacity'] += vol + else: + proxy_graph.add_edge(source, proxy_edge_dst, capacity=vol) + # The actual edge between src and dst is set to have infinite capacity. + vol = float('inf') + elif isinstance(proxy_edge_src, nd.AccessNode) and not remain_free: + # All outgoing edges from access nodes (with data) are set to have infinite capacity. + vol = float('inf') + + if isinstance(proxy_edge_src, nd.ExitNode): + proxy_edge_src = state.entry_node(proxy_edge_src) + + if proxy_graph.has_edge(proxy_edge_src, proxy_edge_dst): + proxy_graph[proxy_edge_src][proxy_edge_dst]['capacity'] += vol + else: + proxy_graph.add_node(proxy_edge_src) + proxy_graph.add_node(proxy_edge_dst) + proxy_graph.add_edge(proxy_edge_src, proxy_edge_dst, capacity=vol) + + for node in scope_nodes: + if isinstance(node, nd.AccessNode) and node.data in state_input_configuration: + if not proxy_graph.has_edge(source, node) and node.data in state.parent.arrays: + vol = state.parent.arrays[node.data].total_size + if isinstance(vol, sp.Expr): + vol = vol.subs(symbols_map) + proxy_graph.add_edge(source, node, capacity=vol) + + _, (_, non_reachable) = nx.minimum_cut(proxy_graph, + source, + sink, + flow_func=edmondskarp.edmonds_karp) + + non_reachable -= {sink} + if len(non_reachable) > 0: + subscope_expansions = set() + for n in non_reachable: + if isinstance(n, nd.EntryNode): + subscope_expansions.update(transitive_scope_children[n]) + elif isinstance(n, nd.ExitNode): + subscope_expansions.update(transitive_scope_children[state.entry_node(n)]) + return subgraph_nodes.union(non_reachable.union(subscope_expansions)) + return subgraph_nodes + def _stateset_predecessor_frontier(states: Set[SDFGState]) -> Tuple[Set[SDFGState], Set[Edge[InterstateEdge]]]: """ For a set of states, return their predecessor frontier. 
@@ -667,14 +941,17 @@ def _cutout_determine_input_config(ct: SDFG, inverse_cutout_reach: Set[SDFGState for state in cutout_states: for dn in state.data_nodes(): noded_descriptors.add(dn.data) - - array = ct.arrays[dn.data] - if not array.transient: - # Non-transients are always part of the system state. - input_configuration.add(dn.data) - elif state.out_degree(dn) > 0: + if state.out_degree(dn) > 0: # This is read from, add to the system state if it is written anywhere else in the graph. - check_for_write_before.add(dn.data) + # Except if it is also written to at the same time and is scalar or of size 1. + array = ct.arrays[dn.data] + if state.in_degree(dn) > 0 and (array.total_size == 1 or isinstance(array, data.Scalar)): + continue + elif not array.transient: + # Non-transients are always part of the input config if they are read and not overwritten anyway. + input_configuration.add(dn.data) + else: + check_for_write_before.add(dn.data) original_state: Optional[SDFGState] = None try: @@ -759,3 +1036,11 @@ def _cutout_determine_output_configuration(ct: SDFG, cutout_reach: Set[SDFGState system_state.add(dn.data) return system_state + + +def _recursively_set_nsdfg_parents(target: SDFG): + for state in target.states(): + for n in state.nodes(): + if isinstance(n, nd.NestedSDFG): + n.sdfg.parent_sdfg = target + _recursively_set_nsdfg_parents(n.sdfg) diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index 8af9e901ab..0bd168751c 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -2,6 +2,7 @@ """ This module initializes the inter-state transformations package.""" from .state_fusion import StateFusion +from .state_fusion_with_happens_before import StateFusionExtended from .state_elimination import (EndStateElimination, StartStateElimination, StateAssignElimination, SymbolAliasPromotion, HoistState) from .fpga_transform_state import FPGATransformState diff --git 
a/dace/transformation/interstate/state_fusion_with_happens_before.py b/dace/transformation/interstate/state_fusion_with_happens_before.py new file mode 100644 index 0000000000..4c6ad3c992 --- /dev/null +++ b/dace/transformation/interstate/state_fusion_with_happens_before.py @@ -0,0 +1,590 @@ +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +""" State fusion transformation """ + +from typing import Dict, List, Set + +import networkx as nx + +from dace import data as dt, dtypes, registry, sdfg, subsets, memlet +from dace.config import Config +from dace.sdfg import nodes +from dace.sdfg import utils as sdutil +from dace.sdfg.state import SDFGState +from dace.transformation import transformation + + +# Helper class for finding connected component correspondences +class CCDesc: + def __init__(self, first_input_nodes: Set[nodes.AccessNode], first_output_nodes: Set[nodes.AccessNode], + second_input_nodes: Set[nodes.AccessNode], second_output_nodes: Set[nodes.AccessNode]) -> None: + self.first_inputs = {n.data for n in first_input_nodes} + self.first_input_nodes = first_input_nodes + self.first_outputs = {n.data for n in first_output_nodes} + self.first_output_nodes = first_output_nodes + self.second_inputs = {n.data for n in second_input_nodes} + self.second_input_nodes = second_input_nodes + self.second_outputs = {n.data for n in second_output_nodes} + self.second_output_nodes = second_output_nodes + + +def top_level_nodes(state: SDFGState): + return state.scope_children()[None] + + +class StateFusionExtended(transformation.MultiStateTransformation): + """ Implements the state-fusion transformation extended to fuse states with RAW and WAW dependencies. + An empty memlet is used to represent a dependency between two subgraphs with RAW and WAW dependencies. 
+ The merge is made by identifying the source in the first state and the sink in the second state, + and linking the bottom of the appropriate source subgraph in the first state with the top of the + appropriate sink subgraph in the second state. + + State-fusion takes two states that are connected through a single edge, + and fuses them into one state. If permissive, also applies if potential memory + access hazards are created. + """ + connections_to_make = [] + first_state = transformation.PatternNode(sdfg.SDFGState) + second_state = transformation.PatternNode(sdfg.SDFGState) + + @staticmethod + def annotates_memlets(): + return False + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.first_state, cls.second_state)] + + @staticmethod + def find_fused_components(first_cc_input, first_cc_output, second_cc_input, second_cc_output) -> List[CCDesc]: + # Make a bipartite graph out of the first and second components + g = nx.DiGraph() + g.add_nodes_from((0, i) for i in range(len(first_cc_output))) + g.add_nodes_from((1, i) for i in range(len(second_cc_output))) + # Find matching nodes in second state + for i, cc1 in enumerate(first_cc_output): + outnames1 = {n.data for n in cc1} + for j, cc2 in enumerate(second_cc_input): + inpnames2 = {n.data for n in cc2} + if len(outnames1 & inpnames2) > 0: + g.add_edge((0, i), (1, j)) + + # Construct result out of connected components of the bipartite graph + result = [] + for cc in nx.weakly_connected_components(g): + input1, output1, input2, output2 = set(), set(), set(), set() + for gind, cind in cc: + if gind == 0: + input1 |= first_cc_input[cind] + output1 |= first_cc_output[cind] + else: + input2 |= second_cc_input[cind] + output2 |= second_cc_output[cind] + result.append(CCDesc(input1, output1, input2, output2)) + + return result + + @staticmethod + def memlets_intersect(graph_a: SDFGState, group_a: List[nodes.AccessNode], inputs_a: bool, graph_b: SDFGState, + group_b: List[nodes.AccessNode], 
inputs_b: bool) -> bool: + """ + Performs an all-pairs check for subset intersection on two + groups of nodes. If group intersects or result is indeterminate, + returns True as a precaution. + + :param graph_a: The graph in which the first set of nodes reside. + :param group_a: The first set of nodes to check. + :param inputs_a: If True, checks inputs of the first group. + :param graph_b: The graph in which the second set of nodes reside. + :param group_b: The second set of nodes to check. + :param inputs_b: If True, checks inputs of the second group. + :return: True if subsets intersect or result is indeterminate. + """ + # Set traversal functions + src_subset = lambda e: (e.data.src_subset if e.data.src_subset is not None else e.data.dst_subset) + dst_subset = lambda e: (e.data.dst_subset if e.data.dst_subset is not None else e.data.src_subset) + if inputs_a: + edges_a = [e for n in group_a for e in graph_a.out_edges(n)] + subset_a = src_subset + else: + edges_a = [e for n in group_a for e in graph_a.in_edges(n)] + subset_a = dst_subset + if inputs_b: + edges_b = [e for n in group_b for e in graph_b.out_edges(n)] + subset_b = src_subset + else: + edges_b = [e for n in group_b for e in graph_b.in_edges(n)] + subset_b = dst_subset + + # Simple all-pairs check + for ea in edges_a: + for eb in edges_b: + result = subsets.intersects(subset_a(ea), subset_b(eb)) + if result is True or result is None: + return True + return False + + def has_path(self, first_state: SDFGState, second_state: SDFGState, + match_nodes: Dict[nodes.AccessNode, nodes.AccessNode], node_a: nodes.Node, node_b: nodes.Node) -> bool: + """ Check for paths between the two states if they are fused. 
""" + for match_a, match_b in match_nodes.items(): + if nx.has_path(first_state._nx, node_a, match_a) and nx.has_path(second_state._nx, match_b, node_b): + return True + return False + + def _check_all_paths(self, first_state: SDFGState, second_state: SDFGState, + match_nodes: Dict[nodes.AccessNode, nodes.AccessNode], nodes_first: List[nodes.AccessNode], + nodes_second: List[nodes.AccessNode], first_read: bool, second_read: bool) -> bool: + for node_a in nodes_first: + succ_a = first_state.successors(node_a) + for node_b in nodes_second: + if all(self.has_path(first_state, second_state, match_nodes, sa, node_b) for sa in succ_a): + return True + # Path not found, check memlets + if StateFusionExtended.memlets_intersect(first_state, nodes_first, first_read, second_state, nodes_second, + second_read): + return False + return True + + def _check_paths(self, first_state: SDFGState, second_state: SDFGState, match_nodes: Dict[nodes.AccessNode, + nodes.AccessNode], + nodes_first: List[nodes.AccessNode], nodes_second: List[nodes.AccessNode], + second_input: Set[nodes.AccessNode], first_read: bool, second_read: bool) -> bool: + fail = False + path_found = False + for match in match_nodes: + for node in nodes_first: + path_to = nx.has_path(first_state._nx, node, match) + if not path_to: + continue + path_found = True + node2 = next(n for n in second_input if n.data == match.data) + if not all(nx.has_path(second_state._nx, node2, n) for n in nodes_second): + fail = True + break + if fail or path_found: + break + + # Check for intersection (if None, fusion is ok) + if fail or not path_found: + if StateFusionExtended.memlets_intersect(first_state, nodes_first, first_read, second_state, nodes_second, + second_read): + return False + return True + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + first_state: SDFGState = self.first_state + second_state: SDFGState = self.second_state + + out_edges = graph.out_edges(first_state) + in_edges = 
graph.in_edges(first_state) + + # First state must have only one output edge (with dst the second + # state). + if len(out_edges) != 1: + return False + # If both states have more than one incoming edge, some control flow + # may become ambiguous + if len(in_edges) > 1 and graph.in_degree(second_state) > 1: + return False + # The interstate edge must not have a condition. + if not out_edges[0].data.is_unconditional(): + return False + # The interstate edge may have assignments, as long as there are input + # edges to the first state that can absorb them. + if out_edges[0].data.assignments: + if not in_edges: + return False + # Fail if symbol is set before the state to fuse + new_assignments = set(out_edges[0].data.assignments.keys()) + if any((new_assignments & set(e.data.assignments.keys())) for e in in_edges): + return False + # Fail if symbol is used in the dataflow of that state + if len(new_assignments & first_state.free_symbols) > 0: + return False + # Fail if assignments have free symbols that are updated in the + # first state + freesyms = out_edges[0].data.free_symbols + if freesyms and any(n.data in freesyms for n in first_state.nodes() + if isinstance(n, nodes.AccessNode) and first_state.in_degree(n) > 0): + return False + # Fail if symbols assigned on the first edge are free symbols on the + # second edge + symbols_used = set(out_edges[0].data.free_symbols) + for e in in_edges: + if e.data.assignments.keys() & symbols_used: + return False + # Also fail in the inverse; symbols assigned on the second edge are free symbols on the first edge + if new_assignments & set(e.data.free_symbols): + return False + + # There can be no state that have output edges pointing to both the + # first and the second state. Such a case will produce a multi-graph. 
+ for src, _, _ in in_edges: + for _, dst, _ in graph.out_edges(src): + if dst == second_state: + return False + + if not permissive: + # Strict mode that inhibits state fusion if Python callbacks are involved + if Config.get_bool('frontend', 'dont_fuse_callbacks'): + for node in (first_state.data_nodes() + second_state.data_nodes()): + if node.data == '__pystate': + return False + + # NOTE: This is quick fix for MPI Waitall (probably also needed for + # Wait), until we have a better SDFG representation of the buffer + # dependencies. + try: + next(node for node in first_state.nodes() + if (isinstance(node, nodes.LibraryNode) and type(node).__name__ == 'Waitall') + or node.label == '_Waitall_') + return False + except StopIteration: + pass + try: + next(node for node in second_state.nodes() + if (isinstance(node, nodes.LibraryNode) and type(node).__name__ == 'Waitall') + or node.label == '_Waitall_') + return False + except StopIteration: + pass + + # If second state has other input edges, there might be issues + # Exceptions are when none of the states contain dataflow, unless + # the first state is an initial state (in which case the new initial + # state would be ambiguous). + first_in_edges = graph.in_edges(first_state) + second_in_edges = graph.in_edges(second_state) + if ((not second_state.is_empty() or not first_state.is_empty() or len(first_in_edges) == 0) + and len(second_in_edges) != 1): + return False + + # Get connected components. 
+ first_cc = [cc_nodes for cc_nodes in nx.weakly_connected_components(first_state._nx)] + second_cc = [cc_nodes for cc_nodes in nx.weakly_connected_components(second_state._nx)] + + # Find source/sink (data) nodes + first_input = {node for node in first_state.source_nodes() if isinstance(node, nodes.AccessNode)} + first_output = { + node + for node in first_state.scope_children()[None] + if isinstance(node, nodes.AccessNode) and node not in first_input + } + second_input = {node for node in second_state.source_nodes() if isinstance(node, nodes.AccessNode)} + second_output = { + node + for node in second_state.scope_children()[None] + if isinstance(node, nodes.AccessNode) and node not in second_input + } + + # Find source/sink (data) nodes by connected component + first_cc_input = [cc.intersection(first_input) for cc in first_cc] + first_cc_output = [cc.intersection(first_output) for cc in first_cc] + second_cc_input = [cc.intersection(second_input) for cc in second_cc] + second_cc_output = [cc.intersection(second_output) for cc in second_cc] + + # Apply transformation in case all paths to the second state's + # nodes go through the same access node, which implies sequential + # behavior in SDFG semantics. 
+ first_output_names = {node.data for node in first_output} + second_input_names = {node.data for node in second_input} + + # If any second input appears more than once, fail + if len(second_input) > len(second_input_names): + return False + + # If any first output that is an input to the second state + # appears in more than one CC, fail + matches = first_output_names & second_input_names + for match in matches: + cc_appearances = 0 + for cc in first_cc_output: + if len([n for n in cc if n.data == match]) > 0: + cc_appearances += 1 + if cc_appearances > 1: + return False + + # Recreate fused connected component correspondences, and then + # check for hazards + resulting_ccs: List[CCDesc] = StateFusionExtended.find_fused_components(first_cc_input, first_cc_output, + second_cc_input, second_cc_output) + + # Check for data races + for fused_cc in resulting_ccs: + # Write-Write hazard - data is output of both first and second + # states, without a read in between + write_write_candidates = ((fused_cc.first_outputs & fused_cc.second_outputs) - fused_cc.second_inputs) + + # Find the leaf (topological) instances of the matches + order = [ + x for x in reversed(list(nx.topological_sort(first_state._nx))) + if isinstance(x, nodes.AccessNode) and x.data in fused_cc.first_outputs + ] + # Those nodes will be the connection points upon fusion + match_nodes: Dict[nodes.AccessNode, nodes.AccessNode] = { + next(n for n in order + if n.data == match): next(n for n in fused_cc.second_input_nodes if n.data == match) + for match in (fused_cc.first_outputs + & fused_cc.second_inputs) + } + + # If we have potential candidates, check if there is a + # path from the first write to the second write (in that + # case, there is no hazard): + for cand in write_write_candidates: + nodes_first = [n for n in first_output if n.data == cand] + nodes_second = [n for n in second_output if n.data == cand] + + # If there is a path for the candidate that goes through + # the match nodes in both 
states, there is no conflict + if not self._check_paths(first_state, second_state, match_nodes, nodes_first, nodes_second, + second_input, False, False): + return False + # End of write-write hazard check + + first_inout = fused_cc.first_inputs | fused_cc.first_outputs + for other_cc in resulting_ccs: + # NOTE: Special handling for `other_cc is fused_cc` + if other_cc is fused_cc: + # Checking for potential Read-Write data races + for d in first_inout: + if d in other_cc.second_outputs: + nodes_second = [n for n in second_output if n.data == d] + # Read-Write race + if d in fused_cc.first_inputs: + nodes_first = [n for n in first_input if n.data == d] + else: + nodes_first = [] + for n2 in nodes_second: + for e in second_state.in_edges(n2): + path = second_state.memlet_path(e) + src = path[0].src + if src in second_input and src.data in fused_cc.first_outputs: + for n1 in fused_cc.first_output_nodes: + if n1.data == src.data: + for n0 in nodes_first: + if not nx.has_path(first_state._nx, n0, n1): + return False + # Read-write hazard where an access node is connected + # to more than one output at once: (a) -> (b) | (d) -> [code] -> (d) + # \-> (c) | + # in the first state, and the same memory is inout in the second state + # All paths need to lead to `src` + if not self._check_all_paths(first_state, second_state, match_nodes, nodes_first, + nodes_second, True, False): + return False + + continue + # If an input/output of a connected component in the first + # state is an output of another connected component in the + # second state, we have a potential data race (Read-Write + # or Write-Write) + for d in first_inout: + if d in other_cc.second_outputs: + # Check for intersection (if None, fusion is ok) + nodes_second = [n for n in second_output if n.data == d] + # Read-Write race + if d in fused_cc.first_inputs: + nodes_first = [n for n in first_input if n.data == d] + if StateFusionExtended.memlets_intersect(first_state, nodes_first, True, second_state, + 
nodes_second, False): + self.connections_to_make.append([nodes_first, nodes_second]) + #return False + # Write-Write race + if d in fused_cc.first_outputs: + nodes_first = [n for n in first_output if n.data == d] + if StateFusionExtended.memlets_intersect(first_state, nodes_first, False, second_state, + nodes_second, False): + self.connections_to_make.append([nodes_first, nodes_second]) + #return False + # End of data race check + + # Read-after-write dependencies: if there is an output of the + # second state that is an input of the first, ensure all paths + # from the input of the first state lead to the output. + # Otherwise, there may be a RAW due to topological sort or + # concurrency. + second_inout = ((fused_cc.first_inputs | fused_cc.first_outputs) & fused_cc.second_outputs) + for inout in second_inout: + nodes_first = [n for n in match_nodes if n.data == inout] + if any(first_state.out_degree(n) > 0 for n in nodes_first): + return False + + # If we have potential candidates, check if there is a + # path from the first read to the second write (in that + # case, there is no hazard): + nodes_first = { + n + for n in fused_cc.first_input_nodes + | fused_cc.first_output_nodes if n.data == inout + } + nodes_second = {n for n in fused_cc.second_output_nodes if n.data == inout} + + # If there is a path for the candidate that goes through + # the match nodes in both states, there is no conflict + if not self._check_paths(first_state, second_state, match_nodes, nodes_first, nodes_second, + second_input, True, False): + return False + + # End of read-write hazard check + + # Read-after-write dependencies: if there is more than one first + # output with the same data, make sure it can be unambiguously + # connected to the second state + if (len(fused_cc.first_output_nodes) > len(fused_cc.first_outputs)): + for inpnode in fused_cc.second_input_nodes: + found = None + for outnode in fused_cc.first_output_nodes: + if outnode.data != inpnode.data: + continue + if 
def apply(self, _, sdfg):
    """
    Fuse ``self.second_state`` into ``self.first_state`` inside ``sdfg``.

    The interstate edge(s) between the two states are removed; any
    assignments they carry are first copied onto every edge that enters the
    first state. If one of the two states is empty it is simply dropped.
    Otherwise all nodes and edges of the second state are moved into the
    first state, an empty-memlet dependency edge is added for every entry of
    ``self.connections_to_make`` (filled in by ``can_be_applied``), and
    matching top-level access nodes of both states are merged.

    :param _: Unused.
    :param sdfg: The SDFG that contains both states. Modified in place.
    """
    first_state: SDFGState = self.first_state
    second_state: SDFGState = self.second_state

    # Remove interstate edge(s), moving their assignments onto the edges that lead into the first state
    for edge in sdfg.edges_between(first_state, second_state):
        if edge.data.assignments:
            for in_edge in sdfg.in_edges(first_state):
                in_edge.data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        if sdfg.start_state == first_state:
            sdfg.start_state = sdfg.node_id(second_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        if sdfg.start_state == second_state:
            sdfg.start_state = sdfg.node_id(first_state)
        return

    # Normal case: both states are not empty

    # Sink data nodes of the first state (taken before merging); only these may
    # become the origin of an added dependency edge
    first_output = [node for node in first_state.sink_nodes() if isinstance(node, nodes.AccessNode)]

    top2 = top_level_nodes(second_state)

    # Top-level source nodes of the second state. The second state's graph is
    # not modified while its nodes are copied over, so compute this only once.
    second_sources = second_state.source_nodes()
    second_top_sources = [node for node in top2 if node in second_sources]

    # NOTE: We exclude Views from the process of merging common data nodes because it may lead to double edges.
    second_mid = [
        x for x in nx.topological_sort(second_state._nx) if isinstance(x, nodes.AccessNode)
        and second_state.out_degree(x) > 0 and not isinstance(sdfg.arrays[x.data], dt.View)
    ]

    # Merge second state to first state
    # First keep a backup of the topological sorted order of the nodes
    sdict = first_state.scope_dict()
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode) and sdict[x] is None
    ]
    for node in second_state.nodes():
        if isinstance(node, nodes.NestedSDFG):
            # update parent information
            node.sdfg.parent = first_state

        # The node could already have been added as the endpoint of a
        # dependency edge (add_nedge below), hence the need to check
        if node not in first_state.nodes():
            first_state.add_node(node)

        # Serialize accesses recorded in connections_to_make: every
        # top-level source of the second state that leads to `node` must
        # run after the corresponding sink nodes of the first state
        for first_side, second_side in self.connections_to_make:
            if node not in second_side:
                continue
            for source in second_top_sources:
                paths = second_state.all_nodes_between(source, node)
                direct_edges = second_state.edges_between(source, node)
                reaches_node = (paths is not None and len(paths) > 0) or len(direct_edges) > 0
                if not reaches_node:
                    continue
                for first_node in first_side:
                    if first_node in first_output:
                        first_state.add_nedge(first_node, source, memlet.Memlet())
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    top = top_level_nodes(first_state)

    # Merge common (data) nodes
    merged_nodes = set()
    for node in second_mid:

        # merge only top level nodes, skip everything else
        if node not in top2:
            continue

        candidates = [x for x in order if x.data == node.data and x in top and x not in merged_nodes]
        source_node = first_state.in_degree(node) == 0

        # If not source node, try to connect every memlet-intersecting candidate
        if not source_node:
            for cand in candidates:
                if StateFusionExtended.memlets_intersect(first_state, [cand], False, second_state, [node], True):
                    if nx.has_path(first_state._nx, cand, node):  # Do not create cycles
                        continue
                    sdutil.change_edge_src(first_state, cand, node)
                    sdutil.change_edge_dest(first_state, cand, node)
                    first_state.remove_node(cand)
            continue

        if not candidates:
            continue
        elif len(candidates) == 1:
            n = candidates[0]
        else:
            # Choose first candidate that intersects memlets
            for cand in candidates:
                if StateFusionExtended.memlets_intersect(first_state, [cand], False, second_state, [node], True):
                    n = cand
                    break
            else:
                # No node intersects, use topologically-last node
                n = candidates[0]

        sdutil.change_edge_src(first_state, node, n)
        sdutil.change_edge_dest(first_state, node, n)
        first_state.remove_node(node)
        merged_nodes.add(n)

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if sdfg.start_state == second_state:
        sdfg.start_state = sdfg.node_id(first_state)
def test_minimum_cut_simple_no_further_input_config():
    """
    Build a single-state SDFG with a chain of tasklets inside one map and cut
    out ``t3``, ``t4`` and the ``tmp6`` access node with input-configuration
    reduction enabled. The cutout state must then hold exactly ten nodes:
    ``t2``, ``t3``, ``t4``, the six ``tmp*`` access nodes and the ``C``
    access node.
    """
    sdfg = dace.SDFG('mincut')
    N = dace.symbol('N')
    for name, shape in (('A', [N]), ('B', [N]), ('C', [N, N])):
        sdfg.add_array(name, shape, dace.float64)
    scalar_names = ('tmp1', 'tmp2', 'tmp3', 'tmp4', 'tmp5', 'tmp6')
    for name in scalar_names:
        sdfg.add_array(name, [1], dace.float64, transient=True)

    state = sdfg.add_state('state')
    map_entry, map_exit = state.add_map('map', {'i': '0:N', 'j': '0:N'})

    # Tasklets, created in the same order as the data flows through them
    first = state.add_tasklet('t1', {'a', 'b'}, {'t'}, 't = a + b')
    fan_out = state.add_tasklet('t2', {'tin'}, {'t1', 't2', 't3', 't4'},
                                't1 = tin + 2\nt2 = tin * 2\nt3 = tin / 2\nt4 = tin + 1')
    join = state.add_tasklet('t3', {'a', 'b'}, {'t'}, 't = a + b')
    last = state.add_tasklet('t4', {'a', 'b', 'c'}, {'t'}, 't = (a - b) * c')

    # One access node per data container, keyed by container name
    acc = {name: state.add_access(name) for name in ('A', 'B', 'C') + scalar_names}

    state.add_memlet_path(acc['A'], map_entry, first, dst_conn='a', memlet=dace.Memlet('A[i]'))
    state.add_memlet_path(acc['B'], map_entry, first, dst_conn='b', memlet=dace.Memlet('B[j]'))

    # (source, source connector, destination, destination connector, memlet expression)
    scalar_edges = [
        (first, 't', acc['tmp1'], None, 'tmp1[0]'),
        (acc['tmp1'], None, fan_out, 'tin', 'tmp1[0]'),
        (fan_out, 't1', acc['tmp2'], None, 'tmp2[0]'),
        (fan_out, 't2', acc['tmp3'], None, 'tmp3[0]'),
        (fan_out, 't3', acc['tmp4'], None, 'tmp4[0]'),
        (fan_out, 't4', acc['tmp5'], None, 'tmp5[0]'),
        (acc['tmp2'], None, join, 'a', 'tmp2[0]'),
        (acc['tmp3'], None, join, 'b', 'tmp3[0]'),
        (acc['tmp4'], None, last, 'a', 'tmp4[0]'),
        (acc['tmp5'], None, last, 'b', 'tmp5[0]'),
        (join, 't', acc['tmp6'], None, 'tmp6[0]'),
        (acc['tmp6'], None, last, 'c', 'tmp6[0]'),
    ]
    for src, src_conn, dst, dst_conn, expr in scalar_edges:
        state.add_edge(src, src_conn, dst, dst_conn, dace.Memlet(expr))

    state.add_memlet_path(last, map_exit, acc['C'], src_conn='t', memlet=dace.Memlet('C[i, j]'))

    cutout = SDFGCutout.singlestate_cutout(state, join, last, acc['tmp6'], reduce_input_config=True)

    cut_nodes = set(cutout.nodes()[0].nodes())
    expected = {
        fan_out, join, last, acc['tmp6'], acc['tmp4'], acc['tmp5'], acc['tmp2'], acc['tmp3'], acc['tmp1'], acc['C']
    }
    assert len(cut_nodes) == 10
    # Every expected original node maps into the cutout, and every cutout node maps back to an expected one
    assert all(cutout._in_translation[original] in cut_nodes for original in expected)
    assert all(cutout._out_translation[copied] in expected for copied in cut_nodes)
def test_minimum_cut_map_scopes():
    """
    Build a single-state SDFG with three consecutive map scopes and cut out
    the third one (``t3`` with its map entry/exit) with input-configuration
    reduction enabled. The cutout state must then hold exactly nine nodes:
    ``t2``, ``t3``, the ``tmp_1``, ``tmp_2`` and ``C`` access nodes, and the
    entry/exit nodes of maps ``m2`` and ``m3``.
    """
    sdfg = dace.SDFG('mincut')
    sdfg.add_array('A', [10, 10], dace.float64)
    sdfg.add_array('B', [10, 10], dace.float64)
    for transient_name in ('tmp_1', 'tmp_2'):
        sdfg.add_array(transient_name, [10, 10], dace.float64, transient=True)
    sdfg.add_array('C', [10, 10], dace.float64)

    state = sdfg.add_state('state')
    add_inputs = state.add_tasklet('t1', {'in1', 'in2'}, {'out1'}, 'out1 = in1 + in2')
    double = state.add_tasklet('t2', {'in1'}, {'out1'}, 'out1 = in1 * 2')
    add_tmps = state.add_tasklet('t3', {'in1', 'in2'}, {'out1'}, 'out1 = in1 + in2')

    # (entry, exit) pair per map, created in the order m1, m2, m3
    maps = {label: state.add_map(label, {'i': '0:10', 'j': '0:10'}) for label in ('m1', 'm2', 'm3')}
    m1_entry, m1_exit = maps['m1']
    m2_entry, m2_exit = maps['m2']
    m3_entry, m3_exit = maps['m3']

    # One access node per data container, keyed by container name
    acc = {name: state.add_access(name) for name in ('A', 'B', 'C', 'tmp_1', 'tmp_2')}

    # First map: tmp_1 = A + B
    state.add_memlet_path(acc['A'], m1_entry, add_inputs, dst_conn='in1', memlet=dace.Memlet('A[i, j]'))
    state.add_memlet_path(acc['B'], m1_entry, add_inputs, dst_conn='in2', memlet=dace.Memlet('B[i, j]'))
    state.add_memlet_path(add_inputs, m1_exit, acc['tmp_1'], src_conn='out1', memlet=dace.Memlet('tmp_1[i, j]'))
    # Second map: tmp_2 = tmp_1 * 2
    state.add_memlet_path(acc['tmp_1'], m2_entry, double, dst_conn='in1', memlet=dace.Memlet('tmp_1[i, j]'))
    state.add_memlet_path(double, m2_exit, acc['tmp_2'], src_conn='out1', memlet=dace.Memlet('tmp_2[i, j]'))
    # Third map: C = tmp_1 + tmp_2
    state.add_memlet_path(acc['tmp_1'], m3_entry, add_tmps, dst_conn='in1', memlet=dace.Memlet('tmp_1[i, j]'))
    state.add_memlet_path(acc['tmp_2'], m3_entry, add_tmps, dst_conn='in2', memlet=dace.Memlet('tmp_2[i, j]'))
    state.add_memlet_path(add_tmps, m3_exit, acc['C'], src_conn='out1', memlet=dace.Memlet('C[i, j]'))

    cutout = SDFGCutout.singlestate_cutout(state, add_tmps, m3_entry, m3_exit, reduce_input_config=True)

    cut_nodes = set(cutout.nodes()[0].nodes())
    expected = {double, add_tmps, acc['tmp_1'], acc['tmp_2'], acc['C'], m2_entry, m2_exit, m3_entry, m3_exit}
    assert len(cut_nodes) == 9
    # Every expected original node maps into the cutout, and every cutout node maps back to an expected one
    assert all(cutout._in_translation[original] in cut_nodes for original in expected)
    assert all(cutout._out_translation[copied] in expected for copied in cut_nodes)
def test_extended_fusion():
    """
    Test the extended state fusion transformation.
    Both states write the scalar ``tmp``; after simplification and repeated
    application of ``StateFusionExtended`` the SDFG must consist of a single
    state.
    """
    sdfg = SDFG('extended_state_fusion_test')
    for array_name in ('A', 'B', 'C', 'D', 'E', 'F'):
        sdfg.add_array(array_name, [20, 20], dtypes.float64)

    sdfg.add_scalar('tmp', dtypes.float64)

    start = sdfg.add_state("start")
    middle = sdfg.add_state("middle")
    sdfg.add_edge(start, middle, InterstateEdge())

    # Access nodes of the first state
    read_a = start.add_read('A')
    read_b = start.add_read('B')
    write_c = start.add_write('C')
    tmp_first = start.add_access('tmp')

    # Access nodes of the second state
    read_d = middle.add_read('D')
    read_e = middle.add_read('E')
    write_f = middle.add_write('F')
    tmp_second = middle.add_access('tmp')

    # Each state holds one compute tasklet and one tasklet that writes tmp
    sum_ab = start.add_tasklet('t1', {'a', 'b'}, {'c'}, 'c[1,1] = a[1,1] + b[1,1]')
    set_tmp_first = start.add_tasklet('t2', {}, {'tmpa'}, 'tmpa=4')
    sum_de = middle.add_tasklet('t3', {'d', 'e'}, {'f'}, 'f[1,1] = e[1,1] + d[1,1]')
    set_tmp_second = middle.add_tasklet('t4', {}, {'tmpa'}, 'tmpa=7')

    start.add_edge(read_a, None, sum_ab, 'a', Memlet.simple('A', '1,1'))
    start.add_edge(read_b, None, sum_ab, 'b', Memlet.simple('B', '1,1'))
    start.add_edge(sum_ab, 'c', write_c, None, Memlet.simple('C', '1,1'))
    start.add_edge(set_tmp_first, 'tmpa', tmp_first, None, Memlet.simple('tmp', '0'))

    middle.add_edge(read_d, None, sum_de, 'd', Memlet.simple('D', '1,1'))
    middle.add_edge(read_e, None, sum_de, 'e', Memlet.simple('E', '1,1'))
    middle.add_edge(sum_de, 'f', write_f, None, Memlet.simple('F', '1,1'))
    middle.add_edge(set_tmp_second, 'tmpa', tmp_second, None, Memlet.simple('tmp', '0'))

    sdfg.simplify()
    sdfg.apply_transformations_repeated(StateFusionExtended)

    state_count = sdfg.number_of_nodes()
    assert state_count == 1