diff --git a/qiskit/visualization/circuit/matplotlib.py b/qiskit/visualization/circuit/matplotlib.py index b1b04ca6602c..965dec412337 100644 --- a/qiskit/visualization/circuit/matplotlib.py +++ b/qiskit/visualization/circuit/matplotlib.py @@ -55,6 +55,8 @@ PORDER_GRAY = 3 PORDER_TEXT = 6 +INFINITE_FOLD = 10000000 + @_optionals.HAS_MATPLOTLIB.require_in_instance @_optionals.HAS_PYLATEX.require_in_instance @@ -79,29 +81,13 @@ def __init__( cregbundle=None, with_layout=False, ): - from matplotlib import patches - from matplotlib import pyplot as plt - - self._patches_mod = patches - self._plt_mod = plt - self._circuit = circuit self._qubits = qubits self._clbits = clbits - self._qubits_dict = {} - self._clbits_dict = {} - self._q_anchors = {} - self._c_anchors = {} - self._wire_map = {} - self._nodes = nodes self._scale = 1.0 if scale is None else scale - self._style, def_font_ratio = load_style(style) - - # If font/subfont ratio changes from default, have to scale width calculations for - # subfont. Font change is auto scaled in the self._figure.set_size_inches call in draw() - self._subfont_factor = self._style["sfs"] * def_font_ratio / self._style["fs"] + self._style = style self._plot_barriers = plot_barriers self._reverse_bits = reverse_bits @@ -117,18 +103,7 @@ def __init__( if self._fold < 2: self._fold = -1 - if ax is None: - self._user_ax = False - self._figure = plt.figure() - self._figure.patch.set_facecolor(color=self._style["bg"]) - self._ax = self._figure.add_subplot(111) - else: - self._user_ax = True - self._ax = ax - self._figure = ax.get_figure() - self._ax.axis("off") - self._ax.set_aspect("equal") - self._ax.tick_params(labelbottom=False, labeltop=False, labelleft=False, labelright=False) + self._ax = ax self._initial_state = initial_state self._global_phase = self._circuit.global_phase @@ -148,18 +123,9 @@ def __init__( else: self._cregbundle = True if cregbundle is None else cregbundle - self._fs = self._style["fs"] - self._sfs = self._style["sfs"] self._lwidth1 = 1.0 self._lwidth15 = 1.5 self._lwidth2 = 2.0 - self._x_offset = 0.0 - - # _data per node with 'width', 'gate_text', 'raw_gate_text', - # 'ctrl_text', 'param', q_xy', 'c_xy', and 'c_indxs' - # and colors 'fc', 'ec', 'lc', 'sc', 'gt', and 'tc' - self._data = {} - self._layer_widths = [] # _char_list for finding text_width of names, labels, and params self._char_list = { @@ -263,31 +229,76 @@ def draw(self, filename=None, verbose=False): """Main entry point to 'matplotlib' ('mpl') drawer. Called from ``visualization.circuit_drawer`` and from ``QuantumCircuit.draw`` through circuit_drawer. """ - # All information for the drawing is first loaded into self._data for the gates and into - # self._qubits_dict and self._clbits_dict for the qubits, clbits, and wires, + + # Import matplotlib and load all the figure, window, and style info + from matplotlib import patches + from matplotlib import pyplot as plt + + # glob_data contains global values used throughout, "n_lines", "x_offset", "next_x_index", + # "patches_mod", subfont_factor" + glob_data = {} + + glob_data["patches_mod"] = patches + plt_mod = plt + + self._style, def_font_ratio = load_style(self._style) + + # If font/subfont ratio changes from default, have to scale width calculations for + # subfont. Font change is auto scaled in the mpl_figure.set_size_inches call in draw() + glob_data["subfont_factor"] = self._style["sfs"] * def_font_ratio / self._style["fs"] + + # if no user ax, setup default figure. Else use the user figure. + if self._ax is None: + is_user_ax = False + mpl_figure = plt.figure() + mpl_figure.patch.set_facecolor(color=self._style["bg"]) + self._ax = mpl_figure.add_subplot(111) + else: + is_user_ax = True + mpl_figure = self._ax.get_figure() + self._ax.axis("off") + self._ax.set_aspect("equal") + self._ax.tick_params(labelbottom=False, labeltop=False, labelleft=False, labelright=False) + + # All information for the drawing is first loaded into node_data for the gates and into + # qubits_dict, clbits_dict, and wire_map for the qubits, clbits, and wires, # followed by the coordinates for each gate. - # get layer widths - self._get_layer_widths() + # load the wire map + wire_map = get_wire_map(self._circuit, self._qubits + self._clbits, self._cregbundle) + + # node_data per node with 'width', 'gate_text', 'raw_gate_text', + # 'ctrl_text', 'param_text', q_xy', and 'c_xy', + # and colors 'fc', 'ec', 'lc', 'sc', 'gt', and 'tc' + node_data = {} + + # dicts for the names and locations of register/bit labels + qubits_dict = {} + clbits_dict = {} # load the _qubit_dict and _clbit_dict with register info - n_lines = self._set_bit_reg_info() + self._set_bit_reg_info(wire_map, qubits_dict, clbits_dict, glob_data) + + # get layer widths + layer_widths = self._get_layer_widths(node_data, glob_data) # load the coordinates for each gate and compute number of folds - max_anc = self._get_coords(n_lines) - num_folds = max(0, max_anc - 1) // self._fold if self._fold > 0 else 0 + max_x_index = self._get_coords( + node_data, wire_map, layer_widths, qubits_dict, clbits_dict, glob_data + ) + num_folds = max(0, max_x_index - 1) // self._fold if self._fold > 0 else 0 # The window size limits are computed, followed by one of the four possible ways # of scaling the drawing. # compute the window size - if max_anc > self._fold > 0: - xmax = self._fold + self._x_offset + 0.1 - ymax = (num_folds + 1) * (n_lines + 1) - 1 + if max_x_index > self._fold > 0: + xmax = self._fold + glob_data["x_offset"] + 0.1 + ymax = (num_folds + 1) * (glob_data["n_lines"] + 1) - 1 else: x_incr = 0.4 if not self._nodes else 0.9 - xmax = max_anc + 1 + self._x_offset - x_incr - ymax = n_lines + xmax = max_x_index + 1 + glob_data["x_offset"] - x_incr + ymax = glob_data["n_lines"] xl = -self._style["margin"][0] xr = xmax + self._style["margin"][1] @@ -297,36 +308,36 @@ def draw(self, filename=None, verbose=False): self._ax.set_ylim(yb, yt) # update figure size and, for backward compatibility, - # need to scale by a default value equal to (self._fs * 3.01 / 72 / 0.65) + # need to scale by a default value equal to (self._style["fs"] * 3.01 / 72 / 0.65) base_fig_w = (xr - xl) * 0.8361111 base_fig_h = (yt - yb) * 0.8361111 scale = self._scale # if user passes in an ax, this size takes priority over any other settings - if self._user_ax: + if is_user_ax: # from stackoverflow #19306510, get the bbox size for the ax and then reset scale - bbox = self._ax.get_window_extent().transformed(self._figure.dpi_scale_trans.inverted()) + bbox = self._ax.get_window_extent().transformed(mpl_figure.dpi_scale_trans.inverted()) scale = bbox.width / base_fig_w / 0.8361111 # if scale not 1.0, use this scale factor elif self._scale != 1.0: - self._figure.set_size_inches(base_fig_w * self._scale, base_fig_h * self._scale) + mpl_figure.set_size_inches(base_fig_w * self._scale, base_fig_h * self._scale) # if "figwidth" style param set, use this to scale elif self._style["figwidth"] > 0.0: # in order to get actual inches, need to scale by factor adj_fig_w = self._style["figwidth"] * 1.282736 - self._figure.set_size_inches(adj_fig_w, adj_fig_w * base_fig_h / base_fig_w) + mpl_figure.set_size_inches(adj_fig_w, adj_fig_w * base_fig_h / base_fig_w) scale = adj_fig_w / base_fig_w # otherwise, display default size else: - self._figure.set_size_inches(base_fig_w, base_fig_h) + mpl_figure.set_size_inches(base_fig_w, base_fig_h) # drawing will scale with 'set_size_inches', but fonts and linewidths do not if scale != 1.0: - self._fs *= scale - self._sfs *= scale + self._style["fs"] *= scale + self._style["sfs"] *= scale self._lwidth1 = 1.0 * scale self._lwidth15 = 1.5 * scale self._lwidth2 = 2.0 * scale @@ -334,46 +345,54 @@ def draw(self, filename=None, verbose=False): # Once the scaling factor has been determined, the global phase, register names # and numbers, wires, and gates are drawn if self._global_phase: - self._plt_mod.text( - xl, yt, "Global Phase: %s" % pi_check(self._global_phase, output="mpl") - ) - self._draw_regs_wires(num_folds, xmax, n_lines, max_anc) - self._draw_ops(verbose) + plt_mod.text(xl, yt, "Global Phase: %s" % pi_check(self._global_phase, output="mpl")) + self._draw_regs_wires(num_folds, xmax, max_x_index, qubits_dict, clbits_dict, glob_data) + self._draw_ops( + self._nodes, node_data, wire_map, layer_widths, clbits_dict, glob_data, verbose + ) if filename: - self._figure.savefig( + mpl_figure.savefig( filename, dpi=self._style["dpi"], bbox_inches="tight", - facecolor=self._figure.get_facecolor(), + facecolor=mpl_figure.get_facecolor(), ) - if not self._user_ax: - matplotlib_close_if_inline(self._figure) - return self._figure + if not is_user_ax: + matplotlib_close_if_inline(mpl_figure) + return mpl_figure - def _get_layer_widths(self): + def _get_layer_widths(self, node_data, glob_data): """Compute the layer_widths for the layers""" - for layer in self._nodes: + + layer_widths = {} + for layer_num, layer in enumerate(self._nodes): widest_box = WID - for node in layer: + for i, node in enumerate(layer): + # Put the layer_num in the first node in the layer and put -1 in the rest + # so that layer widths are not counted more than once + if i != 0: + layer_num = -1 + layer_widths[node] = [1, layer_num] + op = node.op - self._data[node] = {} - self._data[node]["width"] = WID + node_data[node] = {} + node_data[node]["width"] = WID num_ctrl_qubits = 0 if not hasattr(op, "num_ctrl_qubits") else op.num_ctrl_qubits if ( getattr(op, "_directive", False) and (not op.label or not self._plot_barriers) ) or isinstance(op, Measure): - self._data[node]["raw_gate_text"] = op.name + node_data[node]["raw_gate_text"] = op.name continue base_type = None if not hasattr(op, "base_gate") else op.base_gate gate_text, ctrl_text, raw_gate_text = get_gate_ctrl_text( op, "mpl", style=self._style, calibrations=self._calibrations ) - self._data[node]["gate_text"] = gate_text - self._data[node]["ctrl_text"] = ctrl_text - self._data[node]["raw_gate_text"] = raw_gate_text - self._data[node]["param"] = "" + node_data[node]["gate_text"] = gate_text + node_data[node]["ctrl_text"] = ctrl_text + node_data[node]["raw_gate_text"] = raw_gate_text + node_data[node]["param_text"] = "" # if single qubit, no params, and no labels, layer_width is 1 if ( @@ -389,7 +408,9 @@ def _get_layer_widths(self): # small increments at end of the 3 _get_text_width calls are for small # spacing adjustments between gates - ctrl_width = self._get_text_width(ctrl_text, fontsize=self._sfs) - 0.05 + ctrl_width = ( + self._get_text_width(ctrl_text, glob_data, fontsize=self._style["sfs"]) - 0.05 + ) # get param_width, but 0 for gates with array params if ( @@ -397,11 +418,13 @@ def _get_layer_widths(self): and len(op.params) > 0 and not any(isinstance(param, np.ndarray) for param in op.params) ): - param = get_param_str(op, "mpl", ndigits=3) + param_text = get_param_str(op, "mpl", ndigits=3) if isinstance(op, Initialize): - param = f"$[{param.replace('$', '')}]$" - self._data[node]["param"] = param - raw_param_width = self._get_text_width(param, fontsize=self._sfs, param=True) + param_text = f"$[{param_text.replace('$', '')}]$" + node_data[node]["param_text"] = param_text + raw_param_width = self._get_text_width( + param_text, glob_data, fontsize=self._style["sfs"], param=True + ) param_width = raw_param_width + 0.08 else: param_width = raw_param_width = 0.0 @@ -411,14 +434,18 @@ def _get_layer_widths(self): if isinstance(base_type, PhaseGate): gate_text = "P" raw_gate_width = ( - self._get_text_width(gate_text + " ()", fontsize=self._sfs) + self._get_text_width( + gate_text + " ()", glob_data, fontsize=self._style["sfs"] + ) + raw_param_width ) gate_width = (raw_gate_width + 0.08) * 1.58 - # otherwise, standard gate or multiqubit gate + # Otherwise, standard gate or multiqubit gate else: - raw_gate_width = self._get_text_width(gate_text, fontsize=self._fs) + raw_gate_width = self._get_text_width( + gate_text, glob_data, fontsize=self._style["fs"] + ) gate_width = raw_gate_width + 0.10 # add .21 for the qubit numbers on the left of the multibit gates if len(node.qargs) - num_ctrl_qubits > 1: @@ -427,26 +454,28 @@ def _get_layer_widths(self): box_width = max(gate_width, ctrl_width, param_width, WID) if box_width > widest_box: widest_box = box_width - self._data[node]["width"] = max(raw_gate_width, raw_param_width) + node_data[node]["width"] = max(raw_gate_width, raw_param_width) + for node in layer: + layer_widths[node][0] = int(widest_box) + 1 - self._layer_widths.append(int(widest_box) + 1) + return layer_widths - def _set_bit_reg_info(self): + def _set_bit_reg_info(self, wire_map, qubits_dict, clbits_dict, glob_data): """Get all the info for drawing bit/reg names and numbers""" - self._wire_map = get_wire_map(self._circuit, self._qubits + self._clbits, self._cregbundle) longest_wire_label_width = 0 - n_lines = 0 + glob_data["n_lines"] = 0 initial_qbit = " |0>" if self._initial_state else "" initial_cbit = " 0" if self._initial_state else "" idx = 0 pos = y_off = -len(self._qubits) + 1 - for ii, wire in enumerate(self._wire_map): + for ii, wire in enumerate(wire_map): + # if it's a creg, register is the key and just load the index if isinstance(wire, ClassicalRegister): register = wire - index = self._wire_map[wire] + index = wire_map[wire] # otherwise, get the register from find_bit and use bit_index if # it's a bit, or the index of the bit in the register if it's a reg @@ -470,83 +499,87 @@ def _set_bit_reg_info(self): ) reg_remove_under = 0 if reg_size < 2 else 1 text_width = ( - self._get_text_width(wire_label, self._fs, reg_remove_under=reg_remove_under) * 1.15 + self._get_text_width( + wire_label, glob_data, self._style["fs"], reg_remove_under=reg_remove_under + ) + * 1.15 ) if text_width > longest_wire_label_width: longest_wire_label_width = text_width if isinstance(wire, Qubit): pos = -ii - self._qubits_dict[ii] = { + qubits_dict[ii] = { "y": pos, "wire_label": wire_label, - "index": bit_index, - "register": register, } - n_lines += 1 + glob_data["n_lines"] += 1 else: if ( not self._cregbundle or register is None or (self._cregbundle and isinstance(wire, ClassicalRegister)) ): - n_lines += 1 + glob_data["n_lines"] += 1 idx += 1 pos = y_off - idx - self._clbits_dict[ii] = { + clbits_dict[ii] = { "y": pos, "wire_label": wire_label, - "index": bit_index, "register": register, } + glob_data["x_offset"] = -1.2 + longest_wire_label_width - self._x_offset = -1.2 + longest_wire_label_width - return n_lines - - def _get_coords(self, n_lines): - """Load all the coordinate info needed to place the gates on the drawing""" + def _get_coords(self, node_data, wire_map, layer_widths, qubits_dict, clbits_dict, glob_data): + """Load all the coordinate info needed to place the gates on the drawing.""" - # create the anchor arrays - for key, qubit in self._qubits_dict.items(): - self._q_anchors[key] = Anchor(num_wires=n_lines, y_index=qubit["y"], fold=self._fold) - for key, clbit in self._clbits_dict.items(): - self._c_anchors[key] = Anchor(num_wires=n_lines, y_index=clbit["y"], fold=self._fold) - - # get all the necessary coordinates for placing gates on the wires prev_x_index = -1 - for i, layer in enumerate(self._nodes): - layer_width = self._layer_widths[i] - anc_x_index = prev_x_index + 1 + for layer in self._nodes: + curr_x_index = prev_x_index + 1 + l_width = [] for node in layer: - # get qubit index + + # get qubit indexes q_indxs = [] for qarg in node.qargs: if qarg in self._qubits: - q_indxs.append(self._wire_map[qarg]) + q_indxs.append(wire_map[qarg]) + # get clbit indexes c_indxs = [] for carg in node.cargs: if carg in self._clbits: register = get_bit_register(self._circuit, carg) if register is not None and self._cregbundle: - c_indxs.append(self._wire_map[register]) + c_indxs.append(wire_map[register]) else: - c_indxs.append(self._wire_map[carg]) - - # qubit coordinate - self._data[node]["q_xy"] = [ - self._q_anchors[ii].plot_coord(anc_x_index, layer_width, self._x_offset) + c_indxs.append(wire_map[carg]) + + # qubit coordinates + node_data[node]["q_xy"] = [ + self._plot_coord( + curr_x_index, + qubits_dict[ii]["y"], + layer_widths[node][0], + glob_data, + ) for ii in q_indxs ] - # clbit coordinate - self._data[node]["c_xy"] = [ - self._c_anchors[ii].plot_coord(anc_x_index, layer_width, self._x_offset) + # clbit coordinates + node_data[node]["c_xy"] = [ + self._plot_coord( + curr_x_index, + clbits_dict[ii]["y"], + layer_widths[node][0], + glob_data, + ) for ii in c_indxs ] + # update index based on the value from plotting - anc_x_index = self._q_anchors[q_indxs[0]].get_x_index() - self._data[node]["c_indxs"] = c_indxs + curr_x_index = glob_data["next_x_index"] + l_width.append(layer_widths[node][0]) # adjust the column if there have been barriers encountered, but not plotted barrier_offset = 0 @@ -555,12 +588,13 @@ def _get_coords(self, n_lines): barrier_offset = ( -1 if all(getattr(nd.op, "_directive", False) for nd in layer) else 0 ) - prev_x_index = anc_x_index + layer_width + barrier_offset - 1 + prev_x_index = curr_x_index + max(l_width) + barrier_offset - 1 return prev_x_index + 1 - def _get_text_width(self, text, fontsize, param=False, reg_remove_under=None): + def _get_text_width(self, text, glob_data, fontsize, param=False, reg_remove_under=None): """Compute the width of a string in the default font""" + from pylatexenc.latex2text import LatexNodes2Text if not text: @@ -592,7 +626,7 @@ def _get_text_width(self, text, fontsize, param=False, reg_remove_under=None): if param: text = text.replace("-", "+") - f = 0 if fontsize == self._fs else 1 + f = 0 if fontsize == self._style["fs"] else 1 sum_text = 0.0 for c in text: try: @@ -601,42 +635,40 @@ def _get_text_width(self, text, fontsize, param=False, reg_remove_under=None): # if non-ASCII char, use width of 'c', an average size sum_text += self._char_list["c"][f] if f == 1: - sum_text *= self._subfont_factor + sum_text *= glob_data["subfont_factor"] return sum_text - def _draw_regs_wires(self, num_folds, xmax, n_lines, max_anc): + def _draw_regs_wires(self, num_folds, xmax, max_x_index, qubits_dict, clbits_dict, glob_data): """Draw the register names and numbers, wires, and vertical lines at the ends""" for fold_num in range(num_folds + 1): # quantum registers - for qubit in self._qubits_dict.values(): + for qubit in qubits_dict.values(): qubit_label = qubit["wire_label"] - y = qubit["y"] - fold_num * (n_lines + 1) + y = qubit["y"] - fold_num * (glob_data["n_lines"] + 1) self._ax.text( - self._x_offset - 0.2, + glob_data["x_offset"] - 0.2, y, qubit_label, ha="right", va="center", - fontsize=1.25 * self._fs, + fontsize=1.25 * self._style["fs"], color=self._style["tc"], clip_on=True, zorder=PORDER_TEXT, ) # draw the qubit wire - self._line([self._x_offset, y], [xmax, y], zorder=PORDER_REGLINE) + self._line([glob_data["x_offset"], y], [xmax, y], zorder=PORDER_REGLINE) # classical registers this_clbit_dict = {} - for clbit in self._clbits_dict.values(): - clbit_label = clbit["wire_label"] - clbit_reg = clbit["register"] - y = clbit["y"] - fold_num * (n_lines + 1) + for clbit in clbits_dict.values(): + y = clbit["y"] - fold_num * (glob_data["n_lines"] + 1) if y not in this_clbit_dict.keys(): this_clbit_dict[y] = { "val": 1, - "wire_label": clbit_label, - "register": clbit_reg, + "wire_label": clbit["wire_label"], + "register": clbit["register"], } else: this_clbit_dict[y]["val"] += 1 @@ -645,36 +677,36 @@ def _draw_regs_wires(self, num_folds, xmax, n_lines, max_anc): # cregbundle if self._cregbundle and this_clbit["register"] is not None: self._ax.plot( - [self._x_offset + 0.2, self._x_offset + 0.3], + [glob_data["x_offset"] + 0.2, glob_data["x_offset"] + 0.3], [y - 0.1, y + 0.1], color=self._style["cc"], zorder=PORDER_LINE, ) self._ax.text( - self._x_offset + 0.1, + glob_data["x_offset"] + 0.1, y + 0.1, str(this_clbit["register"].size), ha="left", va="bottom", - fontsize=0.8 * self._fs, + fontsize=0.8 * self._style["fs"], color=self._style["tc"], clip_on=True, zorder=PORDER_TEXT, ) self._ax.text( - self._x_offset - 0.2, + glob_data["x_offset"] - 0.2, y, this_clbit["wire_label"], ha="right", va="center", - fontsize=1.25 * self._fs, + fontsize=1.25 * self._style["fs"], color=self._style["tc"], clip_on=True, zorder=PORDER_TEXT, ) # draw the clbit wire self._line( - [self._x_offset, y], + [glob_data["x_offset"], y], [xmax, y], lc=self._style["cc"], ls=self._style["cline"], @@ -685,10 +717,10 @@ def _draw_regs_wires(self, num_folds, xmax, n_lines, max_anc): feedline_r = num_folds > 0 and num_folds > fold_num feedline_l = fold_num > 0 if feedline_l or feedline_r: - xpos_l = self._x_offset - 0.01 - xpos_r = self._fold + self._x_offset + 0.1 - ypos1 = -fold_num * (n_lines + 1) - ypos2 = -(fold_num + 1) * (n_lines) - fold_num + 1 + xpos_l = glob_data["x_offset"] - 0.01 + xpos_r = self._fold + glob_data["x_offset"] + 0.1 + ypos1 = -fold_num * (glob_data["n_lines"] + 1) + ypos2 = -(fold_num + 1) * (glob_data["n_lines"]) - fold_num + 1 if feedline_l: self._ax.plot( [xpos_l, xpos_l], @@ -706,14 +738,14 @@ def _draw_regs_wires(self, num_folds, xmax, n_lines, max_anc): zorder=PORDER_LINE, ) - # draw anchor index number + # draw index number if self._style["index"]: - for layer_num in range(max_anc): + for layer_num in range(max_x_index): if self._fold > 0: - x_coord = layer_num % self._fold + self._x_offset + 0.53 - y_coord = -(layer_num // self._fold) * (n_lines + 1) + 0.65 + x_coord = layer_num % self._fold + glob_data["x_offset"] + 0.53 + y_coord = -(layer_num // self._fold) * (glob_data["n_lines"] + 1) + 0.65 else: - x_coord = layer_num + self._x_offset + 0.53 + x_coord = layer_num + glob_data["x_offset"] + 0.53 y_coord = 0.65 self._ax.text( x_coord, @@ -721,23 +753,27 @@ def _draw_regs_wires(self, num_folds, xmax, n_lines, max_anc): str(layer_num + 1), ha="center", va="center", - fontsize=self._sfs, + fontsize=self._style["sfs"], color=self._style["tc"], clip_on=True, zorder=PORDER_TEXT, ) - def _draw_ops(self, verbose=False): + def _draw_ops( + self, nodes, node_data, wire_map, layer_widths, clbits_dict, glob_data, verbose=False + ): """Draw the gates in the circuit""" + prev_x_index = -1 - for i, layer in enumerate(self._nodes): - layer_width = self._layer_widths[i] - anc_x_index = prev_x_index + 1 + for layer in nodes: + l_width = [] + curr_x_index = prev_x_index + 1 # draw the gates in this layer for node in layer: op = node.op - self._get_colors(node) + + self._get_colors(node, node_data) if verbose: print(op) @@ -745,35 +781,40 @@ def _draw_ops(self, verbose=False): # add conditional if getattr(op, "condition", None): cond_xy = [ - self._c_anchors[ii].plot_coord(anc_x_index, layer_width, self._x_offset) - for ii in self._clbits_dict - ] - if self._clbits_dict: - anc_x_index = max( - anc_x_index, next(iter(self._c_anchors.items()))[1].get_x_index() + self._plot_coord( + curr_x_index, + clbits_dict[ii]["y"], + layer_widths[node][0], + glob_data, ) - self._condition(node, cond_xy) + for ii in clbits_dict + ] + if clbits_dict: + curr_x_index = max(curr_x_index, glob_data["next_x_index"]) + self._condition(node, node_data, wire_map, cond_xy, glob_data) # draw measure if isinstance(op, Measure): - self._measure(node) + self._measure(node, node_data, glob_data) # draw barriers, snapshots, etc. elif getattr(op, "_directive", False): if self._plot_barriers: - self._barrier(node) + self._barrier(node, node_data, glob_data) # draw single qubit gates - elif len(self._data[node]["q_xy"]) == 1 and not node.cargs: - self._gate(node) + elif len(node_data[node]["q_xy"]) == 1 and not node.cargs: + self._gate(node, node_data, glob_data) # draw controlled gates elif isinstance(op, ControlledGate): - self._control_gate(node) + self._control_gate(node, node_data, glob_data) # draw multi-qubit gate as final default else: - self._multiqubit_gate(node) + self._multiqubit_gate(node, node_data, glob_data) + + l_width.append(layer_widths[node][0]) # adjust the column if there have been barriers encountered, but not plotted barrier_offset = 0 @@ -783,15 +824,16 @@ def _draw_ops(self, verbose=False): -1 if all(getattr(nd.op, "_directive", False) for nd in layer) else 0 ) - prev_x_index = anc_x_index + layer_width + barrier_offset - 1 + prev_x_index = curr_x_index + max(l_width) + barrier_offset - 1 - def _get_colors(self, node): + def _get_colors(self, node, node_data): """Get all the colors needed for drawing the circuit""" + op = node.op base_name = None if not hasattr(op, "base_gate") else op.base_gate.name color = None - if self._data[node]["raw_gate_text"] in self._style["dispcol"]: - color = self._style["dispcol"][self._data[node]["raw_gate_text"]] + if node_data[node]["raw_gate_text"] in self._style["dispcol"]: + color = self._style["dispcol"][node_data[node]["raw_gate_text"]] elif op.name in self._style["dispcol"]: color = self._style["dispcol"][op.name] if color is not None: @@ -825,15 +867,16 @@ def _get_colors(self, node): lc = fc # Subtext needs to be same color as gate text sc = gt - self._data[node]["fc"] = fc - self._data[node]["ec"] = ec - self._data[node]["gt"] = gt - self._data[node]["tc"] = self._style["tc"] - self._data[node]["sc"] = sc - self._data[node]["lc"] = lc - - def _condition(self, node, cond_xy): + node_data[node]["fc"] = fc + node_data[node]["ec"] = ec + node_data[node]["gt"] = gt + node_data[node]["tc"] = self._style["tc"] + node_data[node]["sc"] = sc + node_data[node]["lc"] = lc + + def _condition(self, node, node_data, wire_map, cond_xy, glob_data): """Add a conditional to a gate""" + label, val_bits = get_condition_label_val( node.op.condition, self._circuit, self._cregbundle ) @@ -847,17 +890,17 @@ def _condition(self, node, cond_xy): # other cases, only one bit is shown. if not self._cregbundle and isinstance(cond_bit_reg, ClassicalRegister): for idx in range(cond_bit_reg.size): - cond_pos.append(cond_xy[self._wire_map[cond_bit_reg[idx]] - first_clbit]) + cond_pos.append(cond_xy[wire_map[cond_bit_reg[idx]] - first_clbit]) # If it's a register bit and cregbundle, need to use the register to find the location elif self._cregbundle and isinstance(cond_bit_reg, Clbit): register = get_bit_register(self._circuit, cond_bit_reg) if register is not None: - cond_pos.append(cond_xy[self._wire_map[register] - first_clbit]) + cond_pos.append(cond_xy[wire_map[register] - first_clbit]) else: - cond_pos.append(cond_xy[self._wire_map[cond_bit_reg] - first_clbit]) + cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit]) else: - cond_pos.append(cond_xy[self._wire_map[cond_bit_reg] - first_clbit]) + cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit]) xy_plot = [] for idx, xy in enumerate(cond_pos): @@ -869,7 +912,7 @@ def _condition(self, node, cond_xy): fc = self._style["lc"] else: fc = self._style["bg"] - box = self._patches_mod.Circle( + box = glob_data["patches_mod"].Circle( xy=xy, radius=WID * 0.15, fc=fc, @@ -879,7 +922,8 @@ def _condition(self, node, cond_xy): ) self._ax.add_patch(box) xy_plot.append(xy) - qubit_b = min(self._data[node]["q_xy"], key=lambda xy: xy[1]) + + qubit_b = min(node_data[node]["q_xy"], key=lambda xy: xy[1]) clbit_b = min(xy_plot, key=lambda xy: xy[1]) # display the label at the bottom of the lowest conditional and draw the double line @@ -892,31 +936,31 @@ def _condition(self, node, cond_xy): label, ha="center", va="top", - fontsize=self._sfs, + fontsize=self._style["sfs"], color=self._style["tc"], clip_on=True, zorder=PORDER_TEXT, ) self._line(qubit_b, clbit_b, lc=self._style["cc"], ls=self._style["cline"]) - def _measure(self, node): + def _measure(self, node, node_data, glob_data): """Draw the measure symbol and the line to the clbit""" - qx, qy = self._data[node]["q_xy"][0] - cx, cy = self._data[node]["c_xy"][0] + qx, qy = node_data[node]["q_xy"][0] + cx, cy = node_data[node]["c_xy"][0] register, _, reg_index = get_bit_reg_index(self._circuit, node.cargs[0]) # draw gate box - self._gate(node) + self._gate(node, node_data, glob_data) # add measure symbol - arc = self._patches_mod.Arc( + arc = glob_data["patches_mod"].Arc( xy=(qx, qy - 0.15 * HIG), width=WID * 0.7, height=HIG * 0.7, theta1=0, theta2=180, fill=False, - ec=self._data[node]["gt"], + ec=node_data[node]["gt"], linewidth=self._lwidth2, zorder=PORDER_GATE, ) @@ -924,18 +968,18 @@ def _measure(self, node): self._ax.plot( [qx, qx + 0.35 * WID], [qy - 0.15 * HIG, qy + 0.20 * HIG], - color=self._data[node]["gt"], + color=node_data[node]["gt"], linewidth=self._lwidth2, zorder=PORDER_GATE, ) # arrow self._line( - self._data[node]["q_xy"][0], + node_data[node]["q_xy"][0], [cx, cy + 0.35 * WID], lc=self._style["cc"], ls=self._style["cline"], ) - arrowhead = self._patches_mod.Polygon( + arrowhead = glob_data["patches_mod"].Polygon( ( (cx - 0.20 * WID, cy + 0.35 * WID), (cx + 0.20 * WID, cy + 0.35 * WID), @@ -953,15 +997,15 @@ def _measure(self, node): str(reg_index), ha="left", va="bottom", - fontsize=0.8 * self._fs, + fontsize=0.8 * self._style["fs"], color=self._style["tc"], clip_on=True, zorder=PORDER_TEXT, ) - def _barrier(self, node): + def _barrier(self, node, node_data, glob_data): """Draw a barrier""" - for i, xy in enumerate(self._data[node]["q_xy"]): + for i, xy in enumerate(node_data[node]["q_xy"]): xpos, ypos = xy # For the topmost barrier, reduce the rectangle if there's a label to allow for the text. if i == 0 and node.op.label is not None: @@ -976,7 +1020,7 @@ def _barrier(self, node): color=self._style["lc"], zorder=PORDER_TEXT, ) - box = self._patches_mod.Rectangle( + box = glob_data["patches_mod"].Rectangle( xy=(xpos - (0.3 * WID), ypos - 0.5), width=0.6 * WID, height=1.0 + ypos_adj, @@ -997,74 +1041,74 @@ def _barrier(self, node): node.op.label, ha="center", va="top", - fontsize=self._fs, - color=self._data[node]["tc"], + fontsize=self._style["fs"], + color=node_data[node]["tc"], clip_on=True, zorder=PORDER_TEXT, ) - def _gate(self, node, xy=None): + def _gate(self, node, node_data, glob_data, xy=None): """Draw a 1-qubit gate""" if xy is None: - xy = self._data[node]["q_xy"][0] + xy = node_data[node]["q_xy"][0] xpos, ypos = xy - wid = max(self._data[node]["width"], WID) + wid = max(node_data[node]["width"], WID) - box = self._patches_mod.Rectangle( + box = glob_data["patches_mod"].Rectangle( xy=(xpos - 0.5 * wid, ypos - 0.5 * HIG), width=wid, height=HIG, - fc=self._data[node]["fc"], - ec=self._data[node]["ec"], + fc=node_data[node]["fc"], + ec=node_data[node]["ec"], linewidth=self._lwidth15, zorder=PORDER_GATE, ) self._ax.add_patch(box) - if "gate_text" in self._data[node]: + if "gate_text" in node_data[node]: gate_ypos = ypos - if "param" in self._data[node] and self._data[node]["param"] != "": + if "param_text" in node_data[node] and node_data[node]["param_text"] != "": gate_ypos = ypos + 0.15 * HIG self._ax.text( xpos, ypos - 0.3 * HIG, - self._data[node]["param"], + node_data[node]["param_text"], ha="center", va="center", - fontsize=self._sfs, - color=self._data[node]["sc"], + fontsize=self._style["sfs"], + color=node_data[node]["sc"], clip_on=True, zorder=PORDER_TEXT, ) self._ax.text( xpos, gate_ypos, - self._data[node]["gate_text"], + node_data[node]["gate_text"], ha="center", va="center", - fontsize=self._fs, - color=self._data[node]["gt"], + fontsize=self._style["fs"], + color=node_data[node]["gt"], clip_on=True, zorder=PORDER_TEXT, ) - def _multiqubit_gate(self, node, xy=None): + def _multiqubit_gate(self, node, node_data, glob_data, xy=None): """Draw a gate covering more than one qubit""" op = node.op if xy is None: - xy = self._data[node]["q_xy"] + xy = node_data[node]["q_xy"] # Swap gate if isinstance(op, SwapGate): - self._swap(xy, node, self._data[node]["lc"]) + self._swap(xy, node, node_data, node_data[node]["lc"]) return # RZZ Gate elif isinstance(op, RZZGate): - self._symmetric_gate(node, RZZGate) + self._symmetric_gate(node, node_data, RZZGate, glob_data) return - c_xy = self._data[node]["c_xy"] + c_xy = node_data[node]["c_xy"] xpos = min(x[0] for x in xy) ypos = min(y[1] for y in xy) ypos_max = max(y[1] for y in xy) @@ -1073,16 +1117,17 @@ def _multiqubit_gate(self, node, xy=None): cypos = min(y[1] for y in c_xy) ypos = min(ypos, cypos) - wid = max(self._data[node]["width"] + 0.21, WID) + wid = max(node_data[node]["width"] + 0.21, WID) - qubit_span = abs(ypos) - abs(ypos_max) + 1 - height = HIG + (qubit_span - 1) - box = self._patches_mod.Rectangle( + qubit_span = abs(ypos) - abs(ypos_max) + height = HIG + qubit_span + + box = glob_data["patches_mod"].Rectangle( xy=(xpos - 0.5 * wid, ypos - 0.5 * HIG), width=wid, height=height, - fc=self._data[node]["fc"], - ec=self._data[node]["ec"], + fc=node_data[node]["fc"], + ec=node_data[node]["ec"], linewidth=self._lwidth15, zorder=PORDER_GATE, ) @@ -1096,8 +1141,8 @@ def _multiqubit_gate(self, node, xy=None): str(bit), ha="left", va="center", - fontsize=self._fs, - color=self._data[node]["gt"], + fontsize=self._style["fs"], + color=node_data[node]["gt"], clip_on=True, zorder=PORDER_TEXT, ) @@ -1110,43 +1155,43 @@ def _multiqubit_gate(self, node, xy=None): str(bit), ha="left", va="center", - fontsize=self._fs, - color=self._data[node]["gt"], + fontsize=self._style["fs"], + color=node_data[node]["gt"], clip_on=True, zorder=PORDER_TEXT, ) - if "gate_text" in self._data[node] and self._data[node]["gate_text"] != "": - gate_ypos = ypos + 0.5 * (qubit_span - 1) - if "param" in self._data[node] and self._data[node]["param"] != "": + if "gate_text" in node_data[node] and node_data[node]["gate_text"] != "": + gate_ypos = ypos + 0.5 * qubit_span + if "param_text" in node_data[node] and node_data[node]["param_text"] != "": gate_ypos = ypos + 0.4 * height self._ax.text( xpos + 0.11, ypos + 0.2 * height, - self._data[node]["param"], + node_data[node]["param_text"], ha="center", va="center", - fontsize=self._sfs, - color=self._data[node]["sc"], + fontsize=self._style["sfs"], + color=node_data[node]["sc"], clip_on=True, zorder=PORDER_TEXT, ) self._ax.text( xpos + 0.11, gate_ypos, - self._data[node]["gate_text"], + node_data[node]["gate_text"], ha="center", va="center", - fontsize=self._fs, - color=self._data[node]["gt"], + fontsize=self._style["fs"], + color=node_data[node]["gt"], clip_on=True, zorder=PORDER_TEXT, ) - def _control_gate(self, node): + def _control_gate(self, node, node_data, glob_data): """Draw a controlled gate""" op = node.op + xy = node_data[node]["q_xy"] base_type = None if not hasattr(op, "base_gate") else op.base_gate - xy = self._data[node]["q_xy"] qubit_b = min(xy, key=lambda xy: xy[1]) qubit_t = max(xy, key=lambda xy: xy[1]) num_ctrl_qubits = op.num_ctrl_qubits @@ -1155,32 +1200,33 @@ def _control_gate(self, node): op.ctrl_state, num_ctrl_qubits, xy, - ec=self._data[node]["ec"], - tc=self._data[node]["tc"], - text=self._data[node]["ctrl_text"], + glob_data, + ec=node_data[node]["ec"], + tc=node_data[node]["tc"], + text=node_data[node]["ctrl_text"], qargs=node.qargs, ) - self._line(qubit_b, qubit_t, lc=self._data[node]["lc"]) + self._line(qubit_b, qubit_t, lc=node_data[node]["lc"]) if isinstance(op, RZZGate) or isinstance(base_type, (U1Gate, PhaseGate, ZGate, RZZGate)): - self._symmetric_gate(node, base_type) + self._symmetric_gate(node, node_data, base_type, glob_data) elif num_qargs == 1 and isinstance(base_type, XGate): tgt_color = self._style["dispcol"]["target"] tgt = tgt_color if isinstance(tgt_color, str) else tgt_color[0] - self._x_tgt_qubit(xy[num_ctrl_qubits], ec=self._data[node]["ec"], ac=tgt) + self._x_tgt_qubit(xy[num_ctrl_qubits], glob_data, ec=node_data[node]["ec"], ac=tgt) elif num_qargs == 1: - self._gate(node, xy[num_ctrl_qubits:][0]) + self._gate(node, node_data, glob_data, xy[num_ctrl_qubits:][0]) elif isinstance(base_type, SwapGate): - self._swap(xy[num_ctrl_qubits:], node, self._data[node]["lc"]) + self._swap(xy[num_ctrl_qubits:], node, node_data, node_data[node]["lc"]) else: - self._multiqubit_gate(node, xy[num_ctrl_qubits:]) + self._multiqubit_gate(node, node_data, glob_data, xy[num_ctrl_qubits:]) def _set_ctrl_bits( - self, ctrl_state, num_ctrl_qubits, qbit, ec=None, tc=None, text="", qargs=None + self, ctrl_state, num_ctrl_qubits, qbit, glob_data, ec=None, tc=None, text="", qargs=None ): """Determine which qubits are controls and whether they are open or closed""" # place the control label at the top or bottom of controls @@ -1202,12 +1248,14 @@ def _set_ctrl_bits( text_top = True elif not top and qlist[i] == max_ctbit: text_top = False - self._ctrl_qubit(qbit[i], fc=fc_open_close, ec=ec, tc=tc, text=text, text_top=text_top) + self._ctrl_qubit( + qbit[i], glob_data, fc=fc_open_close, ec=ec, tc=tc, text=text, text_top=text_top + ) - def _ctrl_qubit(self, xy, fc=None, ec=None, tc=None, text="", text_top=None): + def _ctrl_qubit(self, xy, glob_data, fc=None, ec=None, tc=None, text="", text_top=None): """Draw a control circle and if top or bottom control, draw control label""" xpos, ypos = xy - box = self._patches_mod.Circle( + box = glob_data["patches_mod"].Circle( xy=(xpos, ypos), radius=WID * 0.15, fc=fc, @@ -1236,17 +1284,17 @@ def _ctrl_qubit(self, xy, fc=None, ec=None, tc=None, text="", text_top=None): text, ha="center", va="top", - fontsize=self._sfs, + fontsize=self._style["sfs"], color=tc, clip_on=True, zorder=PORDER_TEXT, ) - def _x_tgt_qubit(self, xy, ec=None, ac=None): + def _x_tgt_qubit(self, xy, glob_data, ec=None, ac=None): """Draw the cnot target symbol""" linewidth = self._lwidth2 xpos, ypos = xy - box = self._patches_mod.Circle( + box = glob_data["patches_mod"].Circle( xy=(xpos, ypos), radius=HIG * 0.35, fc=ec, @@ -1272,44 +1320,50 @@ def _x_tgt_qubit(self, xy, ec=None, ac=None): zorder=PORDER_GATE + 1, ) - def _symmetric_gate(self, node, base_type): + def _symmetric_gate(self, node, node_data, base_type, glob_data): """Draw symmetric gates for cz, cu1, cp, and rzz""" op = node.op - xy = self._data[node]["q_xy"] + xy = node_data[node]["q_xy"] qubit_b = min(xy, key=lambda xy: xy[1]) qubit_t = max(xy, key=lambda xy: xy[1]) base_type = None if not hasattr(op, "base_gate") else op.base_gate - ec = self._data[node]["ec"] - tc = self._data[node]["tc"] - lc = self._data[node]["lc"] + ec = node_data[node]["ec"] + tc = node_data[node]["tc"] + lc = node_data[node]["lc"] # cz and mcz gates if not isinstance(op, ZGate) and isinstance(base_type, ZGate): num_ctrl_qubits = op.num_ctrl_qubits - self._ctrl_qubit(xy[-1], fc=ec, ec=ec, tc=tc) + self._ctrl_qubit(xy[-1], glob_data, fc=ec, ec=ec, tc=tc) self._line(qubit_b, qubit_t, lc=lc, zorder=PORDER_LINE + 1) # cu1, cp, rzz, and controlled rzz gates (sidetext gates) elif isinstance(op, RZZGate) or isinstance(base_type, (U1Gate, PhaseGate, RZZGate)): num_ctrl_qubits = 0 if isinstance(op, RZZGate) else op.num_ctrl_qubits - gate_text = "P" if isinstance(base_type, PhaseGate) else self._data[node]["gate_text"] + gate_text = "P" if isinstance(base_type, PhaseGate) else node_data[node]["gate_text"] - self._ctrl_qubit(xy[num_ctrl_qubits], fc=ec, ec=ec, tc=tc) + self._ctrl_qubit(xy[num_ctrl_qubits], glob_data, fc=ec, ec=ec, tc=tc) if not isinstance(base_type, (U1Gate, PhaseGate)): - self._ctrl_qubit(xy[num_ctrl_qubits + 1], fc=ec, ec=ec, tc=tc) - - self._sidetext(node, qubit_b, tc=tc, text=f"{gate_text} ({self._data[node]['param']})") + self._ctrl_qubit(xy[num_ctrl_qubits + 1], glob_data, fc=ec, ec=ec, tc=tc) + + self._sidetext( + node, + node_data, + qubit_b, + tc=tc, + text=f"{gate_text} ({node_data[node]['param_text']})", + ) self._line(qubit_b, qubit_t, lc=lc) - def _swap(self, xy, node, color=None): + def _swap(self, xy, node, node_data, color=None): """Draw a Swap gate""" self._swap_cross(xy[0], color=color) self._swap_cross(xy[1], color=color) self._line(xy[0], xy[1], lc=color) # add calibration text - gate_text = self._data[node]["gate_text"].split("\n")[-1] - if self._data[node]["raw_gate_text"] in self._calibrations: + gate_text = node_data[node]["gate_text"].split("\n")[-1] + if node_data[node]["raw_gate_text"] in self._calibrations: xpos, ypos = xy[0] self._ax.text( xpos, @@ -1342,19 +1396,19 @@ def _swap_cross(self, xy, color=None): zorder=PORDER_LINE + 1, ) - def _sidetext(self, node, xy, tc=None, text=""): + def _sidetext(self, node, node_data, xy, tc=None, text=""): """Draw the sidetext for symmetric gates""" xpos, ypos = xy # 0.11 = the initial gap, add 1/2 text width to place on the right - xp = xpos + 0.11 + self._data[node]["width"] / 2 + xp = xpos + 0.11 + node_data[node]["width"] / 2 self._ax.text( xp, ypos + HIG, text, ha="center", va="top", - fontsize=self._sfs, + fontsize=self._style["sfs"], color=tc, clip_on=True, zorder=PORDER_TEXT, @@ -1397,33 +1451,18 @@ def _line(self, xy0, xy1, lc=None, ls=None, zorder=PORDER_LINE): zorder=zorder, ) - -class Anchor: - """Locate the anchors for the gates""" - - def __init__(self, num_wires, y_index, fold): - self._num_wires = num_wires - self._fold = fold - self._y_index = y_index - self._x_index = 0 - - def plot_coord(self, x_index, gate_width, x_offset): + def _plot_coord(self, x_index, y_index, gate_width, glob_data): """Get the coord positions for an index""" - h_pos = x_index % self._fold + 1 # check folding - if self._fold > 0: - if h_pos + (gate_width - 1) > self._fold: - x_index += self._fold - (h_pos - 1) - x_pos = x_index % self._fold + 0.5 * gate_width + 0.04 - y_pos = self._y_index - (x_index // self._fold) * (self._num_wires + 1) - else: - x_pos = x_index + 0.5 * gate_width + 0.04 - y_pos = self._y_index + fold = self._fold if self._fold > 0 else INFINITE_FOLD + h_pos = x_index % fold + 1 - # could have been updated, so need to store - self._x_index = x_index - return x_pos + x_offset, y_pos + if h_pos + (gate_width - 1) > fold: + x_index += fold - (h_pos - 1) + x_pos = x_index % fold + glob_data["x_offset"] + 0.04 + x_pos += 0.5 * gate_width + y_pos = y_index - (x_index // fold) * (glob_data["n_lines"] + 1) - def get_x_index(self): - """Getter for the x index""" - return self._x_index + # could have been updated, so need to store + glob_data["next_x_index"] = x_index + return x_pos, y_pos