From 5e9020237f8a821a641080911adebeab8da84426 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 18 Mar 2022 00:11:05 -0400 Subject: [PATCH 1/3] add test for visitWith with pruned tree --- gtsam/discrete/tests/testDecisionTree.cpp | 44 ++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 967023eebd..ca5daf2014 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -20,13 +20,11 @@ // #define DT_DEBUG_MEMORY // #define DT_NO_PRUNING #define DISABLE_DOT -#include - +#include #include +#include #include -#include - #include using namespace boost::assign; @@ -346,6 +344,44 @@ TEST(DecisionTree, visitWith) { EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); } +/* ************************************************************************** */ +// Test visit, with Choices argument. +TEST(DecisionTree, VisitWithPruned) { + // Create pruned tree + std::pair A("A", 2), B("B", 2), C("C", 2); + std::vector> labels = {C, B, A}; + std::vector nodes = {0, 0, 2, 3, 4, 4, 6, 7}; + DT tree(labels, nodes); + + std::vector> choices; + auto func = [&](const Assignment& choice, const int& d) { + choices.push_back(choice); + }; + tree.visitWith(func); + + EXPECT_LONGS_EQUAL(6, choices.size()); + + Assignment expectedAssignment; + + expectedAssignment = {{"B", 0}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(0)); + + expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(1)); + + expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}}; + EXPECT(expectedAssignment == choices.at(2)); + + expectedAssignment = {{"B", 0}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(3)); + + expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(4)); + + expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}}; + EXPECT(expectedAssignment == choices.at(5)); +} + /* ************************************************************************** */ // Test fold. TEST(DecisionTree, fold) { From 13c60990f7336d320cad3ff5a692d25c32db9398 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 18 Mar 2022 00:11:38 -0400 Subject: [PATCH 2/3] fix visitWith operator() to be more functional --- gtsam/discrete/Assignment.h | 2 ++ gtsam/discrete/DecisionTree-inl.h | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index cdbf0a2e96..90e2dbdd83 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -33,6 +33,8 @@ namespace gtsam { template class Assignment : public std::map { public: + using std::map::operator=; + void print(const std::string& s = "Assignment: ") const { std::cout << s << ": "; for (const typename Assignment::value_type& keyValue : *this) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 01c7b689c1..c616c0269c 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -666,8 +666,11 @@ namespace gtsam { if (!choice) throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); for (size_t i = 0; i < choice->nrChoices(); i++) { - choices[choice->label()] = i; // Set assignment for label to i - (*this)(choice->branches()[i]); // recurse! + choices[choice->label()] = i; // Set assignment for label to i + + VisitWith visit(f); + visit.choices = choices; + visit(choice->branches()[i]); // recurse! } } }; From fa542a20384435c88eaf3cf2b460a2c645908812 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 18 Mar 2022 08:19:01 -0400 Subject: [PATCH 3/3] address review comments --- gtsam/discrete/DecisionTree-inl.h | 8 +++++--- gtsam/discrete/tests/testDecisionTree.cpp | 6 ++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index c616c0269c..627c1a5aac 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -668,9 +668,11 @@ namespace gtsam { for (size_t i = 0; i < choice->nrChoices(); i++) { choices[choice->label()] = i; // Set assignment for label to i - VisitWith visit(f); - visit.choices = choices; - visit(choice->branches()[i]); // recurse! + (*this)(choice->branches()[i]); // recurse! + + // Remove the choice so we are backtracking + auto choice_it = choices.find(choice->label()); + choices.erase(choice_it); } } }; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index ca5daf2014..91deed6253 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -20,11 +20,13 @@ // #define DT_DEBUG_MEMORY // #define DT_NO_PRUNING #define DISABLE_DOT -#include -#include #include + +#include #include +#include + #include using namespace boost::assign;