Skip to content

Commit

Permalink
Make circuit drawers *not crash* on Expr nodes (#10504)
Browse files Browse the repository at this point in the history
* Make circuit drawers *not crash* on `Expr` nodes

This at least causes the circuit visualisers to not crash when
encountering an `Expr` node, and instead emit a warning and make a
best-effort attempt (except for LaTeX) to output _something_.  We intend
to extend the capabilities of these drawers in the future.

* Soften warnings about unsupported `Expr` nodes
  • Loading branch information
jakelishman committed Jul 27, 2023
1 parent 9ae6164 commit c8552f6
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 38 deletions.
7 changes: 7 additions & 0 deletions qiskit/circuit/quantumcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1794,6 +1794,13 @@ def draw(
**latex_source**: raw uncompiled latex output.
.. warning::
Support for :class:`~.expr.Expr` nodes in conditions and :attr:`.SwitchCaseOp.target`
fields is preliminary and incomplete. The ``text`` and ``mpl`` drawers will make a
best-effort attempt to show data dependencies, but the LaTeX-based drawers will skip
these completely.
Args:
output (str): select the output method to use for drawing the circuit.
Valid choices are ``text``, ``mpl``, ``latex``, ``latex_source``.
Expand Down
16 changes: 6 additions & 10 deletions qiskit/visualization/circuit/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Instruction,
Measure,
)
from qiskit.circuit.controlflow import condition_resources
from qiskit.circuit.library import PauliEvolutionGate
from qiskit.circuit import ClassicalRegister, QuantumCircuit, Qubit, ControlFlowOp
from qiskit.circuit.tools import pi_check
Expand Down Expand Up @@ -549,16 +550,11 @@ def slide_from_left(self, node, index):
curr_index = index
last_insertable_index = -1
index_stop = -1
if getattr(node.op, "condition", None):
if isinstance(node.op.condition[0], Clbit):
cond_bit = [clbit for clbit in self.clbits if node.op.condition[0] == clbit]
index_stop = self.measure_map[cond_bit[0]]
else:
for bit in node.op.condition[0]:
max_index = -1
if bit in self.measure_map:
if self.measure_map[bit] > max_index:
index_stop = max_index = self.measure_map[bit]
if (condition := getattr(node.op, "condition", None)) is not None:
index_stop = max(
(self.measure_map[bit] for bit in condition_resources(condition).clbits),
default=index_stop,
)
if node.cargs:
for carg in node.cargs:
try:
Expand Down
6 changes: 6 additions & 0 deletions qiskit/visualization/circuit/circuit_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def circuit_drawer(
**latex_source**: raw uncompiled latex output.
.. warning::
Support for :class:`~.expr.Expr` nodes in conditions and :attr:`.SwitchCaseOp.target` fields
is preliminary and incomplete. The ``text`` and ``mpl`` drawers will make a best-effort
attempt to show data dependencies, but the LaTeX-based drawers will skip these completely.
Args:
circuit (QuantumCircuit): the quantum circuit to draw
scale (float): scale of image to draw (shrink if < 1.0). Only used by
Expand Down
7 changes: 5 additions & 2 deletions qiskit/visualization/circuit/latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
from qiskit.circuit import Clbit, Qubit, ClassicalRegister, QuantumRegister, QuantumCircuit
from qiskit.circuit.classical import expr
from qiskit.circuit.controlledgate import ControlledGate
from qiskit.circuit.library.standard_gates import SwapGate, XGate, ZGate, RZZGate, U1Gate, PhaseGate
from qiskit.circuit.measure import Measure
Expand Down Expand Up @@ -416,7 +417,10 @@ def _build_latex_array(self):
num_cols_op = 1
wire_list = [self._wire_map[qarg] for qarg in node.qargs if qarg in self._qubits]
if getattr(op, "condition", None):
self._add_condition(op, wire_list, column)
if isinstance(op.condition, expr.Expr):
warn("ignoring expression condition, which is not supported yet")
else:
self._add_condition(op, wire_list, column)

if isinstance(op, Measure):
self._build_measure(node, column)
Expand Down Expand Up @@ -619,7 +623,6 @@ def _add_condition(self, op, wire_list, col):
# cwire - the wire number for the first wire for the condition register
# or if cregbundle, wire number of the condition register itself
# gap - the number of wires from cwire to the bottom gate qubit

label, val_bits = get_condition_label_val(op.condition, self._circuit, self._cregbundle)
cond_is_bit = isinstance(op.condition[0], Clbit)
cond_reg = op.condition[0]
Expand Down
76 changes: 50 additions & 26 deletions qiskit/visualization/circuit/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""mpl circuit visualization backend."""

import collections
import itertools
import re
from warnings import warn
Expand All @@ -33,6 +34,8 @@
ForLoopOp,
SwitchCaseOp,
)
from qiskit.circuit.controlflow import condition_resources
from qiskit.circuit.classical import expr
from qiskit.circuit.library.standard_gates import (
SwapGate,
RZZGate,
Expand Down Expand Up @@ -1090,45 +1093,66 @@ def _condition(self, node, node_data, wire_map, cond_xy, glob_data):
# For SwitchCaseOp convert the target to a fully closed Clbit or register
# in condition format
if isinstance(node.op, SwitchCaseOp):
if isinstance(node.op.target, Clbit):
if isinstance(node.op.target, expr.Expr):
condition = node.op.target
elif isinstance(node.op.target, Clbit):
condition = (node.op.target, 1)
else:
condition = (node.op.target, 2 ** (node.op.target.size) - 1)
else:
condition = node.op.condition
label, val_bits = get_condition_label_val(condition, self._circuit, self._cregbundle)
cond_bit_reg = condition[0]
cond_bit_val = int(condition[1])

override_fc = False
first_clbit = len(self._qubits)
cond_pos = []

# In the first case, multiple bits are indicated on the drawing. In all
# other cases, only one bit is shown.
if not self._cregbundle and isinstance(cond_bit_reg, ClassicalRegister):
for idx in range(cond_bit_reg.size):
cond_pos.append(cond_xy[wire_map[cond_bit_reg[idx]] - first_clbit])

# If it's a register bit and cregbundle, need to use the register to find the location
elif self._cregbundle and isinstance(cond_bit_reg, Clbit):
register = get_bit_register(self._circuit, cond_bit_reg)
if register is not None:
cond_pos.append(cond_xy[wire_map[register] - first_clbit])
if isinstance(condition, expr.Expr):
# If fixing this, please update the docstrings of `QuantumCircuit.draw` and
# `visualization.circuit_drawer` to remove warnings.
condition_bits = condition_resources(condition).clbits
label = "[expression]"
override_fc = True
registers = collections.defaultdict(list)
for bit in condition_bits:
registers[get_bit_register(self._circuit, bit)].append(bit)
# Registerless bits don't care whether cregbundle is set.
cond_pos.extend(cond_xy[wire_map[bit] - first_clbit] for bit in registers.pop(None, ()))
if self._cregbundle:
cond_pos.extend(
cond_xy[wire_map[register[0]] - first_clbit] for register in registers
)
else:
cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit])
cond_pos.extend(
cond_xy[wire_map[bit] - first_clbit]
for register, bits in registers.items()
for bit in bits
)
val_bits = ["1"] * len(cond_pos)
else:
cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit])
label, val_bits = get_condition_label_val(condition, self._circuit, self._cregbundle)
cond_bit_reg = condition[0]
cond_bit_val = int(condition[1])
override_fc = cond_bit_val != 0

# In the first case, multiple bits are indicated on the drawing. In all
# other cases, only one bit is shown.
if not self._cregbundle and isinstance(cond_bit_reg, ClassicalRegister):
for idx in range(cond_bit_reg.size):
cond_pos.append(cond_xy[wire_map[cond_bit_reg[idx]] - first_clbit])

# If it's a register bit and cregbundle, need to use the register to find the location
elif self._cregbundle and isinstance(cond_bit_reg, Clbit):
register = get_bit_register(self._circuit, cond_bit_reg)
if register is not None:
cond_pos.append(cond_xy[wire_map[register] - first_clbit])
else:
cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit])
else:
cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit])

xy_plot = []
for idx, xy in enumerate(cond_pos):
if val_bits[idx] == "1" or (
isinstance(cond_bit_reg, ClassicalRegister)
and cond_bit_val != 0
and self._cregbundle
):
fc = self._style["lc"]
else:
fc = self._style["bg"]
for val_bit, xy in zip(val_bits, cond_pos):
fc = self._style["lc"] if override_fc or val_bit == "1" else self._style["bg"]
box = glob_data["patches_mod"].Circle(
xy=xy,
radius=WID * 0.15,
Expand Down
24 changes: 24 additions & 0 deletions qiskit/visualization/circuit/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

from warnings import warn
from shutil import get_terminal_size
import collections
import itertools
import sys

from qiskit.circuit import Qubit, Clbit, ClassicalRegister
from qiskit.circuit import ControlledGate
from qiskit.circuit import Reset
from qiskit.circuit import Measure
from qiskit.circuit.classical import expr
from qiskit.circuit.controlflow import node_resources
from qiskit.circuit.library.standard_gates import IGate, RZZGate, SwapGate, SXGate, SXdgGate
from qiskit.circuit.tools.pi_check import pi_check

Expand Down Expand Up @@ -1344,6 +1347,27 @@ def set_cl_multibox(self, condition, top_connect="┴"):
Returns:
List: list of tuples of connections between clbits for multi-bit conditions
"""
if isinstance(condition, expr.Expr):
# If fixing this, please update the docstrings of `QuantumCircuit.draw` and
# `visualization.circuit_drawer` to remove warnings.
label = "<expression>"
out = []
condition_bits = node_resources(condition).clbits
registers = collections.defaultdict(list)
for bit in condition_bits:
registers[get_bit_register(self._circuit, bit)].append(bit)
if registerless := registers.pop(None, ()):
out.extend(self.set_cond_bullets(label, ["1"] * len(registerless), registerless))
if self.cregbundle:
# It's hard to do something properly sensible here without more major rewrites, so
# as a minimum to *not crash* we'll just treat a condition that touches part of a
# register like it touched the whole register.
for register in registers:
self.set_clbit(register[0], BoxOnClWire(label=label, top_connect=top_connect))
else:
for register, bits in registers.items():
out.extend(self.set_cond_bullets(label, ["1"] * len(bits), bits))
return out
label, val_bits = get_condition_label_val(condition, self._circuit, self.cregbundle)
if isinstance(condition[0], ClassicalRegister):
cond_reg = condition[0]
Expand Down

0 comments on commit c8552f6

Please sign in to comment.