Improvements for Trace Computation (#628)
## Description

This PR introduces two improvements to the trace computation:

1. Adds a compute table for storing the results of eliminated nodes during
   (partial) trace computation
   - this allows the full trace computation to scale linearly with the
     number of nodes (see the small memoization toy below for why per-node
     caching achieves this)
   - however, the partial trace computation still scales with the number of
     paths in the DD if the bottom qubits are to be eliminated:
     non-eliminated nodes cannot be stored in the compute table, as this
     would prevent their proper elimination in subsequent trace calls
2. Normalizes the result: for matrix DDs, the trace is divided by two at
   each eliminated level, mapping it to the interval [0, 1] instead of
   [0, 2^N] (a usage sketch with the new return values follows below)
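
To illustrate the scaling argument behind point 1, here is a small,
self-contained toy (plain C++, not mqt-core code; every name in it is made up
for the example): caching one result per shared node turns a traversal that
would otherwise visit every root-to-leaf path into one that visits every node
only once.

```cpp
// Toy illustration (not mqt-core code): the same DAG traversed without and
// with a per-node compute table. With maximal node sharing, the number of
// root-to-leaf paths grows exponentially while the number of nodes stays
// linear, so memoization is what makes the full trace cheap.
#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <vector>

struct Node {
  Node* left = nullptr;
  Node* right = nullptr;
};

// Without memoization: one recursive call per root-to-leaf *path*.
std::size_t visitAllPaths(const Node* n, std::size_t& calls) {
  ++calls;
  if (n == nullptr) {
    return 1U;
  }
  return visitAllPaths(n->left, calls) + visitAllPaths(n->right, calls);
}

// With a per-node table: each node is fully processed only once.
std::size_t visitMemoized(const Node* n, std::size_t& calls,
                          std::unordered_map<const Node*, std::size_t>& table) {
  ++calls;
  if (n == nullptr) {
    return 1U;
  }
  if (const auto it = table.find(n); it != table.end()) {
    return it->second; // cache hit: reuse the previously computed result
  }
  const auto result = visitMemoized(n->left, calls, table) +
                      visitMemoized(n->right, calls, table);
  table.emplace(n, result);
  return result;
}

int main() {
  // A maximally shared chain of depth 20: both successors of every node
  // point to the same next node, so there are 2^20 root-to-leaf paths over
  // just 20 nodes (decision-diagram nodes are shared in a similar way).
  constexpr std::size_t depth = 20;
  std::vector<Node> nodes(depth);
  for (std::size_t i = 0; i + 1 < depth; ++i) {
    nodes[i].left = &nodes[i + 1];
    nodes[i].right = &nodes[i + 1];
  }

  std::size_t pathCalls = 0;
  visitAllPaths(&nodes[0], pathCalls);

  std::size_t nodeCalls = 0;
  std::unordered_map<const Node*, std::size_t> table;
  visitMemoized(&nodes[0], nodeCalls, table);

  std::cout << "calls without memoization: " << pathCalls << '\n'  // 2^21 - 1
            << "calls with memoization:    " << nodeCalls << '\n'; // 2*20 + 1
  return 0;
}
```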

Fixes #336 
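
For reference, a minimal usage sketch of the normalized trace API, mirroring
the tests added in this PR (the include paths are assumptions and may differ
from the actual headers):

```cpp
// Minimal sketch of the normalized trace API, mirroring the tests added in
// this PR. The include paths below are assumptions and may differ.
#include "dd/GateMatrixDefinitions.hpp"
#include "dd/Package.hpp"
#include "dd/RealNumber.hpp"

#include <cassert>
#include <memory>

int main() {
  auto dd = std::make_unique<dd::Package<>>(2);

  // The full trace is now normalized by 2^n: tr(I)/4 = 1 (previously 4).
  assert(dd->trace(dd->makeIdent(), 2).r == 1.);

  // SWAP has trace 2, so its normalized trace is 2 / 2^2 = 0.5.
  const auto swap = dd->makeTwoQubitGateDD(dd::SWAP_MAT, 0, 1);
  assert(dd->trace(swap, 2) == 0.5);

  // Tracing out qubit 0 of SWAP yields 0.5 * I on the remaining qubit, and
  // tracing out the rest reproduces the full trace computed above.
  const auto ptr = dd->partialTrace(swap, {true, false});
  assert(dd::RealNumber::val(ptr.w.r) == 0.5);
  assert(dd->trace(ptr, 1) == dd->trace(swap, 2));
  return 0;
}
```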

## Checklist:

<!---
This checklist serves as a reminder of a couple of things that ensure
your pull request will be merged swiftly.
-->

- [x] The pull request only contains commits that are related to it.
- [x] I have added appropriate tests and documentation.
- [x] I have made sure that all CI jobs on GitHub pass.
- [x] The pull request introduces no new warnings and follows the
project's style guidelines.

---------

Co-authored-by: burgholzer <burgholzer@me.com>
TeWas and burgholzer authored Jun 19, 2024
1 parent 8b55942 commit 35e06ca
Showing 3 changed files with 200 additions and 17 deletions.
6 changes: 4 additions & 2 deletions include/mqt-core/dd/DDpackageConfig.hpp
@@ -6,8 +6,6 @@

namespace dd {
struct DDPackageConfig {
// Note the order of parameters here must be the *same* as in the template
// definition.
static constexpr std::size_t UT_VEC_NBUCKET = 32768U;
static constexpr std::size_t UT_VEC_INITIAL_ALLOCATION_SIZE = 2048U;
static constexpr std::size_t UT_MAT_NBUCKET = 32768U;
@@ -22,6 +20,8 @@ struct DDPackageConfig {
static constexpr std::size_t CT_MAT_MAT_MULT_NBUCKET = 16384U;
static constexpr std::size_t CT_VEC_KRON_NBUCKET = 4096U;
static constexpr std::size_t CT_MAT_KRON_NBUCKET = 4096U;
static constexpr std::size_t CT_DM_TRACE_NBUCKET = 1U;
static constexpr std::size_t CT_MAT_TRACE_NBUCKET = 4096U;
static constexpr std::size_t CT_VEC_INNER_PROD_NBUCKET = 4096U;
static constexpr std::size_t CT_DM_NOISE_NBUCKET = 1U;
static constexpr std::size_t UT_DM_NBUCKET = 1U;
@@ -63,6 +63,8 @@ struct DensityMatrixSimulatorDDPackageConfig : public dd::DDPackageConfig {
static constexpr std::size_t UT_MAT_INITIAL_ALLOCATION_SIZE = 1U;
static constexpr std::size_t CT_VEC_KRON_NBUCKET = 1U;
static constexpr std::size_t CT_MAT_KRON_NBUCKET = 1U;
static constexpr std::size_t CT_DM_TRACE_NBUCKET = 4096U;
static constexpr std::size_t CT_MAT_TRACE_NBUCKET = 1U;
static constexpr std::size_t CT_VEC_INNER_PROD_NBUCKET = 1U;
static constexpr std::size_t STOCHASTIC_CACHE_OPS = 1U;
static constexpr std::size_t CT_VEC_ADD_MAG_NBUCKET = 1U;
87 changes: 76 additions & 11 deletions include/mqt-core/dd/Package.hpp
@@ -5,6 +5,7 @@
#include "dd/CachedEdge.hpp"
#include "dd/Complex.hpp"
#include "dd/ComplexNumbers.hpp"
#include "dd/ComplexValue.hpp"
#include "dd/ComputeTable.hpp"
#include "dd/DDDefinitions.hpp"
#include "dd/DDpackageConfig.hpp"
@@ -256,16 +257,22 @@ template <class Config> class Package {
}
// invalidate all compute tables involving matrices if any matrix node has
// been collected
if (mCollect > 0 || dCollect > 0) {
if (mCollect > 0) {
matrixAdd.clear();
conjugateMatrixTranspose.clear();
matrixKronecker.clear();
matrixTrace.clear();
matrixVectorMultiplication.clear();
matrixMatrixMultiplication.clear();
stochasticNoiseOperationCache.clear();
}
// invalidate all compute tables involving density matrices if any density
// matrix node has been collected
if (dCollect > 0) {
densityAdd.clear();
densityDensityMultiplication.clear();
densityNoise.clear();
densityTrace.clear();
}
// invalidate all compute tables where any component of the entry contains
// numbers from the complex table if any complex numbers were collected
@@ -276,10 +283,12 @@
vectorInnerProduct.clear();
vectorKronecker.clear();
matrixKronecker.clear();
matrixTrace.clear();
stochasticNoiseOperationCache.clear();
densityAdd.clear();
densityDensityMultiplication.clear();
densityNoise.clear();
densityTrace.clear();
}
return vCollect > 0 || mCollect > 0 || cCollect > 0;
}
@@ -885,11 +894,13 @@ template <class Config> class Package {
vectorInnerProduct.clear();
vectorKronecker.clear();
matrixKronecker.clear();
matrixTrace.clear();

stochasticNoiseOperationCache.clear();
densityAdd.clear();
densityDensityMultiplication.clear();
densityNoise.clear();
densityTrace.clear();
}

///
@@ -1939,6 +1950,19 @@ template <class Config> class Package {
/// (Partial) trace
///
public:
UnaryComputeTable<dNode*, dCachedEdge, Config::CT_DM_TRACE_NBUCKET>
densityTrace{};
UnaryComputeTable<mNode*, mCachedEdge, Config::CT_MAT_TRACE_NBUCKET>
matrixTrace{};

template <class Node> [[nodiscard]] auto& getTraceComputeTable() {
if constexpr (std::is_same_v<Node, mNode>) {
return matrixTrace;
} else {
return densityTrace;
}
}

mEdge partialTrace(const mEdge& a, const std::vector<bool>& eliminate) {
auto r = trace(a, eliminate, eliminate.size());
return {r.p, cn.lookup(r.w)};
@@ -1947,7 +1971,7 @@
template <class Node>
ComplexValue trace(const Edge<Node>& a, const std::size_t numQubits) {
if (a.isIdentity()) {
return a.w * std::pow(2, numQubits);
return static_cast<ComplexValue>(a.w);
}
const auto eliminate = std::vector<bool>(numQubits, true);
return trace(a, eliminate, numQubits).w;
@@ -1975,7 +1999,24 @@
}

private:
/// TODO: introduce a compute table for the trace?
/**
* @brief Computes the normalized (partial) trace using a compute table to
* store results for eliminated nodes.
* @details At each level, perform a lookup and store results in the compute
* table only if all lower-level qubits are eliminated as well.
*
* This optimization allows the full trace
* computation to scale linearly with respect to the number of nodes.
* However, the partial trace computation still scales with the number of
* paths to the lowest level in the DD that should be traced out.
*
* For matrices, normalization is continuously applied, dividing by two at
* each level marked for elimination, thereby ensuring that the result is
* mapped to the interval [0,1] (as opposed to the interval [0,2^N]).
*
* For density matrices, such normalization is not applied as the trace of
* density matrices is always 1 by definition.
*/
template <class Node>
CachedEdge<Node> trace(const Edge<Node>& a,
const std::vector<bool>& eliminate, std::size_t level,
@@ -1985,24 +2026,48 @@
return CachedEdge<Node>::zero();
}

if (std::none_of(eliminate.begin(), eliminate.end(),
// If `a` is the identity matrix or there is nothing left to eliminate,
// then simply return `a`
if (a.isIdentity() ||
std::none_of(eliminate.begin(),
eliminate.begin() +
static_cast<std::vector<bool>::difference_type>(level),
[](bool v) { return v; })) {
return CachedEdge<Node>{a.p, aWeight};
}

if (a.isIdentity()) {
const auto elims =
std::count(eliminate.begin(),
eliminate.begin() + static_cast<int64_t>(level), true);
return CachedEdge<Node>{a.p, aWeight * std::pow(2, elims)};
}

const auto v = a.p->v;
if (eliminate[v]) {
// Lookup nodes marked for elimination in the compute table if all
// lower-level qubits are eliminated as well: if the trace has already
// been computed, return the result
const auto eliminateAll = std::all_of(
eliminate.begin(),
eliminate.begin() +
static_cast<std::vector<bool>::difference_type>(level),
[](bool e) { return e; });
if (eliminateAll) {
if (const auto* r = getTraceComputeTable<Node>().lookup(a.p);
r != nullptr) {
return {r->p, r->w * aWeight};
}
}

const auto elims = alreadyEliminated + 1;
auto r = add2(trace(a.p->e[0], eliminate, level - 1, elims),
trace(a.p->e[3], eliminate, level - 1, elims), v - 1);

// The resulting weight is continuously normalized to the range [0,1] for
// matrix nodes
if constexpr (std::is_same_v<Node, mNode>) {
r.w = r.w / 2.0;
}

// Insert result into compute table if all lower-level qubits are
// eliminated as well
if (eliminateAll) {
getTraceComputeTable<Node>().insert(a.p, r);
}
r.w = r.w * aWeight;
return r;
}
124 changes: 120 additions & 4 deletions test/dd/test_package.cpp
@@ -285,21 +285,137 @@ TEST(DDPackageTest, IdentityTrace) {
auto dd = std::make_unique<dd::Package<>>(4);
auto fullTrace = dd->trace(dd->makeIdent(), 4);

ASSERT_EQ(fullTrace.r, 16.);
ASSERT_EQ(fullTrace.r, 1.);
}

TEST(DDPackageTest, CNotKronTrace) {
auto dd = std::make_unique<dd::Package<>>(4);
auto cxGate = dd->makeGateDD(dd::X_MAT, 1_pc, 0);
auto cxGateKron = dd->kronecker(cxGate, cxGate, 2);
auto fullTrace = dd->trace(cxGateKron, 4);
ASSERT_EQ(fullTrace, 0.25);
}

TEST(DDPackageTest, PartialIdentityTrace) {
auto dd = std::make_unique<dd::Package<>>(2);
auto tr = dd->partialTrace(dd->makeIdent(), {false, true});
auto mul = dd->multiply(tr, tr);
EXPECT_EQ(dd::RealNumber::val(mul.w.r), 4.0);
EXPECT_EQ(dd::RealNumber::val(mul.w.r), 1.);
}

TEST(DDPackageTest, PartialNonIdentityTrace) {
TEST(DDPackageTest, PartialSWapMatTrace) {
auto dd = std::make_unique<dd::Package<>>(2);
auto swapGate = dd->makeTwoQubitGateDD(dd::SWAP_MAT, 0, 1);
auto ptr = dd->partialTrace(swapGate, {true, false});
EXPECT_EQ(ptr.w * ptr.w, 1.);
auto fullTrace = dd->trace(ptr, 1);
auto fullTraceOriginal = dd->trace(swapGate, 2);
EXPECT_EQ(dd::RealNumber::val(ptr.w.r), 0.5);
// Check that successively tracing out subsystems is the same as computing the
// full trace from the beginning
EXPECT_EQ(fullTrace, fullTraceOriginal);
}

TEST(DDPackageTest, PartialTraceKeepInnerQubits) {
// Check that the partial trace computation is correct when tracing out the
// outer qubits only. This test shows that we should avoid storing
// non-eliminated nodes in the compute table, as this would prevent their
// proper elimination in subsequent trace calls.

const std::size_t numQubits = 8;
auto dd = std::make_unique<dd::Package<>>(numQubits);
const auto swapGate = dd->makeTwoQubitGateDD(dd::SWAP_MAT, 0, 1);
auto swapKron = swapGate;
for (std::size_t i = 0; i < 3; ++i) {
swapKron = dd->kronecker(swapKron, swapGate, 2);
}
auto fullTraceOriginal = dd->trace(swapKron, numQubits);
auto ptr = dd->partialTrace(
swapKron, {true, true, false, false, false, false, true, true});
auto fullTrace = dd->trace(ptr, 4);
EXPECT_EQ(dd::RealNumber::val(ptr.w.r), 0.25);
EXPECT_EQ(fullTrace.r, 0.0625);
// Check that successively tracing out subsystems is the same as computing the
// full trace from the beginning
EXPECT_EQ(fullTrace, fullTraceOriginal);
}

TEST(DDPackageTest, TraceComplexity) {
// Check that the full trace computation scales with the number of nodes
// instead of paths in the DD due to the usage of a compute table
for (std::size_t numQubits = 1; numQubits <= 10; ++numQubits) {
auto dd = std::make_unique<dd::Package<>>(numQubits);
auto& computeTable = dd->getTraceComputeTable<dd::mNode>();
const auto hGate = dd->makeGateDD(dd::H_MAT, 0);
auto hKron = hGate;
for (std::size_t i = 0; i < numQubits - 1; ++i) {
hKron = dd->kronecker(hKron, hGate, 1);
}
dd->trace(hKron, numQubits);
const auto& stats = computeTable.getStats();
ASSERT_EQ(stats.lookups, 2 * numQubits - 1);
ASSERT_EQ(stats.hits, numQubits - 1);
}
}

TEST(DDPackageTest, KeepBottomQubitsPartialTraceComplexity) {
// Check that during the trace computation, once a level is reached
// where the remaining qubits should not be eliminated, the function does not
// recurse further but immediately returns the current CachedEdge<Node>.
const std::size_t numQubits = 8;
auto dd = std::make_unique<dd::Package<>>(numQubits);
auto& uniqueTable = dd->getUniqueTable<dd::mNode>();
const auto hGate = dd->makeGateDD(dd::H_MAT, 0);
auto hKron = hGate;
for (std::size_t i = 0; i < numQubits - 1; ++i) {
hKron = dd->kronecker(hKron, hGate, 1);
}

const std::size_t maxNodeVal = 6;
std::array<std::size_t, maxNodeVal> lookupValues{};

for (std::size_t i = 0; i < maxNodeVal; ++i) {
// Store the number of lookups performed so far for the six bottom qubits
lookupValues[i] = uniqueTable.getStats(i).lookups;
}
dd->partialTrace(hKron,
{false, false, false, false, false, false, true, true});
for (std::size_t i = 0; i < maxNodeVal; ++i) {
// Check that the partial trace computation performs no additional lookups
// on the bottom qubits that are not eliminated
ASSERT_EQ(uniqueTable.getStats(i).lookups, lookupValues[i]);
}
}

TEST(DDPackageTest, PartialTraceComplexity) {
// In the worst case, the partial trace computation scales with the number of
// paths in the DD. This situation arises particularly when tracing out the
// bottom qubits.
const std::size_t numQubits = 9;
auto dd = std::make_unique<dd::Package<>>(numQubits);
auto& uniqueTable = dd->getUniqueTable<dd::mNode>();
const auto hGate = dd->makeGateDD(dd::H_MAT, 0);
auto hKron = hGate;
for (std::size_t i = 0; i < numQubits - 2; ++i) {
hKron = dd->kronecker(hKron, hGate, 1);
}
hKron = dd->kronecker(hKron, dd->makeIdent(), 1);

const std::size_t maxNodeVal = 6;
std::array<std::size_t, maxNodeVal + 1> lookupValues{};
for (std::size_t i = 1; i <= maxNodeVal; ++i) {
// Store the number of lookups performed so far for levels 1 through 6
lookupValues[i] = uniqueTable.getStats(i).lookups;
}

dd->partialTrace(
hKron, {true, false, false, false, false, false, false, true, true});
for (std::size_t i = 1; i < maxNodeVal; ++i) {
// Check that the number of lookups scales with the number of paths in the
// DD
ASSERT_EQ(uniqueTable.getStats(i).lookups,
lookupValues[i] +
static_cast<std::size_t>(std::pow(4, (maxNodeVal - i))));
}
}

TEST(DDPackageTest, StateGenerationManipulation) {
