From c4d388990d44294e6ade625917ff20f2076d307f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 15 Sep 2022 13:23:12 -0400 Subject: [PATCH 01/18] prune hybrid gaussian ISAM more efficiently and without the need to specify the root --- gtsam/hybrid/HybridBayesTree.cpp | 4 +- gtsam/hybrid/HybridGaussianISAM.cpp | 78 ++++++++++++------- gtsam/hybrid/HybridGaussianISAM.h | 7 +- gtsam/hybrid/HybridNonlinearISAM.h | 7 +- gtsam/hybrid/tests/testHybridIncremental.cpp | 10 +-- .../hybrid/tests/testHybridNonlinearISAM.cpp | 19 ++--- 6 files changed, 70 insertions(+), 55 deletions(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 3fa4d6268b..d600714d43 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -89,12 +89,12 @@ struct HybridAssignmentData { gaussianbayesTree_(gbt) {} /** - * @brief A function used during tree traversal that operators on each node + * @brief A function used during tree traversal that operates on each node * before visiting the node's children. * * @param node The current node being visited. * @param parentData The HybridAssignmentData from the parent node. - * @return HybridAssignmentData + * @return HybridAssignmentData which is passed to the children. */ static HybridAssignmentData AssignmentPreOrderVisitor( const HybridBayesTree::sharedNode& node, diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 23a95c021d..bf2f60da66 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -14,9 +14,10 @@ * @date March 31, 2022 * @author Fan Jiang * @author Frank Dellaert - * @author Richard Roberts + * @author Varun Agrawal */ +#include #include #include #include @@ -121,38 +122,59 @@ bool IsSubset(KeyVector a, KeyVector b) { } /* ************************************************************************* */ -void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) { +void HybridGaussianISAM::prune(const size_t maxNrLeaves) { auto decisionTree = boost::dynamic_pointer_cast( - this->clique(root)->conditional()->inner()); + this->roots_.at(0)->conditional()->inner()); + DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); decisionTree->root_ = prunedDiscreteFactor.root_; - std::vector prunedKeys; - for (auto&& clique : nodes()) { - // The cliques can be repeated for each frontal so we record it in - // prunedKeys and check if we have already pruned a particular clique. - if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) != - prunedKeys.end()) { - continue; - } - - // Add all the keys of the current clique to be pruned to prunedKeys - for (auto&& key : clique.second->conditional()->frontals()) { - prunedKeys.push_back(key); - } - - // Convert parents() to a KeyVector for comparison - KeyVector parents; - for (auto&& parent : clique.second->conditional()->parents()) { - parents.push_back(parent); - } - - if (IsSubset(parents, decisionTree->keys())) { - auto gaussianMixture = boost::dynamic_pointer_cast( - clique.second->conditional()->inner()); - - gaussianMixture->prune(prunedDiscreteFactor); + /// Helper struct for pruning the hybrid bayes tree. + struct HybridPrunerData { + /// The discrete decision tree after pruning. + DecisionTreeFactor prunedDiscreteFactor; + HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, + const HybridBayesTree::sharedNode& parentClique) + : prunedDiscreteFactor(prunedDiscreteFactor) {} + + /** + * @brief A function used during tree traversal that operates on each node + * before visiting the node's children. + * + * @param node The current node being visited. + * @param parentData The data from the parent node. + * @return HybridPrunerData which is passed to the children. + */ + static HybridPrunerData AssignmentPreOrderVisitor( + const HybridBayesTree::sharedNode& clique, + HybridPrunerData& parentData) { + // Get the conditional + HybridConditional::shared_ptr conditional = clique->conditional(); + + // If conditional is hybrid, we prune it. + if (conditional->isHybrid()) { + auto gaussianMixture = conditional->asMixture(); + + // Check if the number of discrete keys match, + // else we get an assignment error. + // TODO(Varun) Update prune method to handle assignment subset? + if (gaussianMixture->discreteKeys() == + parentData.prunedDiscreteFactor.discreteKeys()) { + gaussianMixture->prune(parentData.prunedDiscreteFactor); + } + } + return parentData; } + }; + + HybridPrunerData rootData(prunedDiscreteFactor, 0); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + visitorPost); } } diff --git a/gtsam/hybrid/HybridGaussianISAM.h b/gtsam/hybrid/HybridGaussianISAM.h index ff09787293..59e2218076 100644 --- a/gtsam/hybrid/HybridGaussianISAM.h +++ b/gtsam/hybrid/HybridGaussianISAM.h @@ -66,11 +66,10 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM { /** * @brief Prune the underlying Bayes tree. - * - * @param root The root key in the discrete conditional decision tree. - * @param maxNumberLeaves + * + * @param maxNumberLeaves The max number of leaf nodes to keep. */ - void prune(const Key& root, const size_t maxNumberLeaves); + void prune(const size_t maxNumberLeaves); }; /// traits diff --git a/gtsam/hybrid/HybridNonlinearISAM.h b/gtsam/hybrid/HybridNonlinearISAM.h index d96485fff0..b1998fb300 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.h +++ b/gtsam/hybrid/HybridNonlinearISAM.h @@ -82,12 +82,9 @@ class GTSAM_EXPORT HybridNonlinearISAM { /** * @brief Prune the underlying Bayes tree. * - * @param root The root key in the discrete conditional decision tree. - * @param maxNumberLeaves + * @param maxNumberLeaves The max number of leaf nodes to keep. */ - void prune(const Key& root, const size_t maxNumberLeaves) { - isam_.prune(root, maxNumberLeaves); - } + void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); } /** Return the current linearization point */ const Values& getLinearizationPoint() const { return linPoint_; } diff --git a/gtsam/hybrid/tests/testHybridIncremental.cpp b/gtsam/hybrid/tests/testHybridIncremental.cpp index 8e16d02b93..a0a87933f4 100644 --- a/gtsam/hybrid/tests/testHybridIncremental.cpp +++ b/gtsam/hybrid/tests/testHybridIncremental.cpp @@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) { size_t maxNrLeaves = 5; incrementalHybrid.update(graph1); - incrementalHybrid.prune(M(3), maxNrLeaves); + incrementalHybrid.prune(maxNrLeaves); /* unpruned factor is: @@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { // Run update with pruning size_t maxComponents = 5; incrementalHybrid.update(graph1); - incrementalHybrid.prune(M(3), maxComponents); + incrementalHybrid.prune(maxComponents); // Check if we have a bayes tree with 4 hybrid nodes, // each with 2, 4, 8, and 5 (pruned) leaves respetively. @@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { // Run update with pruning a second time. incrementalHybrid.update(graph2); - incrementalHybrid.prune(M(4), maxComponents); + incrementalHybrid.prune(maxComponents); // Check if we have a bayes tree with pruned hybrid nodes, // with 5 (pruned) leaves. @@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) { // The MHS at this point should be a 2 level tree on (1, 2). // 1 has 2 choices, and 2 has 4 choices. inc.update(gfg); - inc.prune(M(2), 2); + inc.prune(2); /*************** Run Round 4 ***************/ // Add odometry factor with discrete modes. @@ -531,7 +531,7 @@ TEST(HybridGaussianISAM, NonTrivial) { // Keep pruning! inc.update(gfg); - inc.prune(M(3), 3); + inc.prune(3); // The final discrete graph should not be empty since we have eliminated // all continuous variables. diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 46076b3306..4e1710c424 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -256,7 +256,7 @@ TEST(HybridNonlinearISAM, Approx_inference) { incrementalHybrid.update(graph1, initial); HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(M(3), maxNrLeaves); + bayesTree.prune(maxNrLeaves); /* unpruned factor is: @@ -355,7 +355,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { incrementalHybrid.update(graph1, initial); HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(M(3), maxComponents); + bayesTree.prune(maxComponents); // Check if we have a bayes tree with 4 hybrid nodes, // each with 2, 4, 8, and 5 (pruned) leaves respetively. @@ -380,7 +380,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { incrementalHybrid.update(graph2, initial); bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(M(4), maxComponents); + bayesTree.prune(maxComponents); // Check if we have a bayes tree with pruned hybrid nodes, // with 5 (pruned) leaves. @@ -482,8 +482,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { still = boost::make_shared(W(1), W(2), Pose2(0, 0, 0), noise_model); moving = - boost::make_shared(W(1), W(2), odometry, - noise_model); + boost::make_shared(W(1), W(2), odometry, noise_model); components = {moving, still}; mixtureFactor = boost::make_shared( contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components); @@ -515,7 +514,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { // The MHS at this point should be a 2 level tree on (1, 2). // 1 has 2 choices, and 2 has 4 choices. inc.update(fg, initial); - inc.prune(M(2), 2); + inc.prune(2); fg = HybridNonlinearFactorGraph(); initial = Values(); @@ -526,8 +525,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { still = boost::make_shared(W(2), W(3), Pose2(0, 0, 0), noise_model); moving = - boost::make_shared(W(2), W(3), odometry, - noise_model); + boost::make_shared(W(2), W(3), odometry, noise_model); components = {moving, still}; mixtureFactor = boost::make_shared( contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components); @@ -551,7 +549,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { // Keep pruning! inc.update(fg, initial); - inc.prune(M(3), 3); + inc.prune(3); fg = HybridNonlinearFactorGraph(); initial = Values(); @@ -560,8 +558,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { // The final discrete graph should not be empty since we have eliminated // all continuous variables. - auto discreteTree = - bayesTree[M(3)]->conditional()->asDiscreteConditional(); + auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional(); EXPECT_LONGS_EQUAL(3, discreteTree->size()); // Test if the optimal discrete mode assignment is (1, 1, 1). From aef1669a5072670d6db8dcdc64a94299c88c8421 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 15 Sep 2022 13:24:33 -0400 Subject: [PATCH 02/18] Add labelformatter to Assignment for convenience --- gtsam/discrete/Assignment.h | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index 90e2dbdd83..c6e426ca15 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -11,15 +11,17 @@ /** * @file Assignment.h - * @brief An assignment from labels to a discrete value index (size_t) + * @brief An assignment from labels to a discrete value index (size_t) * @author Frank Dellaert * @date Feb 5, 2012 */ #pragma once +#include #include #include +#include #include #include @@ -32,13 +34,30 @@ namespace gtsam { */ template class Assignment : public std::map { + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when + * printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } + public: using std::map::operator=; - void print(const std::string& s = "Assignment: ") const { + void print(const std::string& s = "Assignment: ", + const std::function& labelFormatter = + &DefaultFormatter) const { std::cout << s << ": "; - for (const typename Assignment::value_type& keyValue : *this) - std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + for (const typename Assignment::value_type& keyValue : *this) { + std::cout << "(" << labelFormatter(keyValue.first) << ", " + << keyValue.second << ")"; + } std::cout << std::endl; } From 8b5586f3b61e02e08d3e2621f83dd33915225a71 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 10:23:00 -0400 Subject: [PATCH 03/18] move prune method to HybridBayesTree class --- gtsam/hybrid/HybridBayesTree.cpp | 57 +++++++++++++++++++++++ gtsam/hybrid/HybridBayesTree.h | 7 +++ gtsam/hybrid/HybridGaussianISAM.cpp | 72 ----------------------------- 3 files changed, 64 insertions(+), 72 deletions(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index d600714d43..8fb487ae20 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { return result; } +/* ************************************************************************* */ +void HybridBayesTree::prune(const size_t maxNrLeaves) { + auto decisionTree = boost::dynamic_pointer_cast( + this->roots_.at(0)->conditional()->inner()); + + DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); + decisionTree->root_ = prunedDiscreteFactor.root_; + + /// Helper struct for pruning the hybrid bayes tree. + struct HybridPrunerData { + /// The discrete decision tree after pruning. + DecisionTreeFactor prunedDiscreteFactor; + HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, + const HybridBayesTree::sharedNode& parentClique) + : prunedDiscreteFactor(prunedDiscreteFactor) {} + + /** + * @brief A function used during tree traversal that operates on each node + * before visiting the node's children. + * + * @param node The current node being visited. + * @param parentData The data from the parent node. + * @return HybridPrunerData which is passed to the children. + */ + static HybridPrunerData AssignmentPreOrderVisitor( + const HybridBayesTree::sharedNode& clique, + HybridPrunerData& parentData) { + // Get the conditional + HybridConditional::shared_ptr conditional = clique->conditional(); + + // If conditional is hybrid, we prune it. + if (conditional->isHybrid()) { + auto gaussianMixture = conditional->asMixture(); + + // Check if the number of discrete keys match, + // else we get an assignment error. + // TODO(Varun) Update prune method to handle assignment subset? + if (gaussianMixture->discreteKeys() == + parentData.prunedDiscreteFactor.discreteKeys()) { + gaussianMixture->prune(parentData.prunedDiscreteFactor); + } + } + return parentData; + } + }; + + HybridPrunerData rootData(prunedDiscreteFactor, 0); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + visitorPost); + } +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 3fa344d4dd..d443e33c4c 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { */ VectorValues optimize(const DiscreteValues& assignment) const; + /** + * @brief Prune the underlying Bayes tree. + * + * @param maxNumberLeaves The max number of leaf nodes to keep. + */ + void prune(const size_t maxNumberLeaves); + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index bf2f60da66..0a9d0b0de6 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -106,76 +106,4 @@ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors, this->updateInternal(newFactors, &orphans, ordering, function); } -/* ************************************************************************* */ -/** - * @brief Check if `b` is a subset of `a`. - * Non-const since they need to be sorted. - * - * @param a KeyVector - * @param b KeyVector - * @return True if the keys of b is a subset of a, else false. - */ -bool IsSubset(KeyVector a, KeyVector b) { - std::sort(a.begin(), a.end()); - std::sort(b.begin(), b.end()); - return std::includes(a.begin(), a.end(), b.begin(), b.end()); -} - -/* ************************************************************************* */ -void HybridGaussianISAM::prune(const size_t maxNrLeaves) { - auto decisionTree = boost::dynamic_pointer_cast( - this->roots_.at(0)->conditional()->inner()); - - DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); - decisionTree->root_ = prunedDiscreteFactor.root_; - - /// Helper struct for pruning the hybrid bayes tree. - struct HybridPrunerData { - /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteFactor; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, - const HybridBayesTree::sharedNode& parentClique) - : prunedDiscreteFactor(prunedDiscreteFactor) {} - - /** - * @brief A function used during tree traversal that operates on each node - * before visiting the node's children. - * - * @param node The current node being visited. - * @param parentData The data from the parent node. - * @return HybridPrunerData which is passed to the children. - */ - static HybridPrunerData AssignmentPreOrderVisitor( - const HybridBayesTree::sharedNode& clique, - HybridPrunerData& parentData) { - // Get the conditional - HybridConditional::shared_ptr conditional = clique->conditional(); - - // If conditional is hybrid, we prune it. - if (conditional->isHybrid()) { - auto gaussianMixture = conditional->asMixture(); - - // Check if the number of discrete keys match, - // else we get an assignment error. - // TODO(Varun) Update prune method to handle assignment subset? - if (gaussianMixture->discreteKeys() == - parentData.prunedDiscreteFactor.discreteKeys()) { - gaussianMixture->prune(parentData.prunedDiscreteFactor); - } - } - return parentData; - } - }; - - HybridPrunerData rootData(prunedDiscreteFactor, 0); - { - treeTraversal::no_op visitorPost; - // Limits OpenMP threads since we're mixing TBB and OpenMP - TbbOpenMPMixedScope threadLimiter; - treeTraversal::DepthFirstForestParallel( - *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, - visitorPost); - } -} - } // namespace gtsam From dcad55c0322df722d4a58bb683879bc93bea4de6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 10:26:25 -0400 Subject: [PATCH 04/18] optional maxNrLeaves for HybridGaussianISAM --- gtsam/hybrid/HybridGaussianISAM.cpp | 8 +++++++- gtsam/hybrid/HybridGaussianISAM.h | 11 ++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 0a9d0b0de6..6946775b94 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -42,6 +42,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree) void HybridGaussianISAM::updateInternal( const HybridGaussianFactorGraph& newFactors, HybridBayesTree::Cliques* orphans, + const boost::optional& maxNrLeaves, const boost::optional& ordering, const HybridBayesTree::Eliminate& function) { // Remove the contaminated part of the Bayes tree @@ -92,6 +93,10 @@ void HybridGaussianISAM::updateInternal( HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(elimination_ordering, function, index); + if (maxNrLeaves) { + bayesTree->prune(*maxNrLeaves); + } + // Re-add into Bayes tree data structures this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(), bayesTree->roots().end()); @@ -100,10 +105,11 @@ void HybridGaussianISAM::updateInternal( /* ************************************************************************* */ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors, + const boost::optional& maxNrLeaves, const boost::optional& ordering, const HybridBayesTree::Eliminate& function) { Cliques orphans; - this->updateInternal(newFactors, &orphans, ordering, function); + this->updateInternal(newFactors, &orphans, maxNrLeaves, ordering, function); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianISAM.h b/gtsam/hybrid/HybridGaussianISAM.h index 59e2218076..a6a82b3e89 100644 --- a/gtsam/hybrid/HybridGaussianISAM.h +++ b/gtsam/hybrid/HybridGaussianISAM.h @@ -48,6 +48,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM { void updateInternal( const HybridGaussianFactorGraph& newFactors, HybridBayesTree::Cliques* orphans, + const boost::optional& maxNrLeaves = boost::none, const boost::optional& ordering = boost::none, const HybridBayesTree::Eliminate& function = HybridBayesTree::EliminationTraitsType::DefaultEliminate); @@ -57,19 +58,15 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM { * @brief Perform update step with new factors. * * @param newFactors Factor graph of new factors to add and eliminate. + * @param maxNrLeaves The maximum number of leaves to keep after pruning. + * @param ordering Custom elimination ordering. * @param function Elimination function. */ void update(const HybridGaussianFactorGraph& newFactors, + const boost::optional& maxNrLeaves = boost::none, const boost::optional& ordering = boost::none, const HybridBayesTree::Eliminate& function = HybridBayesTree::EliminationTraitsType::DefaultEliminate); - - /** - * @brief Prune the underlying Bayes tree. - * - * @param maxNumberLeaves The max number of leaf nodes to keep. - */ - void prune(const size_t maxNumberLeaves); }; /// traits From 57659261227a07806ab766e88638dd7b0fbd667b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 10:47:01 -0400 Subject: [PATCH 05/18] optional ordering argument for HybridNonlinearISAM::update --- gtsam/hybrid/HybridNonlinearISAM.cpp | 9 ++++++--- gtsam/hybrid/HybridNonlinearISAM.h | 4 +++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridNonlinearISAM.cpp b/gtsam/hybrid/HybridNonlinearISAM.cpp index 36cda4e807..d05e081dd4 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.cpp +++ b/gtsam/hybrid/HybridNonlinearISAM.cpp @@ -33,7 +33,9 @@ void HybridNonlinearISAM::saveGraph(const string& s, /* ************************************************************************* */ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors, - const Values& initialValues) { + const Values& initialValues, + const boost::optional& maxNrLeaves, + const boost::optional& ordering) { if (newFactors.size() > 0) { // Reorder and relinearize every reorderInterval updates if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) { @@ -51,7 +53,8 @@ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors, newFactors.linearize(linPoint_); // Update ISAM - isam_.update(*linearizedNewFactors, boost::none, eliminationFunction_); + isam_.update(*linearizedNewFactors, maxNrLeaves, ordering, + eliminationFunction_); } } @@ -66,7 +69,7 @@ void HybridNonlinearISAM::reorder_relinearize() { // Just recreate the whole BayesTree // TODO: allow for constrained ordering here // TODO: decouple relinearization and reordering to avoid - isam_.update(*factors_.linearize(newLinPoint), boost::none, + isam_.update(*factors_.linearize(newLinPoint), boost::none, boost::none, eliminationFunction_); // Update linearization point diff --git a/gtsam/hybrid/HybridNonlinearISAM.h b/gtsam/hybrid/HybridNonlinearISAM.h index b1998fb300..47aa81c558 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.h +++ b/gtsam/hybrid/HybridNonlinearISAM.h @@ -118,7 +118,9 @@ class GTSAM_EXPORT HybridNonlinearISAM { /** Add new factors along with their initial linearization points */ void update(const HybridNonlinearFactorGraph& newFactors, - const Values& initialValues); + const Values& initialValues, + const boost::optional& maxNrLeaves = boost::none, + const boost::optional& ordering = boost::none); /** Relinearization and reordering of variables */ void reorder_relinearize(); From a96c3db29e58041b128f382050ae802fefea2c3e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 10:48:08 -0400 Subject: [PATCH 06/18] minor fix --- gtsam/hybrid/GaussianMixture.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 6816dfbf67..3a654ddade 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -119,11 +119,12 @@ void GaussianMixture::print(const std::string &s, "", [&](Key k) { return formatter(k); }, [&](const GaussianConditional::shared_ptr &gf) -> std::string { RedirectCout rd; - if (gf && !gf->empty()) + if (gf && !gf->empty()) { gf->print("", formatter); - else - return {"nullptr"}; - return rd.str(); + return rd.str(); + } else { + return "nullptr"; + } }); } From 93528c3d4f8ae5c87a063a3232fb2894e7899504 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 12:19:24 -0400 Subject: [PATCH 07/18] Only eliminate variables that are in newFactors --- gtsam/hybrid/HybridGaussianISAM.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 6946775b94..c7811992e9 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -62,23 +62,29 @@ void HybridGaussianISAM::updateInternal( for (const sharedClique& orphan : *orphans) factors += boost::make_shared >(orphan); + // Get all the discrete keys from the new factors KeySet allDiscrete; - for (auto& factor : factors) { + for (auto& factor : newFactors) { for (auto& k : factor->discreteKeys()) { allDiscrete.insert(k.first); } } + + // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; + // Insert continuous keys first. for (auto& k : newFactorKeys) { if (!allDiscrete.exists(k)) { newKeysDiscreteLast.push_back(k); } } + // Insert discrete keys at the end std::copy(allDiscrete.begin(), allDiscrete.end(), std::back_inserter(newKeysDiscreteLast)); // Get an ordering where the new keys are eliminated last - const VariableIndex index(factors); + const VariableIndex index(newFactors); + Ordering elimination_ordering; if (ordering) { elimination_ordering = *ordering; @@ -89,10 +95,14 @@ void HybridGaussianISAM::updateInternal( true); } + GTSAM_PRINT(elimination_ordering); + std::cout << "\n\n\n\neliminateMultifrontal" << std::endl; + GTSAM_PRINT(factors); // eliminate all factors (top, added, orphans) into a new Bayes tree HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(elimination_ordering, function, index); + std::cout << "optionally prune" << std::endl; if (maxNrLeaves) { bayesTree->prune(*maxNrLeaves); } From aebcde99e2ae949a1bc20dac3d1fb5853991a576 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 18:13:59 -0400 Subject: [PATCH 08/18] add push_back to HybridBayesNet --- gtsam/hybrid/HybridBayesNet.h | 2 ++ gtsam/hybrid/tests/testHybridBayesNet.cpp | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index e84103a50c..f28224d373 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { HybridConditional(boost::make_shared(key, table))); } + using Base::push_back; + /// Get a specific Gaussian mixture by index `i`. GaussianMixture::shared_ptr atMixture(size_t i) const; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 0c15ee83da..cc2ab1759c 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -52,6 +52,21 @@ TEST(HybridBayesNet, Creation) { EXPECT(df.equals(expected)); } +/* ****************************************************************************/ +// Test adding a bayes net to another one. +TEST(HybridBayesNet, Add) { + HybridBayesNet bayesNet; + + bayesNet.add(Asia, "99/1"); + + DiscreteConditional expected(Asia, "99/1"); + + HybridBayesNet other; + other.push_back(bayesNet); + EXPECT(bayesNet.equals(other)); +} + + /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { From 9ef5c184ec379acb04c1bbd8619bda0e9ca25b5f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 17 Sep 2022 08:04:55 -0400 Subject: [PATCH 09/18] move renamed allDiscreteKeys and allContinuousKeys to HybridFactorGraph --- gtsam/hybrid/HybridFactorGraph.h | 22 ++++++++++++++++++++ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 24 +--------------------- gtsam/hybrid/HybridGaussianFactorGraph.h | 6 ------ 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index fc730f0c97..ea071a0209 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph { push_hybrid(p); } } + + /// Get all the discrete keys in the factor graph. + const KeySet allDiscreteKeys() const { + KeySet discrete_keys; + for (auto& factor : factors_) { + for (const DiscreteKey& k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } + return discrete_keys; + } + + /// Get all the continuous keys in the factor graph. + const KeySet allContinuousKeys() const { + KeySet keys; + for (auto& factor : factors_) { + for (const Key& key : factor->continuousKeys()) { + keys.insert(key); + } + } + return keys; + } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 09b592bd69..c08e774f24 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -404,31 +404,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { FactorGraph::add(boost::make_shared(factor)); } -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const { - KeySet discrete_keys; - for (auto &factor : factors_) { - for (const DiscreteKey &k : factor->discreteKeys()) { - discrete_keys.insert(k.first); - } - } - return discrete_keys; -} - -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getContinuousKeys() const { - KeySet keys; - for (auto &factor : factors_) { - for (const Key &key : factor->continuousKeys()) { - keys.insert(key); - } - } - return keys; -} - /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { - KeySet discrete_keys = getDiscreteKeys(); + KeySet discrete_keys = allDiscreteKeys(); for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index ad5cde09b0..bd24cdeaa7 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } - /// Get all the discrete keys in the factor graph. - const KeySet getDiscreteKeys() const; - - /// Get all the continuous keys in the factor graph. - const KeySet getContinuousKeys() const; - /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. From 12db5dd947481bc1990dddae5f527ca4724be788 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 17 Sep 2022 08:05:07 -0400 Subject: [PATCH 10/18] undo changes --- gtsam/hybrid/HybridGaussianISAM.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index c7811992e9..1a95c0c932 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -62,9 +62,9 @@ void HybridGaussianISAM::updateInternal( for (const sharedClique& orphan : *orphans) factors += boost::make_shared >(orphan); - // Get all the discrete keys from the new factors + // Get all the discrete keys from the factors KeySet allDiscrete; - for (auto& factor : newFactors) { + for (auto& factor : factors) { for (auto& k : factor->discreteKeys()) { allDiscrete.insert(k.first); } @@ -83,7 +83,7 @@ void HybridGaussianISAM::updateInternal( std::back_inserter(newKeysDiscreteLast)); // Get an ordering where the new keys are eliminated last - const VariableIndex index(newFactors); + const VariableIndex index(factors); Ordering elimination_ordering; if (ordering) { @@ -95,14 +95,10 @@ void HybridGaussianISAM::updateInternal( true); } - GTSAM_PRINT(elimination_ordering); - std::cout << "\n\n\n\neliminateMultifrontal" << std::endl; - GTSAM_PRINT(factors); // eliminate all factors (top, added, orphans) into a new Bayes tree HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(elimination_ordering, function, index); - std::cout << "optionally prune" << std::endl; if (maxNrLeaves) { bayesTree->prune(*maxNrLeaves); } From 2f8a0f82e0337e80dca1b7d9c608b1f7cbce792c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 19 Sep 2022 18:23:18 -0400 Subject: [PATCH 11/18] rename testHybridIncremental to testHybridGaussianISAM --- .../{testHybridIncremental.cpp => testHybridGaussianISAM.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename gtsam/hybrid/tests/{testHybridIncremental.cpp => testHybridGaussianISAM.cpp} (100%) diff --git a/gtsam/hybrid/tests/testHybridIncremental.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp similarity index 100% rename from gtsam/hybrid/tests/testHybridIncremental.cpp rename to gtsam/hybrid/tests/testHybridGaussianISAM.cpp From c2ca426acce0bf60eb14f66267b08b029abc28fb Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 20 Sep 2022 05:16:26 -0400 Subject: [PATCH 12/18] rename allDiscreteKeys and allContinuousKeys to discreteKeys and continuousKeys respectively --- gtsam/hybrid/HybridFactorGraph.h | 4 ++-- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- gtsam/hybrid/HybridGaussianISAM.cpp | 7 +------ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index ea071a0209..05a17b000f 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -137,7 +137,7 @@ class HybridFactorGraph : public FactorGraph { } /// Get all the discrete keys in the factor graph. - const KeySet allDiscreteKeys() const { + const KeySet discreteKeys() const { KeySet discrete_keys; for (auto& factor : factors_) { for (const DiscreteKey& k : factor->discreteKeys()) { @@ -148,7 +148,7 @@ class HybridFactorGraph : public FactorGraph { } /// Get all the continuous keys in the factor graph. - const KeySet allContinuousKeys() const { + const KeySet continuousKeys() const { KeySet keys; for (auto& factor : factors_) { for (const Key& key : factor->continuousKeys()) { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c08e774f24..ddb776ff4c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -406,7 +406,7 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { - KeySet discrete_keys = allDiscreteKeys(); + KeySet discrete_keys = discreteKeys(); for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 1a95c0c932..57e50104d7 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -63,12 +63,7 @@ void HybridGaussianISAM::updateInternal( factors += boost::make_shared >(orphan); // Get all the discrete keys from the factors - KeySet allDiscrete; - for (auto& factor : factors) { - for (auto& k : factor->discreteKeys()) { - allDiscrete.insert(k.first); - } - } + KeySet allDiscrete = factors.discreteKeys(); // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; From ad328756d2d50b1ec83a21c2b5a8734e3e7da393 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 3 Oct 2022 19:14:03 -0400 Subject: [PATCH 13/18] improved hybrid bayes net pruning --- gtsam/hybrid/HybridBayesNet.cpp | 28 +++++++++++++++++++++-- gtsam/hybrid/HybridBayesNet.h | 14 +++++++++--- gtsam/hybrid/tests/testHybridBayesNet.cpp | 19 ++++++++++++++- 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index bca50a902e..7889707900 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -31,8 +31,32 @@ static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { } /* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const { +DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { + AlgebraicDecisionTree decisionTree; + + // The canonical decision tree factor which will get the discrete conditionals + // added to it. + DecisionTreeFactor dtFactor; + + for (size_t i = 0; i < this->size(); i++) { + HybridConditional::shared_ptr conditional = this->at(i); + if (conditional->isDiscrete()) { + // Convert to a DecisionTreeFactor and add it to the main factor. + DecisionTreeFactor f(*conditional->asDiscreteConditional()); + dtFactor = dtFactor * f; + } + } + return boost::make_shared(dtFactor); +} + +/* ************************************************************************* */ +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { + // Get the decision tree of only the discrete keys + auto discreteConditionals = this->discreteConditionals(); + const DecisionTreeFactor::shared_ptr discreteFactor = + boost::make_shared( + discreteConditionals->prune(maxNrLeaves)); + /* To Prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index f28224d373..b8234d70ab 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -111,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ VectorValues optimize(const DiscreteValues &assignment) const; - /// Prune the Hybrid Bayes Net given the discrete decision tree. - HybridBayesNet prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const; + protected: + /** + * @brief Get all the discrete conditionals as a decision tree factor. + * + * @return DecisionTreeFactor::shared_ptr + */ + DecisionTreeFactor::shared_ptr discreteConditionals() const; + + public: + /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. + HybridBayesNet prune(size_t maxNrLeaves) const; /// @} diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index cc2ab1759c..5885fdcdcc 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -66,7 +66,6 @@ TEST(HybridBayesNet, Add) { EXPECT(bayesNet.equals(other)); } - /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { @@ -184,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test bayes net pruning +TEST(HybridBayesNet, Prune) { + Switching s(4); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + + auto prunedBayesNet = hybridBayesNet->prune(2); + HybridValues pruned_delta = prunedBayesNet.optimize(); + + EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); + EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); +} + /* ****************************************************************************/ // Test HybridBayesNet serialization. TEST(HybridBayesNet, Serialization) { From d6d44fc3b422051bab6d0858cc6190fbb3121610 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 3 Oct 2022 19:14:13 -0400 Subject: [PATCH 14/18] minor cleanup --- gtsam/hybrid/HybridConditional.h | 2 -- gtsam/inference/BayesTree.h | 1 - 2 files changed, 3 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index b43bb9945d..93ce33bea1 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -34,8 +34,6 @@ namespace gtsam { -class HybridGaussianFactorGraph; - /** * Hybrid Conditional Density * diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 924a505a2d..b4b07a3573 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -33,7 +33,6 @@ namespace gtsam { // Forward declarations template class FactorGraph; template class EliminatableClusterTree; - class HybridBayesTreeClique; /* ************************************************************************* */ /** clique statistics */ From bc8c77c54d36e55732dc60e0349e1815b607aa15 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 4 Oct 2022 10:12:02 -0400 Subject: [PATCH 15/18] rename test file to correct form --- ...ianHybridFactorGraph.cpp => testHybridGaussianFactorGraph.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename gtsam/hybrid/tests/{testGaussianHybridFactorGraph.cpp => testHybridGaussianFactorGraph.cpp} (100%) diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp similarity index 100% rename from gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp rename to gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp From 8820bf272c268ae298d99d8e13fe81052af80fc2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 4 Oct 2022 12:33:28 -0400 Subject: [PATCH 16/18] Add test to expose bug in elimination with gaussian conditionals --- .../tests/testHybridGaussianFactorGraph.cpp | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 40da42412d..d199d76113 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -500,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } } +/* ************************************************************************* */ TEST(HybridGaussianFactorGraph, optimize) { HybridGaussianFactorGraph hfg; @@ -521,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) { EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); } + +/* ************************************************************************* */ +// Test adding of gaussian conditional and re-elimination. +TEST(HybridGaussianFactorGraph, Conditionals) { + Switching switching(4); + HybridGaussianFactorGraph hfg; + + hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1) + Ordering ordering; + ordering.push_back(X(1)); + HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering); + + hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1) + hfg.push_back(*bayes_net); + hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2) + hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1) + ordering.push_back(X(2)); + ordering.push_back(X(3)); + ordering.push_back(M(1)); + ordering.push_back(M(2)); + + bayes_net = hfg.eliminateSequential(ordering); + + HybridValues result = bayes_net->optimize(); + + Values expected_continuous; + expected_continuous.insert(X(1), 0); + expected_continuous.insert(X(2), 1); + expected_continuous.insert(X(3), 2); + expected_continuous.insert(X(4), 4); + Values result_continuous = + switching.linearizationPoint.retract(result.continuous()); + EXPECT(assert_equal(expected_continuous, result_continuous)); + + DiscreteValues expected_discrete; + expected_discrete[M(1)] = 1; + expected_discrete[M(2)] = 1; + EXPECT(assert_equal(expected_discrete, result.discrete())); +} + /* ************************************************************************* */ int main() { TestResult tr; From 9002b6829183fdebb1ce65eb61bb261956c64307 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 4 Oct 2022 12:33:37 -0400 Subject: [PATCH 17/18] fix the bug --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ddb776ff4c..0faf0e86e4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals( } } else if (f->isContinuous()) { - deferredFactors.push_back( - boost::dynamic_pointer_cast(f)->inner()); + if (auto gf = boost::dynamic_pointer_cast(f)) { + deferredFactors.push_back(gf->inner()); + } + if (auto cg = boost::dynamic_pointer_cast(f)) { + deferredFactors.push_back(cg->asGaussian()); + } } else if (f->isDiscrete()) { // Don't do anything for discrete-only factors From 6238a1f9017017c7d21ca2708b4dab46d3d3af2b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 4 Oct 2022 12:34:53 -0400 Subject: [PATCH 18/18] more docs for Switching example --- gtsam/hybrid/tests/Switching.h | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 8bcb26c927..3ae8f0bb1c 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -115,7 +115,6 @@ inline std::pair> makeBinaryOrdering( /* *************************************************************************** */ using MotionModel = BetweenFactor; -// using MotionMixture = MixtureFactor; // Test fixture with switching network. struct Switching { @@ -125,7 +124,13 @@ struct Switching { HybridGaussianFactorGraph linearizedFactorGraph; Values linearizationPoint; - /// Create with given number of time steps. + /** + * @brief Create with given number of time steps. + * + * @param K The total number of timesteps. + * @param between_sigma The stddev between poses. + * @param prior_sigma The stddev on priors (also used for measurements). + */ Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1) : K(K) { // Create DiscreteKeys for binary K modes, modes[0] will not be used. @@ -166,6 +171,8 @@ struct Switching { linearizationPoint.insert(X(k), static_cast(k)); } + // The ground truth is robot moving forward + // and one less than the linearization point linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint); }