Skip to content

Commit

Permalink
Remove recursion from ConstrainedReschedule pass (#10051)
Browse files Browse the repository at this point in the history
* Remove recurssion from ConstrainedReschedule pass

The ConstrainedReschedule pass previosuly was using a recursive depth
first traversal to push back overlapping gates after aligning
operations. This however would cause a failure for a sufficiently large
circuit when the recursion depth could potentially exceed the maximum
stack depth allowed in python. To address this, this commit rewrites the
depth first traversal to be iterative instead of recursive. This removes
the stack depth limitation and should let the pass run with any size
circuit.

However, the performance of this pass is poor for large circuits. One
thing we can look at using to try and speed it up is rustworkx's
dfs_search() function which will let us shift the traversal to rust and
call back to python to do the timing offsets. If this is insufficient
we'll have to investigate a different algorithm for adjusting the time
that doesn't require multiple iterations like the current approach.

Fixes #10049

* Use rustworkx's dfs_search instead of manual dfs implementation

This commit rewrites the pass to leverage rustworkx's dfs_search
function which provides a way to have rustworkx traverse the graph in a
depth first manner and then provides hook points to execute code at
different named portions of the DFS. By leveraging this function we're
able to speed up the search by leveraging rust to perform the actual
graph traversal.

* Revert "Use rustworkx's dfs_search instead of manual dfs implementation"

This made performance of the pass worse so reverting this for now. We
can investigate this at a later date.

This reverts commit bd3cbb2.

* Remove visited node check from DFS

This commit removes the visited node check and skip logic from the DFS
traversal. To ensure this code behaves identically to the recursive
version before this PR this logic is removed because there wasn't a
similar check in that version.

(cherry picked from commit 112bd6e)
  • Loading branch information
mtreinish authored and mergify[bot] committed May 2, 2023
1 parent 5f77ee8 commit 3de2e5c
Showing 1 changed file with 56 additions and 54 deletions.
110 changes: 56 additions & 54 deletions qiskit/transpiler/passes/scheduling/alignments/reschedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,9 @@ def _get_next_gate(cls, dag: DAGCircuit, node: DAGOpNode) -> List[DAGOpNode]:
Returns:
A list of non-delay successors.
"""
op_nodes = []
for next_node in dag.successors(node):
if isinstance(next_node, DAGOutNode):
continue
op_nodes.append(next_node)

return op_nodes
if not isinstance(next_node, DAGOutNode):
yield next_node

def _push_node_back(self, dag: DAGCircuit, node: DAGOpNode, shift: int):
"""Update start time of current node. Successors are also shifted to avoid overlap.
Expand All @@ -114,57 +110,63 @@ def _push_node_back(self, dag: DAGCircuit, node: DAGOpNode, shift: int):
conditional_latency = self.property_set.get("conditional_latency", 0)
clbit_write_latency = self.property_set.get("clbit_write_latency", 0)

# Compute shifted t1 of this node separately for qreg and creg
this_t0 = node_start_time[node]
new_t1q = this_t0 + node.op.duration + shift
this_qubits = set(node.qargs)
if isinstance(node.op, Measure):
# creg access ends at the end of instruction
new_t1c = new_t1q
this_clbits = set(node.cargs)
else:
if node.op.condition_bits:
# conditional access ends at the beginning of node start time
new_t1c = this_t0 + shift
this_clbits = set(node.op.condition_bits)
else:
new_t1c = None
this_clbits = set()

# Check successors for overlap
for next_node in self._get_next_gate(dag, node):
# Compute next node start time separately for qreg and creg
next_t0q = node_start_time[next_node]
next_qubits = set(next_node.qargs)
if isinstance(next_node.op, Measure):
# creg access starts after write latency
next_t0c = next_t0q + clbit_write_latency
next_clbits = set(next_node.cargs)
nodes_with_overlap = [(node, shift)]
shift_stack = []
while nodes_with_overlap:
node, shift = nodes_with_overlap.pop()
shift_stack.append((node, shift))
# Compute shifted t1 of this node separately for qreg and creg
this_t0 = node_start_time[node]
new_t1q = this_t0 + node.op.duration + shift
this_qubits = set(node.qargs)
if isinstance(node.op, Measure):
# creg access ends at the end of instruction
new_t1c = new_t1q
this_clbits = set(node.cargs)
else:
if next_node.op.condition_bits:
# conditional access starts before node start time
next_t0c = next_t0q - conditional_latency
next_clbits = set(next_node.op.condition_bits)
if node.op.condition_bits:
# conditional access ends at the beginning of node start time
new_t1c = this_t0 + shift
this_clbits = set(node.op.condition_bits)
else:
next_t0c = None
next_clbits = set()
# Compute overlap if there is qubits overlap
if any(this_qubits & next_qubits):
qreg_overlap = new_t1q - next_t0q
else:
qreg_overlap = 0
# Compute overlap if there is clbits overlap
if any(this_clbits & next_clbits):
creg_overlap = new_t1c - next_t0c
else:
creg_overlap = 0
# Shift next node if there is finite overlap in either in qubits or clbits
overlap = max(qreg_overlap, creg_overlap)
if overlap > 0:
self._push_node_back(dag, next_node, overlap)

new_t1c = None
this_clbits = set()

# Check successors for overlap
for next_node in self._get_next_gate(dag, node):
# Compute next node start time separately for qreg and creg
next_t0q = node_start_time[next_node]
next_qubits = set(next_node.qargs)
if isinstance(next_node.op, Measure):
# creg access starts after write latency
next_t0c = next_t0q + clbit_write_latency
next_clbits = set(next_node.cargs)
else:
if next_node.op.condition_bits:
# conditional access starts before node start time
next_t0c = next_t0q - conditional_latency
next_clbits = set(next_node.op.condition_bits)
else:
next_t0c = None
next_clbits = set()
# Compute overlap if there is qubits overlap
if any(this_qubits & next_qubits):
qreg_overlap = new_t1q - next_t0q
else:
qreg_overlap = 0
# Compute overlap if there is clbits overlap
if any(this_clbits & next_clbits):
creg_overlap = new_t1c - next_t0c
else:
creg_overlap = 0
# Shift next node if there is finite overlap in either in qubits or clbits
overlap = max(qreg_overlap, creg_overlap)
if overlap > 0:
nodes_with_overlap.append((next_node, overlap))
# Update start time of this node after all overlaps are resolved
node_start_time[node] += shift
while shift_stack:
node, shift = shift_stack.pop()
node_start_time[node] += shift

def run(self, dag: DAGCircuit):
"""Run rescheduler.
Expand Down

0 comments on commit 3de2e5c

Please sign in to comment.