From 7ce6aca46cd70760c0dc66efe56a6a791a23fdf5 Mon Sep 17 00:00:00 2001 From: Thomas Loke Date: Tue, 2 Mar 2021 06:40:11 +0800 Subject: [PATCH] Re-add dispatch table (#68) * Re-add dispatch table * Update changelog --- CHANGELOG.md | 8 ++ pennylane_lightning/src/rework/Apply.cpp | 2 +- .../src/rework/GateFactory.cpp | 90 ------------------- .../src/rework/GateFactory.hpp | 37 -------- pennylane_lightning/src/rework/Gates.cpp | 51 ++++++++++- pennylane_lightning/src/rework/Gates.hpp | 11 +++ setup.py | 2 - 7 files changed, 69 insertions(+), 132 deletions(-) delete mode 100644 pennylane_lightning/src/rework/GateFactory.cpp delete mode 100644 pennylane_lightning/src/rework/GateFactory.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bef4d08e0..a64ad77d05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,14 +6,22 @@ ### Improvements +* Add new lightweight backend with performance improvements + [(#57)](https://github.com/PennyLaneAI/pennylane-lightning/pull/57) + ### Documentation ### Bug fixes +* Re-add dispatch table after fixing static initialisation order issue + [(#68)](https://github.com/PennyLaneAI/pennylane-lightning/pull/68) + ### Contributors This release contains contributions from (in alphabetical order): +Thomas Loke + --- # Release 0.14.1 diff --git a/pennylane_lightning/src/rework/Apply.cpp b/pennylane_lightning/src/rework/Apply.cpp index f8a439b2aa..18095d268c 100644 --- a/pennylane_lightning/src/rework/Apply.cpp +++ b/pennylane_lightning/src/rework/Apply.cpp @@ -14,7 +14,7 @@ #include #include "Apply.hpp" -#include "GateFactory.hpp" +#include "Gates.hpp" #include "StateVector.hpp" #include "Util.hpp" diff --git a/pennylane_lightning/src/rework/GateFactory.cpp b/pennylane_lightning/src/rework/GateFactory.cpp deleted file mode 100644 index ac67054cce..0000000000 --- a/pennylane_lightning/src/rework/GateFactory.cpp +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2021 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "GateFactory.hpp" - -using std::string; -using std::unique_ptr; -using std::vector; - -// FIXME: This should be reworked to use a function dispatch table -unique_ptr Pennylane::constructGate(const string& label, const vector& parameters) { - std::unique_ptr gate; - - if (Pennylane::XGate::label == label) { - gate = std::make_unique(Pennylane::XGate::create(parameters)); - } - else if (Pennylane::YGate::label == label) { - gate = std::make_unique(Pennylane::YGate::create(parameters)); - } - else if (Pennylane::ZGate::label == label) { - gate = std::make_unique(Pennylane::ZGate::create(parameters)); - } - else if (Pennylane::HadamardGate::label == label) { - gate = std::make_unique(Pennylane::HadamardGate::create(parameters)); - } - else if (Pennylane::SGate::label == label) { - gate = std::make_unique(Pennylane::SGate::create(parameters)); - } - else if (Pennylane::TGate::label == label) { - gate = std::make_unique(Pennylane::TGate::create(parameters)); - } - else if (Pennylane::RotationXGate::label == label) { - gate = std::make_unique(Pennylane::RotationXGate::create(parameters)); - } - else if (Pennylane::RotationYGate::label == label) { - gate = std::make_unique(Pennylane::RotationYGate::create(parameters)); - } - else if (Pennylane::RotationZGate::label == label) { - gate = std::make_unique(Pennylane::RotationZGate::create(parameters)); - } - else if (Pennylane::PhaseShiftGate::label == label) { - gate = std::make_unique(Pennylane::PhaseShiftGate::create(parameters)); - } - else if (Pennylane::GeneralRotationGate::label == label) { - gate = std::make_unique(Pennylane::GeneralRotationGate::create(parameters)); - } - else if (Pennylane::CNOTGate::label == label) { - gate = std::make_unique(Pennylane::CNOTGate::create(parameters)); - } - else if (Pennylane::SWAPGate::label == label) { - gate = std::make_unique(Pennylane::SWAPGate::create(parameters)); - } - else if (Pennylane::CZGate::label == label) { - gate = std::make_unique(Pennylane::CZGate::create(parameters)); - } - else if (Pennylane::CRotationXGate::label == label) { - gate = std::make_unique(Pennylane::CRotationXGate::create(parameters)); - } - else if (Pennylane::CRotationYGate::label == label) { - gate = std::make_unique(Pennylane::CRotationYGate::create(parameters)); - } - else if (Pennylane::CRotationZGate::label == label) { - gate = std::make_unique(Pennylane::CRotationZGate::create(parameters)); - } - else if (Pennylane::CGeneralRotationGate::label == label) { - gate = std::make_unique(Pennylane::CGeneralRotationGate::create(parameters)); - } - else if (Pennylane::ToffoliGate::label == label) { - gate = std::make_unique(Pennylane::ToffoliGate::create(parameters)); - } - else if (Pennylane::CSWAPGate::label == label) { - gate = std::make_unique(Pennylane::CSWAPGate::create(parameters)); - } - else { - throw std::invalid_argument(label + " is not a valid gate type"); - } - - return gate; -} diff --git a/pennylane_lightning/src/rework/GateFactory.hpp b/pennylane_lightning/src/rework/GateFactory.hpp deleted file mode 100644 index e5873fbf40..0000000000 --- a/pennylane_lightning/src/rework/GateFactory.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2021 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -/** - * @file - * Contains methods that produce gates from the requisite parameters. - */ -#pragma once - -#include -#include - -#include "Gates.hpp" - -namespace Pennylane { - - /** - * Produces the requested gate, defined by a label and the list of parameters - * - * @param label unique string corresponding to a gate type - * @param parameters defines the gate parameterisation (may be zero-length for some gates) - * @return the gate wrapped in std::unique_ptr - * @throws std::invalid_argument thrown if the gate type is not defined, or if the number of parameters to the gate is incorrect - */ - std::unique_ptr constructGate(const std::string& label, const std::vector& parameters); - -} diff --git a/pennylane_lightning/src/rework/Gates.cpp b/pennylane_lightning/src/rework/Gates.cpp index f34a15fed0..f08a1915d3 100644 --- a/pennylane_lightning/src/rework/Gates.cpp +++ b/pennylane_lightning/src/rework/Gates.cpp @@ -15,15 +15,20 @@ #define _USE_MATH_DEFINES #include +#include +#include #include "Gates.hpp" -#include "typedefs.hpp" #include "Util.hpp" -using Pennylane::CplxType; +using std::function; +using std::map; using std::string; +using std::unique_ptr; using std::vector; +using Pennylane::CplxType; + template static void validateLength(const string& errorPrefix, const vector& vec, int requiredLength) { if (vec.size() != requiredLength) @@ -375,3 +380,45 @@ const std::vector Pennylane::CSWAPGate::matrix{ 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; + +// ------------------------------------------------------------------------------------------------------------- + +template +static void addToDispatchTable(map(const vector&)>>& dispatchTable) { + dispatchTable.emplace(GateType::label, [](const vector& parameters) { return std::make_unique(GateType::create(parameters)); }); +} + +static map(const vector&)>> createDispatchTable() { + map(const vector&)>> dispatchTable; + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + addToDispatchTable(dispatchTable); + return dispatchTable; +} + +static const map(const vector&)>> dispatchTable = createDispatchTable(); + +unique_ptr Pennylane::constructGate(const string& label, const vector& parameters) { + auto dispatchTableIterator = dispatchTable.find(label); + if (dispatchTableIterator == dispatchTable.end()) + throw std::invalid_argument(label + " is not a supported gate type"); + + return dispatchTableIterator->second(parameters); +} diff --git a/pennylane_lightning/src/rework/Gates.hpp b/pennylane_lightning/src/rework/Gates.hpp index c73e7559d4..f0673defb8 100644 --- a/pennylane_lightning/src/rework/Gates.hpp +++ b/pennylane_lightning/src/rework/Gates.hpp @@ -17,6 +17,7 @@ */ #pragma once +#include #include #include "typedefs.hpp" @@ -298,4 +299,14 @@ namespace Pennylane { } }; + /** + * Produces the requested gate, defined by a label and the list of parameters + * + * @param label unique string corresponding to a gate type + * @param parameters defines the gate parameterisation (may be zero-length for some gates) + * @return the gate wrapped in std::unique_ptr + * @throws std::invalid_argument thrown if the gate type is not defined, or if the number of parameters to the gate is incorrect + */ + std::unique_ptr constructGate(const std::string& label, const std::vector& parameters); + } diff --git a/setup.py b/setup.py index 164eb7274b..db34f3f7c3 100644 --- a/setup.py +++ b/setup.py @@ -167,13 +167,11 @@ def build_extensions(self): "lightning_qubit_new_ops", sources=[ "pennylane_lightning/src/rework/Apply.cpp", - "pennylane_lightning/src/rework/GateFactory.cpp", "pennylane_lightning/src/rework/Gates.cpp", "pennylane_lightning/src/rework/StateVector.cpp", ], depends=[ "pennylane_lightning/src/rework/Apply.hpp", - "pennylane_lightning/src/rework/GateFactory.hpp", "pennylane_lightning/src/rework/Gates.hpp", "pennylane_lightning/src/rework/StateVector.hpp", "pennylane_lightning/src/rework/typedefs.hpp",