Skip to content

Commit

Permalink
Optimize lightning.tensor by adding direct MPS sites data set (#983)
Browse files Browse the repository at this point in the history
**Context:**
Optimize `lightning.tensor` by adding direct MPS sites data set

**Description of the Change:**
Adding the `MPSPrep` gate to be able to pass an MPS directly to the
Tensor Network.
The `MPSPrep` gate frontend was developed on this
[PR](PennyLaneAI/pennylane#6431)

**Benefits:**
Avoids the decomposition from a state vector into MPS sites, which is
expensive and inefficient.

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-74709]

---------

Co-authored-by: ringo-but-quantum <github-ringo-but-quantum@xanadu.ai>
  • Loading branch information
LuisAlfredoNu and ringo-but-quantum authored Dec 20, 2024
1 parent 11b03d4 commit 87bcd10
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 36 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

### Improvements

* Optimize lightning.tensor by adding direct MPS sites data set with `qml.MPSPrep`.
[(#983)](https://github.com/PennyLaneAI/pennylane-lightning/pull/983)

* Replace the `dummy_tensor_update` method with the `cutensornetStateCaptureMPS` API to ensure that further gates can be applied after the `cutensornetStateCompute` call.
  [(#1028)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1028/)

Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.40.0-dev41"
__version__ = "0.40.0-dev42"
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,6 @@ class TNCuda : public TNCudaBase<PrecisionT, Derived> {
using ComplexT = std::complex<PrecisionT>;
using BaseType = TNCudaBase<PrecisionT, Derived>;

protected:
// Note both maxBondDim_ and bondDims_ are used for both MPS and Exact
// Tensor Network. Per Exact Tensor Network, maxBondDim_ is 1 and bondDims_
// is {1}. Per Exact Tensor Network, setting bondDims_ allows call to
// appendInitialMPSState_() to append the initial state to the Exact Tensor
// Network state.
const std::size_t
maxBondDim_; // maxBondDim_ default is 1 for Exact Tensor Network
const std::vector<std::size_t>
bondDims_; // bondDims_ default is {1} for Exact Tensor Network

private:
const std::vector<std::vector<std::size_t>> sitesModes_;
const std::vector<std::vector<std::size_t>> sitesExtents_;
const std::vector<std::vector<int64_t>> sitesExtents_int64_;

SharedCublasCaller cublascaller_;

std::shared_ptr<TNCudaGateCache<PrecisionT>> gate_cache_;
std::set<int64_t> gate_ids_;

std::vector<std::size_t> identiy_gate_ids_;

std::vector<TensorCuda<PrecisionT>> tensors_;
std::vector<TensorCuda<PrecisionT>> tensors_out_;

public:
TNCuda() = delete;

Expand Down Expand Up @@ -499,7 +473,27 @@ class TNCuda : public TNCudaBase<PrecisionT, Derived> {
projected_mode_values, numHyperSamples);
}

/**
 * @brief Get a const reference to the extents of each MPS site
 * (`sitesExtents_`).
 *
 * Each inner vector holds the extents of one site: boundary sites have
 * two extents {bond, physical} and interior sites have three
 * {bond, physical, bond}.
 *
 * @return const std::vector<std::vector<std::size_t>> &
 */
[[nodiscard]] auto getSitesExtents() const
    -> const std::vector<std::vector<std::size_t>> & {
    return sitesExtents_;
}

protected:
// Note both maxBondDim_ and bondDims_ are used for both MPS and Exact
// Tensor Network. For Exact Tensor Network, maxBondDim_ is 1 and bondDims_
// is {1}. For Exact Tensor Network, setting bondDims_ allows call to
// appendInitialMPSState_() to append the initial state to the Exact Tensor
// Network state.
const std::size_t
maxBondDim_; // maxBondDim_ default is 1 for Exact Tensor Network
const std::vector<std::size_t>
bondDims_; // bondDims_ default is {1} for Exact Tensor Network

/**
* @brief Get a vector of pointers to tensor data of each site.
*
Expand Down Expand Up @@ -578,6 +572,20 @@ class TNCuda : public TNCudaBase<PrecisionT, Derived> {
}

private:
const std::vector<std::vector<std::size_t>> sitesModes_;
const std::vector<std::vector<std::size_t>> sitesExtents_;
const std::vector<std::vector<int64_t>> sitesExtents_int64_;

SharedCublasCaller cublascaller_;

std::shared_ptr<TNCudaGateCache<PrecisionT>> gate_cache_;
std::set<int64_t> gate_ids_;

std::vector<std::size_t> identiy_gate_ids_;

std::vector<TensorCuda<PrecisionT>> tensors_;
std::vector<TensorCuda<PrecisionT>> tensors_out_;

/**
* @brief Get accessor of a state tensor
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@
#include "TypeList.hpp"
#include "Util.hpp"
#include "cuda_helpers.hpp"
#include "tncuda_helpers.hpp"

/// @cond DEV
namespace {
using namespace Pennylane;
using namespace Pennylane::Bindings;
using namespace Pennylane::LightningGPU::Util;
using Pennylane::LightningTensor::TNCuda::MPSTNCuda;
using namespace Pennylane::LightningTensor::TNCuda::Util;
} // namespace
/// @endcond

Expand Down Expand Up @@ -137,6 +138,20 @@ void registerBackendClassSpecificBindingsMPS(PyClass &pyclass) {
.def(
"updateMPSSitesData",
[](TensorNet &tensor_network, std::vector<np_arr_c> &tensors) {
// Extract the incoming MPS shape
std::vector<std::vector<std::size_t>> MPS_shape_source;
for (std::size_t idx = 0; idx < tensors.size(); idx++) {
py::buffer_info numpyArrayInfo = tensors[idx].request();
auto MPS_site_source_shape = numpyArrayInfo.shape;
std::vector<std::size_t> MPS_site_source(
MPS_site_source_shape.begin(),
MPS_site_source_shape.end());
MPS_shape_source.emplace_back(std::move(MPS_site_source));
}

const auto &MPS_shape_dest = tensor_network.getSitesExtents();
MPSShapeCheck(MPS_shape_dest, MPS_shape_source);

for (std::size_t idx = 0; idx < tensors.size(); idx++) {
py::buffer_info numpyArrayInfo = tensors[idx].request();
auto *data_ptr = static_cast<std::complex<PrecisionT> *>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ TEMPLATE_TEST_CASE("MPSTNCuda::getDataVector()", "[MPSTNCuda]", float, double) {

TEMPLATE_TEST_CASE("MPOTNCuda::getBondDims()", "[MPOTNCuda]", float, double) {
using cp_t = std::complex<TestType>;
SECTION("Check if bondDims is correctly set") {
SECTION("Check if bondDims is correct set") {
const std::size_t num_qubits = 3;
const std::size_t maxBondDim = 128;
const DevTag<int> dev_tag{0, 0};
Expand Down Expand Up @@ -323,4 +323,64 @@ TEMPLATE_TEST_CASE("MPOTNCuda::getBondDims()", "[MPOTNCuda]", float, double) {

CHECK(bondDims == expected_bondDims);
}
}
}

// Verify the per-site extents layout built by the MPS tensor network.
// Boundary sites have extents {bond, 2}; interior sites {bond, 2, bond},
// with bond dimensions growing as powers of two, capped at maxBondDim.
TEMPLATE_TEST_CASE("MPSTNCuda::getSitesExtents()", "[MPSTNCuda]", float,
                   double) {
    SECTION("Check if sitesExtents return is correct with 3 qubits") {
        const std::size_t num_qubits = 3;
        const std::size_t maxBondDim = 128;
        const DevTag<int> dev_tag{0, 0};

        const std::vector<std::vector<std::size_t>> reference{
            {{2, 2}, {2, 2, 2}, {2, 2}}};

        MPSTNCuda<TestType> mps{num_qubits, maxBondDim, dev_tag};

        const auto &sitesExtents = mps.getSitesExtents();

        CHECK(reference == sitesExtents);
    }

    SECTION("Check if sitesExtents return is correct with 8 qubits") {
        const std::size_t num_qubits = 8;
        const std::size_t maxBondDim = 128;
        const DevTag<int> dev_tag{0, 0};

        // Bond dims double toward the center (maxBondDim never reached).
        const std::vector<std::vector<std::size_t>> reference{{{2, 2},
                                                               {2, 2, 4},
                                                               {4, 2, 8},
                                                               {8, 2, 16},
                                                               {16, 2, 8},
                                                               {8, 2, 4},
                                                               {4, 2, 2},
                                                               {2, 2}}};

        MPSTNCuda<TestType> mps{num_qubits, maxBondDim, dev_tag};

        const auto &sitesExtents = mps.getSitesExtents();

        CHECK(reference == sitesExtents);
    }

    SECTION("Check if sitesExtents return is correct with 8 qubits and "
            "maxBondDim=8") {
        const std::size_t num_qubits = 8;
        const std::size_t maxBondDim = 8;
        const DevTag<int> dev_tag{0, 0};

        // Bond dims are truncated at maxBondDim = 8 in the center.
        const std::vector<std::vector<std::size_t>> reference{{{2, 2},
                                                               {2, 2, 4},
                                                               {4, 2, 8},
                                                               {8, 2, 8},
                                                               {8, 2, 8},
                                                               {8, 2, 4},
                                                               {4, 2, 2},
                                                               {2, 2}}};

        MPSTNCuda<TestType> mps{num_qubits, maxBondDim, dev_tag};

        const auto &sitesExtents = mps.getSitesExtents();

        CHECK(reference == sitesExtents);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,52 @@ TEST_CASE("swap_op_wires_queue", "[TNCuda_utils]") {
REQUIRE(swap_wires_queue[1] == swap_wires_queue_ref1);
}
}

// Exercise MPSShapeCheck: it must accept an incoming MPS whose per-site
// extents exactly match the destination layout, and abort otherwise.
TEST_CASE("MPSShapeCheck", "[TNCuda_utils]") {
    using MPSShape = std::vector<std::vector<std::size_t>>;

    // Destination layout of a 4-site MPS: {2,2}, {2,2,4}, {4,2,2}, {2,2}.
    const MPSShape expected_shape{{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};

    SECTION("Correct incoming MPS shape") {
        const MPSShape matching_shape{{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};

        REQUIRE_NOTHROW(MPSShapeCheck(expected_shape, matching_shape));
    }

    SECTION("Incorrect incoming MPS shape, bond dimension") {
        // Interior bond dimensions differ (2 instead of 4).
        const MPSShape bad_bond_shape{{2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2}};

        REQUIRE_THROWS_WITH(
            MPSShapeCheck(expected_shape, bad_bond_shape),
            Catch::Matchers::Contains("The incoming MPS does not have the "
                                      "correct layout for lightning.tensor"));
    }

    SECTION("Incorrect incoming MPS shape, physical dimension") {
        // Physical (qubit) dimensions differ (4 instead of 2).
        const MPSShape bad_phys_shape{{4, 2}, {2, 4, 4}, {4, 4, 2}, {2, 4}};

        REQUIRE_THROWS_WITH(
            MPSShapeCheck(expected_shape, bad_phys_shape),
            Catch::Matchers::Contains("The incoming MPS does not have the "
                                      "correct layout for lightning.tensor"));
    }

    SECTION("Incorrect incoming MPS shape, number sites") {
        // One site fewer than the destination layout.
        const MPSShape bad_site_count{{2, 2}, {2, 2, 2}, {2, 2}};

        REQUIRE_THROWS_WITH(
            MPSShapeCheck(expected_shape, bad_site_count),
            Catch::Matchers::Contains("The incoming MPS does not have the "
                                      "correct layout for lightning.tensor"));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,19 @@ inline auto create_swap_wire_pair_queue(const std::vector<std::size_t> &wires)
return {local_wires, swap_wires_queue};
}

/**
 * @brief Check that the incoming MPS has the site-extents layout the C++
 * backend expects, aborting with an error message otherwise.
 *
 * The comparison is an exact equality of the nested extent lists: same
 * number of sites and identical extents per site. Any mismatch triggers
 * PL_ABORT_IF_NOT with a lightning.tensor layout error.
 *
 * @param MPS_shape_dest Dimension list of each site of the destination MPS
 * (the layout the tensor network state was built with).
 * @param MPS_shape_source Dimension list of each site of the incoming MPS.
 */
inline void
MPSShapeCheck(const std::vector<std::vector<std::size_t>> &MPS_shape_dest,
              const std::vector<std::vector<std::size_t>> &MPS_shape_source) {
    PL_ABORT_IF_NOT(MPS_shape_dest == MPS_shape_source,
                    "The incoming MPS does not have the correct layout for "
                    "lightning.tensor.")
}

} // namespace Pennylane::LightningTensor::TNCuda::Util
1 change: 1 addition & 0 deletions pennylane_lightning/core/src/utils/Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,4 +592,5 @@ bool areVecsDisjoint(const std::vector<T> &v1, const std::vector<T> &v2) {
}
return true;
}

} // namespace Pennylane::Util
15 changes: 11 additions & 4 deletions pennylane_lightning/lightning_tensor/_tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import numpy as np
import pennylane as qml
from pennylane import BasisState, DeviceError, StatePrep
from pennylane import BasisState, DeviceError, MPSPrep, StatePrep
from pennylane.ops.op_math import Adjoint
from pennylane.tape import QuantumScript
from pennylane.wires import Wires
Expand Down Expand Up @@ -433,17 +433,24 @@ def apply_operations(self, operations):
# State preparation is currently done in Python
if operations: # make sure operations[0] exists
if isinstance(operations[0], StatePrep):
if self.method == "tn":
raise DeviceError("Exact Tensor Network does not support StatePrep")

if self.method == "mps":
self._apply_state_vector(
operations[0].parameters[0].copy(), operations[0].wires
)
operations = operations[1:]
if self.method == "tn":
raise DeviceError("Exact Tensor Network does not support StatePrep")
elif isinstance(operations[0], BasisState):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]
elif isinstance(operations[0], MPSPrep):
if self.method == "mps":
mps = operations[0].mps
self._tensornet.updateMPSSitesData(mps)
operations = operations[1:]

if self.method == "tn":
raise DeviceError("Exact Tensor Network does not support MPSPrep")

self._apply_lightning(operations)

Expand Down
4 changes: 4 additions & 0 deletions pennylane_lightning/lightning_tensor/lightning_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
{
"Identity",
"BasisState",
"MPSPrep",
"QubitUnitary",
"ControlledQubitUnitary",
"DiagonalQubitUnitary",
Expand Down Expand Up @@ -169,6 +170,9 @@ def stopping_condition(op: Operator) -> bool:
if isinstance(op, qml.ControlledQubitUnitary):
return True

if isinstance(op, qml.MPSPrep):
return True

return op.has_matrix and op.name in _operations


Expand Down
Loading

0 comments on commit 87bcd10

Please sign in to comment.