Skip to content

Commit

Permalink
🔀 Merge pull request #61 from DVLab-NTU/feature/manual-rule
Browse files Browse the repository at this point in the history
Adjust ZX coords
  • Loading branch information
JoshuaLau0220 authored Feb 26, 2024
2 parents d6061db + ee6e64d commit 6226205
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 27 deletions.
6 changes: 3 additions & 3 deletions src/qcir/qcir_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace qsyn::qcir {

std::function<bool(size_t const&)> valid_qcir_id(QCirMgr const& qcir_mgr) {
return [&](size_t const& id) {
if (qcir_mgr.is_id(id)) return true;
if (qcir_mgr.get() && qcir_mgr.is_id(id)) return true;
spdlog::error("QCir {} does not exist!!", id);
return false;
};
Expand All @@ -45,7 +45,7 @@ std::function<bool(size_t const&)> valid_qcir_id(QCirMgr const& qcir_mgr) {
std::function<bool(size_t const&)> valid_qcir_gate_id(QCirMgr const& qcir_mgr) {
return [&](size_t const& id) {
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false;
if (qcir_mgr.get()->get_gate(id) != nullptr) return true;
if (qcir_mgr.get() && qcir_mgr.get()->get_gate(id) != nullptr) return true;
spdlog::error("Gate ID {} does not exist!!", id);
return false;
};
Expand All @@ -54,7 +54,7 @@ std::function<bool(size_t const&)> valid_qcir_gate_id(QCirMgr const& qcir_mgr) {
std::function<bool(QubitIdType const&)> valid_qcir_qubit_id(QCirMgr const& qcir_mgr) {
return [&](QubitIdType const& id) {
if (!dvlab::utils::mgr_has_data(qcir_mgr)) return false;
if (qcir_mgr.get()->get_qubit(id) != nullptr) return true;
if (qcir_mgr.get() && qcir_mgr.get()->get_qubit(id) != nullptr) return true;
spdlog::error("Qubit ID {} does not exist!!", id);
return false;
};
Expand Down
2 changes: 1 addition & 1 deletion src/util/data_structure_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class DataStructureManager { // NOLINT(hicpp-special-member-functions, cppcoreg

size_t get_next_id() const { return _next_id; }

T* get() const { return _list.at(_focused_id).get(); }
T* get() const { return size() ? _list.at(_focused_id).get() : nullptr; }

void set_by_id(size_t id, std::unique_ptr<T> t) {
if (_list.contains(id)) {
Expand Down
36 changes: 36 additions & 0 deletions src/zx/simplifier/rules/pivot_boundary_rule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
****************************************************************************/

#include "./zx_rules_template.hpp"
#include "zx/zxgraph.hpp"

using namespace qsyn::zx;

Expand Down Expand Up @@ -83,3 +84,38 @@ void PivotBoundaryRule::apply(ZXGraph& graph, std::vector<MatchType> const& matc

PivotRuleInterface::apply(graph, matches);
}

bool PivotBoundaryRule::is_candidate(ZXGraph& graph, ZXVertex* vb, ZXVertex* vn) {
if (!graph.is_graph_like()) {
spdlog::error("The graph is not graph like!");
return false;
}
if (!vb->is_z()) {
spdlog::error("Vertex {} is not a Z vertex", vb->get_id());
return false;
}
bool has_boundary = false;
for (const auto& [nb, etype] : graph.get_neighbors(vb)) {
if (nb->is_boundary()) {
has_boundary = true;
break;
}
}
if (!has_boundary) {
spdlog::error("Vertex {} is not connected to a boundary", vb->get_id());
return false;
}
if (!vn->has_n_pi_phase()) {
spdlog::error("Vertex {} is not a Z vertex with phase n π", vn->get_id());
return false;
}
if (!graph.is_neighbor(vb, vn)) {
spdlog::error("Vertices {} and {} are not connected", vb->get_id(), vn->get_id());
return false;
}
// if (graph.has_dangling_neighbors(vn)) {
// spdlog::error("Vertex {} is the axel of a phase gadget", vn->get_id());
// return false;
// }
return true;
}
1 change: 1 addition & 0 deletions src/zx/simplifier/rules/zx_rules_template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class PivotBoundaryRule : public PivotRuleInterface {

std::vector<MatchType> find_matches(ZXGraph const& graph) const override;
void apply(ZXGraph& graph, std::vector<MatchType> const& matches) const override;
bool is_candidate(ZXGraph& graph, ZXVertex* v0, ZXVertex* v1);
};

class SpiderFusionRule : public ZXRuleTemplate<std::pair<ZXVertex*, ZXVertex*>> {
Expand Down
41 changes: 41 additions & 0 deletions src/zx/simplifier/simp_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

#include "./simp_cmd.hpp"

#include <fmt/core.h>

#include <cstddef>
#include <string>
#include <vector>

#include "./simplify.hpp"
#include "argparse/arg_parser.hpp"
Expand Down Expand Up @@ -177,4 +180,42 @@ Command zxgraph_rule_cmd(zx::ZXGraphMgr &zxgraph_mgr) {
}};
}

// REVIEW - Logic of check function is not completed
Command zxgraph_manual_apply_cmd(zx::ZXGraphMgr &zxgraph_mgr) {
return Command{
"manual",
[&](ArgumentParser &parser) {
parser.description("apply simplification rules on specific candidates");

auto mutex = parser.add_mutually_exclusive_group().required(true);
mutex.add_argument<bool>("--pivot")
.action(store_true)
.help("applies pivot rules to vertex pairs with phase 0 or π");
mutex.add_argument<bool>("--pivot-boundary")
.action(store_true)
.help("applies pivot rules to vertex pairs connected to the boundary");
mutex.add_argument<bool>("--pivot-gadget")
.action(store_true)
.help("unfuses the phase and applies pivot rules to form gadgets");

parser.add_argument<size_t>("vertices")
.nargs(2)
.constraint(valid_zxvertex_id(zxgraph_mgr))
.help("the vertices on which the rule applies");
},
[&](ArgumentParser const &parser) {
if (!dvlab::utils::mgr_has_data(zxgraph_mgr)) return dvlab::CmdExecResult::error;
auto vertices = parser.get<std::vector<size_t>>("vertices");
ZXVertex *bound = zxgraph_mgr.get()->find_vertex_by_id(vertices[0]);
ZXVertex *vert = zxgraph_mgr.get()->find_vertex_by_id(vertices[1]);

const bool is_cand = PivotBoundaryRule().is_candidate(*zxgraph_mgr.get(), bound, vert);
if (!is_cand) return CmdExecResult::error;

std::vector<std::pair<ZXVertex *, ZXVertex *>> match;
match.emplace_back(bound, vert);
PivotBoundaryRule().apply(*zxgraph_mgr.get(), match);
return CmdExecResult::done;
}};
}
} // namespace qsyn::zx
1 change: 1 addition & 0 deletions src/zx/simplifier/simp_cmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ namespace qsyn::zx {

dvlab::Command zxgraph_optimize_cmd(ZXGraphMgr &zxgraph_mgr);
dvlab::Command zxgraph_rule_cmd(ZXGraphMgr &zxgraph_mgr);
dvlab::Command zxgraph_manual_apply_cmd(ZXGraphMgr &zxgraph_mgr);

} // namespace qsyn::zx
3 changes: 2 additions & 1 deletion src/zx/simplifier/simplify.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class Simplifier {
hadamard_rule_simp();
}
~Simplifier() {
_simp_graph->adjustVertexCoordinates();
// REVIEW - Whether to adjust
// _simp_graph->adjust_vertex_coordinates();
}
Simplifier(Simplifier const& other) = default;
Simplifier(Simplifier&& other) = default;
Expand Down
16 changes: 15 additions & 1 deletion src/zx/zx_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,18 @@ namespace qsyn::zx {

std::function<bool(size_t const&)> valid_zxvertex_id(ZXGraphMgr const& zxgraph_mgr) {
return [&](size_t const& id) {
if (zxgraph_mgr.get()->is_v_id(id)) return true;
if (zxgraph_mgr.get() && zxgraph_mgr.get()->is_v_id(id)) return true;
spdlog::error("Cannot find vertex with ID {} in the ZXGraph!!", id);
return false;
};
}

std::function<bool(size_t const&)> zxgraph_id_not_exist(ZXGraphMgr const& zxgraph_mgr) {
return [&](size_t const& id) {
if (!zxgraph_mgr.get()) {
spdlog::error("ZXGraphMgr does not exist!!");
return true;
}
if (!zxgraph_mgr.is_id(id)) return true;
spdlog::error("ZXGraph {} already exists!!", id);
spdlog::info("Use `-Replace` if you want to overwrite it.");
Expand All @@ -43,6 +47,10 @@ std::function<bool(size_t const&)> zxgraph_id_not_exist(ZXGraphMgr const& zxgrap

std::function<bool(int const&)> zxgraph_input_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr) {
return [&](int const& qid) {
if (!zxgraph_mgr.get()) {
spdlog::error("ZXGraphMgr does not exist!!");
return true;
}
if (!zxgraph_mgr.get()->is_input_qubit(qid)) return true;
spdlog::error("This qubit's input already exists!!");
return false;
Expand All @@ -51,6 +59,10 @@ std::function<bool(int const&)> zxgraph_input_qubit_not_exist(ZXGraphMgr const&

std::function<bool(int const&)> zxgraph_output_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr) {
return [&](int const& qid) {
if (!zxgraph_mgr.get()) {
spdlog::error("ZXGraphMgr does not exist!!");
return true;
}
if (!zxgraph_mgr.get()->is_output_qubit(qid)) return true;
spdlog::error("This qubit's output already exists!!");
return false;
Expand Down Expand Up @@ -242,6 +254,7 @@ Command zxgraph_draw_cmd(ZXGraphMgr const& zxgraph_mgr) {
[&](ArgumentParser const& parser) {
if (!dvlab::utils::mgr_has_data(zxgraph_mgr)) return CmdExecResult::error;
if (parser.parsed("filepath")) {
zxgraph_mgr.get()->adjust_vertex_coordinates();
if (!zxgraph_mgr.get()->write_pdf(parser.get<std::string>("filepath"))) return CmdExecResult::error;
}
if (parser.parsed("-cli")) {
Expand Down Expand Up @@ -661,6 +674,7 @@ Command zxgraph_cmd(ZXGraphMgr& zxgraph_mgr) {
cmd.add_subcommand(zxgraph_gflow_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_optimize_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_rule_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_manual_apply_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_vertex_cmd(zxgraph_mgr));
cmd.add_subcommand(zxgraph_edge_cmd(zxgraph_mgr));
return cmd;
Expand Down
4 changes: 4 additions & 0 deletions src/zx/zx_cmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

namespace qsyn::zx {

std::function<bool(size_t const&)> valid_zxvertex_id(ZXGraphMgr const& zxgraph_mgr);
std::function<bool(size_t const&)> zxgraph_id_not_exist(ZXGraphMgr const& zxgraph_mgr);
std::function<bool(int const&)> zxgraph_input_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr);
std::function<bool(int const&)> zxgraph_output_qubit_not_exist(ZXGraphMgr const& zxgraph_mgr);
bool add_zx_cmds(dvlab::CommandLineInterface& cli, qsyn::zx::ZXGraphMgr& zxgraph_mgr);

} // namespace qsyn::zx
2 changes: 1 addition & 1 deletion src/zx/zxgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class ZXGraph { // NOLINT(cppcoreguidelines-special-member-functions) : copy-sw
void add_gadget(Phase p, std::vector<ZXVertex*> const& vertices);
void remove_gadget(ZXVertex* v);
std::unordered_map<size_t, ZXVertex*> create_id_to_vertex_map() const;
void adjustVertexCoordinates();
void adjust_vertex_coordinates();

// Print functions (zxGraphPrint.cpp)
void print_graph(spdlog::level::level_enum lvl = spdlog::level::level_enum::off) const;
Expand Down
61 changes: 41 additions & 20 deletions src/zx/zxgraph_action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Copyright [ Copyright(c) 2023 DVLab, GIEE, NTU, Taiwan ]
****************************************************************************/

#include <fmt/core.h>

#include <cstddef>
#include <gsl/narrow>
#include <queue>
Expand Down Expand Up @@ -237,46 +239,65 @@ std::unordered_map<size_t, ZXVertex*> ZXGraph::create_id_to_vertex_map() const {
* @brief Rearrange vertices on each qubit so that each vertex can be seperated in the printed graph.
*
*/
void ZXGraph::adjustVertexCoordinates() {
void ZXGraph::adjust_vertex_coordinates() {
// FIXME - QubitId -> RowId
std::unordered_map<QubitIdType, std::vector<ZXVertex*>> qubit_id_to_vertices_map;
std::unordered_set<QubitIdType> visited_qubit_ids;
std::queue<ZXVertex*> vertex_queue;
std::vector<ZXVertex*> vertex_queue;
// NOTE - Check Gadgets
// FIXME - When replacing QubitId with RowId, add 0.5 on it
for (auto const& i : _vertices) {
if (i->get_qubit() == -2 && get_num_neighbors(i) > 1) {
std::unordered_map<QubitIdType, size_t> num_neighbor_qubits;
for (auto const& [nb, _] : get_neighbors(i)) {
if (num_neighbor_qubits.contains(nb->get_qubit())) {
num_neighbor_qubits[nb->get_qubit()]++;
fmt::println("add qb: {}", nb->get_qubit());
} else
num_neighbor_qubits[nb->get_qubit()] = 1;
}
fmt::println("move to {}", (*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
i->set_qubit((*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
}
}

// REVIEW - Whether to move the vertex from row -2 when it is no longer a gadget
// for (auto const& i : _vertices) {
// if (i->get_qubit() == -2 && get_num_neighbors(i) > 1) {
// std::unordered_map<QubitIdType, size_t> num_neighbor_qubits;
// for (auto const& [nb, _] : get_neighbors(i)) {
// if (num_neighbor_qubits.contains(nb->get_qubit())) {
// num_neighbor_qubits[nb->get_qubit()]++;
// } else
// num_neighbor_qubits[nb->get_qubit()] = 1;
// }
// // fmt::println("move to {}", (*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
// i->set_qubit((*max_element(num_neighbor_qubits.begin(), num_neighbor_qubits.end(), [](const std::pair<QubitIdType, size_t>& p1, const std::pair<QubitIdType, size_t>& p2) { return p1.second < p2.second; })).first);
// }
// }

for (auto const& i : _inputs) {
vertex_queue.push(i);
vertex_queue.emplace_back(i);
visited_qubit_ids.insert(gsl::narrow<QubitIdType>(i->get_id()));
}
while (!vertex_queue.empty()) {
ZXVertex* v = vertex_queue.front();
vertex_queue.pop();
vertex_queue.erase(vertex_queue.begin());
qubit_id_to_vertices_map[v->get_qubit()].emplace_back(v);
for (auto const& nb : get_neighbors(v) | std::views::keys) {
if (visited_qubit_ids.find(gsl::narrow<QubitIdType>(nb->get_id())) == visited_qubit_ids.end()) {
vertex_queue.push(nb);
vertex_queue.emplace_back(nb);
visited_qubit_ids.insert(gsl::narrow<QubitIdType>(nb->get_id()));
}
}
}
std::vector<ZXVertex*> gadgets;
double non_gadget = 0;
for (size_t i = 0; i < qubit_id_to_vertices_map[-2].size(); i++) {
if (get_num_neighbors(qubit_id_to_vertices_map[-2][i]) == 1) { // Not Gadgets
gadgets.emplace_back(qubit_id_to_vertices_map[-2][i]);
} else
non_gadget++;
}
auto end_it = std::remove_if(
qubit_id_to_vertices_map[-2].begin(),
qubit_id_to_vertices_map[-2].end(),
[this](ZXVertex* v) {
return this->get_num_neighbors(v) == 1;
});
qubit_id_to_vertices_map[-2].erase(end_it, qubit_id_to_vertices_map[-2].end());

qubit_id_to_vertices_map[-2].insert(qubit_id_to_vertices_map[-2].end(), gadgets.begin(), gadgets.end());
double max_col = 0.0;
for (auto& i : qubit_id_to_vertices_map) {
double col = i.first < 0 ? 0.5 : 0.0;
double col = i.first == -2 ? 0.5 : i.first == -1 ? 0.5 + non_gadget
: 0.0;
for (auto& v : i.second) {
v->set_col(col);
col++;
Expand Down

0 comments on commit 6226205

Please sign in to comment.