diff --git a/src/qcir/qcir_cmd.cpp b/src/qcir/qcir_cmd.cpp index fe83e079..97564bda 100644 --- a/src/qcir/qcir_cmd.cpp +++ b/src/qcir/qcir_cmd.cpp @@ -36,7 +36,7 @@ namespace qsyn::qcir { std::function valid_qcir_id(QCirMgr const& qcir_mgr) { return [&](size_t const& id) { - if (qcir_mgr.is_id(id)) return true; + if (qcir_mgr.get() && qcir_mgr.is_id(id)) return true; spdlog::error("QCir {} does not exist!!", id); return false; }; @@ -45,7 +45,7 @@ std::function valid_qcir_id(QCirMgr const& qcir_mgr) { std::function valid_qcir_gate_id(QCirMgr const& qcir_mgr) { return [&](size_t const& id) { if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false; - if (qcir_mgr.get()->get_gate(id) != nullptr) return true; + if (qcir_mgr.get() && qcir_mgr.get()->get_gate(id) != nullptr) return true; spdlog::error("Gate ID {} does not exist!!", id); return false; }; @@ -54,7 +54,7 @@ std::function valid_qcir_gate_id(QCirMgr const& qcir_mgr) { std::function valid_qcir_qubit_id(QCirMgr const& qcir_mgr) { return [&](QubitIdType const& id) { if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false; - if (qcir_mgr.get()->get_qubit(id) != nullptr) return true; + if (qcir_mgr.get() && qcir_mgr.get()->get_qubit(id) != nullptr) return true; spdlog::error("Qubit ID {} does not exist!!", id); return false; }; diff --git a/src/util/data_structure_manager.hpp b/src/util/data_structure_manager.hpp index 1edb07a7..cb49836a 100644 --- a/src/util/data_structure_manager.hpp +++ b/src/util/data_structure_manager.hpp @@ -69,7 +69,7 @@ class DataStructureManager { // NOLINT(hicpp-special-member-functions, cppcoreg size_t get_next_id() const { return _next_id; } - T* get() const { return _list.at(_focused_id).get(); } + T* get() const { return size() ? _list.at(_focused_id).get() : nullptr; } void set_by_id(size_t id, std::unique_ptr t) { if (_list.contains(id)) { diff --git a/src/zx/simplifier/rules/pivot_boundary_rule.cpp b/src/zx/simplifier/rules/pivot_boundary_rule.cpp index c2632b55..6ee164af 100644 --- a/src/zx/simplifier/rules/pivot_boundary_rule.cpp +++ b/src/zx/simplifier/rules/pivot_boundary_rule.cpp @@ -6,6 +6,7 @@ ****************************************************************************/ #include "./zx_rules_template.hpp" +#include "zx/zxgraph.hpp" using namespace qsyn::zx; @@ -83,3 +84,38 @@ void PivotBoundaryRule::apply(ZXGraph& graph, std::vector const& matc PivotRuleInterface::apply(graph, matches); } + +bool PivotBoundaryRule::is_candidate(ZXGraph& graph, ZXVertex* vb, ZXVertex* vn) { + if (!graph.is_graph_like()) { + spdlog::error("The graph is not graph like!"); + return false; + } + if (!vb->is_z()) { + spdlog::error("Vertex {} is not a Z vertex", vb->get_id()); + return false; + } + bool has_boundary = false; + for (const auto& [nb, etype] : graph.get_neighbors(vb)) { + if (nb->is_boundary()) { + has_boundary = true; + break; + } + } + if (!has_boundary) { + spdlog::error("Vertex {} is not connected to a boundary", vb->get_id()); + return false; + } + if (!vn->has_n_pi_phase()) { + spdlog::error("Vertex {} is not a Z vertex with phase n π", vn->get_id()); + return false; + } + if (!graph.is_neighbor(vb, vn)) { + spdlog::error("Vertices {} and {} are not connected", vb->get_id(), vn->get_id()); + return false; + } + // if (graph.has_dangling_neighbors(vn)) { + // spdlog::error("Vertex {} is the axel of a phase gadget", vn->get_id()); + // return false; + // } + return true; +} diff --git a/src/zx/simplifier/rules/zx_rules_template.hpp b/src/zx/simplifier/rules/zx_rules_template.hpp index dc5f5f8b..78dd4873 100644 --- a/src/zx/simplifier/rules/zx_rules_template.hpp +++ b/src/zx/simplifier/rules/zx_rules_template.hpp @@ -172,6 +172,7 @@ class PivotBoundaryRule : public PivotRuleInterface { std::vector find_matches(ZXGraph const& graph) const override; void apply(ZXGraph& graph, std::vector const& matches) const override; + bool is_candidate(ZXGraph& graph, ZXVertex* v0, ZXVertex* v1); }; class SpiderFusionRule : public ZXRuleTemplate> { diff --git a/src/zx/simplifier/simp_cmd.cpp b/src/zx/simplifier/simp_cmd.cpp index f9448ff2..3d81ef89 100644 --- a/src/zx/simplifier/simp_cmd.cpp +++ b/src/zx/simplifier/simp_cmd.cpp @@ -7,8 +7,11 @@ #include "./simp_cmd.hpp" +#include + #include #include +#include #include "./simplify.hpp" #include "argparse/arg_parser.hpp" @@ -177,4 +180,42 @@ Command zxgraph_rule_cmd(zx::ZXGraphMgr &zxgraph_mgr) { }}; } +// REVIEW - Logic of check function is not completed +Command zxgraph_manual_apply_cmd(zx::ZXGraphMgr &zxgraph_mgr) { + return Command{ + "manual", + [&](ArgumentParser &parser) { + parser.description("apply simplification rules on specific candidates"); + + auto mutex = parser.add_mutually_exclusive_group().required(true); + mutex.add_argument("--pivot") + .action(store_true) + .help("applies pivot rules to vertex pairs with phase 0 or π"); + mutex.add_argument("--pivot-boundary") + .action(store_true) + .help("applies pivot rules to vertex pairs connected to the boundary"); + mutex.add_argument("--pivot-gadget") + .action(store_true) + .help("unfuses the phase and applies pivot rules to form gadgets"); + + parser.add_argument("vertices") + .nargs(2) + .constraint(valid_zxvertex_id(zxgraph_mgr)) + .help("the vertices on which the rule applies"); + }, + [&](ArgumentParser const &parser) { + if (!dvlab::utils::mgr_has_data(zxgraph_mgr)) return dvlab::CmdExecResult::error; + auto vertices = parser.get>("vertices"); + ZXVertex *bound = zxgraph_mgr.get()->find_vertex_by_id(vertices[0]); + ZXVertex *vert = zxgraph_mgr.get()->find_vertex_by_id(vertices[1]); + + const bool is_cand = PivotBoundaryRule().is_candidate(*zxgraph_mgr.get(), bound, vert); + if (!is_cand) return CmdExecResult::error; + + std::vector> match; + match.emplace_back(bound, vert); + PivotBoundaryRule().apply(*zxgraph_mgr.get(), match); + return CmdExecResult::done; + }}; +} } // namespace qsyn::zx diff --git a/src/zx/simplifier/simp_cmd.hpp b/src/zx/simplifier/simp_cmd.hpp index 56898885..c50aadf1 100644 --- a/src/zx/simplifier/simp_cmd.hpp +++ b/src/zx/simplifier/simp_cmd.hpp @@ -14,5 +14,6 @@ namespace qsyn::zx { dvlab::Command zxgraph_optimize_cmd(ZXGraphMgr &zxgraph_mgr); dvlab::Command zxgraph_rule_cmd(ZXGraphMgr &zxgraph_mgr); +dvlab::Command zxgraph_manual_apply_cmd(ZXGraphMgr &zxgraph_mgr); } // namespace qsyn::zx diff --git a/src/zx/simplifier/simplify.hpp b/src/zx/simplifier/simplify.hpp index 8387601a..a9bd9569 100644 --- a/src/zx/simplifier/simplify.hpp +++ b/src/zx/simplifier/simplify.hpp @@ -25,7 +25,8 @@ class Simplifier { hadamard_rule_simp(); } ~Simplifier() { - _simp_graph->adjustVertexCoordinates(); + // REVIEW - Whether to adjust + // _simp_graph->adjust_vertex_coordinates(); } Simplifier(Simplifier const& other) = default; Simplifier(Simplifier&& other) = default; diff --git a/src/zx/zx_cmd.cpp b/src/zx/zx_cmd.cpp index 0c07286e..724a6c27 100644 --- a/src/zx/zx_cmd.cpp +++ b/src/zx/zx_cmd.cpp @@ -26,7 +26,7 @@ namespace qsyn::zx { std::function valid_zxvertex_id(ZXGraphMgr const& zxgraph_mgr) { return [&](size_t const& id) { - if (zxgraph_mgr.get()->is_v_id(id)) return true; + if (zxgraph_mgr.get() && zxgraph_mgr.get()->is_v_id(id)) return true; spdlog::error("Cannot find vertex with ID {} in the ZXGraph!!", id); return false; }; @@ -34,6 +34,10 @@ std::function valid_zxvertex_id(ZXGraphMgr const& zxgraph_m std::function zxgraph_id_not_exist(ZXGraphMgr const& zxgraph_mgr) { return [&](size_t const& id) { + if (!zxgraph_mgr.get()) { + spdlog::error("ZXGraphMgr does not exist!!"); + return true; + } if (!zxgraph_mgr.is_id(id)) return true; spdlog::error("ZXGraph {} already exists!!", id); spdlog::info("Use `-Replace` if you want to overwrite it."); @@ -43,6 +47,10 @@ std::function zxgraph_id_not_exist(ZXGraphMgr const& zxgrap std::function zxgraph_input_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr) { return [&](int const& qid) { + if (!zxgraph_mgr.get()) { + spdlog::error("ZXGraphMgr does not exist!!"); + return true; + } if (!zxgraph_mgr.get()->is_input_qubit(qid)) return true; spdlog::error("This qubit's input already exists!!"); return false; @@ -51,6 +59,10 @@ std::function zxgraph_input_qubit_not_exist(ZXGraphMgr const& std::function zxgraph_output_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr) { return [&](int const& qid) { + if (!zxgraph_mgr.get()) { + spdlog::error("ZXGraphMgr does not exist!!"); + return true; + } if (!zxgraph_mgr.get()->is_output_qubit(qid)) return true; spdlog::error("This qubit's output already exists!!"); return false; @@ -242,6 +254,7 @@ Command zxgraph_draw_cmd(ZXGraphMgr const& zxgraph_mgr) { [&](ArgumentParser const& parser) { if (!dvlab::utils::mgr_has_data(zxgraph_mgr)) return CmdExecResult::error; if (parser.parsed("filepath")) { + zxgraph_mgr.get()->adjust_vertex_coordinates(); if (!zxgraph_mgr.get()->write_pdf(parser.get("filepath"))) return CmdExecResult::error; } if (parser.parsed("-cli")) { @@ -661,6 +674,7 @@ Command zxgraph_cmd(ZXGraphMgr& zxgraph_mgr) { cmd.add_subcommand(zxgraph_gflow_cmd(zxgraph_mgr)); cmd.add_subcommand(zxgraph_optimize_cmd(zxgraph_mgr)); cmd.add_subcommand(zxgraph_rule_cmd(zxgraph_mgr)); + cmd.add_subcommand(zxgraph_manual_apply_cmd(zxgraph_mgr)); cmd.add_subcommand(zxgraph_vertex_cmd(zxgraph_mgr)); cmd.add_subcommand(zxgraph_edge_cmd(zxgraph_mgr)); return cmd; diff --git a/src/zx/zx_cmd.hpp b/src/zx/zx_cmd.hpp index 39567555..54b95cef 100644 --- a/src/zx/zx_cmd.hpp +++ b/src/zx/zx_cmd.hpp @@ -14,6 +14,10 @@ namespace qsyn::zx { +std::function valid_zxvertex_id(ZXGraphMgr const& zxgraph_mgr); +std::function zxgraph_id_not_exist(ZXGraphMgr const& zxgraph_mgr); +std::function zxgraph_input_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr); +std::function zxgraph_output_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr); bool add_zx_cmds(dvlab::CommandLineInterface& cli, qsyn::zx::ZXGraphMgr& zxgraph_mgr); } // namespace qsyn::zx diff --git a/src/zx/zxgraph.hpp b/src/zx/zxgraph.hpp index 5558cffb..2d312295 100644 --- a/src/zx/zxgraph.hpp +++ b/src/zx/zxgraph.hpp @@ -232,7 +232,7 @@ class ZXGraph { // NOLINT(cppcoreguidelines-special-member-functions) : copy-sw void add_gadget(Phase p, std::vector const& vertices); void remove_gadget(ZXVertex* v); std::unordered_map create_id_to_vertex_map() const; - void adjustVertexCoordinates(); + void adjust_vertex_coordinates(); // Print functions (zxGraphPrint.cpp) void print_graph(spdlog::level::level_enum lvl = spdlog::level::level_enum::off) const; diff --git a/src/zx/zxgraph_action.cpp b/src/zx/zxgraph_action.cpp index 333b2786..98c30d3f 100644 --- a/src/zx/zxgraph_action.cpp +++ b/src/zx/zxgraph_action.cpp @@ -5,6 +5,8 @@ Copyright [ Copyright(c) 2023 DVLab, GIEE, NTU, Taiwan ] ****************************************************************************/ +#include + #include #include #include @@ -237,46 +239,65 @@ std::unordered_map ZXGraph::create_id_to_vertex_map() const { * @brief Rearrange vertices on each qubit so that each vertex can be seperated in the printed graph. * */ -void ZXGraph::adjustVertexCoordinates() { +void ZXGraph::adjust_vertex_coordinates() { // FIXME - QubitId -> RowId std::unordered_map> qubit_id_to_vertices_map; std::unordered_set visited_qubit_ids; - std::queue vertex_queue; + std::vector vertex_queue; // NOTE - Check Gadgets // FIXME - When replacing QubitId with RowId, add 0.5 on it - for (auto const& i : _vertices) { - if (i->get_qubit() == -2 && get_num_neighbors(i) > 1) { - std::unordered_map num_neighbor_qubits; - for (auto const& [nb, _] : get_neighbors(i)) { - if (num_neighbor_qubits.contains(nb->get_qubit())) { - num_neighbor_qubits[nb->get_qubit()]++; - fmt::println("add qb: {}", nb->get_qubit()); - } else - num_neighbor_qubits[nb->get_qubit()] = 1; - } - fmt::println("move to {}", (*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair& p1, const std::pair& p2) { return p1.second < p2.second; })).first); - i->set_qubit((*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair& p1, const std::pair& p2) { return p1.second < p2.second; })).first); - } - } + + // REVIEW - Whether to move the vertex from row -2 when it is no longer a gadget + // for (auto const& i : _vertices) { + // if (i->get_qubit() == -2 && get_num_neighbors(i) > 1) { + // std::unordered_map num_neighbor_qubits; + // for (auto const& [nb, _] : get_neighbors(i)) { + // if (num_neighbor_qubits.contains(nb->get_qubit())) { + // num_neighbor_qubits[nb->get_qubit()]++; + // } else + // num_neighbor_qubits[nb->get_qubit()] = 1; + // } + // // fmt::println("move to {}", (*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair& p1, const std::pair& p2) { return p1.second < p2.second; })).first); + // i->set_qubit((*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair& p1, const std::pair& p2) { return p1.second < p2.second; })).first); + // } + // } for (auto const& i : _inputs) { - vertex_queue.push(i); + vertex_queue.emplace_back(i); visited_qubit_ids.insert(gsl::narrow(i->get_id())); } while (!vertex_queue.empty()) { ZXVertex* v = vertex_queue.front(); - vertex_queue.pop(); + vertex_queue.erase(vertex_queue.begin()); qubit_id_to_vertices_map[v->get_qubit()].emplace_back(v); for (auto const& nb : get_neighbors(v) | std::views::keys) { if (visited_qubit_ids.find(gsl::narrow(nb->get_id())) == visited_qubit_ids.end()) { - vertex_queue.push(nb); + vertex_queue.emplace_back(nb); visited_qubit_ids.insert(gsl::narrow(nb->get_id())); } } } + std::vector gadgets; + double non_gadget = 0; + for (size_t i = 0; i < qubit_id_to_vertices_map[-2].size(); i++) { + if (get_num_neighbors(qubit_id_to_vertices_map[-2][i]) == 1) { // Not Gadgets + gadgets.emplace_back(qubit_id_to_vertices_map[-2][i]); + } else + non_gadget++; + } + auto end_it = std::remove_if( + qubit_id_to_vertices_map[-2].begin(), + qubit_id_to_vertices_map[-2].end(), + [this](ZXVertex* v) { + return this->get_num_neighbors(v) == 1; + }); + qubit_id_to_vertices_map[-2].erase(end_it, qubit_id_to_vertices_map[-2].end()); + + qubit_id_to_vertices_map[-2].insert(qubit_id_to_vertices_map[-2].end(), gadgets.begin(), gadgets.end()); double max_col = 0.0; for (auto& i : qubit_id_to_vertices_map) { - double col = i.first < 0 ? 0.5 : 0.0; + double col = i.first == -2 ? 0.5 : i.first == -1 ? 0.5 + non_gadget + : 0.0; for (auto& v : i.second) { v->set_col(col); col++;