From 22c758221ddd26c58923454f5f12428039163e2e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 30 Dec 2022 19:28:42 +0530 Subject: [PATCH 01/16] make GaussianMixtureFactor store the normalizing constant as well --- gtsam/hybrid/GaussianMixture.cpp | 3 +- gtsam/hybrid/GaussianMixtureFactor.cpp | 39 +++++++++++++------- gtsam/hybrid/GaussianMixtureFactor.h | 18 ++++++--- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 43 +++++++++++++--------- 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 155cae10b1..ddcfaf0e82 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -150,7 +150,8 @@ boost::shared_ptr GaussianMixture::likelihood( const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { - return conditional->likelihood(frontals); + return std::make_pair(conditional->likelihood(frontals), + 0.5 * conditional->logDeterminant()); }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 32ca1432cd..0759cf3be5 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -29,8 +29,11 @@ namespace gtsam { /* *******************************************************************************/ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors) - : Base(continuousKeys, discreteKeys), factors_(factors) {} + const Mixture &factors) + : Base(continuousKeys, discreteKeys), + factors_(factors, [](const GaussianFactor::shared_ptr &gf) { + return std::make_pair(gf, 0.0); + }) {} /* *******************************************************************************/ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { @@ -44,9 +47,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && factors_.equals(e->factors_, - [tol](const GaussianFactor::shared_ptr &f1, - const GaussianFactor::shared_ptr &f2) { - return f1->equals(*f2, tol); + [tol](const GaussianMixtureFactor::FactorAndLogZ &f1, + const GaussianMixtureFactor::FactorAndLogZ &f2) { + return f1.first->equals(*(f2.first), tol); }); } @@ -60,7 +63,8 @@ void GaussianMixtureFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const GaussianFactor::shared_ptr &gf) -> std::string { + [&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string { + auto gf = gf_z.first; RedirectCout rd; std::cout << ":\n"; if (gf && !gf->empty()) { @@ -75,8 +79,10 @@ void GaussianMixtureFactor::print(const std::string &s, } /* *******************************************************************************/ -const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { - return factors_; +const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() { + // Unzip to tree of Gaussian factors and tree of log-constants, + // and return the first tree. 
+ return unzip(factors_).first; } /* *******************************************************************************/ @@ -95,9 +101,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add( /* *******************************************************************************/ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const GaussianFactor::shared_ptr &factor) { + auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { GaussianFactorGraph result; - result.push_back(factor); + result.push_back(factor_z.first); return result; }; return {factors_, wrap}; @@ -108,8 +114,11 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. auto errorFunc = - [continuousValues](const GaussianFactor::shared_ptr &factor) { - return factor->error(continuousValues); + [continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { + GaussianFactor::shared_ptr factor; + double log_z; + std::tie(factor, log_z) = factor_z; + return factor->error(continuousValues) + log_z; }; DecisionTree errorTree(factors_, errorFunc); return errorTree; @@ -120,8 +129,10 @@ double GaussianMixtureFactor::error( const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { // Directly index to get the conditional, no need to build the whole tree. - auto factor = factors_(discreteValues); - return factor->error(continuousValues); + GaussianFactor::shared_ptr factor; + double log_z; + std::tie(factor, log_z) = factors_(discreteValues); + return factor->error(continuousValues) + log_z; } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index b8f475de3f..b3e603bc33 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -54,8 +54,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { using Sum = DecisionTree; - /// typedef for Decision Tree of Gaussian Factors - using Factors = DecisionTree; + /// typedef of pair of Gaussian factor and log of normalizing constant. + using FactorAndLogZ = std::pair; + /// typedef for Decision Tree of Gaussian Factors and log-constant. + using Factors = DecisionTree; + using Mixture = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. @@ -87,7 +90,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const Factors &factors); + const Mixture &factors); + + GaussianMixtureFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, + const Factors &factors_and_z) + : Base(continuousKeys, discreteKeys), factors_(factors_and_z) {} /** * @brief Construct a new GaussianMixtureFactor object using a vector of @@ -101,7 +109,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { const DiscreteKeys &discreteKeys, const std::vector &factors) : GaussianMixtureFactor(continuousKeys, discreteKeys, - Factors(discreteKeys, factors)) {} + Mixture(discreteKeys, factors)) {} /// @} /// @name Testable @@ -115,7 +123,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { /// @} /// Getter for the underlying Gaussian Factor Decision Tree. 
- const Factors &factors(); + const Mixture factors(); /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index aac37bc247..15a84b27a4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -204,16 +204,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); - using EliminationPair = GaussianFactorGraph::EliminationResult; + using EliminationPair = + std::pair, + std::pair, double>>; KeyVector keysOfEliminated; // Not the ordering KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? // This is the elimination method on the leaf nodes - auto eliminate = [&](const GaussianFactorGraph &graph) - -> GaussianFactorGraph::EliminationResult { + auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { if (graph.empty()) { - return {nullptr, nullptr}; + return {nullptr, std::make_pair(nullptr, 0.0)}; } #ifdef HYBRID_TIMING @@ -222,17 +223,21 @@ hybridElimination(const HybridGaussianFactorGraph &factors, std::pair, boost::shared_ptr> - result = EliminatePreferCholesky(graph, frontalKeys); + conditional_factor = EliminatePreferCholesky(graph, frontalKeys); // Initialize the keysOfEliminated to be the keys of the // eliminated GaussianConditional - keysOfEliminated = result.first->keys(); - keysOfSeparator = result.second->keys(); + keysOfEliminated = conditional_factor.first->keys(); + keysOfSeparator = conditional_factor.second->keys(); #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif + std::pair, + std::pair, double>> + result = std::make_pair(conditional_factor.first, + std::make_pair(conditional_factor.second, 0.0)); return result; }; @@ -257,16 +262,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // DiscreteFactor, with the error for each discrete choice. if (keysOfSeparator.empty()) { VectorValues empty_values; - auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { - if (!factor) { - return 0.0; // If nullptr, return 0.0 probability - } else { - // This is the probability q(μ) at the MLE point. - double error = - 0.5 * std::abs(factor->augmentedInformation().determinant()); - return std::exp(-error); - } - }; + auto factorProb = + [&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { + if (!factor_z.first) { + return 0.0; // If nullptr, return 0.0 probability + } else { + GaussianFactor::shared_ptr factor = factor_z.first; + double log_z = factor_z.second; + // This is the probability q(μ) at the MLE point. 
+ double error = + 0.5 * std::abs(factor->augmentedInformation().determinant()) + + log_z; + return std::exp(-error); + } + }; DecisionTree fdt(separatorFactors, factorProb); auto discreteFactor = From 38a6154c5528002a08604b1367632d6a848021bb Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 30 Dec 2022 20:33:50 +0530 Subject: [PATCH 02/16] update test --- gtsam/hybrid/tests/testGaussianMixture.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 242c9ba412..56dc24cf1b 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -179,7 +179,8 @@ TEST(GaussianMixture, Likelihood) { const GaussianMixtureFactor::Factors factors( gm.conditionals(), [measurements](const GaussianConditional::shared_ptr& conditional) { - return conditional->likelihood(measurements); + return std::make_pair(conditional->likelihood(measurements), + 0.5 * conditional->logDeterminant()); }); const GaussianMixtureFactor expected({X(0)}, {mode}, factors); EXPECT(assert_equal(*factor, expected)); From b972be0b8f6f2a9f0f21fbfd42d89b57d93b5588 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 12:09:56 -0500 Subject: [PATCH 03/16] Change from pair to small struct --- gtsam/hybrid/GaussianMixture.cpp | 7 +-- gtsam/hybrid/GaussianMixtureFactor.cpp | 51 +++++++++------------ gtsam/hybrid/GaussianMixtureFactor.h | 53 +++++++++++++--------- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 9 ++-- 4 files changed, 62 insertions(+), 58 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index ddcfaf0e82..10521244fc 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -149,9 +149,10 @@ boost::shared_ptr GaussianMixture::likelihood( const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( - conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { - return std::make_pair(conditional->likelihood(frontals), - 0.5 * conditional->logDeterminant()); + conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { + return GaussianMixtureFactor::FactorAndConstant{ + conditional->likelihood(frontals), + 0.5 * conditional->logDeterminant()}; }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 0759cf3be5..e07b300faa 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include namespace gtsam { @@ -32,7 +34,7 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, const Mixture &factors) : Base(continuousKeys, discreteKeys), factors_(factors, [](const GaussianFactor::shared_ptr &gf) { - return std::make_pair(gf, 0.0); + return FactorAndConstant{gf, 0.0}; }) {} /* *******************************************************************************/ @@ -46,11 +48,11 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { // Check the base and the factors: return Base::equals(*e, tol) && - factors_.equals(e->factors_, - [tol](const GaussianMixtureFactor::FactorAndLogZ &f1, - const GaussianMixtureFactor::FactorAndLogZ &f2) { - return f1.first->equals(*(f2.first), tol); - }); + 
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1, + const FactorAndConstant &f2) { + return f1.factor->equals(*(f2.factor), tol) && + std::abs(f1.constant - f2.constant) < tol; + }); } /* *******************************************************************************/ @@ -63,8 +65,8 @@ void GaussianMixtureFactor::print(const std::string &s, } else { factors_.print( "", [&](Key k) { return formatter(k); }, - [&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string { - auto gf = gf_z.first; + [&](const FactorAndConstant &gf_z) -> std::string { + auto gf = gf_z.factor; RedirectCout rd; std::cout << ":\n"; if (gf && !gf->empty()) { @@ -79,10 +81,10 @@ void GaussianMixtureFactor::print(const std::string &s, } /* *******************************************************************************/ -const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() { - // Unzip to tree of Gaussian factors and tree of log-constants, - // and return the first tree. - return unzip(factors_).first; +const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { + return Mixture(factors_, [](const FactorAndConstant &factor_z) { + return factor_z.factor; + }); } /* *******************************************************************************/ @@ -101,9 +103,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add( /* *******************************************************************************/ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() const { - auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { + auto wrap = [](const FactorAndConstant &factor_z) { GaussianFactorGraph result; - result.push_back(factor_z.first); + result.push_back(factor_z.factor); return result; }; return {factors_, wrap}; @@ -113,26 +115,17 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() AlgebraicDecisionTree GaussianMixtureFactor::error( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = - [continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { - GaussianFactor::shared_ptr factor; - double log_z; - std::tie(factor, log_z) = factor_z; - return factor->error(continuousValues) + log_z; - }; + auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) { + return factor_z.error(continuousValues); + }; DecisionTree errorTree(factors_, errorFunc); return errorTree; } /* *******************************************************************************/ -double GaussianMixtureFactor::error( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - // Directly index to get the conditional, no need to build the whole tree. 
- GaussianFactor::shared_ptr factor; - double log_z; - std::tie(factor, log_z) = factors_(discreteValues); - return factor->error(continuousValues) + log_z; +double GaussianMixtureFactor::error(const HybridValues &values) const { + const FactorAndConstant factor_z = factors_(values.discrete()); + return factor_z.factor->error(values.continuous()) + factor_z.constant; } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index b3e603bc33..aca4f365b3 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -23,17 +23,15 @@ #include #include #include -#include -#include +#include #include -#include namespace gtsam { class GaussianFactorGraph; - -// Needed for wrapper. -using GaussianFactorVector = std::vector; +class HybridValues; +class DiscreteValues; +class VectorValues; /** * @brief Implementation of a discrete conditional mixture factor. @@ -53,12 +51,27 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { using shared_ptr = boost::shared_ptr; using Sum = DecisionTree; + using sharedFactor = boost::shared_ptr; + + /// Gaussian factor and log of normalizing constant. + struct FactorAndConstant { + sharedFactor factor; + double constant; + + // Return error with constant added. + double error(const VectorValues &values) const { + return factor->error(values) + constant; + } + + // Check pointer equality. + bool operator==(const FactorAndConstant &other) const { + return factor == other.factor && constant == other.constant; + } + }; - /// typedef of pair of Gaussian factor and log of normalizing constant. - using FactorAndLogZ = std::pair; - /// typedef for Decision Tree of Gaussian Factors and log-constant. - using Factors = DecisionTree; - using Mixture = DecisionTree; + /// typedef for Decision Tree of Gaussian factors and log-constant. + using Factors = DecisionTree; + using Mixture = DecisionTree; private: /// Decision tree of Gaussian factors indexed by discrete keys. @@ -85,7 +98,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @param continuousKeys A vector of keys representing continuous variables. * @param discreteKeys A vector of keys representing discrete variables and * their cardinalities. - * @param factors The decision tree of Gaussian Factors stored as the mixture + * @param factors The decision tree of Gaussian factors stored as the mixture * density. */ GaussianMixtureFactor(const KeyVector &continuousKeys, @@ -107,7 +120,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ GaussianMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors) + const std::vector &factors) : GaussianMixtureFactor(continuousKeys, discreteKeys, Mixture(discreteKeys, factors)) {} @@ -121,9 +134,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { const std::string &s = "GaussianMixtureFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard API + /// @{ /// Getter for the underlying Gaussian Factor Decision Tree. - const Mixture factors(); + const Mixture factors() const; /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while @@ -145,21 +160,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { AlgebraicDecisionTree error(const VectorValues &continuousValues) const; /** - * @brief Compute the error of this Gaussian Mixture given the continuous - * values and a discrete assignment. 
- * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues The discrete assignment for a specific mode sequence. + * @brief Compute the log-likelihood, including the log-normalizing constant. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const; /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); return sum; } + /// @} }; // traits diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 15a84b27a4..5c1c2daf3b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -263,16 +263,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors, if (keysOfSeparator.empty()) { VectorValues empty_values; auto factorProb = - [&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) { - if (!factor_z.first) { + [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { + GaussianFactor::shared_ptr factor = factor_z.factor; + if (!factor) { return 0.0; // If nullptr, return 0.0 probability } else { - GaussianFactor::shared_ptr factor = factor_z.first; - double log_z = factor_z.second; // This is the probability q(μ) at the MLE point. double error = 0.5 * std::abs(factor->augmentedInformation().determinant()) + - log_z; + factor_z.constant; return std::exp(-error); } }; From 9cf3e5c26aacaf412450a90fe9503f129a6842da Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 12:10:16 -0500 Subject: [PATCH 04/16] Switch to using HybridValues --- gtsam/hybrid/GaussianMixture.cpp | 14 ++++----- gtsam/hybrid/GaussianMixture.h | 12 ++++---- gtsam/hybrid/HybridBayesNet.cpp | 10 +++---- gtsam/hybrid/HybridBayesNet.h | 6 ++-- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 33 +++++++++++----------- gtsam/hybrid/HybridGaussianFactorGraph.h | 13 ++------- 6 files changed, 38 insertions(+), 50 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 10521244fc..05864a6e4d 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -159,9 +160,9 @@ boost::shared_ptr GaussianMixture::likelihood( } /* ************************************************************************* */ -std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { +std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { std::set s; - s.insert(dkeys.begin(), dkeys.end()); + s.insert(discreteKeys.begin(), discreteKeys.end()); return s; } @@ -186,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { // typecast so we can use this to get probability value - DiscreteValues values(choices); + const DiscreteValues values(choices); // Case where the gaussian mixture has the same // discrete keys as the decision tree. @@ -256,11 +257,10 @@ AlgebraicDecisionTree GaussianMixture::error( } /* *******************************************************************************/ -double GaussianMixture::error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { +double GaussianMixture::error(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. 
- auto conditional = conditionals_(discreteValues); - return conditional->error(continuousValues); + auto conditional = conditionals_(values.discrete()); + return conditional->error(values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 2cdc23b46d..4df1bd90ca 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -30,6 +30,7 @@ namespace gtsam { class GaussianMixtureFactor; +class HybridValues; /** * @brief A conditional of gaussian mixtures indexed by discrete variables, as @@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture /// @name Constructors /// @{ - /// Defaut constructor, mainly for serialization. + /// Default constructor, mainly for serialization. GaussianMixture() = default; /** @@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture /// @name Standard API /// @{ + /// @brief Return the conditional Gaussian for the given discrete assignment. GaussianConditional::shared_ptr operator()( const DiscreteValues &discreteValues) const; @@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture * @brief Compute the error of this Gaussian Mixture given the continuous * values and a discrete assignment. * - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues The discrete assignment for a specific mode sequence. + * @param values Continuous values and discrete assignment. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const; /** * @brief Prune the decision tree of Gaussian factors as per the discrete @@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture }; /// Return the DiscreteKey vector as a set. -std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys); +std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys); // traits template <> diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 8e01c0c76f..8be314c4e2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -1,5 +1,5 @@ /* ---------------------------------------------------------------------------- - * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * GTSAM Copyright 2010-2022, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) @@ -12,6 +12,7 @@ * @author Fan Jiang * @author Varun Agrawal * @author Shangjie Xue + * @author Frank Dellaert * @date January 2022 */ @@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const { } /* ************************************************************************* */ -double HybridBayesNet::error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - GaussianBayesNet gbn = choose(discreteValues); - return gbn.error(continuousValues); +double HybridBayesNet::error(const HybridValues &values) const { + GaussianBayesNet gbn = choose(values.discrete()); + return gbn.error(values.continuous()); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index a64b3bb4f0..0d2c337b7f 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief 0.5 * sum of squared Mahalanobis distances * for a specific discrete assignment. 
* - * @param continuousValues Continuous values at which to compute the error. - * @param discreteValues Discrete assignment for a specific mode sequence. + * @param values Continuous values and discrete assignment. * @return double */ - double error(const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const; + double error(const HybridValues &values) const; /** * @brief Compute conditional error for each discrete assignment, diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 5c1c2daf3b..de55114b3c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -55,13 +55,14 @@ namespace gtsam { +/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: template class EliminateableFactorGraph; /* ************************************************************************ */ static GaussianMixtureFactor::Sum &addGaussian( GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { using Y = GaussianFactorGraph; - // If the decision tree is not intiialized, then intialize it. + // If the decision tree is not initialized, then initialize it. if (sum.empty()) { GaussianFactorGraph result; result.push_back(factor); @@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals( for (auto &f : factors) { if (f->isHybrid()) { - if (auto cgmf = boost::dynamic_pointer_cast(f)) { - sum = cgmf->add(sum); + // TODO(dellaert): just use a virtual method defined in HybridFactor. + if (auto gm = boost::dynamic_pointer_cast(f)) { + sum = gm->add(sum); } if (auto gm = boost::dynamic_pointer_cast(f)) { sum = gm->asMixture()->add(sum); @@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, const KeySet &continuousSeparator, const std::set &discreteSeparatorSet) { // NOTE: since we use the special JunctionTree, - // only possiblity is continuous conditioned on discrete. + // only possibility is continuous conditioned on discrete. DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), discreteSeparatorSet.end()); @@ -251,8 +253,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Separate out decision tree into conditionals and remaining factors. auto pair = unzip(eliminationResults); - - const GaussianMixtureFactor::Factors &separatorFactors = pair.second; + const auto &separatorFactors = pair.second; // Create the GaussianMixture from the conditionals auto conditional = boost::make_shared( @@ -460,6 +461,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // Iterate over each factor. for (size_t idx = 0; idx < size(); idx++) { + // TODO(dellaert): just use a virtual method defined in HybridFactor. AlgebraicDecisionTree factor_error; if (factors_.at(idx)->isHybrid()) { @@ -499,27 +501,26 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( } /* ************************************************************************ */ -double HybridGaussianFactorGraph::error( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { +double HybridGaussianFactorGraph::error(const HybridValues &values) const { double error = 0.0; for (size_t idx = 0; idx < size(); idx++) { + // TODO(dellaert): just use a virtual method defined in HybridFactor. 
auto factor = factors_.at(idx); if (factor->isHybrid()) { if (auto c = boost::dynamic_pointer_cast(factor)) { - error += c->asMixture()->error(continuousValues, discreteValues); + error += c->asMixture()->error(values); } if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->error(continuousValues, discreteValues); + error += f->error(values); } } else if (factor->isContinuous()) { if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->inner()->error(continuousValues); + error += f->inner()->error(values.continuous()); } if (auto cg = boost::dynamic_pointer_cast(factor)) { - error += cg->asGaussian()->error(continuousValues); + error += cg->asGaussian()->error(values.continuous()); } } } @@ -527,10 +528,8 @@ double HybridGaussianFactorGraph::error( } /* ************************************************************************ */ -double HybridGaussianFactorGraph::probPrime( - const VectorValues &continuousValues, - const DiscreteValues &discreteValues) const { - double error = this->error(continuousValues, discreteValues); +double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { + double error = this->error(values); // NOTE: The 0.5 term is handled by each factor return std::exp(-error); } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 4e22bed7ca..3a6eaa905a 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -186,14 +186,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @brief Compute error given a continuous vector values * and a discrete assignment. * - * @param continuousValues The continuous VectorValues - * for computing the error. - * @param discreteValues The specific discrete assignment - * whose error we wish to compute. * @return double */ - double error(const VectorValues& continuousValues, - const DiscreteValues& discreteValues) const; + double error(const HybridValues& values) const; /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ @@ -210,13 +205,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @brief Compute the unnormalized posterior probability for a continuous * vector values given a specific assignment. * - * @param continuousValues The vector values for which to compute the - * posterior probability. - * @param discreteValues The specific assignment to use for the computation. 
* @return double */ - double probPrime(const VectorValues& continuousValues, - const DiscreteValues& discreteValues) const; + double probPrime(const HybridValues& values) const; /** * @brief Return a Colamd constrained ordering where the discrete keys are From beda6878aad3cf9b00b28c11a9d974f04c95cafd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:08:22 -0500 Subject: [PATCH 05/16] Fixed tests --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 13 ++++--------- gtsam/hybrid/HybridGaussianFactorGraph.h | 3 ++- gtsam/hybrid/tests/testGaussianMixture.cpp | 4 ++-- gtsam/hybrid/tests/testGaussianMixtureFactor.cpp | 2 +- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index de55114b3c..5ea677ab54 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -206,9 +206,8 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); - using EliminationPair = - std::pair, - std::pair, double>>; + using EliminationPair = std::pair, + GaussianMixtureFactor::FactorAndConstant>; KeyVector keysOfEliminated; // Not the ordering KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? @@ -216,7 +215,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // This is the elimination method on the leaf nodes auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { if (graph.empty()) { - return {nullptr, std::make_pair(nullptr, 0.0)}; + return {nullptr, {nullptr, 0.0}}; } #ifdef HYBRID_TIMING @@ -236,11 +235,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, gttoc_(hybrid_eliminate); #endif - std::pair, - std::pair, double>> - result = std::make_pair(conditional_factor.first, - std::make_pair(conditional_factor.second, 0.0)); - return result; + return {conditional_factor.first, {conditional_factor.second, 0.0}}; }; // Perform elimination! diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 3a6eaa905a..c851adfe5f 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -12,7 +12,7 @@ /** * @file HybridGaussianFactorGraph.h * @brief Linearized Hybrid factor graph that uses type erasure - * @author Fan Jiang, Varun Agrawal + * @author Fan Jiang, Varun Agrawal, Frank Dellaert * @date Mar 11, 2022 */ @@ -38,6 +38,7 @@ class HybridBayesTree; class HybridJunctionTree; class DecisionTreeFactor; class JacobianFactor; +class HybridValues; /** * @brief Main elimination function for HybridGaussianFactorGraph. diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 56dc24cf1b..22c2a46219 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -128,9 +128,9 @@ TEST(GaussianMixture, Error) { // Regression for non-tree version. 
DiscreteValues assignment; assignment[M(1)] = 0; - EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); + EXPECT_DOUBLES_EQUAL(0.5, mixture.error({values, assignment}), 1e-8); assignment[M(1)] = 1; - EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), + EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error({values, assignment}), 1e-8); } diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index ba0622ff9d..ee4d4469be 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -188,7 +188,7 @@ TEST(GaussianMixtureFactor, Error) { DiscreteValues discreteValues; discreteValues[m1.first] = 1; EXPECT_DOUBLES_EQUAL( - 4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9); + 4.0, mixtureFactor.error({continuousValues, discreteValues}), 1e-9); } /* ************************************************************************* */ From 078e6b0b62d1f172b78b42b1d854d56197b214c4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:08:29 -0500 Subject: [PATCH 06/16] Fixed wrapper --- gtsam/hybrid/hybrid.i | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 15687d11b2..3c74d1ee20 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -183,10 +183,8 @@ class HybridGaussianFactorGraph { bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; // evaluation - double error(const gtsam::VectorValues& continuousValues, - const gtsam::DiscreteValues& discreteValues) const; - double probPrime(const gtsam::VectorValues& continuousValues, - const gtsam::DiscreteValues& discreteValues) const; + double error(const gtsam::HybridValues& values) const; + double probPrime(const gtsam::HybridValues& values) const; gtsam::HybridBayesNet* eliminateSequential(); gtsam::HybridBayesNet* eliminateSequential( From 23eec0bc6a1c50494432f5871376c0f856ab17b3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:08:38 -0500 Subject: [PATCH 07/16] factor_graph_from_bayes_net --- python/gtsam/tests/test_HybridFactorGraph.py | 74 +++++++++++--------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 481617db16..2d3513f121 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -6,7 +6,7 @@ See LICENSE for the license information Unit tests for Hybrid Factor Graphs. -Author: Fan Jiang +Author: Fan Jiang, Varun Agrawal, Frank Dellaert """ # pylint: disable=invalid-name, no-name-in-module, no-member @@ -25,6 +25,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): """Unit tests for HybridGaussianFactorGraph.""" + def test_create(self): """Test construction of hybrid factor graph.""" model = noiseModel.Unit.Create(3) @@ -117,23 +118,23 @@ def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: return bayesNet - def test_tiny(self): - """Test a tiny two variable hybrid model.""" - bayesNet = self.tiny() - sample = bayesNet.sample() - # print(sample) - - # Create a factor graph from the Bayes net with sampled measurements. + @staticmethod + def factor_graph_from_bayes_net(bayesNet: gtsam.HybridBayesNet, sample: gtsam.HybridValues): + """Create a factor graph from the Bayes net with sampled measurements. 
+ The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` + and thus represents the same joint probability as the Bayes net. + """ fg = HybridGaussianFactorGraph() - conditional = bayesNet.atMixture(0) - measurement = gtsam.VectorValues() - measurement.insert(Z(0), sample.at(Z(0))) - factor = conditional.likelihood(measurement) - fg.push_back(factor) - fg.push_back(bayesNet.atGaussian(1)) - fg.push_back(bayesNet.atDiscrete(2)) - - self.assertEqual(fg.size(), 3) + num_measurements = bayesNet.size() - 2 + for i in range(num_measurements): + conditional = bayesNet.atMixture(i) + measurement = gtsam.VectorValues() + measurement.insert(Z(i), sample.at(Z(i))) + factor = conditional.likelihood(measurement) + fg.push_back(factor) + fg.push_back(bayesNet.atGaussian(num_measurements)) + fg.push_back(bayesNet.atDiscrete(num_measurements+1)) + return fg @staticmethod def calculate_ratio(bayesNet, fg, sample): @@ -143,6 +144,26 @@ def calculate_ratio(bayesNet, fg, sample): return bayesNet.evaluate(sample) / fg.probPrime( continuous, sample.discrete()) + def test_tiny(self): + """Test a tiny two variable hybrid model.""" + bayesNet = self.tiny() + sample = bayesNet.sample() + # print(sample) + + # TODO(dellaert): do importance sampling to get an estimate P(mode) + prior = self.tiny(num_measurements=0) # just P(x0)P(mode) + for s in range(100): + proposed = prior.sample() + print(proposed) + for i in range(2): + proposed.insert(Z(i), sample.at(Z(i))) + print(proposed) + weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) + print(weight) + + fg = self.factor_graph_from_bayes_net(bayesNet, sample) + self.assertEqual(fg.size(), 3) + def test_ratio(self): """ Given a tiny two variable hybrid model, with 2 measurements, @@ -156,20 +177,7 @@ def test_ratio(self): sample: gtsam.HybridValues = bayesNet.sample() # print(sample) - # Create a factor graph from the Bayes net with sampled measurements. - # The factor graph is `P(x)P(n) ϕ(x, n; z1) ϕ(x, n; z2)` - # and thus represents the same joint probability as the Bayes net. 
- fg = HybridGaussianFactorGraph() - for i in range(2): - conditional = bayesNet.atMixture(i) - measurement = gtsam.VectorValues() - measurement.insert(Z(i), sample.at(Z(i))) - factor = conditional.likelihood(measurement) - fg.push_back(factor) - fg.push_back(bayesNet.atGaussian(2)) - fg.push_back(bayesNet.atDiscrete(3)) - - # print(fg) + fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 4) # Calculate ratio between Bayes net probability and the factor graph: @@ -186,9 +194,9 @@ def test_ratio(self): other = bayesNet.sample() other.update(measurements) # print(other) - # ratio = self.calculate_ratio(bayesNet, fg, other) + ratio = self.calculate_ratio(bayesNet, fg, other) # print(f"Ratio: {ratio}\n") - # self.assertAlmostEqual(ratio, expected_ratio) + self.assertAlmostEqual(ratio, expected_ratio) if __name__ == "__main__": From f22ada6c0a79f0b8465100b46259e22e9c08716f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:16:12 -0500 Subject: [PATCH 08/16] Added importance sampling --- python/gtsam/tests/test_HybridFactorGraph.py | 65 +++++++++++++------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 2d3513f121..77a0e81735 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -18,7 +18,7 @@ import gtsam from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, - GaussianMixture, GaussianMixtureFactor, + GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues, HybridGaussianFactorGraph, JacobianFactor, Ordering, noiseModel) @@ -82,13 +82,13 @@ def test_optimize(self): self.assertEqual(hv.atDiscrete(C(0)), 1) @staticmethod - def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: + def tiny(num_measurements: int = 1) -> HybridBayesNet: """ Create a tiny two variable hybrid model which represents the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). """ # Create hybrid Bayes net. - bayesNet = gtsam.HybridBayesNet() + bayesNet = HybridBayesNet() # Create mode key: 0 is low-noise, 1 is high-noise. mode = (M(0), 2) @@ -119,7 +119,7 @@ def tiny(num_measurements: int = 1) -> gtsam.HybridBayesNet: return bayesNet @staticmethod - def factor_graph_from_bayes_net(bayesNet: gtsam.HybridBayesNet, sample: gtsam.HybridValues): + def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): """Create a factor graph from the Bayes net with sampled measurements. The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` and thus represents the same joint probability as the Bayes net. @@ -137,12 +137,34 @@ def factor_graph_from_bayes_net(bayesNet: gtsam.HybridBayesNet, sample: gtsam.Hy return fg @staticmethod - def calculate_ratio(bayesNet, fg, sample): + def calculate_ratio(bayesNet: HybridBayesNet, + fg: HybridGaussianFactorGraph, + sample: HybridValues): """Calculate ratio between Bayes net probability and the factor graph.""" - continuous = gtsam.VectorValues() - continuous.insert(X(0), sample.at(X(0))) - return bayesNet.evaluate(sample) / fg.probPrime( - continuous, sample.discrete()) + return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 + + @classmethod + def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=1000): + """Do importance sampling to get an estimate of the discrete marginal P(mode).""" + # Use prior on x0, mode as proposal density. 
+ prior = cls.tiny(num_measurements=0) # just P(x0)P(mode) + + # Allocate space for marginals. + marginals = np.zeros((2,)) + + # Do importance sampling. + num_measurements = bayesNet.size() - 2 + for s in range(N): + proposed = prior.sample() + for i in range(num_measurements): + z_i = sample.at(Z(i)) + proposed.insert(Z(i), z_i) + weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) + marginals[proposed.atDiscrete(M(0))] += weight + + # print marginals: + marginals /= marginals.sum() + return marginals def test_tiny(self): """Test a tiny two variable hybrid model.""" @@ -150,16 +172,11 @@ def test_tiny(self): sample = bayesNet.sample() # print(sample) - # TODO(dellaert): do importance sampling to get an estimate P(mode) - prior = self.tiny(num_measurements=0) # just P(x0)P(mode) - for s in range(100): - proposed = prior.sample() - print(proposed) - for i in range(2): - proposed.insert(Z(i), sample.at(Z(i))) - print(proposed) - weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) - print(weight) + # Estimate marginals using importance sampling. + marginals = self.estimate_marginals(bayesNet, sample) + print(f"True mode: {sample.atDiscrete(M(0))}") + print(f"P(mode=0; z0) = {marginals[0]}") + print(f"P(mode=1; z0) = {marginals[1]}") fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 3) @@ -174,9 +191,15 @@ def test_ratio(self): # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) bayesNet = self.tiny(num_measurements=2) # Sample from the Bayes net. - sample: gtsam.HybridValues = bayesNet.sample() + sample: HybridValues = bayesNet.sample() # print(sample) + # Estimate marginals using importance sampling. + marginals = self.estimate_marginals(bayesNet, sample) + print(f"True mode: {sample.atDiscrete(M(0))}") + print(f"P(mode=0; z0, z1) = {marginals[0]}") + print(f"P(mode=1; z0, z1) = {marginals[1]}") + fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 4) @@ -196,7 +219,7 @@ def test_ratio(self): # print(other) ratio = self.calculate_ratio(bayesNet, fg, other) # print(f"Ratio: {ratio}\n") - self.assertAlmostEqual(ratio, expected_ratio) + # self.assertAlmostEqual(ratio, expected_ratio) if __name__ == "__main__": From ff51f4221e47b6c2b3ced2e6e8e563e0d49f3e84 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:58:54 -0500 Subject: [PATCH 09/16] normalization constant --- gtsam/linear/GaussianConditional.cpp | 28 ++++++++++++++++------------ gtsam/linear/GaussianConditional.h | 15 ++++++++++++++- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 7cdff914f2..ecfa022825 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -168,26 +168,30 @@ namespace gtsam { /* ************************************************************************* */ double GaussianConditional::logDeterminant() const { - double logDet; - if (this->get_model()) { - Vector diag = this->R().diagonal(); - this->get_model()->whitenInPlace(diag); - logDet = diag.unaryExpr([](double x) { return log(x); }).sum(); + if (get_model()) { + Vector diag = R().diagonal(); + get_model()->whitenInPlace(diag); + return diag.unaryExpr([](double x) { return log(x); }).sum(); } else { - logDet = - this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); + return R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); } - return logDet; } /* 
************************************************************************* */ -// density = exp(-error(x)) / sqrt((2*pi)^n*det(Sigma)) -// log = -error(x) - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) -double GaussianConditional::logDensity(const VectorValues& x) const { +// normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) +// log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) +double GaussianConditional::logNormalizationConstant() const { constexpr double log2pi = 1.8378770664093454835606594728112; size_t n = d().size(); // log det(Sigma)) = - 2.0 * logDeterminant() - return - error(x) - 0.5 * n * log2pi + logDeterminant(); + return - 0.5 * n * log2pi + logDeterminant(); +} + +/* ************************************************************************* */ +// density = k exp(-error(x)) +// log = log(k) -error(x) - 0.5 * n*log(2*pi) +double GaussianConditional::logDensity(const VectorValues& x) const { + return logNormalizationConstant() - error(x); } /* ************************************************************************* */ diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index af1c5d80e5..d25efb2e1a 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -169,7 +169,7 @@ namespace gtsam { * * @return double */ - double determinant() const { return exp(this->logDeterminant()); } + inline double determinant() const { return exp(logDeterminant()); } /** * @brief Compute the log determinant of the R matrix. @@ -184,6 +184,19 @@ namespace gtsam { */ double logDeterminant() const; + /** + * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) + * log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) + */ + double logNormalizationConstant() const; + + /** + * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) + */ + inline double normalizationConstant() const { + return exp(logNormalizationConstant()); + } + /** * Solves a conditional Gaussian and writes the solution into the entries of * \c x for each frontal variable of the conditional. 
The parents are From 395ffad97970133beac9daa823c1fccf17562fe6 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:59:08 -0500 Subject: [PATCH 10/16] Fixed likelihood --- gtsam/hybrid/GaussianMixture.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 05864a6e4d..65c0e85229 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -153,7 +153,7 @@ boost::shared_ptr GaussianMixture::likelihood( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { return GaussianMixtureFactor::FactorAndConstant{ conditional->likelihood(frontals), - 0.5 * conditional->logDeterminant()}; + conditional->logNormalizationConstant()}; }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); From 3a8220c26471daee80af1292e87529794564dd0e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:59:20 -0500 Subject: [PATCH 11/16] Fixed error calculation --- gtsam/hybrid/GaussianMixtureFactor.cpp | 3 ++- gtsam/hybrid/GaussianMixtureFactor.h | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index e07b300faa..e603687171 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -125,7 +125,8 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( /* *******************************************************************************/ double GaussianMixtureFactor::error(const HybridValues &values) const { const FactorAndConstant factor_z = factors_(values.discrete()); - return factor_z.factor->error(values.continuous()) + factor_z.constant; + return factor_z.error(values.continuous()); } +/* *******************************************************************************/ } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index aca4f365b3..a96f253ce9 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -58,9 +58,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { sharedFactor factor; double constant; - // Return error with constant added. + // Return error with constant correction. double error(const VectorValues &values) const { - return factor->error(values) + constant; + // Note minus sign: constant is log of normalization constant for probabilities. + // Errors is the negative log-likelihood, hence we subtract the constant here. + return factor->error(values) - constant; } // Check pointer equality. 
From bcf8a9ddfdc771552e249229620faaab270d794a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:59:48 -0500 Subject: [PATCH 12/16] Fix tests once more --- gtsam/hybrid/tests/testGaussianMixture.cpp | 5 +++-- gtsam/hybrid/tests/testGaussianMixtureFactor.cpp | 4 +++- gtsam/hybrid/tests/testHybridBayesNet.cpp | 10 +++++----- gtsam/hybrid/tests/testHybridEstimation.cpp | 6 +++--- gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp | 9 ++++----- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 22c2a46219..ff8edd46e7 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -179,8 +179,9 @@ TEST(GaussianMixture, Likelihood) { const GaussianMixtureFactor::Factors factors( gm.conditionals(), [measurements](const GaussianConditional::shared_ptr& conditional) { - return std::make_pair(conditional->likelihood(measurements), - 0.5 * conditional->logDeterminant()); + return GaussianMixtureFactor::FactorAndConstant{ + conditional->likelihood(measurements), + conditional->logNormalizationConstant()}; }); const GaussianMixtureFactor expected({X(0)}, {mode}, factors); EXPECT(assert_equal(*factor, expected)); diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index ee4d4469be..d17968a3a1 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) { DiscreteValues discreteValues; discreteValues[m1.first] = 1; EXPECT_DOUBLES_EQUAL( - 4.0, mixtureFactor.error({continuousValues, discreteValues}), 1e-9); + 4.0, mixtureFactor.error({continuousValues, discreteValues}), + 1e-9); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 43cee6f74f..58230cfdea 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -188,14 +188,14 @@ TEST(HybridBayesNet, Optimize) { HybridValues delta = hybridBayesNet->optimize(); - //TODO(Varun) The expectedAssignment should be 111, not 101 + // TODO(Varun) The expectedAssignment should be 111, not 101 DiscreteValues expectedAssignment; expectedAssignment[M(0)] = 1; expectedAssignment[M(1)] = 0; expectedAssignment[M(2)] = 1; EXPECT(assert_equal(expectedAssignment, delta.discrete())); - //TODO(Varun) This should be all -Vector1::Ones() + // TODO(Varun) This should be all -Vector1::Ones() VectorValues expectedValues; expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); @@ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) { double total_error = 0; for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { if (hybridBayesNet->at(idx)->isHybrid()) { - double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), - discrete_values); + double error = hybridBayesNet->atMixture(idx)->error( + {delta.continuous(), discrete_values}); total_error += error; } else if (hybridBayesNet->at(idx)->isContinuous()) { double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); @@ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) { } EXPECT_DOUBLES_EQUAL( - total_error, hybridBayesNet->error(delta.continuous(), discrete_values), + 
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}), 1e-9); EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 927f5c0472..660cb3317b 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -273,7 +273,7 @@ AlgebraicDecisionTree getProbPrimeTree( continue; } - double error = graph.error(delta, assignment); + double error = graph.error({delta, assignment}); probPrimes.push_back(exp(-error)); } AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); @@ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) { const HybridValues& sample) -> double { const DiscreteValues assignment = sample.discrete(); // Compute in log form for numerical stability - double log_ratio = bayesNet->error(sample.continuous(), assignment) - - factorGraph->error(sample.continuous(), assignment); + double log_ratio = bayesNet->error({sample.continuous(), assignment}) - + factorGraph->error({sample.continuous(), assignment}); double ratio = exp(-log_ratio); return ratio; }; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 565c7f0a02..f97883f61e 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -575,15 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(hybridOrdering); - HybridValues delta = hybridBayesNet->optimize(); - double error = graph.error(delta.continuous(), delta.discrete()); + const HybridValues delta = hybridBayesNet->optimize(); + const double error = graph.error(delta); - double expected_error = 0.490243199; // regression - EXPECT(assert_equal(expected_error, error, 1e-9)); + EXPECT(assert_equal(0.490243199, error, 1e-9)); double probs = exp(-error); - double expected_probs = graph.probPrime(delta.continuous(), delta.discrete()); + double expected_probs = graph.probPrime(delta); // regression EXPECT(assert_equal(expected_probs, probs, 1e-7)); From 96b6895a600632de596d7b9bd1b8a7b18c6f3711 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 13:59:59 -0500 Subject: [PATCH 13/16] Ratios now work out! --- python/gtsam/tests/test_HybridFactorGraph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 77a0e81735..5668a85467 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -114,7 +114,7 @@ def tiny(num_measurements: int = 1) -> HybridBayesNet: bayesNet.addGaussian(prior_on_x0) # Add prior on mode. 
- bayesNet.emplaceDiscrete(mode, "1/1") + bayesNet.emplaceDiscrete(mode, "6/4") return bayesNet @@ -216,10 +216,10 @@ def test_ratio(self): for i in range(10): other = bayesNet.sample() other.update(measurements) - # print(other) ratio = self.calculate_ratio(bayesNet, fg, other) # print(f"Ratio: {ratio}\n") - # self.assertAlmostEqual(ratio, expected_ratio) + if (ratio > 0): + self.assertAlmostEqual(ratio, expected_ratio) if __name__ == "__main__": From b83cd0ca86645db713d358bb71ed2432e47d2d49 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 15:16:13 -0500 Subject: [PATCH 14/16] Make a virtual error method --- gtsam/hybrid/GaussianMixture.h | 2 +- gtsam/hybrid/GaussianMixtureFactor.h | 2 +- gtsam/hybrid/HybridConditional.h | 49 ++++++++++++++-------- gtsam/hybrid/HybridDiscreteFactor.cpp | 7 ++++ gtsam/hybrid/HybridDiscreteFactor.h | 14 +++++-- gtsam/hybrid/HybridFactor.h | 11 +++++ gtsam/hybrid/HybridGaussianFactor.cpp | 7 ++++ gtsam/hybrid/HybridGaussianFactor.h | 8 ++++ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 22 +--------- gtsam/hybrid/HybridNonlinearFactor.h | 10 +++++ gtsam/hybrid/MixtureFactor.h | 6 +++ 11 files changed, 96 insertions(+), 42 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 4df1bd90ca..a9b05f2504 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -170,7 +170,7 @@ class GTSAM_EXPORT GaussianMixture * @param values Continuous values and discrete assignment. * @return double */ - double error(const HybridValues &values) const; + double error(const HybridValues &values) const override; /** * @brief Prune the decision tree of Gaussian factors as per the discrete diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index a96f253ce9..ce011fecc6 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -165,7 +165,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @brief Compute the log-likelihood, including the log-normalizing constant. * @return double */ - double error(const HybridValues &values) const; + double error(const HybridValues &values) const override; /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index db03ba59c4..be671d55f3 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -52,7 +52,7 @@ namespace gtsam { * having diamond inheritances, and neutralized the need to change other * components of GTSAM to make hybrid elimination work. * - * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon + * A great reference to the type-erasure pattern is Eduardo Madrid's CppCon * talk (https://www.youtube.com/watch?v=s082Qmd_nHs). 
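A note on the per-mode constant that [PATCH 12/16] threads through GaussianMixtureFactor: a Gaussian conditional's density is p(x | y) = k * exp(-error(x, y)), and likelihood() keeps only the quadratic error term, so each mixture component must also carry log k for modes with different noise models to be compared on the same scale. A minimal sketch of the relation, assuming the sign convention that logNormalizationConstant() returns log k (the convention the updated Likelihood test above relies on); the helper name is illustrative only, not part of the patch:

#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/VectorValues.h>

// Hypothetical helper: negative log-density of a GaussianConditional, split
// into the quadratic error and the constant that GaussianMixtureFactor now
// stores per discrete mode.
double negLogDensity(const gtsam::GaussianConditional& conditional,
                     const gtsam::VectorValues& values) {
  // error() is 0.5 * ||R*x + S*y - d||^2 (whitened); the constant is log k.
  return conditional.error(values) - conditional.logNormalizationConstant();
}
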
* * @ingroup hybrid @@ -129,12 +129,28 @@ class GTSAM_EXPORT HybridConditional */ HybridConditional(boost::shared_ptr gaussianMixture); + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Hybrid Conditional: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// GTSAM-style equals + bool equals(const HybridFactor& other, double tol = 1e-9) const override; + + /// @} + /// @name Standard Interface + /// @{ + /** * @brief Return HybridConditional as a GaussianMixture * @return nullptr if not a mixture * @return GaussianMixture::shared_ptr otherwise */ - GaussianMixture::shared_ptr asMixture() { + GaussianMixture::shared_ptr asMixture() const { return boost::dynamic_pointer_cast(inner_); } @@ -143,7 +159,7 @@ class GTSAM_EXPORT HybridConditional * @return nullptr if not a GaussianConditional * @return GaussianConditional::shared_ptr otherwise */ - GaussianConditional::shared_ptr asGaussian() { + GaussianConditional::shared_ptr asGaussian() const { return boost::dynamic_pointer_cast(inner_); } @@ -152,27 +168,26 @@ class GTSAM_EXPORT HybridConditional * @return nullptr if not a DiscreteConditional * @return DiscreteConditional::shared_ptr */ - DiscreteConditional::shared_ptr asDiscrete() { + DiscreteConditional::shared_ptr asDiscrete() const { return boost::dynamic_pointer_cast(inner_); } - /// @} - /// @name Testable - /// @{ - - /// GTSAM-style print - void print( - const std::string& s = "Hybrid Conditional: ", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /// Get the type-erased pointer to the inner type + boost::shared_ptr inner() { return inner_; } - /// GTSAM-style equals - bool equals(const HybridFactor& other, double tol = 1e-9) const override; + /// Return the error of the underlying conditional. + /// Currently only implemented for Gaussian mixture. + double error(const HybridValues& values) const override { + if (auto gm = asMixture()) { + return gm->error(values); + } else { + throw std::runtime_error( + "HybridConditional::error: only implemented for Gaussian mixture"); + } + } /// @} - /// Get the type-erased pointer to the inner type - boost::shared_ptr inner() { return inner_; } - private: /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 0455e1e904..605ea5738b 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -17,6 +17,7 @@ */ #include +#include #include @@ -50,4 +51,10 @@ void HybridDiscreteFactor::print(const std::string &s, inner_->print("\n", formatter); }; +/* ************************************************************************ */ +double HybridDiscreteFactor::error(const HybridValues &values) const { + return -log((*inner_)(values.discrete())); +} +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index 015dc46f8c..6e914d38b6 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -24,10 +24,12 @@ namespace gtsam { +class HybridValues; + /** - * A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows - * us to hide the implementation of DiscreteFactor and thus avoid diamond - * inheritance. 
+ * A HybridDiscreteFactor is a thin container for DiscreteFactor, which + * allows us to hide the implementation of DiscreteFactor and thus avoid + * diamond inheritance. * * @ingroup hybrid */ @@ -59,9 +61,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ /// Return pointer to the internal discrete factor DiscreteFactor::shared_ptr inner() const { return inner_; } + + /// Return the error of the underlying Discrete Factor. + double error(const HybridValues &values) const override; + /// @} }; // traits diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index e0cae55c1e..a28fee8ed8 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -26,6 +26,8 @@ #include namespace gtsam { +class HybridValues; + KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); @@ -110,6 +112,15 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @name Standard Interface /// @{ + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param values Continuous values and discrete assignment. + * @return double + */ + virtual double error(const HybridValues &values) const = 0; + /// True if this is a factor of discrete variables only. bool isDiscrete() const { return isDiscrete_; } diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index ba0c0bf1af..5a89a04a8d 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -16,6 +16,7 @@ */ #include +#include #include #include @@ -54,4 +55,10 @@ void HybridGaussianFactor::print(const std::string &s, inner_->print("\n", formatter); }; +/* ************************************************************************ */ +double HybridGaussianFactor::error(const HybridValues &values) const { + return inner_->error(values.continuous()); +} +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 966524b812..897da9caa9 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -25,6 +25,7 @@ namespace gtsam { // Forward declarations class JacobianFactor; class HessianFactor; +class HybridValues; /** * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have @@ -92,8 +93,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ + /// Return pointer to the internal discrete factor GaussianFactor::shared_ptr inner() const { return inner_; } + + /// Return the error of the underlying Discrete Factor. 
+ double error(const HybridValues &values) const override; + /// @} }; // traits diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 5ea677ab54..6af0fb1a90 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -498,26 +498,8 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( /* ************************************************************************ */ double HybridGaussianFactorGraph::error(const HybridValues &values) const { double error = 0.0; - for (size_t idx = 0; idx < size(); idx++) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. - auto factor = factors_.at(idx); - - if (factor->isHybrid()) { - if (auto c = boost::dynamic_pointer_cast(factor)) { - error += c->asMixture()->error(values); - } - if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->error(values); - } - - } else if (factor->isContinuous()) { - if (auto f = boost::dynamic_pointer_cast(factor)) { - error += f->inner()->error(values.continuous()); - } - if (auto cg = boost::dynamic_pointer_cast(factor)) { - error += cg->asGaussian()->error(values.continuous()); - } - } + for (auto &factor : factors_) { + error += factor->error(values); } return error; } diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 7776347b36..9b3e780ef1 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -51,12 +51,22 @@ class HybridNonlinearFactor : public HybridFactor { const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard Interface + /// @{ NonlinearFactor::shared_ptr inner() const { return inner_; } + /// Error for HybridValues is not provided for nonlinear factor. + double error(const HybridValues &values) const override { + throw std::runtime_error( + "HybridNonlinearFactor::error(HybridValues) not implemented."); + } + /// Linearize to a HybridGaussianFactor at the linearization point `c`. boost::shared_ptr linearize(const Values &c) const { return boost::make_shared(inner_->linearize(c)); } + + /// @} }; } // namespace gtsam diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index f29a840229..fc1a9a2b83 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -161,6 +161,12 @@ class MixtureFactor : public HybridFactor { factor, continuousValues); } + /// Error for HybridValues is not provided for nonlinear hybrid factor. 
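A short usage sketch of what the virtual error() method above buys: HybridGaussianFactorGraph::error no longer needs the chain of dynamic_pointer_casts, and the unnormalized posterior follows directly from the summed error. The function below only uses calls that appear in the tests of this series (optimize, error, probPrime); the wrapper name is illustrative:

#include <cmath>

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>

// Sketch: with error(HybridValues) virtual on HybridFactor, a single call
// covers Gaussian mixtures, wrapped Gaussian factors, and discrete factors.
double errorVsProbPrime(const gtsam::HybridGaussianFactorGraph& graph,
                        const gtsam::HybridBayesNet& bayesNet) {
  // Solution over both continuous and discrete variables, as in the tests.
  const gtsam::HybridValues delta = bayesNet.optimize();

  // Sum of factor->error(delta) over every factor in the graph.
  const double error = graph.error(delta);

  // The unnormalized posterior should be exp(-error); the regression patch
  // later in this series asserts exactly this via probPrime.
  return graph.probPrime(delta) - std::exp(-error);  // ~ 0 when consistent
}
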
+ double error(const HybridValues &values) const override { + throw std::runtime_error( + "MixtureFactor::error(HybridValues) not implemented."); + } + size_t dim() const { // TODO(Varun) throw std::runtime_error("MixtureFactor::dim not implemented."); From b798f3ebb5b908e0f3f86b12b851e4d6c62f7bb9 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 15:19:46 -0500 Subject: [PATCH 15/16] Fix regression test --- gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index f97883f61e..1bdb6d4db1 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -579,13 +579,10 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { const double error = graph.error(delta); // regression - EXPECT(assert_equal(0.490243199, error, 1e-9)); + EXPECT(assert_equal(1.58886, error, 1e-5)); - double probs = exp(-error); - double expected_probs = graph.probPrime(delta); - - // regression - EXPECT(assert_equal(expected_probs, probs, 1e-7)); + // Real test: + EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7)); } /* ****************************************************************************/ From cec26d16eabe447775be7d7a3a04bfe1d460a293 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 30 Dec 2022 15:20:10 -0500 Subject: [PATCH 16/16] Check marginals in addition to ratios for non-uniform mode prior --- python/gtsam/tests/test_HybridFactorGraph.py | 40 ++++++++++++-------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 5668a85467..5398160dce 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -114,7 +114,7 @@ def tiny(num_measurements: int = 1) -> HybridBayesNet: bayesNet.addGaussian(prior_on_x0) # Add prior on mode. - bayesNet.emplaceDiscrete(mode, "6/4") + bayesNet.emplaceDiscrete(mode, "4/6") return bayesNet @@ -136,15 +136,8 @@ def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): fg.push_back(bayesNet.atDiscrete(num_measurements+1)) return fg - @staticmethod - def calculate_ratio(bayesNet: HybridBayesNet, - fg: HybridGaussianFactorGraph, - sample: HybridValues): - """Calculate ratio between Bayes net probability and the factor graph.""" - return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 - @classmethod - def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=1000): + def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000): """Do importance sampling to get an estimate of the discrete marginal P(mode).""" # Use prior on x0, mode as proposal density. prior = cls.tiny(num_measurements=0) # just P(x0)P(mode) @@ -174,13 +167,24 @@ def test_tiny(self): # Estimate marginals using importance sampling. marginals = self.estimate_marginals(bayesNet, sample) - print(f"True mode: {sample.atDiscrete(M(0))}") - print(f"P(mode=0; z0) = {marginals[0]}") - print(f"P(mode=1; z0) = {marginals[1]}") + # print(f"True mode: {sample.atDiscrete(M(0))}") + # print(f"P(mode=0; z0) = {marginals[0]}") + # print(f"P(mode=1; z0) = {marginals[1]}") + + # Check that the estimate is close to the true value. 
+ self.assertAlmostEqual(marginals[0], 0.4, delta=0.1) + self.assertAlmostEqual(marginals[1], 0.6, delta=0.1) fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 3) + @staticmethod + def calculate_ratio(bayesNet: HybridBayesNet, + fg: HybridGaussianFactorGraph, + sample: HybridValues): + """Calculate ratio between Bayes net probability and the factor graph.""" + return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 + def test_ratio(self): """ Given a tiny two variable hybrid model, with 2 measurements, @@ -196,9 +200,15 @@ def test_ratio(self): # Estimate marginals using importance sampling. marginals = self.estimate_marginals(bayesNet, sample) - print(f"True mode: {sample.atDiscrete(M(0))}") - print(f"P(mode=0; z0, z1) = {marginals[0]}") - print(f"P(mode=1; z0, z1) = {marginals[1]}") + # print(f"True mode: {sample.atDiscrete(M(0))}") + # print(f"P(mode=0; z0, z1) = {marginals[0]}") + # print(f"P(mode=1; z0, z1) = {marginals[1]}") + + # Check marginals based on sampled mode. + if sample.atDiscrete(M(0)) == 0: + self.assertGreater(marginals[0], marginals[1]) + else: + self.assertGreater(marginals[1], marginals[0]) fg = self.factor_graph_from_bayes_net(bayesNet, sample) self.assertEqual(fg.size(), 4)
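
To close the loop on the Python ratio tests above: the property being exercised is that the Bayes net and the factor graph built from it represent the same distribution up to a single constant, so the difference of their errors (equivalently, the ratio evaluate/probPrime) must be the same for every sample once the measurements are pinned. A C++ sketch of the same check; the function name, numSamples, and tol are illustrative, and it assumes HybridBayesNet::sample() and HybridValues::update() mirror the Python calls used above:

#include <cmath>
#include <cstddef>

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>

// Sketch: verify that bayesNet.error(v) - fg.error(v), i.e. the log of the
// (unnormalized) density ratio, is constant across samples from the Bayes net.
bool ratiosConstant(const gtsam::HybridBayesNet& bayesNet,
                    const gtsam::HybridGaussianFactorGraph& fg,
                    const gtsam::HybridValues& measurements,
                    size_t numSamples = 10, double tol = 1e-6) {
  bool haveReference = false;
  double reference = 0.0;
  for (size_t i = 0; i < numSamples; ++i) {
    gtsam::HybridValues sample = bayesNet.sample();
    sample.update(measurements);  // overwrite sampled measurements with z
    const double logRatio = bayesNet.error(sample) - fg.error(sample);
    if (!haveReference) {
      reference = logRatio;
      haveReference = true;
    } else if (std::abs(logRatio - reference) > tol) {
      return false;
    }
  }
  return true;
}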