From ad328756d2d50b1ec83a21c2b5a8734e3e7da393 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 3 Oct 2022 19:14:03 -0400 Subject: [PATCH 1/2] improved hybrid bayes net pruning --- gtsam/hybrid/HybridBayesNet.cpp | 28 +++++++++++++++++++++-- gtsam/hybrid/HybridBayesNet.h | 14 +++++++++--- gtsam/hybrid/tests/testHybridBayesNet.cpp | 19 ++++++++++++++- 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index bca50a902e..7889707900 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -31,8 +31,32 @@ static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { } /* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const { +DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { + AlgebraicDecisionTree decisionTree; + + // The canonical decision tree factor which will get the discrete conditionals + // added to it. + DecisionTreeFactor dtFactor; + + for (size_t i = 0; i < this->size(); i++) { + HybridConditional::shared_ptr conditional = this->at(i); + if (conditional->isDiscrete()) { + // Convert to a DecisionTreeFactor and add it to the main factor. + DecisionTreeFactor f(*conditional->asDiscreteConditional()); + dtFactor = dtFactor * f; + } + } + return boost::make_shared(dtFactor); +} + +/* ************************************************************************* */ +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { + // Get the decision tree of only the discrete keys + auto discreteConditionals = this->discreteConditionals(); + const DecisionTreeFactor::shared_ptr discreteFactor = + boost::make_shared( + discreteConditionals->prune(maxNrLeaves)); + /* To Prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index f28224d373..b8234d70ab 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -111,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ VectorValues optimize(const DiscreteValues &assignment) const; - /// Prune the Hybrid Bayes Net given the discrete decision tree. - HybridBayesNet prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const; + protected: + /** + * @brief Get all the discrete conditionals as a decision tree factor. + * + * @return DecisionTreeFactor::shared_ptr + */ + DecisionTreeFactor::shared_ptr discreteConditionals() const; + + public: + /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. + HybridBayesNet prune(size_t maxNrLeaves) const; /// @} diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index cc2ab1759c..5885fdcdcc 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -66,7 +66,6 @@ TEST(HybridBayesNet, Add) { EXPECT(bayesNet.equals(other)); } - /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { @@ -184,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test bayes net pruning +TEST(HybridBayesNet, Prune) { + Switching s(4); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + + auto prunedBayesNet = hybridBayesNet->prune(2); + HybridValues pruned_delta = prunedBayesNet.optimize(); + + EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); + EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); +} + /* ****************************************************************************/ // Test HybridBayesNet serialization. TEST(HybridBayesNet, Serialization) { From d6d44fc3b422051bab6d0858cc6190fbb3121610 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 3 Oct 2022 19:14:13 -0400 Subject: [PATCH 2/2] minor cleanup --- gtsam/hybrid/HybridConditional.h | 2 -- gtsam/inference/BayesTree.h | 1 - 2 files changed, 3 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index b43bb9945d..93ce33bea1 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -34,8 +34,6 @@ namespace gtsam { -class HybridGaussianFactorGraph; - /** * Hybrid Conditional Density * diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 924a505a2d..b4b07a3573 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -33,7 +33,6 @@ namespace gtsam { // Forward declarations template class FactorGraph; template class EliminatableClusterTree; - class HybridBayesTreeClique; /* ************************************************************************* */ /** clique statistics */