diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index 90e2dbdd83..c6e426ca15 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -11,15 +11,17 @@ /** * @file Assignment.h - * @brief An assignment from labels to a discrete value index (size_t) + * @brief An assignment from labels to a discrete value index (size_t) * @author Frank Dellaert * @date Feb 5, 2012 */ #pragma once +#include #include #include +#include #include #include @@ -32,13 +34,30 @@ namespace gtsam { */ template class Assignment : public std::map { + /** + * @brief Default method used by `labelFormatter` or `valueFormatter` when + * printing. + * + * @param x The value passed to format. + * @return std::string + */ + static std::string DefaultFormatter(const L& x) { + std::stringstream ss; + ss << x; + return ss.str(); + } + public: using std::map::operator=; - void print(const std::string& s = "Assignment: ") const { + void print(const std::string& s = "Assignment: ", + const std::function& labelFormatter = + &DefaultFormatter) const { std::cout << s << ": "; - for (const typename Assignment::value_type& keyValue : *this) - std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + for (const typename Assignment::value_type& keyValue : *this) { + std::cout << "(" << labelFormatter(keyValue.first) << ", " + << keyValue.second << ")"; + } std::cout << std::endl; } diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 6816dfbf67..3a654ddade 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -119,11 +119,12 @@ void GaussianMixture::print(const std::string &s, "", [&](Key k) { return formatter(k); }, [&](const GaussianConditional::shared_ptr &gf) -> std::string { RedirectCout rd; - if (gf && !gf->empty()) + if (gf && !gf->empty()) { gf->print("", formatter); - else - return {"nullptr"}; - return rd.str(); + return rd.str(); + } else { + return "nullptr"; + } }); } diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index bca50a902e..7889707900 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -31,8 +31,32 @@ static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { } /* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const { +DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { + AlgebraicDecisionTree decisionTree; + + // The canonical decision tree factor which will get the discrete conditionals + // added to it. + DecisionTreeFactor dtFactor; + + for (size_t i = 0; i < this->size(); i++) { + HybridConditional::shared_ptr conditional = this->at(i); + if (conditional->isDiscrete()) { + // Convert to a DecisionTreeFactor and add it to the main factor. + DecisionTreeFactor f(*conditional->asDiscreteConditional()); + dtFactor = dtFactor * f; + } + } + return boost::make_shared(dtFactor); +} + +/* ************************************************************************* */ +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { + // Get the decision tree of only the discrete keys + auto discreteConditionals = this->discreteConditionals(); + const DecisionTreeFactor::shared_ptr discreteFactor = + boost::make_shared( + discreteConditionals->prune(maxNrLeaves)); + /* To Prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index e84103a50c..b8234d70ab 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { HybridConditional(boost::make_shared(key, table))); } + using Base::push_back; + /// Get a specific Gaussian mixture by index `i`. GaussianMixture::shared_ptr atMixture(size_t i) const; @@ -109,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ VectorValues optimize(const DiscreteValues &assignment) const; - /// Prune the Hybrid Bayes Net given the discrete decision tree. - HybridBayesNet prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const; + protected: + /** + * @brief Get all the discrete conditionals as a decision tree factor. + * + * @return DecisionTreeFactor::shared_ptr + */ + DecisionTreeFactor::shared_ptr discreteConditionals() const; + + public: + /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. + HybridBayesNet prune(size_t maxNrLeaves) const; /// @} diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 3fa4d6268b..8fb487ae20 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -89,12 +89,12 @@ struct HybridAssignmentData { gaussianbayesTree_(gbt) {} /** - * @brief A function used during tree traversal that operators on each node + * @brief A function used during tree traversal that operates on each node * before visiting the node's children. * * @param node The current node being visited. * @param parentData The HybridAssignmentData from the parent node. - * @return HybridAssignmentData + * @return HybridAssignmentData which is passed to the children. */ static HybridAssignmentData AssignmentPreOrderVisitor( const HybridBayesTree::sharedNode& node, @@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { return result; } +/* ************************************************************************* */ +void HybridBayesTree::prune(const size_t maxNrLeaves) { + auto decisionTree = boost::dynamic_pointer_cast( + this->roots_.at(0)->conditional()->inner()); + + DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); + decisionTree->root_ = prunedDiscreteFactor.root_; + + /// Helper struct for pruning the hybrid bayes tree. + struct HybridPrunerData { + /// The discrete decision tree after pruning. + DecisionTreeFactor prunedDiscreteFactor; + HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, + const HybridBayesTree::sharedNode& parentClique) + : prunedDiscreteFactor(prunedDiscreteFactor) {} + + /** + * @brief A function used during tree traversal that operates on each node + * before visiting the node's children. + * + * @param node The current node being visited. + * @param parentData The data from the parent node. + * @return HybridPrunerData which is passed to the children. + */ + static HybridPrunerData AssignmentPreOrderVisitor( + const HybridBayesTree::sharedNode& clique, + HybridPrunerData& parentData) { + // Get the conditional + HybridConditional::shared_ptr conditional = clique->conditional(); + + // If conditional is hybrid, we prune it. + if (conditional->isHybrid()) { + auto gaussianMixture = conditional->asMixture(); + + // Check if the number of discrete keys match, + // else we get an assignment error. + // TODO(Varun) Update prune method to handle assignment subset? + if (gaussianMixture->discreteKeys() == + parentData.prunedDiscreteFactor.discreteKeys()) { + gaussianMixture->prune(parentData.prunedDiscreteFactor); + } + } + return parentData; + } + }; + + HybridPrunerData rootData(prunedDiscreteFactor, 0); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + visitorPost); + } +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 3fa344d4dd..d443e33c4c 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { */ VectorValues optimize(const DiscreteValues& assignment) const; + /** + * @brief Prune the underlying Bayes tree. + * + * @param maxNumberLeaves The max number of leaf nodes to keep. + */ + void prune(const size_t maxNumberLeaves); + /// @} private: diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index b43bb9945d..93ce33bea1 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -34,8 +34,6 @@ namespace gtsam { -class HybridGaussianFactorGraph; - /** * Hybrid Conditional Density * diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index fc730f0c97..05a17b000f 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph { push_hybrid(p); } } + + /// Get all the discrete keys in the factor graph. + const KeySet discreteKeys() const { + KeySet discrete_keys; + for (auto& factor : factors_) { + for (const DiscreteKey& k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } + return discrete_keys; + } + + /// Get all the continuous keys in the factor graph. + const KeySet continuousKeys() const { + KeySet keys; + for (auto& factor : factors_) { + for (const Key& key : factor->continuousKeys()) { + keys.insert(key); + } + } + return keys; + } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 09b592bd69..0faf0e86e4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals( } } else if (f->isContinuous()) { - deferredFactors.push_back( - boost::dynamic_pointer_cast(f)->inner()); + if (auto gf = boost::dynamic_pointer_cast(f)) { + deferredFactors.push_back(gf->inner()); + } + if (auto cg = boost::dynamic_pointer_cast(f)) { + deferredFactors.push_back(cg->asGaussian()); + } } else if (f->isDiscrete()) { // Don't do anything for discrete-only factors @@ -404,31 +408,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { FactorGraph::add(boost::make_shared(factor)); } -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const { - KeySet discrete_keys; - for (auto &factor : factors_) { - for (const DiscreteKey &k : factor->discreteKeys()) { - discrete_keys.insert(k.first); - } - } - return discrete_keys; -} - -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getContinuousKeys() const { - KeySet keys; - for (auto &factor : factors_) { - for (const Key &key : factor->continuousKeys()) { - keys.insert(key); - } - } - return keys; -} - /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { - KeySet discrete_keys = getDiscreteKeys(); + KeySet discrete_keys = discreteKeys(); for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index ad5cde09b0..bd24cdeaa7 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } - /// Get all the discrete keys in the factor graph. - const KeySet getDiscreteKeys() const; - - /// Get all the continuous keys in the factor graph. - const KeySet getContinuousKeys() const; - /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 23a95c021d..57e50104d7 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -14,9 +14,10 @@ * @date March 31, 2022 * @author Fan Jiang * @author Frank Dellaert - * @author Richard Roberts + * @author Varun Agrawal */ +#include #include #include #include @@ -41,6 +42,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree) void HybridGaussianISAM::updateInternal( const HybridGaussianFactorGraph& newFactors, HybridBayesTree::Cliques* orphans, + const boost::optional& maxNrLeaves, const boost::optional& ordering, const HybridBayesTree::Eliminate& function) { // Remove the contaminated part of the Bayes tree @@ -60,23 +62,24 @@ void HybridGaussianISAM::updateInternal( for (const sharedClique& orphan : *orphans) factors += boost::make_shared >(orphan); - KeySet allDiscrete; - for (auto& factor : factors) { - for (auto& k : factor->discreteKeys()) { - allDiscrete.insert(k.first); - } - } + // Get all the discrete keys from the factors + KeySet allDiscrete = factors.discreteKeys(); + + // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; + // Insert continuous keys first. for (auto& k : newFactorKeys) { if (!allDiscrete.exists(k)) { newKeysDiscreteLast.push_back(k); } } + // Insert discrete keys at the end std::copy(allDiscrete.begin(), allDiscrete.end(), std::back_inserter(newKeysDiscreteLast)); // Get an ordering where the new keys are eliminated last const VariableIndex index(factors); + Ordering elimination_ordering; if (ordering) { elimination_ordering = *ordering; @@ -91,6 +94,10 @@ void HybridGaussianISAM::updateInternal( HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(elimination_ordering, function, index); + if (maxNrLeaves) { + bayesTree->prune(*maxNrLeaves); + } + // Re-add into Bayes tree data structures this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(), bayesTree->roots().end()); @@ -99,61 +106,11 @@ void HybridGaussianISAM::updateInternal( /* ************************************************************************* */ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors, + const boost::optional& maxNrLeaves, const boost::optional& ordering, const HybridBayesTree::Eliminate& function) { Cliques orphans; - this->updateInternal(newFactors, &orphans, ordering, function); -} - -/* ************************************************************************* */ -/** - * @brief Check if `b` is a subset of `a`. - * Non-const since they need to be sorted. - * - * @param a KeyVector - * @param b KeyVector - * @return True if the keys of b is a subset of a, else false. - */ -bool IsSubset(KeyVector a, KeyVector b) { - std::sort(a.begin(), a.end()); - std::sort(b.begin(), b.end()); - return std::includes(a.begin(), a.end(), b.begin(), b.end()); -} - -/* ************************************************************************* */ -void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) { - auto decisionTree = boost::dynamic_pointer_cast( - this->clique(root)->conditional()->inner()); - DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); - decisionTree->root_ = prunedDiscreteFactor.root_; - - std::vector prunedKeys; - for (auto&& clique : nodes()) { - // The cliques can be repeated for each frontal so we record it in - // prunedKeys and check if we have already pruned a particular clique. - if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) != - prunedKeys.end()) { - continue; - } - - // Add all the keys of the current clique to be pruned to prunedKeys - for (auto&& key : clique.second->conditional()->frontals()) { - prunedKeys.push_back(key); - } - - // Convert parents() to a KeyVector for comparison - KeyVector parents; - for (auto&& parent : clique.second->conditional()->parents()) { - parents.push_back(parent); - } - - if (IsSubset(parents, decisionTree->keys())) { - auto gaussianMixture = boost::dynamic_pointer_cast( - clique.second->conditional()->inner()); - - gaussianMixture->prune(prunedDiscreteFactor); - } - } + this->updateInternal(newFactors, &orphans, maxNrLeaves, ordering, function); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianISAM.h b/gtsam/hybrid/HybridGaussianISAM.h index ff09787293..a6a82b3e89 100644 --- a/gtsam/hybrid/HybridGaussianISAM.h +++ b/gtsam/hybrid/HybridGaussianISAM.h @@ -48,6 +48,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM { void updateInternal( const HybridGaussianFactorGraph& newFactors, HybridBayesTree::Cliques* orphans, + const boost::optional& maxNrLeaves = boost::none, const boost::optional& ordering = boost::none, const HybridBayesTree::Eliminate& function = HybridBayesTree::EliminationTraitsType::DefaultEliminate); @@ -57,20 +58,15 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM { * @brief Perform update step with new factors. * * @param newFactors Factor graph of new factors to add and eliminate. + * @param maxNrLeaves The maximum number of leaves to keep after pruning. + * @param ordering Custom elimination ordering. * @param function Elimination function. */ void update(const HybridGaussianFactorGraph& newFactors, + const boost::optional& maxNrLeaves = boost::none, const boost::optional& ordering = boost::none, const HybridBayesTree::Eliminate& function = HybridBayesTree::EliminationTraitsType::DefaultEliminate); - - /** - * @brief Prune the underlying Bayes tree. - * - * @param root The root key in the discrete conditional decision tree. - * @param maxNumberLeaves - */ - void prune(const Key& root, const size_t maxNumberLeaves); }; /// traits diff --git a/gtsam/hybrid/HybridNonlinearISAM.cpp b/gtsam/hybrid/HybridNonlinearISAM.cpp index 36cda4e807..d05e081dd4 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.cpp +++ b/gtsam/hybrid/HybridNonlinearISAM.cpp @@ -33,7 +33,9 @@ void HybridNonlinearISAM::saveGraph(const string& s, /* ************************************************************************* */ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors, - const Values& initialValues) { + const Values& initialValues, + const boost::optional& maxNrLeaves, + const boost::optional& ordering) { if (newFactors.size() > 0) { // Reorder and relinearize every reorderInterval updates if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) { @@ -51,7 +53,8 @@ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors, newFactors.linearize(linPoint_); // Update ISAM - isam_.update(*linearizedNewFactors, boost::none, eliminationFunction_); + isam_.update(*linearizedNewFactors, maxNrLeaves, ordering, + eliminationFunction_); } } @@ -66,7 +69,7 @@ void HybridNonlinearISAM::reorder_relinearize() { // Just recreate the whole BayesTree // TODO: allow for constrained ordering here // TODO: decouple relinearization and reordering to avoid - isam_.update(*factors_.linearize(newLinPoint), boost::none, + isam_.update(*factors_.linearize(newLinPoint), boost::none, boost::none, eliminationFunction_); // Update linearization point diff --git a/gtsam/hybrid/HybridNonlinearISAM.h b/gtsam/hybrid/HybridNonlinearISAM.h index d96485fff0..47aa81c558 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.h +++ b/gtsam/hybrid/HybridNonlinearISAM.h @@ -82,12 +82,9 @@ class GTSAM_EXPORT HybridNonlinearISAM { /** * @brief Prune the underlying Bayes tree. * - * @param root The root key in the discrete conditional decision tree. - * @param maxNumberLeaves + * @param maxNumberLeaves The max number of leaf nodes to keep. */ - void prune(const Key& root, const size_t maxNumberLeaves) { - isam_.prune(root, maxNumberLeaves); - } + void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); } /** Return the current linearization point */ const Values& getLinearizationPoint() const { return linPoint_; } @@ -121,7 +118,9 @@ class GTSAM_EXPORT HybridNonlinearISAM { /** Add new factors along with their initial linearization points */ void update(const HybridNonlinearFactorGraph& newFactors, - const Values& initialValues); + const Values& initialValues, + const boost::optional& maxNrLeaves = boost::none, + const boost::optional& ordering = boost::none); /** Relinearization and reordering of variables */ void reorder_relinearize(); diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 8bcb26c927..3ae8f0bb1c 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -115,7 +115,6 @@ inline std::pair> makeBinaryOrdering( /* *************************************************************************** */ using MotionModel = BetweenFactor; -// using MotionMixture = MixtureFactor; // Test fixture with switching network. struct Switching { @@ -125,7 +124,13 @@ struct Switching { HybridGaussianFactorGraph linearizedFactorGraph; Values linearizationPoint; - /// Create with given number of time steps. + /** + * @brief Create with given number of time steps. + * + * @param K The total number of timesteps. + * @param between_sigma The stddev between poses. + * @param prior_sigma The stddev on priors (also used for measurements). + */ Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1) : K(K) { // Create DiscreteKeys for binary K modes, modes[0] will not be used. @@ -166,6 +171,8 @@ struct Switching { linearizationPoint.insert(X(k), static_cast(k)); } + // The ground truth is robot moving forward + // and one less than the linearization point linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint); } diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 0c15ee83da..5885fdcdcc 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -52,6 +52,20 @@ TEST(HybridBayesNet, Creation) { EXPECT(df.equals(expected)); } +/* ****************************************************************************/ +// Test adding a bayes net to another one. +TEST(HybridBayesNet, Add) { + HybridBayesNet bayesNet; + + bayesNet.add(Asia, "99/1"); + + DiscreteConditional expected(Asia, "99/1"); + + HybridBayesNet other; + other.push_back(bayesNet); + EXPECT(bayesNet.equals(other)); +} + /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { @@ -169,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test bayes net pruning +TEST(HybridBayesNet, Prune) { + Switching s(4); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + + auto prunedBayesNet = hybridBayesNet->prune(2); + HybridValues pruned_delta = prunedBayesNet.optimize(); + + EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); + EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); +} + /* ****************************************************************************/ // Test HybridBayesNet serialization. TEST(HybridBayesNet, Serialization) { diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp similarity index 92% rename from gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp rename to gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 40da42412d..d199d76113 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -500,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } } +/* ************************************************************************* */ TEST(HybridGaussianFactorGraph, optimize) { HybridGaussianFactorGraph hfg; @@ -521,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) { EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); } + +/* ************************************************************************* */ +// Test adding of gaussian conditional and re-elimination. +TEST(HybridGaussianFactorGraph, Conditionals) { + Switching switching(4); + HybridGaussianFactorGraph hfg; + + hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1) + Ordering ordering; + ordering.push_back(X(1)); + HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering); + + hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1) + hfg.push_back(*bayes_net); + hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2) + hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1) + ordering.push_back(X(2)); + ordering.push_back(X(3)); + ordering.push_back(M(1)); + ordering.push_back(M(2)); + + bayes_net = hfg.eliminateSequential(ordering); + + HybridValues result = bayes_net->optimize(); + + Values expected_continuous; + expected_continuous.insert(X(1), 0); + expected_continuous.insert(X(2), 1); + expected_continuous.insert(X(3), 2); + expected_continuous.insert(X(4), 4); + Values result_continuous = + switching.linearizationPoint.retract(result.continuous()); + EXPECT(assert_equal(expected_continuous, result_continuous)); + + DiscreteValues expected_discrete; + expected_discrete[M(1)] = 1; + expected_discrete[M(2)] = 1; + EXPECT(assert_equal(expected_discrete, result.discrete())); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridIncremental.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp similarity index 99% rename from gtsam/hybrid/tests/testHybridIncremental.cpp rename to gtsam/hybrid/tests/testHybridGaussianISAM.cpp index 8e16d02b93..a0a87933f4 100644 --- a/gtsam/hybrid/tests/testHybridIncremental.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) { size_t maxNrLeaves = 5; incrementalHybrid.update(graph1); - incrementalHybrid.prune(M(3), maxNrLeaves); + incrementalHybrid.prune(maxNrLeaves); /* unpruned factor is: @@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { // Run update with pruning size_t maxComponents = 5; incrementalHybrid.update(graph1); - incrementalHybrid.prune(M(3), maxComponents); + incrementalHybrid.prune(maxComponents); // Check if we have a bayes tree with 4 hybrid nodes, // each with 2, 4, 8, and 5 (pruned) leaves respetively. @@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { // Run update with pruning a second time. incrementalHybrid.update(graph2); - incrementalHybrid.prune(M(4), maxComponents); + incrementalHybrid.prune(maxComponents); // Check if we have a bayes tree with pruned hybrid nodes, // with 5 (pruned) leaves. @@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) { // The MHS at this point should be a 2 level tree on (1, 2). // 1 has 2 choices, and 2 has 4 choices. inc.update(gfg); - inc.prune(M(2), 2); + inc.prune(2); /*************** Run Round 4 ***************/ // Add odometry factor with discrete modes. @@ -531,7 +531,7 @@ TEST(HybridGaussianISAM, NonTrivial) { // Keep pruning! inc.update(gfg); - inc.prune(M(3), 3); + inc.prune(3); // The final discrete graph should not be empty since we have eliminated // all continuous variables. diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 46076b3306..4e1710c424 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -256,7 +256,7 @@ TEST(HybridNonlinearISAM, Approx_inference) { incrementalHybrid.update(graph1, initial); HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(M(3), maxNrLeaves); + bayesTree.prune(maxNrLeaves); /* unpruned factor is: @@ -355,7 +355,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { incrementalHybrid.update(graph1, initial); HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(M(3), maxComponents); + bayesTree.prune(maxComponents); // Check if we have a bayes tree with 4 hybrid nodes, // each with 2, 4, 8, and 5 (pruned) leaves respetively. @@ -380,7 +380,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { incrementalHybrid.update(graph2, initial); bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(M(4), maxComponents); + bayesTree.prune(maxComponents); // Check if we have a bayes tree with pruned hybrid nodes, // with 5 (pruned) leaves. @@ -482,8 +482,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { still = boost::make_shared(W(1), W(2), Pose2(0, 0, 0), noise_model); moving = - boost::make_shared(W(1), W(2), odometry, - noise_model); + boost::make_shared(W(1), W(2), odometry, noise_model); components = {moving, still}; mixtureFactor = boost::make_shared( contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components); @@ -515,7 +514,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { // The MHS at this point should be a 2 level tree on (1, 2). // 1 has 2 choices, and 2 has 4 choices. inc.update(fg, initial); - inc.prune(M(2), 2); + inc.prune(2); fg = HybridNonlinearFactorGraph(); initial = Values(); @@ -526,8 +525,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { still = boost::make_shared(W(2), W(3), Pose2(0, 0, 0), noise_model); moving = - boost::make_shared(W(2), W(3), odometry, - noise_model); + boost::make_shared(W(2), W(3), odometry, noise_model); components = {moving, still}; mixtureFactor = boost::make_shared( contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components); @@ -551,7 +549,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { // Keep pruning! inc.update(fg, initial); - inc.prune(M(3), 3); + inc.prune(3); fg = HybridNonlinearFactorGraph(); initial = Values(); @@ -560,8 +558,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { // The final discrete graph should not be empty since we have eliminated // all continuous variables. - auto discreteTree = - bayesTree[M(3)]->conditional()->asDiscreteConditional(); + auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional(); EXPECT_LONGS_EQUAL(3, discreteTree->size()); // Test if the optimal discrete mode assignment is (1, 1, 1). diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 924a505a2d..b4b07a3573 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -33,7 +33,6 @@ namespace gtsam { // Forward declarations template class FactorGraph; template class EliminatableClusterTree; - class HybridBayesTreeClique; /* ************************************************************************* */ /** clique statistics */