Skip to content

Commit

Permalink
🐛 fix compute table lookup
Browse files Browse the repository at this point in the history
Due to terminals no longer being represented by separate entities, some of the checks whether a compute table lookup was successful did not work anymore as they checked for `nullptr`. This commit adjusts the `lookup` function to return a pointer to an entry. This makes it easier to distinguish a successful lookup from one that failed.

Signed-off-by: Lukas Burgholzer <lukas.burgholzer@jku.at>
  • Loading branch information
burgholzer committed Jul 27, 2023
1 parent b1b16b1 commit 702413d
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 64 deletions.
22 changes: 9 additions & 13 deletions include/dd/ComputeTable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,13 @@ class ComputeTable {
++count;
}

ResultType lookup(const LeftOperandType& leftOperand,
const RightOperandType& rightOperand,
[[maybe_unused]] const bool useDensityMatrix = false) {
ResultType result{};
ResultType* lookup(const LeftOperandType& leftOperand,
const RightOperandType& rightOperand,
[[maybe_unused]] const bool useDensityMatrix = false) {
ResultType* result = nullptr;
lookups++;
const auto key = hash(leftOperand, rightOperand);
auto& entry = table[key];
if (entry.result.p == nullptr) {
return result;
}
if (entry.leftOperand != leftOperand) {
return result;
}
Expand All @@ -68,20 +65,19 @@ class ComputeTable {
// Since density matrices are reduced representations of matrices, a
// density matrix may not be returned when a matrix is required and vice
// versa
if (dNode::isDensityMatrixNode(entry.result.p->flags) !=
useDensityMatrix) {
if (!dNode::isTerminal(entry.result.p) &&
dNode::isDensityMatrixNode(entry.result.p->flags) !=
useDensityMatrix) {
return result;
}
}
hits++;
return entry.result;
return &entry.result;
}

void clear() {
if (count > 0) {
for (auto& entry : table) {
entry.result.p = nullptr;
}
std::fill(table.begin(), table.end(), Entry{});
count = 0;
}
}
Expand Down
58 changes: 27 additions & 31 deletions include/dd/Package.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,12 +1607,12 @@ template <class Config> class Package {
}

auto& computeTable = getAddComputeTable<Node>();
auto r = computeTable.lookup({x.p, x.w}, {y.p, y.w});
if (!Node::isTerminal(r.p)) {
if (r.w.approximatelyZero()) {
if (const auto* r = computeTable.lookup({x.p, x.w}, {y.p, y.w});
r != nullptr) {
if (r->w.approximatelyZero()) {
return Edge<Node>::zero;
}
return {r.p, cn.getCached(r.w)};
return {r->p, cn.getCached(r->w)};
}

const Qubit w = (x.isTerminal() || (!y.isTerminal() && y.p->v > x.p->v))
Expand Down Expand Up @@ -1686,9 +1686,8 @@ template <class Config> class Package {
}

// check in compute table
auto r = matrixTranspose.lookup(a);
if (!r.isTerminal()) {
return r;
if (const auto* r = matrixTranspose.lookup(a); r != nullptr) {
return *r;
}

std::array<mEdge, NEDGE> e{};
Expand All @@ -1699,23 +1698,22 @@ template <class Config> class Package {
}
}
// create new top node
r = makeDDNode(a.p->v, e);
auto res = makeDDNode(a.p->v, e);
// adjust top weight
r.w = cn.lookup(cn.mulTemp(r.w, a.w));
res.w = cn.lookup(cn.mulTemp(res.w, a.w));

// put in compute table
matrixTranspose.insert(a, r);
return r;
matrixTranspose.insert(a, res);
return res;
}
mEdge conjugateTranspose(const mEdge& a) {
if (a.isTerminal()) { // terminal case
return {a.p, ComplexNumbers::conj(a.w)};
}

// check if in compute table
auto r = conjugateMatrixTranspose.lookup(a);
if (!r.isTerminal()) {
return r;
if (const auto* r = conjugateMatrixTranspose.lookup(a); r != nullptr) {
return *r;
}

std::array<mEdge, NEDGE> e{};
Expand All @@ -1726,14 +1724,14 @@ template <class Config> class Package {
}
}
// create new top node
r = makeDDNode(a.p->v, e);
auto res = makeDDNode(a.p->v, e);

// adjust top weight including conjugate
r.w = cn.lookup(cn.mulTemp(r.w, ComplexNumbers::conj(a.w)));
res.w = cn.lookup(cn.mulTemp(res.w, ComplexNumbers::conj(a.w)));

// put it in the compute table
conjugateMatrixTranspose.insert(a, r);
return r;
conjugateMatrixTranspose.insert(a, res);
return res;
}

///
Expand Down Expand Up @@ -1850,12 +1848,13 @@ template <class Config> class Package {

auto& computeTable =
getMultiplicationComputeTable<LeftOperandNode, RightOperandNode>();
auto r = computeTable.lookup(xCopy, yCopy, generateDensityMatrix);
if (!RightOperandNode::isTerminal(r.p)) {
if (r.w.approximatelyZero()) {
if (const auto* r =
computeTable.lookup(xCopy, yCopy, generateDensityMatrix);
r != nullptr) {
if (r->w.approximatelyZero()) {
return ResultEdge::zero;
}
auto e = ResultEdge{r.p, cn.getCached(r.w)};
auto e = ResultEdge{r->p, cn.getCached(r->w)};
ComplexNumbers::mul(e.w, e.w, x.w);
ComplexNumbers::mul(e.w, e.w, y.w);
if (e.w.approximatelyZero()) {
Expand Down Expand Up @@ -2102,9 +2101,8 @@ template <class Config> class Package {
// Set to one to generate more lookup hits
auto xCopy = vEdge{x.p, Complex::one};
auto yCopy = vEdge{y.p, Complex::one};
auto r = vectorInnerProduct.lookup(xCopy, yCopy);
if (!vNode::isTerminal(r.p)) {
auto c = cn.getTemporary(r.w);
if (const auto* r = vectorInnerProduct.lookup(xCopy, yCopy); r != nullptr) {
auto c = cn.getTemporary(r->w);
ComplexNumbers::mul(c, c, x.w);
ComplexNumbers::mul(c, c, y.w);
return {c.r->value, c.i->value};
Expand All @@ -2131,9 +2129,8 @@ template <class Config> class Package {
sum.r += cv.r;
sum.i += cv.i;
}
r.w = sum;

vectorInnerProduct.insert(xCopy, yCopy, r);
vectorInnerProduct.insert(xCopy, yCopy, {vNode::getTerminal(), sum});
auto c = cn.getTemporary(sum);
ComplexNumbers::mul(c, c, x.w);
ComplexNumbers::mul(c, c, y.w);
Expand Down Expand Up @@ -2225,12 +2222,11 @@ template <class Config> class Package {
}

auto& computeTable = getKroneckerComputeTable<Node>();
auto r = computeTable.lookup(x, y);
if (!Node::isTerminal(r.p)) {
if (r.w.approximatelyZero()) {
if (const auto* r = computeTable.lookup(x, y); r != nullptr) {
if (r->w.approximatelyZero()) {
return Edge<Node>::zero;
}
return {r.p, cn.getCached(r.w)};
return {r->p, cn.getCached(r->w)};
}

constexpr std::size_t n = std::tuple_size_v<decltype(x.p->e)>;
Expand Down
13 changes: 4 additions & 9 deletions include/dd/UnaryComputeTable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,22 @@ class UnaryComputeTable {
++count;
}

ResultType lookup(const OperandType& operand) {
ResultType result{};
ResultType* lookup(const OperandType& operand) {
ResultType* result = nullptr;
lookups++;
const auto key = hash(operand);
auto& entry = table[key];
if (entry.result.p == nullptr) {
return result;
}
if (entry.operand != operand) {
return result;
}

hits++;
return entry.result;
return &entry.result;
}

void clear() {
if (count > 0) {
for (auto& entry : table) {
entry.result.p = nullptr;
}
std::fill(table.begin(), table.end(), Entry{});
count = 0;
}
hits = 0;
Expand Down
25 changes: 14 additions & 11 deletions test/dd/test_package.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1241,29 +1241,32 @@ TEST(DDPackageTest, dNodeMulCache1) {

const auto xCopy = dd::dEdge{state.p, dd::Complex::one};
const auto yCopy = dd::dEdge{densityMatrix0.p, dd::Complex::one};
const auto cachedResult = computeTable.lookup(xCopy, yCopy, false);
ASSERT_NE(cachedResult.p, nullptr);
const auto* cachedResult = computeTable.lookup(xCopy, yCopy, false);
ASSERT_NE(cachedResult, nullptr);
ASSERT_NE(cachedResult->p, nullptr);
state = dd->multiply(state, densityMatrix0, 0, false);
ASSERT_NE(state.p, nullptr);
ASSERT_EQ(state.p, cachedResult.p);
ASSERT_EQ(state.p, cachedResult->p);

const auto densityMatrix1 = dd::densityFromMatrixEdge(operation);
const auto xCopy1 = dd::dEdge{densityMatrix1.p, dd::Complex::one};
const auto yCopy1 = dd::dEdge{state.p, dd::Complex::one};
const auto cachedResult1 = computeTable.lookup(xCopy1, yCopy1, true);
ASSERT_NE(cachedResult1.p, nullptr);
const auto* cachedResult1 = computeTable.lookup(xCopy1, yCopy1, true);
ASSERT_NE(cachedResult1, nullptr);
ASSERT_NE(cachedResult1->p, nullptr);
state = dd->multiply(densityMatrix1, state, 0, true);
ASSERT_NE(state.p, nullptr);
ASSERT_EQ(state.p, cachedResult1.p);
ASSERT_EQ(state.p, cachedResult1->p);

// try a repeated lookup
const auto cachedResult2 = computeTable.lookup(xCopy1, yCopy1, true);
ASSERT_NE(cachedResult2.p, nullptr);
ASSERT_EQ(cachedResult2.p, cachedResult1.p);
const auto* cachedResult2 = computeTable.lookup(xCopy1, yCopy1, true);
ASSERT_NE(cachedResult2, nullptr);
ASSERT_NE(cachedResult2->p, nullptr);
ASSERT_EQ(cachedResult2->p, cachedResult1->p);

computeTable.clear();
const auto cachedResult3 = computeTable.lookup(xCopy1, yCopy1, true);
ASSERT_EQ(cachedResult3.p, nullptr);
const auto* cachedResult3 = computeTable.lookup(xCopy1, yCopy1, true);
ASSERT_EQ(cachedResult3, nullptr);
}

TEST(DDPackageTest, dNoiseCache) {
Expand Down

0 comments on commit 702413d

Please sign in to comment.