diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 25e18cfaf2..b3f0d69b0e 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -71,7 +71,7 @@ namespace gtsam { static inline double id(const double& x) { return x; } }; - AlgebraicDecisionTree() : Base(1.0) {} + AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {} // Explicitly non-explicit constructor AlgebraicDecisionTree(const Base& add) : Base(add) {} @@ -158,9 +158,9 @@ namespace gtsam { } /// print method customized to value type `double`. - void print(const std::string& s, - const typename Base::LabelFormatter& labelFormatter = - &DefaultFormatter) const { + void print(const std::string& s = "", + const typename Base::LabelFormatter& labelFormatter = + &DefaultFormatter) const { auto valueFormatter = [](const double& v) { return (boost::format("%4.8g") % v).str(); }; diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index fa0cc189b6..c0815b2d7a 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -85,8 +85,8 @@ size_t GaussianMixture::nrComponents() const { /* *******************************************************************************/ GaussianConditional::shared_ptr GaussianMixture::operator()( - const DiscreteValues &discreteVals) const { - auto &ptr = conditionals_(discreteVals); + const DiscreteValues &discreteValues) const { + auto &ptr = conditionals_(discreteValues); if (!ptr) return nullptr; auto conditional = boost::dynamic_pointer_cast(ptr); if (conditional) @@ -207,4 +207,30 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { conditionals_.root_ = pruned_conditionals.root_; } +/* *******************************************************************************/ +AlgebraicDecisionTree GaussianMixture::error( + const VectorValues &continuousValues) const { + // functor to calculate to double error value from GaussianConditional. + auto errorFunc = + [continuousValues](const GaussianConditional::shared_ptr &conditional) { + if (conditional) { + return conditional->error(continuousValues); + } else { + // Return arbitrarily large error if conditional is null + // Conditional is null if it is pruned out. + return 1e50; + } + }; + DecisionTree errorTree(conditionals_, errorFunc); + return errorTree; +} + +/* *******************************************************************************/ +double GaussianMixture::error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + // Directly index to get the conditional, no need to build the whole tree. + auto conditional = conditionals_(discreteValues); + return conditional->error(continuousValues); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index dada35ae4b..88d5a02c0c 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -122,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture /// @{ GaussianConditional::shared_ptr operator()( - const DiscreteValues &discreteVals) const; + const DiscreteValues &discreteValues) const; /// Returns the total number of continuous components size_t nrComponents() const; @@ -144,6 +144,26 @@ class GTSAM_EXPORT GaussianMixture /// Getter for the underlying Conditionals DecisionTree const Conditionals &conditionals(); + /** + * @brief Compute error of the GaussianMixture as a tree. + * + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the conditionals, and leaf values as the error. + */ + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param continuousValues Continuous values at which to compute the error. + * @param discreteValues The discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const; + /** * @brief Prune the decision tree of Gaussian factors as per the discrete * `decisionTree`. diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 181b1e6a5a..fd437f52c0 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -95,4 +95,26 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() }; return {factors_, wrap}; } + +/* *******************************************************************************/ +AlgebraicDecisionTree GaussianMixtureFactor::error( + const VectorValues &continuousValues) const { + // functor to convert from sharedFactor to double error value. + auto errorFunc = + [continuousValues](const GaussianFactor::shared_ptr &factor) { + return factor->error(continuousValues); + }; + DecisionTree errorTree(factors_, errorFunc); + return errorTree; +} + +/* *******************************************************************************/ +double GaussianMixtureFactor::error( + const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + // Directly index to get the conditional, no need to build the whole tree. + auto factor = factors_(discreteValues); + return factor->error(continuousValues); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 0d649f8576..0b65b5aa93 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -20,15 +20,19 @@ #pragma once +#include #include #include +#include #include #include +#include namespace gtsam { class GaussianFactorGraph; +// Needed for wrapper. using GaussianFactorVector = std::vector; /** @@ -126,6 +130,26 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ Sum add(const Sum &sum) const; + /** + * @brief Compute error of the GaussianMixtureFactor as a tree. + * + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the factors involved, and leaf values as the error. + */ + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param continuousValues Continuous values at which to compute the error. + * @param discreteValues The discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const; + /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index a641363846..48c4b6d508 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -232,4 +232,56 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { return gbn.optimize(); } +/* ************************************************************************* */ +double HybridBayesNet::error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + GaussianBayesNet gbn = this->choose(discreteValues); + return gbn.error(continuousValues); +} + +/* ************************************************************************* */ +AlgebraicDecisionTree HybridBayesNet::error( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree; + + // Iterate over each factor. + for (size_t idx = 0; idx < size(); idx++) { + AlgebraicDecisionTree conditional_error; + + if (factors_.at(idx)->isHybrid()) { + // If factor is hybrid, select based on assignment and compute error. + GaussianMixture::shared_ptr gm = this->atMixture(idx); + conditional_error = gm->error(continuousValues); + + // Assign for the first index, add error for subsequent ones. + if (idx == 0) { + error_tree = conditional_error; + } else { + error_tree = error_tree + conditional_error; + } + + } else if (factors_.at(idx)->isContinuous()) { + // If continuous only, get the (double) error + // and add it to the error_tree + double error = this->atGaussian(idx)->error(continuousValues); + // Add the computed error to every leaf of the error tree. + error_tree = error_tree.apply( + [error](double leaf_value) { return leaf_value + error; }); + + } else if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we skip. + continue; + } + } + + return error_tree; +} + +/* ************************************************************************* */ +AlgebraicDecisionTree HybridBayesNet::probPrime( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree = this->error(continuousValues); + return error_tree.apply([](double error) { return exp(-error); }); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 84c1da6ea8..f8ec609119 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -124,6 +124,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. HybridBayesNet prune(size_t maxNrLeaves); + /** + * @brief 0.5 * sum of squared Mahalanobis distances + * for a specific discrete assignment. + * + * @param continuousValues Continuous values at which to compute the error. + * @param discreteValues Discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const; + + /** + * @brief Compute conditional error for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which to compute the error. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + + /** + * @brief Compute unnormalized probability q(μ|M), + * for each discrete assignment, and return as a tree. + * q(μ|M) is the unnormalized probability at the MLE point μ, + * conditioned on the discrete variables. + * + * @param continuousValues Continuous values at which to compute the + * probability. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree probPrime( + const VectorValues &continuousValues) const; + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 041603fbd5..32653bdecf 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -423,4 +423,58 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { return ordering; } +/* ************************************************************************ */ +AlgebraicDecisionTree HybridGaussianFactorGraph::error( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree(0.0); + + // Iterate over each factor. + for (size_t idx = 0; idx < size(); idx++) { + AlgebraicDecisionTree factor_error; + + if (factors_.at(idx)->isHybrid()) { + // If factor is hybrid, select based on assignment. + GaussianMixtureFactor::shared_ptr gaussianMixture = + boost::static_pointer_cast(factors_.at(idx)); + // Compute factor error. + factor_error = gaussianMixture->error(continuousValues); + + // If first factor, assign error, else add it. + if (idx == 0) { + error_tree = factor_error; + } else { + error_tree = error_tree + factor_error; + } + + } else if (factors_.at(idx)->isContinuous()) { + // If continuous only, get the (double) error + // and add it to the error_tree + auto hybridGaussianFactor = + boost::static_pointer_cast(factors_.at(idx)); + GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); + + // Compute the error of the gaussian factor. + double error = gaussian->error(continuousValues); + // Add the gaussian factor error to every leaf of the error tree. + error_tree = error_tree.apply( + [error](double leaf_value) { return leaf_value + error; }); + + } else if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we skip. + continue; + } + } + + return error_tree; +} + +/* ************************************************************************ */ +AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree = this->error(continuousValues); + AlgebraicDecisionTree prob_tree = + error_tree.apply([](double error) { return exp(-error); }); + return prob_tree; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 6a03625001..ac9ae1a462 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -41,7 +41,7 @@ class JacobianFactor; /** * @brief Main elimination function for HybridGaussianFactorGraph. - * + * * @param factors The factor graph to eliminate. * @param keys The elimination ordering. * @return The conditional on the ordering keys and the remaining factors. @@ -99,11 +99,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph using shared_ptr = boost::shared_ptr; ///< shared_ptr to This using Values = gtsam::Values; ///< backwards compatibility - using Indices = KeyVector; ///> map from keys to values + using Indices = KeyVector; ///< map from keys to values /// @name Constructors /// @{ + /// @brief Default constructor. HybridGaussianFactorGraph() = default; /** @@ -170,6 +171,28 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } + /** + * @brief Compute error for each discrete assignment, + * and return as a tree. + * + * Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$. + * + * @param continuousValues Continuous values at which to compute the error. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree error(const VectorValues& continuousValues) const; + + /** + * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ + * for each discrete assignment, and return as a tree. + * + * @param continuousValues Continuous values at which to compute the + * probability. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree probPrime( + const VectorValues& continuousValues) const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 5e7337d0cc..f29a840229 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor { * elements based on the number of discrete keys and the cardinality of the * keys, so that the decision tree is constructed appropriately. * - * @tparam FACTOR The type of the factor shared pointers being passed in. Will - * be typecast to NonlinearFactor shared pointers. + * @tparam FACTOR The type of the factor shared pointers being passed in. + * Will be typecast to NonlinearFactor shared pointers. * @param keys Vector of keys for continuous factors. * @param discreteKeys Vector of discrete keys. - * @param factors Vector of shared pointers to factors. + * @param factors Vector of nonlinear factors. * @param normalized Flag indicating if the factor error is already * normalized. */ @@ -107,8 +108,12 @@ class MixtureFactor : public HybridFactor { std::copy(f->keys().begin(), f->keys().end(), std::inserter(factor_keys_set, factor_keys_set.end())); - nonlinear_factors.push_back( - boost::dynamic_pointer_cast(f)); + if (auto nf = boost::dynamic_pointer_cast(f)) { + nonlinear_factors.push_back(nf); + } else { + throw std::runtime_error( + "Factors passed into MixtureFactor need to be nonlinear!"); + } } factors_ = Factors(discreteKeys, nonlinear_factors); @@ -121,22 +126,39 @@ class MixtureFactor : public HybridFactor { ~MixtureFactor() = default; + /** + * @brief Compute error of the MixtureFactor as a tree. + * + * @param continuousValues The continuous values for which to compute the + * error. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the factor, and leaf values as the error. + */ + AlgebraicDecisionTree error(const Values& continuousValues) const { + // functor to convert from sharedFactor to double error value. + auto errorFunc = [continuousValues](const sharedFactor& factor) { + return factor->error(continuousValues); + }; + DecisionTree errorTree(factors_, errorFunc); + return errorTree; + } + /** * @brief Compute error of factor given both continuous and discrete values. * - * @param continuousVals The continuous Values. - * @param discreteVals The discrete Values. + * @param continuousValues The continuous Values. + * @param discreteValues The discrete Values. * @return double The error of this factor. */ - double error(const Values& continuousVals, - const DiscreteValues& discreteVals) const { - // Retrieve the factor corresponding to the assignment in discreteVals. - auto factor = factors_(discreteVals); + double error(const Values& continuousValues, + const DiscreteValues& discreteValues) const { + // Retrieve the factor corresponding to the assignment in discreteValues. + auto factor = factors_(discreteValues); // Compute the error for the selected factor - const double factorError = factor->error(continuousVals); + const double factorError = factor->error(continuousValues); if (normalized_) return factorError; - return factorError + - this->nonlinearFactorLogNormalizingConstant(factor, continuousVals); + return factorError + this->nonlinearFactorLogNormalizingConstant( + factor, continuousValues); } size_t dim() const { @@ -149,7 +171,7 @@ class MixtureFactor : public HybridFactor { /// print to stdout void print( - const std::string& s = "MixtureFactor", + const std::string& s = "", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { std::cout << (s.empty() ? "" : s + " "); Base::print("", keyFormatter); @@ -192,17 +214,18 @@ class MixtureFactor : public HybridFactor { /// Linearize specific nonlinear factors based on the assignment in /// discreteValues. GaussianFactor::shared_ptr linearize( - const Values& continuousVals, const DiscreteValues& discreteVals) const { - auto factor = factors_(discreteVals); - return factor->linearize(continuousVals); + const Values& continuousValues, + const DiscreteValues& discreteValues) const { + auto factor = factors_(discreteValues); + return factor->linearize(continuousValues); } /// Linearize all the continuous factors to get a GaussianMixtureFactor. boost::shared_ptr linearize( - const Values& continuousVals) const { + const Values& continuousValues) const { // functional to linearize each factor in the decision tree - auto linearizeDT = [continuousVals](const sharedFactor& factor) { - return factor->linearize(continuousVals); + auto linearizeDT = [continuousValues](const sharedFactor& factor) { + return factor->linearize(continuousValues); }; DecisionTree linearized_factors( diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 86029a48aa..899c129e00 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -196,22 +196,24 @@ class HybridNonlinearFactorGraph { #include class MixtureFactor : gtsam::HybridFactor { - MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, - const gtsam::DecisionTree& factors, bool normalized = false); + MixtureFactor( + const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, + const gtsam::DecisionTree& factors, + bool normalized = false); template MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, const std::vector& factors, bool normalized = false); - double error(const gtsam::Values& continuousVals, - const gtsam::DiscreteValues& discreteVals) const; + double error(const gtsam::Values& continuousValues, + const gtsam::DiscreteValues& discreteValues) const; double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor, const gtsam::Values& values) const; GaussianMixtureFactor* linearize( - const gtsam::Values& continuousVals) const; + const gtsam::Values& continuousValues) const; void print(string s = "MixtureFactor\n", const gtsam::KeyFormatter& keyFormatter = diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 420e22315b..310081f028 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -78,15 +78,58 @@ TEST(GaussianMixture, Equals) { GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); - GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals); + GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals); // Let's check that this worked: DiscreteValues mode; mode[m1.first] = 1; - auto actual = mixtureFactor(mode); + auto actual = mixture(mode); EXPECT(actual == conditional1); } +/* ************************************************************************* */ +/// Test error method of GaussianMixture. +TEST(GaussianMixture, Error) { + Matrix22 S1 = Matrix22::Identity(); + Matrix22 S2 = Matrix22::Identity() * 2; + Matrix22 R1 = Matrix22::Ones(); + Matrix22 R2 = Matrix22::Ones(); + Vector2 d1(1, 2), d2(2, 1); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional0 = boost::make_shared(X(1), d1, R1, + X(2), S1, model), + conditional1 = boost::make_shared(X(1), d2, R2, + X(2), S2, model); + + // Create decision tree + DiscreteKey m1(M(1), 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals); + + VectorValues values; + values.insert(X(1), Vector2::Ones()); + values.insert(X(2), Vector2::Zero()); + auto error_tree = mixture.error(values); + + // regression + std::vector discrete_keys = {m1}; + std::vector leaves = {0.5, 4.3252595}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + EXPECT(assert_equal(expected_error, error_tree, 1e-6)); + + // Regression for non-tree version. + DiscreteValues assignment; + assignment[M(1)] = 0; + EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); + assignment[M(1)] = 1; + EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index cb9068c303..ba0622ff9d 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file GaussianMixtureFactor.cpp + * @file testGaussianMixtureFactor.cpp * @brief Unit tests for GaussianMixtureFactor * @author Varun Agrawal * @author Fan Jiang @@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) { EXPECT(assert_print_equal(expected, mixtureFactor)); } -TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { +TEST(GaussianMixtureFactor, GaussianMixture) { KeyVector keys; keys.push_back(X(0)); keys.push_back(X(1)); @@ -151,6 +151,46 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size()); } +/* ************************************************************************* */ +// Test the error of the GaussianMixtureFactor +TEST(GaussianMixtureFactor, Error) { + DiscreteKey m1(1, 2); + + auto A01 = Matrix2::Identity(); + auto A02 = Matrix2::Identity(); + + auto A11 = Matrix2::Identity(); + auto A12 = Matrix2::Identity() * 2; + + auto b = Vector2::Zero(); + + auto f0 = boost::make_shared(X(1), A01, X(2), A02, b); + auto f1 = boost::make_shared(X(1), A11, X(2), A12, b); + std::vector factors{f0, f1}; + + GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + VectorValues continuousValues; + continuousValues.insert(X(1), Vector2(0, 0)); + continuousValues.insert(X(2), Vector2(1, 1)); + + // error should return a tree of errors, with nodes for each discrete value. + AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + + std::vector discrete_keys = {m1}; + // Error values for regression test + std::vector errors = {1, 4}; + AlgebraicDecisionTree expected_error(discrete_keys, errors); + + EXPECT(assert_equal(expected_error, error_tree)); + + // Test for single leaf given discrete assignment P(X|M,Z). + DiscreteValues discreteValues; + discreteValues[m1.first] = 1; + EXPECT_DOUBLES_EQUAL( + 4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index e1fe724695..050c01aeda 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -183,6 +183,60 @@ TEST(HybridBayesNet, OptimizeMultifrontal) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test bayes net error +TEST(HybridBayesNet, Error) { + Switching s(3); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + auto error_tree = hybridBayesNet->error(delta.continuous()); + + std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; + std::vector leaves = {0.0097568009, 3.3973404e-31, 0.029126214, + 0.0097568009}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-9)); + + // Error on pruned bayes net + auto prunedBayesNet = hybridBayesNet->prune(2); + auto pruned_error_tree = prunedBayesNet.error(delta.continuous()); + + std::vector pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009}; + AlgebraicDecisionTree expected_pruned_error(discrete_keys, + pruned_leaves); + + // regression + EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9)); + + // Verify error computation and check for specific error value + DiscreteValues discrete_values; + boost::assign::insert(discrete_values)(M(0), 1)(M(1), 1); + + double total_error = 0; + for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { + if (hybridBayesNet->at(idx)->isHybrid()) { + double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), + discrete_values); + total_error += error; + } else if (hybridBayesNet->at(idx)->isContinuous()) { + double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); + total_error += error; + } + } + + EXPECT_DOUBLES_EQUAL( + total_error, hybridBayesNet->error(delta.continuous(), discrete_values), + 1e-9); + EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); +} + /* ****************************************************************************/ // Test bayes net pruning TEST(HybridBayesNet, Prune) { diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index ed6b97ab04..7877461b67 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -562,6 +562,36 @@ TEST(HybridGaussianFactorGraph, Conditionals) { EXPECT(assert_equal(expected_discrete, result.discrete())); } +/* ****************************************************************************/ +// Test hybrid gaussian factor graph error and unnormalized probabilities +TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { + Switching s(3); + + HybridGaussianFactorGraph graph = s.linearizedFactorGraph; + + Ordering hybridOrdering = graph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + graph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + auto error_tree = graph.error(delta.continuous()); + + std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; + std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-7)); + + auto probs = graph.probPrime(delta.continuous()); + std::vector prob_leaves = {0.36793249, 0.61247742, 0.59489556, + 0.99029064}; + AlgebraicDecisionTree expected_probs(discrete_keys, prob_leaves); + + // regression + EXPECT(assert_equal(expected_probs, probs, 1e-7)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp new file mode 100644 index 0000000000..fe3212eda5 --- /dev/null +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -0,0 +1,109 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testMixtureFactor.cpp + * @brief Unit tests for MixtureFactor + * @author Varun Agrawal + * @date October 2022 + */ + +#include +#include +#include +#include +#include + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +/* ************************************************************************* */ +// Check iterators of empty mixture. +TEST(MixtureFactor, Constructor) { + MixtureFactor factor; + MixtureFactor::const_iterator const_it = factor.begin(); + CHECK(const_it == factor.end()); + MixtureFactor::iterator it = factor.begin(); + CHECK(it == factor.end()); +} + +/* ************************************************************************* */ +// Test .print() output. +TEST(MixtureFactor, Printing) { + DiscreteKey m1(1, 2); + double between0 = 0.0; + double between1 = 1.0; + + Vector1 sigmas = Vector1(1.0); + auto model = noiseModel::Diagonal::Sigmas(sigmas, false); + + auto f0 = + boost::make_shared>(X(1), X(2), between0, model); + auto f1 = + boost::make_shared>(X(1), X(2), between1, model); + std::vector factors{f0, f1}; + + MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + std::string expected = + R"(Hybrid [x1 x2; 1] +MixtureFactor + Choice(1) + 0 Leaf Nonlinear factor on 2 keys + 1 Leaf Nonlinear factor on 2 keys +)"; + EXPECT(assert_print_equal(expected, mixtureFactor)); +} + +/* ************************************************************************* */ +// Test the error of the MixtureFactor +TEST(MixtureFactor, Error) { + DiscreteKey m1(1, 2); + + double between0 = 0.0; + double between1 = 1.0; + + Vector1 sigmas = Vector1(1.0); + auto model = noiseModel::Diagonal::Sigmas(sigmas, false); + + auto f0 = + boost::make_shared>(X(1), X(2), between0, model); + auto f1 = + boost::make_shared>(X(1), X(2), between1, model); + std::vector factors{f0, f1}; + + MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + Values continuousValues; + continuousValues.insert(X(1), 0); + continuousValues.insert(X(2), 1); + + AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + + std::vector discrete_keys = {m1}; + std::vector errors = {0.5, 0}; + AlgebraicDecisionTree expected_error(discrete_keys, errors); + + EXPECT(assert_equal(expected_error, error_tree)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */