diff --git a/include/mqt-core/dd/DDpackageConfig.hpp b/include/mqt-core/dd/DDpackageConfig.hpp index ccd511996..717fc7064 100644 --- a/include/mqt-core/dd/DDpackageConfig.hpp +++ b/include/mqt-core/dd/DDpackageConfig.hpp @@ -6,8 +6,6 @@ namespace dd { struct DDPackageConfig { - // Note the order of parameters here must be the *same* as in the template - // definition. static constexpr std::size_t UT_VEC_NBUCKET = 32768U; static constexpr std::size_t UT_VEC_INITIAL_ALLOCATION_SIZE = 2048U; static constexpr std::size_t UT_MAT_NBUCKET = 32768U; @@ -22,6 +20,8 @@ struct DDPackageConfig { static constexpr std::size_t CT_MAT_MAT_MULT_NBUCKET = 16384U; static constexpr std::size_t CT_VEC_KRON_NBUCKET = 4096U; static constexpr std::size_t CT_MAT_KRON_NBUCKET = 4096U; + static constexpr std::size_t CT_DM_TRACE_NBUCKET = 1U; + static constexpr std::size_t CT_MAT_TRACE_NBUCKET = 4096U; static constexpr std::size_t CT_VEC_INNER_PROD_NBUCKET = 4096U; static constexpr std::size_t CT_DM_NOISE_NBUCKET = 1U; static constexpr std::size_t UT_DM_NBUCKET = 1U; @@ -63,6 +63,8 @@ struct DensityMatrixSimulatorDDPackageConfig : public dd::DDPackageConfig { static constexpr std::size_t UT_MAT_INITIAL_ALLOCATION_SIZE = 1U; static constexpr std::size_t CT_VEC_KRON_NBUCKET = 1U; static constexpr std::size_t CT_MAT_KRON_NBUCKET = 1U; + static constexpr std::size_t CT_DM_TRACE_NBUCKET = 4096U; + static constexpr std::size_t CT_MAT_TRACE_NBUCKET = 1U; static constexpr std::size_t CT_VEC_INNER_PROD_NBUCKET = 1U; static constexpr std::size_t STOCHASTIC_CACHE_OPS = 1U; static constexpr std::size_t CT_VEC_ADD_MAG_NBUCKET = 1U; diff --git a/include/mqt-core/dd/Package.hpp b/include/mqt-core/dd/Package.hpp index 3f541a0d1..bd0e99471 100644 --- a/include/mqt-core/dd/Package.hpp +++ b/include/mqt-core/dd/Package.hpp @@ -5,6 +5,7 @@ #include "dd/CachedEdge.hpp" #include "dd/Complex.hpp" #include "dd/ComplexNumbers.hpp" +#include "dd/ComplexValue.hpp" #include "dd/ComputeTable.hpp" #include 
"dd/DDDefinitions.hpp" #include "dd/DDpackageConfig.hpp" @@ -256,16 +257,22 @@ template class Package { } // invalidate all compute tables involving matrices if any matrix node has // been collected - if (mCollect > 0 || dCollect > 0) { + if (mCollect > 0) { matrixAdd.clear(); conjugateMatrixTranspose.clear(); matrixKronecker.clear(); + matrixTrace.clear(); matrixVectorMultiplication.clear(); matrixMatrixMultiplication.clear(); stochasticNoiseOperationCache.clear(); + } + // invalidate all compute tables involving density matrices if any density + // matrix node has been collected + if (dCollect > 0) { densityAdd.clear(); densityDensityMultiplication.clear(); densityNoise.clear(); + densityTrace.clear(); } // invalidate all compute tables where any component of the entry contains // numbers from the complex table if any complex numbers were collected @@ -276,10 +283,12 @@ template class Package { vectorInnerProduct.clear(); vectorKronecker.clear(); matrixKronecker.clear(); + matrixTrace.clear(); stochasticNoiseOperationCache.clear(); densityAdd.clear(); densityDensityMultiplication.clear(); densityNoise.clear(); + densityTrace.clear(); } return vCollect > 0 || mCollect > 0 || cCollect > 0; } @@ -885,11 +894,13 @@ template class Package { vectorInnerProduct.clear(); vectorKronecker.clear(); matrixKronecker.clear(); + matrixTrace.clear(); stochasticNoiseOperationCache.clear(); densityAdd.clear(); densityDensityMultiplication.clear(); densityNoise.clear(); + densityTrace.clear(); } /// @@ -1939,6 +1950,19 @@ template class Package { /// (Partial) trace /// public: + UnaryComputeTable + densityTrace{}; + UnaryComputeTable + matrixTrace{}; + + template [[nodiscard]] auto& getTraceComputeTable() { + if constexpr (std::is_same_v) { + return matrixTrace; + } else { + return densityTrace; + } + } + mEdge partialTrace(const mEdge& a, const std::vector& eliminate) { auto r = trace(a, eliminate, eliminate.size()); return {r.p, cn.lookup(r.w)}; @@ -1947,7 +1971,7 @@ template 
class Package { template ComplexValue trace(const Edge& a, const std::size_t numQubits) { if (a.isIdentity()) { - return a.w * std::pow(2, numQubits); + return static_cast(a.w); } const auto eliminate = std::vector(numQubits, true); return trace(a, eliminate, numQubits).w; @@ -1975,7 +1999,24 @@ template class Package { } private: - /// TODO: introduce a compute table for the trace? + /** + * @brief Computes the normalized (partial) trace using a compute table to + * store results for eliminated nodes. + * @details At each level, perform a lookup and store results in the compute + * table only if all lower-level qubits are eliminated as well. + * + * This optimization allows the full trace + * computation to scale linearly with respect to the number of nodes. + * However, the partial trace computation still scales with the number of + * paths to the lowest level in the DD that should be traced out. + * + * For matrices, normalization is continuously applied, dividing by two at + * each level marked for elimination, so that the full trace is scaled by + * 1/2^N (e.g., the trace of the N-qubit identity is 1 instead of 2^N). + * + * For density matrices, such normalization is not applied as the trace of + * density matrices is always 1 by definition. 
+ */ template CachedEdge trace(const Edge& a, const std::vector& eliminate, std::size_t level, @@ -1985,24 +2026,48 @@ template class Package { return CachedEdge::zero(); } - if (std::none_of(eliminate.begin(), eliminate.end(), + // If `a` is the identity matrix or there is nothing left to eliminate, + // then simply return `a` + if (a.isIdentity() || + std::none_of(eliminate.begin(), + eliminate.begin() + + static_cast::difference_type>(level), [](bool v) { return v; })) { return CachedEdge{a.p, aWeight}; } - if (a.isIdentity()) { - const auto elims = - std::count(eliminate.begin(), - eliminate.begin() + static_cast(level), true); - return CachedEdge{a.p, aWeight * std::pow(2, elims)}; - } - const auto v = a.p->v; if (eliminate[v]) { + // Look up nodes marked for elimination in the compute table if all + // lower-level qubits are eliminated as well: if the trace has already + // been computed, return the result + const auto eliminateAll = std::all_of( + eliminate.begin(), + eliminate.begin() + + static_cast::difference_type>(level), + [](bool e) { return e; }); + if (eliminateAll) { + if (const auto* r = getTraceComputeTable().lookup(a.p); + r != nullptr) { + return {r->p, r->w * aWeight}; + } + } + const auto elims = alreadyEliminated + 1; auto r = add2(trace(a.p->e[0], eliminate, level - 1, elims), trace(a.p->e[3], eliminate, level - 1, elims), v - 1); + // The resulting weight is halved at every eliminated level for matrix + // nodes (continuous normalization of the trace) + if constexpr (std::is_same_v) { + r.w = r.w / 2.0; + } + + // Insert result into compute table if all lower-level qubits are + // eliminated as well + if (eliminateAll) { + getTraceComputeTable().insert(a.p, r); + } r.w = r.w * aWeight; return r; } diff --git a/test/dd/test_package.cpp b/test/dd/test_package.cpp index 531773cc8..dbe38c1f1 100644 --- a/test/dd/test_package.cpp +++ b/test/dd/test_package.cpp @@ -285,21 +285,137 @@ TEST(DDPackageTest, IdentityTrace) { auto dd = std::make_unique>(4); auto fullTrace = 
dd->trace(dd->makeIdent(), 4); - ASSERT_EQ(fullTrace.r, 16.); + ASSERT_EQ(fullTrace.r, 1.); +} + +TEST(DDPackageTest, CNotKronTrace) { + auto dd = std::make_unique>(4); + auto cxGate = dd->makeGateDD(dd::X_MAT, 1_pc, 0); + auto cxGateKron = dd->kronecker(cxGate, cxGate, 2); + auto fullTrace = dd->trace(cxGateKron, 4); + ASSERT_EQ(fullTrace, 0.25); } TEST(DDPackageTest, PartialIdentityTrace) { auto dd = std::make_unique>(2); auto tr = dd->partialTrace(dd->makeIdent(), {false, true}); auto mul = dd->multiply(tr, tr); - EXPECT_EQ(dd::RealNumber::val(mul.w.r), 4.0); + EXPECT_EQ(dd::RealNumber::val(mul.w.r), 1.); } -TEST(DDPackageTest, PartialNonIdentityTrace) { +TEST(DDPackageTest, PartialSWapMatTrace) { auto dd = std::make_unique>(2); auto swapGate = dd->makeTwoQubitGateDD(dd::SWAP_MAT, 0, 1); auto ptr = dd->partialTrace(swapGate, {true, false}); - EXPECT_EQ(ptr.w * ptr.w, 1.); + auto fullTrace = dd->trace(ptr, 1); + auto fullTraceOriginal = dd->trace(swapGate, 2); + EXPECT_EQ(dd::RealNumber::val(ptr.w.r), 0.5); + // Check that successively tracing out subsystems is the same as computing the + // full trace from the beginning + EXPECT_EQ(fullTrace, fullTraceOriginal); +} + +TEST(DDPackageTest, PartialTraceKeepInnerQubits) { + // Check that the partial trace computation is correct when tracing out the + // outer qubits only. This test shows that we should avoid storing + // non-eliminated nodes in the compute table, as this would prevent their + // proper elimination in subsequent trace calls. 
+ + const std::size_t numQubits = 8; + auto dd = std::make_unique>(numQubits); + const auto swapGate = dd->makeTwoQubitGateDD(dd::SWAP_MAT, 0, 1); + auto swapKron = swapGate; + for (std::size_t i = 0; i < 3; ++i) { + swapKron = dd->kronecker(swapKron, swapGate, 2); + } + auto fullTraceOriginal = dd->trace(swapKron, numQubits); + auto ptr = dd->partialTrace( + swapKron, {true, true, false, false, false, false, true, true}); + auto fullTrace = dd->trace(ptr, 4); + EXPECT_EQ(dd::RealNumber::val(ptr.w.r), 0.25); + EXPECT_EQ(fullTrace.r, 0.0625); + // Check that successively tracing out subsystems is the same as computing the + // full trace from the beginning + EXPECT_EQ(fullTrace, fullTraceOriginal); +} + +TEST(DDPackageTest, TraceComplexity) { + // Check that the full trace computation scales with the number of nodes + // instead of paths in the DD due to the usage of a compute table + for (std::size_t numQubits = 1; numQubits <= 10; ++numQubits) { + auto dd = std::make_unique>(numQubits); + auto& computeTable = dd->getTraceComputeTable(); + const auto hGate = dd->makeGateDD(dd::H_MAT, 0); + auto hKron = hGate; + for (std::size_t i = 0; i < numQubits - 1; ++i) { + hKron = dd->kronecker(hKron, hGate, 1); + } + dd->trace(hKron, numQubits); + const auto& stats = computeTable.getStats(); + ASSERT_EQ(stats.lookups, 2 * numQubits - 1); + ASSERT_EQ(stats.hits, numQubits - 1); + } +} + +TEST(DDPackageTest, KeepBottomQubitsPartialTraceComplexity) { + // Check that during the trace computation, once a level is reached + // where the remaining qubits should not be eliminated, the function does not + // recurse further but immediately returns the current CachedEdge. 
+ const std::size_t numQubits = 8; + auto dd = std::make_unique>(numQubits); + auto& uniqueTable = dd->getUniqueTable(); + const auto hGate = dd->makeGateDD(dd::H_MAT, 0); + auto hKron = hGate; + for (std::size_t i = 0; i < numQubits - 1; ++i) { + hKron = dd->kronecker(hKron, hGate, 1); + } + + const std::size_t maxNodeVal = 6; + std::array lookupValues{}; + + for (std::size_t i = 0; i < maxNodeVal; ++i) { + // Store the number of lookups performed so far for the six bottom qubits + lookupValues[i] = uniqueTable.getStats(i).lookups; + } + dd->partialTrace(hKron, + {false, false, false, false, false, false, true, true}); + for (std::size_t i = 0; i < maxNodeVal; ++i) { + // Check that the partial trace computation performs no additional lookups + // on the bottom qubits that are not eliminated + ASSERT_EQ(uniqueTable.getStats(i).lookups, lookupValues[i]); + } +} + +TEST(DDPackageTest, PartialTraceComplexity) { + // In the worst case, the partial trace computation scales with the number of + // paths in the DD. This situation arises particularly when tracing out the + // bottom qubits. 
+ const std::size_t numQubits = 9; + auto dd = std::make_unique>(numQubits); + auto& uniqueTable = dd->getUniqueTable(); + const auto hGate = dd->makeGateDD(dd::H_MAT, 0); + auto hKron = hGate; + for (std::size_t i = 0; i < numQubits - 2; ++i) { + hKron = dd->kronecker(hKron, hGate, 1); + } + hKron = dd->kronecker(hKron, dd->makeIdent(), 1); + + const std::size_t maxNodeVal = 6; + std::array lookupValues{}; + for (std::size_t i = 1; i <= maxNodeVal; ++i) { + // Store the number of lookups performed so far for levels 1 through 6 + lookupValues[i] = uniqueTable.getStats(i).lookups; + } + + dd->partialTrace( + hKron, {true, false, false, false, false, false, false, true, true}); + for (std::size_t i = 1; i < maxNodeVal; ++i) { + // Check that the number of lookups scales with the number of paths in the + // DD + ASSERT_EQ(uniqueTable.getStats(i).lookups, + lookupValues[i] + + static_cast(std::pow(4, (maxNodeVal - i)))); + } } TEST(DDPackageTest, StateGenerationManipulation) {