From 3de2e5cfacb491cf051ed586d6bf36501354adfe Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 2 May 2023 10:29:05 -0400 Subject: [PATCH] Remove recursion from ConstrainedReschedule pass (#10051) * Remove recursion from ConstrainedReschedule pass The ConstrainedReschedule pass previously was using a recursive depth first traversal to push back overlapping gates after aligning operations. This however would cause a failure for a sufficiently large circuit when the recursion depth could potentially exceed the maximum stack depth allowed in python. To address this, this commit rewrites the depth first traversal to be iterative instead of recursive. This removes the stack depth limitation and should let the pass run with any size circuit. However, the performance of this pass is poor for large circuits. One thing we can look at using to try and speed it up is rustworkx's dfs_search() function which will let us shift the traversal to rust and call back to python to do the timing offsets. If this is insufficient we'll have to investigate a different algorithm for adjusting the time that doesn't require multiple iterations like the current approach. Fixes #10049 * Use rustworkx's dfs_search instead of manual dfs implementation This commit rewrites the pass to leverage rustworkx's dfs_search function which provides a way to have rustworkx traverse the graph in a depth first manner and then provides hook points to execute code at different named portions of the DFS. By leveraging this function we're able to speed up the search by leveraging rust to perform the actual graph traversal. * Revert "Use rustworkx's dfs_search instead of manual dfs implementation" This made performance of the pass worse so reverting this for now. We can investigate this at a later date. This reverts commit bd3cbb271030b3d88c1ee36691b623a1714348c1. * Remove visited node check from DFS This commit removes the visited node check and skip logic from the DFS traversal. 
To ensure this code behaves identically to the recursive version before this PR this logic is removed because there wasn't a similar check in that version. (cherry picked from commit 112bd6ea29f8605aea79a191ea9e0c40fb1ba5ea) --- .../scheduling/alignments/reschedule.py | 110 +++++++++--------- 1 file changed, 56 insertions(+), 54 deletions(-) diff --git a/qiskit/transpiler/passes/scheduling/alignments/reschedule.py b/qiskit/transpiler/passes/scheduling/alignments/reschedule.py index af28fed1adb6..350ee31882c8 100644 --- a/qiskit/transpiler/passes/scheduling/alignments/reschedule.py +++ b/qiskit/transpiler/passes/scheduling/alignments/reschedule.py @@ -88,13 +88,9 @@ def _get_next_gate(cls, dag: DAGCircuit, node: DAGOpNode) -> List[DAGOpNode]: Returns: A list of non-delay successors. """ - op_nodes = [] for next_node in dag.successors(node): - if isinstance(next_node, DAGOutNode): - continue - op_nodes.append(next_node) - - return op_nodes + if not isinstance(next_node, DAGOutNode): + yield next_node def _push_node_back(self, dag: DAGCircuit, node: DAGOpNode, shift: int): """Update start time of current node. Successors are also shifted to avoid overlap. 
@@ -114,57 +110,63 @@ def _push_node_back(self, dag: DAGCircuit, node: DAGOpNode, shift: int): conditional_latency = self.property_set.get("conditional_latency", 0) clbit_write_latency = self.property_set.get("clbit_write_latency", 0) - # Compute shifted t1 of this node separately for qreg and creg - this_t0 = node_start_time[node] - new_t1q = this_t0 + node.op.duration + shift - this_qubits = set(node.qargs) - if isinstance(node.op, Measure): - # creg access ends at the end of instruction - new_t1c = new_t1q - this_clbits = set(node.cargs) - else: - if node.op.condition_bits: - # conditional access ends at the beginning of node start time - new_t1c = this_t0 + shift - this_clbits = set(node.op.condition_bits) - else: - new_t1c = None - this_clbits = set() - - # Check successors for overlap - for next_node in self._get_next_gate(dag, node): - # Compute next node start time separately for qreg and creg - next_t0q = node_start_time[next_node] - next_qubits = set(next_node.qargs) - if isinstance(next_node.op, Measure): - # creg access starts after write latency - next_t0c = next_t0q + clbit_write_latency - next_clbits = set(next_node.cargs) + nodes_with_overlap = [(node, shift)] + shift_stack = [] + while nodes_with_overlap: + node, shift = nodes_with_overlap.pop() + shift_stack.append((node, shift)) + # Compute shifted t1 of this node separately for qreg and creg + this_t0 = node_start_time[node] + new_t1q = this_t0 + node.op.duration + shift + this_qubits = set(node.qargs) + if isinstance(node.op, Measure): + # creg access ends at the end of instruction + new_t1c = new_t1q + this_clbits = set(node.cargs) else: - if next_node.op.condition_bits: - # conditional access starts before node start time - next_t0c = next_t0q - conditional_latency - next_clbits = set(next_node.op.condition_bits) + if node.op.condition_bits: + # conditional access ends at the beginning of node start time + new_t1c = this_t0 + shift + this_clbits = set(node.op.condition_bits) else: - next_t0c 
= None - next_clbits = set() - # Compute overlap if there is qubits overlap - if any(this_qubits & next_qubits): - qreg_overlap = new_t1q - next_t0q - else: - qreg_overlap = 0 - # Compute overlap if there is clbits overlap - if any(this_clbits & next_clbits): - creg_overlap = new_t1c - next_t0c - else: - creg_overlap = 0 - # Shift next node if there is finite overlap in either in qubits or clbits - overlap = max(qreg_overlap, creg_overlap) - if overlap > 0: - self._push_node_back(dag, next_node, overlap) - + new_t1c = None + this_clbits = set() + + # Check successors for overlap + for next_node in self._get_next_gate(dag, node): + # Compute next node start time separately for qreg and creg + next_t0q = node_start_time[next_node] + next_qubits = set(next_node.qargs) + if isinstance(next_node.op, Measure): + # creg access starts after write latency + next_t0c = next_t0q + clbit_write_latency + next_clbits = set(next_node.cargs) + else: + if next_node.op.condition_bits: + # conditional access starts before node start time + next_t0c = next_t0q - conditional_latency + next_clbits = set(next_node.op.condition_bits) + else: + next_t0c = None + next_clbits = set() + # Compute overlap if there is qubits overlap + if any(this_qubits & next_qubits): + qreg_overlap = new_t1q - next_t0q + else: + qreg_overlap = 0 + # Compute overlap if there is clbits overlap + if any(this_clbits & next_clbits): + creg_overlap = new_t1c - next_t0c + else: + creg_overlap = 0 + # Shift next node if there is finite overlap in either in qubits or clbits + overlap = max(qreg_overlap, creg_overlap) + if overlap > 0: + nodes_with_overlap.append((next_node, overlap)) # Update start time of this node after all overlaps are resolved - node_start_time[node] += shift + while shift_stack: + node, shift = shift_stack.pop() + node_start_time[node] += shift def run(self, dag: DAGCircuit): """Run rescheduler.