Skip to content

Commit

Permalink
Update the Runtime CacheManager (#502)
Browse files Browse the repository at this point in the history
This PR,
- [X] Fixes the issue with not caching `QubitUnitary` ops in
LightningQubit
- [X] Update the cache-manager to support `QubitUnitrary` with respect
to the recent changes in the Lightning C++ API
- [X] Update the cache-manager to also cache controlled wires and values
to be later used in the computation of adjoint-jacobian.
  • Loading branch information
maliasadi authored Feb 15, 2024
1 parent d1a5ce5 commit b009321
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 50 deletions.
26 changes: 26 additions & 0 deletions frontend/test/pytest/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,5 +1065,31 @@ def f(x, y):
assert np.allclose(flatten_res_jax, flatten_res_catalyst)


@pytest.mark.xfail(reason="QubitUnitrary is not support with catalyst.grad")
@pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)])
def test_adj_qubitunitary(inp, backend):
"""Test the adjoint method."""

def f(x):
qml.RX(x, wires=0)
U1 = 1 / np.sqrt(2) * np.array([[1.0, 1.0], [1.0, -1.0]], dtype=complex)
qml.QubitUnitary(U1, wires=0)
return qml.expval(qml.PauliY(0))

@qjit()
def compiled(x: float):
g = qml.qnode(qml.device(backend, wires=1), diff_method="adjoint")(f)
h = grad(g, method="auto")
return h(x)

def interpreted(x):
device = qml.device("default.qubit", wires=1)
g = qml.QNode(f, device, diff_method="backprop")
h = qml.grad(g, argnum=0)
return h(x)

assert np.allclose(compiled(inp), interpreted(inp))


if __name__ == "__main__":
pytest.main(["-x", __file__])
4 changes: 2 additions & 2 deletions runtime/include/QuantumDevice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ struct QuantumDevice {
/**
* @brief Start recording a quantum tape if provided.
*
* @note This is backed by the `Catalyst::Runtime::CacheManager` property in
* @note This is backed by the `Catalyst::Runtime::CacheManager<ComplexT>` property in
* the device implementation.
*/
virtual void StartTapeRecording() = 0;

/**
* @brief Stop recording a quantum tape if provided.
*
* @note This is backed by the `Catalyst::Runtime::CacheManager` property in
* @note This is backed by the `Catalyst::Runtime::CacheManager<ComplexT>` property in
* the device implementation.
*/
virtual void StopTapeRecording() = 0;
Expand Down
100 changes: 67 additions & 33 deletions runtime/lib/backend/common/CacheManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <complex>
#include <string>
#include <vector>

Expand All @@ -29,13 +30,16 @@ namespace Catalyst::Runtime {
* of a circuit with taking advantage of gradient methods provided by
* simulators.
*/
class CacheManager {
template <typename ComplexT = std::complex<double>> class CacheManager {
protected:
// Operations Data
std::vector<std::string> ops_names_{};
std::vector<std::vector<double>> ops_params_{};
std::vector<std::vector<size_t>> ops_wires_{};
std::vector<bool> ops_inverses_{};
std::vector<std::vector<ComplexT>> ops_matrixs_{};
std::vector<std::vector<size_t>> ops_controlled_wires_{};
std::vector<std::vector<bool>> ops_controlled_values_{};

// Observables Data
std::vector<ObsIdType> obs_keys_{};
Expand All @@ -58,15 +62,18 @@ class CacheManager {
*/
void Reset()
{
this->ops_names_.clear();
this->ops_params_.clear();
this->ops_wires_.clear();
this->ops_inverses_.clear();

this->obs_keys_.clear();
this->obs_callees_.clear();

this->num_params_ = 0;
ops_names_.clear();
ops_params_.clear();
ops_wires_.clear();
ops_inverses_.clear();
ops_matrixs_.clear();
ops_controlled_wires_.clear();
ops_controlled_values_.clear();

obs_keys_.clear();
obs_callees_.clear();

num_params_ = 0;
}

/**
Expand All @@ -76,22 +83,31 @@ class CacheManager {
* @param params Parameters of the gate
* @param wires Wires the gate acts on
* @param inverse If true, inverse of the gate is applied
* @param matrix Unitary matrix for the 'MatrixOp' operations
* @param controlled_wires Control wires
* @param controlled_values Control values
*/
void addOperation(const std::string &name, const std::vector<double> &params,
const std::vector<size_t> &dev_wires, bool inverse)
const std::vector<size_t> &dev_wires, bool inverse,
const std::vector<ComplexT> &matrix = {},
const std::vector<size_t> &controlled_wires = {},
const std::vector<bool> &controlled_values = {})
{
this->ops_names_.push_back(name);
this->ops_params_.push_back(params);
ops_names_.push_back(name);
ops_params_.push_back(params);

std::vector<size_t> wires_ul;
wires_ul.reserve(dev_wires.size());
std::transform(dev_wires.begin(), dev_wires.end(), std::back_inserter(wires_ul),
[](auto w) { return static_cast<size_t>(w); });

this->ops_wires_.push_back(wires_ul);
this->ops_inverses_.push_back(inverse);
ops_wires_.push_back(wires_ul);
ops_inverses_.push_back(inverse);
ops_matrixs_.push_back(matrix);
ops_controlled_wires_.push_back(controlled_wires);
ops_controlled_values_.push_back(controlled_values);

this->num_params_ += params.size();
num_params_ += params.size();
}

/**
Expand All @@ -102,70 +118,88 @@ class CacheManager {
*/
void addObservable(const ObsIdType id, const MeasurementsT &callee = MeasurementsT::None)
{
this->obs_keys_.push_back(id);
this->obs_callees_.push_back(callee);
obs_keys_.push_back(id);
obs_callees_.push_back(callee);
}

/**
* @brief Get a reference to observables keys.
*/
auto getObservablesKeys() -> const std::vector<ObsIdType> & { return this->obs_keys_; }
auto getObservablesKeys() -> const std::vector<ObsIdType> & { return obs_keys_; }

/**
* @brief Get a reference to observables callees.
*/
auto getObservablesCallees() -> const std::vector<MeasurementsT> &
{
return this->obs_callees_;
}
auto getObservablesCallees() -> const std::vector<MeasurementsT> & { return obs_callees_; }

/**
* @brief Get a reference to operations names.
*/
auto getOperationsNames() -> const std::vector<std::string> & { return this->ops_names_; }
auto getOperationsNames() -> const std::vector<std::string> & { return ops_names_; }

/**
* @brief Get a a reference to operations parameters.
*/
auto getOperationsParameters() -> const std::vector<std::vector<double>> &
{
return this->ops_params_;
return ops_params_;
}

/**
* @brief Get a a reference to operations wires.
*/
auto getOperationsWires() -> const std::vector<std::vector<size_t>> &
auto getOperationsWires() -> const std::vector<std::vector<size_t>> & { return ops_wires_; }

/**
* @brief Get a reference to operations inverses.
*/
auto getOperationsInverses() -> const std::vector<bool> & { return ops_inverses_; }

/**
* @brief Get a reference to operations matrices.
*/
auto getOperationsMatrices() -> const std::vector<std::vector<ComplexT>> &
{
return this->ops_wires_;
return ops_matrixs_;
}

/**
* @brief Get a reference to operations inverses.
* @brief Get a reference to operations controlled wires.
*/
auto getOperationsInverses() -> const std::vector<bool> & { return this->ops_inverses_; }
auto getOperationsControlledWires() -> const std::vector<std::vector<size_t>> &
{
return ops_controlled_wires_;
}

/**
* @brief Get a reference to operations controlled values.
*/
auto getOperationsControlledValues() -> const std::vector<std::vector<bool>> &
{
return ops_controlled_values_;
}

/**
* @brief Get total number of cached gates.
*/
[[nodiscard]] auto getNumGates() const -> size_t
{
return this->ops_names_.size() + this->obs_keys_.size();
return ops_names_.size() + obs_keys_.size();
}

/**
* @brief Get number of operations.
*/
[[nodiscard]] auto getNumOperations() const -> size_t { return this->ops_names_.size(); }
[[nodiscard]] auto getNumOperations() const -> size_t { return ops_names_.size(); }

/**
* @brief Get number of observables.
*/
[[nodiscard]] auto getNumObservables() const -> size_t { return this->obs_keys_.size(); }
[[nodiscard]] auto getNumObservables() const -> size_t { return obs_keys_.size(); }

/**
* @brief Get total number of cached gates.
*/
[[nodiscard]] auto getNumParams() const -> size_t { return this->num_params_; }
[[nodiscard]] auto getNumParams() const -> size_t { return num_params_; }
};
} // namespace Catalyst::Runtime
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ auto LightningSimulator::One() const -> Result
void LightningSimulator::NamedOperation(const std::string &name, const std::vector<double> &params,
const std::vector<QubitIdType> &wires, bool inverse)
{
// Check the validity of qubits
RT_FAIL_IF(wires.empty(), "Invalid number of qubits");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires");

// First, check if operation `name` is supported by the simulator
auto &&[op_num_wires, op_num_params] =
Lightning::lookup_gates(Lightning::simulator_gate_info, name);
Expand All @@ -131,19 +135,30 @@ void LightningSimulator::NamedOperation(const std::string &name, const std::vect

// Update tape caching if required
if (this->tape_recording) {
this->cache_manager.addOperation(name, params, dev_wires, inverse);
this->cache_manager.addOperation(name, params, dev_wires, inverse, {},
{/*controlled_wires*/}, {/*controlled_values*/});
}
}

void LightningSimulator::MatrixOperation(const std::vector<std::complex<double>> &matrix,
const std::vector<QubitIdType> &wires, bool inverse)
{
// Check the validity of qubits
RT_FAIL_IF(wires.empty(), "Invalid number of qubits");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires");

// Convert wires to device wires
// with checking validity of wires
auto &&dev_wires = getDeviceWires(wires);

// Update the state-vector
this->device_sv->applyMatrix(matrix.data(), dev_wires, inverse);

// Update tape caching if required
if (this->tape_recording) {
this->cache_manager.addOperation("QubitUnitary", {}, dev_wires, inverse, matrix,
{/*controlled_wires*/}, {/*controlled_values*/});
}
}

auto LightningSimulator::Observable(ObsId id, const std::vector<std::complex<double>> &matrix,
Expand Down Expand Up @@ -470,10 +485,14 @@ void LightningSimulator::Gradient(std::vector<DataView<double, 1>> &gradients,
auto &&ops_names = this->cache_manager.getOperationsNames();
auto &&ops_params = this->cache_manager.getOperationsParameters();
auto &&ops_wires = this->cache_manager.getOperationsWires();

auto &&ops_inverses = this->cache_manager.getOperationsInverses();
const auto &&ops = Pennylane::Algorithms::OpsData<StateVectorT>(ops_names, ops_params,
ops_wires, ops_inverses);
auto &&ops_matrices = this->cache_manager.getOperationsMatrices();
auto &&ops_controlled_wires = this->cache_manager.getOperationsControlledWires();
auto &&ops_controlled_values = this->cache_manager.getOperationsControlledValues();

const auto &&ops = Pennylane::Algorithms::OpsData<StateVectorT>(
ops_names, ops_params, ops_wires, ops_inverses, ops_matrices, ops_controlled_wires,
ops_controlled_values);

// create the vector of observables
auto &&obs_keys = this->cache_manager.getObservablesKeys();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class LightningSimulator final : public Catalyst::Runtime::QuantumDevice {
"Local"}; // tidy: readability-magic-numbers

Catalyst::Runtime::QubitManager<QubitIdType, size_t> qubit_manager{};
Catalyst::Runtime::CacheManager cache_manager{};
Catalyst::Runtime::CacheManager<std::complex<double>> cache_manager{};
bool tape_recording{false};
size_t device_shots;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ void LightningKokkosSimulator::NamedOperation(const std::string &name,
const std::vector<double> &params,
const std::vector<QubitIdType> &wires, bool inverse)
{
// Check the validity of qubits
RT_FAIL_IF(wires.empty(), "Invalid number of qubits");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires");

// First, check if operation `name` is supported by the simulator
auto &&[op_num_wires, op_num_params] =
Lightning::lookup_gates(Lightning::simulator_gate_info, name);
Expand All @@ -154,19 +158,21 @@ void LightningKokkosSimulator::NamedOperation(const std::string &name,

// Update tape caching if required
if (this->tape_recording) {
this->cache_manager.addOperation(name, params, dev_wires, inverse);
this->cache_manager.addOperation(name, params, dev_wires, inverse, {},
{/*controlled_wires*/}, {/*controlled_values*/});
}
}

void LightningKokkosSimulator::MatrixOperation(const std::vector<std::complex<double>> &matrix,
const std::vector<QubitIdType> &wires, bool inverse)
{
// Check the validity of qubits
RT_FAIL_IF(wires.empty(), "Invalid number of qubits");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires");

using UnmanagedComplexHostView = Kokkos::View<Kokkos::complex<double> *, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

// Check the validity of number of qubits and parameters
RT_FAIL_IF(!wires.size(), "Invalid number of qubits");

// Convert wires to device wires
auto &&dev_wires = getDeviceWires(wires);

Expand All @@ -183,7 +189,8 @@ void LightningKokkosSimulator::MatrixOperation(const std::vector<std::complex<do

// Update tape caching if required
if (this->tape_recording) {
this->cache_manager.addOperation("MatrixOp", {}, dev_wires, inverse);
this->cache_manager.addOperation("QubitUnitary", {}, dev_wires, inverse, matrix_kok,
{/*controlled_wires*/}, {/*controlled_values*/});
}
}

Expand Down Expand Up @@ -519,9 +526,13 @@ void LightningKokkosSimulator::Gradient(std::vector<DataView<double, 1>> &gradie
auto &&ops_params = this->cache_manager.getOperationsParameters();
auto &&ops_wires = this->cache_manager.getOperationsWires();
auto &&ops_inverses = this->cache_manager.getOperationsInverses();
auto &&ops_matrices = this->cache_manager.getOperationsMatrices();
auto &&ops_controlled_wires = this->cache_manager.getOperationsControlledWires();
auto &&ops_controlled_values = this->cache_manager.getOperationsControlledValues();

const auto &&ops = Pennylane::Algorithms::OpsData<StateVectorT>(ops_names, ops_params,
ops_wires, ops_inverses);
const auto &&ops = Pennylane::Algorithms::OpsData<StateVectorT>(
ops_names, ops_params, ops_wires, ops_inverses, ops_matrices, ops_controlled_wires,
ops_controlled_values);

// Create the vector of observables
auto &&obs_keys = this->cache_manager.getObservablesKeys();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class LightningKokkosSimulator final : public Catalyst::Runtime::QuantumDevice {
static constexpr bool GLOBAL_RESULT_FALSE_CONST = false;

Catalyst::Runtime::QubitManager<QubitIdType, size_t> qubit_manager{};
Catalyst::Runtime::CacheManager cache_manager{};
Catalyst::Runtime::CacheManager<Kokkos::complex<double>> cache_manager{};
bool tape_recording{false};

size_t device_shots;
Expand Down
Loading

0 comments on commit b009321

Please sign in to comment.