From 64744b057e174684eb56d4618d0484fe3e4c236e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 26 Oct 2022 10:14:09 -0400 Subject: [PATCH 01/18] Hybrid Mixture error calculation --- gtsam/hybrid/GaussianMixtureFactor.h | 18 ++++++++++++++++++ gtsam/hybrid/MixtureFactor.h | 16 ++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index a0a51af55e..f27c491808 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -24,11 +24,13 @@ #include #include #include +#include namespace gtsam { class GaussianFactorGraph; +// Needed for wrapper. using GaussianFactorVector = std::vector; /** @@ -125,6 +127,22 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ Sum add(const Sum &sum) const; + /** + * @brief Compute error of the GaussianMixtureFactor as a tree. + * + * @param continuousVals The continuous VectorValues. + * @return DecisionTree A decision tree with corresponding keys + * as the factor but leaf values as the error. + */ + DecisionTree error(const VectorValues &c) const { + // functor to convert from sharedFactor to double error value. + auto errorFunc = [c](const GaussianFactor::shared_ptr &factor) { + return factor->error(c); + }; + DecisionTree errorTree(factors_, errorFunc); + return errorTree; + } + /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 5e7337d0cc..5a2383221a 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -121,6 +121,22 @@ class MixtureFactor : public HybridFactor { ~MixtureFactor() = default; + /** + * @brief Compute error of the MixtureFactor as a tree. + * + * @param continuousVals The continuous values for which to compute the error. + * @return DecisionTree A decision tree with corresponding keys + * as the factor but leaf values as the error. + */ + DecisionTree error(const Values& continuousVals) const { + // functor to convert from sharedFactor to double error value. + auto errorFunc = [continuousVals](const sharedFactor& factor) { + return factor->error(continuousVals); + }; + DecisionTree errorTree(factors_, errorFunc); + return errorTree; + } + /** * @brief Compute error of factor given both continuous and discrete values. * From c41b58fc986b110c8d873d8695cf3053c3242304 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 31 Oct 2022 15:36:48 -0400 Subject: [PATCH 02/18] Add GaussianMixtureFactor::error method and unit test --- gtsam/hybrid/GaussianMixtureFactor.cpp | 12 ++++++ gtsam/hybrid/GaussianMixtureFactor.h | 12 ++---- .../tests/testGaussianMixtureFactor.cpp | 37 ++++++++++++++++++- 3 files changed, 50 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 181b1e6a5a..a8500911a5 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -95,4 +95,16 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() }; return {factors_, wrap}; } + +/* *******************************************************************************/ +AlgebraicDecisionTree GaussianMixtureFactor::error( + const VectorValues &continuousVals) const { + // functor to convert from sharedFactor to double error value. + auto errorFunc = [continuousVals](const GaussianFactor::shared_ptr &factor) { + return factor->error(continuousVals); + }; + DecisionTree errorTree(factors_, errorFunc); + return errorTree; +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index f27c491808..31ec3c1a08 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -20,6 +20,7 @@ #pragma once +#include #include #include #include @@ -131,17 +132,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @brief Compute error of the GaussianMixtureFactor as a tree. * * @param continuousVals The continuous VectorValues. - * @return DecisionTree A decision tree with corresponding keys + * @return AlgebraicDecisionTree A decision tree with corresponding keys * as the factor but leaf values as the error. */ - DecisionTree error(const VectorValues &c) const { - // functor to convert from sharedFactor to double error value. - auto errorFunc = [c](const GaussianFactor::shared_ptr &factor) { - return factor->error(c); - }; - DecisionTree errorTree(factors_, errorFunc); - return errorTree; - } + AlgebraicDecisionTree error(const VectorValues &continuousVals) const; /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index cb9068c303..e6248f5c93 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file GaussianMixtureFactor.cpp + * @file testGaussianMixtureFactor.cpp * @brief Unit tests for GaussianMixtureFactor * @author Varun Agrawal * @author Fan Jiang @@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) { EXPECT(assert_print_equal(expected, mixtureFactor)); } -TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { +TEST(GaussianMixtureFactor, GaussianMixture) { KeyVector keys; keys.push_back(X(0)); keys.push_back(X(1)); @@ -151,6 +151,39 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size()); } +/* ************************************************************************* */ +// Test the error of the GaussianMixtureFactor +TEST(GaussianMixtureFactor, Error) { + DiscreteKey m1(1, 2); + + auto A01 = Matrix2::Identity(); + auto A02 = Matrix2::Identity(); + + auto A11 = Matrix2::Identity(); + auto A12 = Matrix2::Identity() * 2; + + auto b = Vector2::Zero(); + + auto f0 = boost::make_shared(X(1), A01, X(2), A02, b); + auto f1 = boost::make_shared(X(1), A11, X(2), A12, b); + std::vector factors{f0, f1}; + + GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + VectorValues continuousVals; + continuousVals.insert(X(1), Vector2(0, 0)); + continuousVals.insert(X(2), Vector2(1, 1)); + + // error should return a tree of errors, with nodes for each discrete value. + AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousVals); + + std::vector discrete_keys = {m1}; + std::vector errors = {1, 4}; + AlgebraicDecisionTree expected_error(discrete_keys, errors); + + EXPECT(assert_equal(expected_error, error_tree)); +} + /* ************************************************************************* */ int main() { TestResult tr; From aa1c65d0dcf4da65459aa84a00c55d772f936660 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 31 Oct 2022 15:37:06 -0400 Subject: [PATCH 03/18] default string for AlgebraicDecisionTree::print --- gtsam/discrete/AlgebraicDecisionTree.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9769715a17..2c749a4718 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -156,9 +156,9 @@ namespace gtsam { } /// print method customized to value type `double`. - void print(const std::string& s, - const typename Base::LabelFormatter& labelFormatter = - &DefaultFormatter) const { + void print(const std::string& s = "", + const typename Base::LabelFormatter& labelFormatter = + &DefaultFormatter) const { auto valueFormatter = [](const double& v) { return (boost::format("%4.8g") % v).str(); }; From 5c375f6d03d81a273c8dcaa4541883579ed71a56 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 31 Oct 2022 15:38:12 -0400 Subject: [PATCH 04/18] add unit tests for MixtureFactor --- gtsam/hybrid/tests/testMixtureFactor.cpp | 108 +++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 gtsam/hybrid/tests/testMixtureFactor.cpp diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp new file mode 100644 index 0000000000..17ada8c890 --- /dev/null +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -0,0 +1,108 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testMixtureFactor.cpp + * @brief Unit tests for MixtureFactor + * @author Varun Agrawal + * @date October 2022 + */ + +#include +#include +#include +#include +#include + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +/* ************************************************************************* */ +// Check iterators of empty mixture. +TEST(MixtureFactor, Constructor) { + MixtureFactor factor; + MixtureFactor::const_iterator const_it = factor.begin(); + CHECK(const_it == factor.end()); + MixtureFactor::iterator it = factor.begin(); + CHECK(it == factor.end()); +} + + +TEST(MixtureFactor, Printing) { + DiscreteKey m1(1, 2); + double between0 = 0.0; + double between1 = 1.0; + + Vector1 sigmas = Vector1(1.0); + auto model = noiseModel::Diagonal::Sigmas(sigmas, false); + + auto f0 = + boost::make_shared>(X(1), X(2), between0, model); + auto f1 = + boost::make_shared>(X(1), X(2), between1, model); + std::vector factors{f0, f1}; + + MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + std::string expected = + R"(Hybrid [x1 x2; 1] +MixtureFactor + Choice(1) + 0 Leaf Nonlinear factor on 2 keys + 1 Leaf Nonlinear factor on 2 keys +)"; + EXPECT(assert_print_equal(expected, mixtureFactor)); +} + +/* ************************************************************************* */ +// Test the error of the MixtureFactor +TEST(MixtureFactor, Error) { + DiscreteKey m1(1, 2); + + double between0 = 0.0; + double between1 = 1.0; + + Vector1 sigmas = Vector1(1.0); + auto model = noiseModel::Diagonal::Sigmas(sigmas, false); + + auto f0 = + boost::make_shared>(X(1), X(2), between0, model); + auto f1 = + boost::make_shared>(X(1), X(2), between1, model); + std::vector factors{f0, f1}; + + MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + Values continuousVals; + continuousVals.insert(X(1), 0); + continuousVals.insert(X(2), 1); + + AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousVals); + + std::vector discrete_keys = {m1}; + std::vector errors = {0.5, 0}; + AlgebraicDecisionTree expected_error(discrete_keys, errors); + + EXPECT(assert_equal(expected_error, error_tree)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ \ No newline at end of file From d834897b14f134c98a78307f894d4c805094c881 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 31 Oct 2022 15:38:23 -0400 Subject: [PATCH 05/18] update MixtureFactor so that all tests pass --- gtsam/hybrid/MixtureFactor.h | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 5a2383221a..511705cf3c 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -107,8 +107,12 @@ class MixtureFactor : public HybridFactor { std::copy(f->keys().begin(), f->keys().end(), std::inserter(factor_keys_set, factor_keys_set.end())); - nonlinear_factors.push_back( - boost::dynamic_pointer_cast(f)); + if (auto nf = boost::dynamic_pointer_cast(f)) { + nonlinear_factors.push_back(nf); + } else { + throw std::runtime_error( + "Factors passed into MixtureFactor need to be nonlinear!"); + } } factors_ = Factors(discreteKeys, nonlinear_factors); @@ -125,10 +129,10 @@ class MixtureFactor : public HybridFactor { * @brief Compute error of the MixtureFactor as a tree. * * @param continuousVals The continuous values for which to compute the error. - * @return DecisionTree A decision tree with corresponding keys + * @return AlgebraicDecisionTree A decision tree with corresponding keys * as the factor but leaf values as the error. */ - DecisionTree error(const Values& continuousVals) const { + AlgebraicDecisionTree error(const Values& continuousVals) const { // functor to convert from sharedFactor to double error value. auto errorFunc = [continuousVals](const sharedFactor& factor) { return factor->error(continuousVals); @@ -165,7 +169,7 @@ class MixtureFactor : public HybridFactor { /// print to stdout void print( - const std::string& s = "MixtureFactor", + const std::string& s = "", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { std::cout << (s.empty() ? "" : s + " "); Base::print("", keyFormatter); From c0eeb0cfcd07e0b0906174c306e8f61bbcea13c5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 31 Oct 2022 19:22:53 -0400 Subject: [PATCH 06/18] add newline --- gtsam/hybrid/tests/testMixtureFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 17ada8c890..c00d70e5ab 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -105,4 +105,4 @@ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } -/* ************************************************************************* */ \ No newline at end of file +/* ************************************************************************* */ From 9365a02bdb8cbdbab5aa978f0c52d736df5a4de9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 1 Nov 2022 14:01:20 -0400 Subject: [PATCH 07/18] add specific assignment error for GaussianMixtureFactor --- gtsam/hybrid/GaussianMixtureFactor.cpp | 8 ++++++++ gtsam/hybrid/GaussianMixtureFactor.h | 12 ++++++++++++ gtsam/hybrid/tests/testGaussianMixtureFactor.cpp | 6 ++++++ 3 files changed, 26 insertions(+) diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index a8500911a5..16802516e3 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -107,4 +107,12 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( return errorTree; } +/* *******************************************************************************/ +double GaussianMixtureFactor::error( + const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const { + auto factor = factors_(discreteValues); + return factor->error(continuousVals); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 31ec3c1a08..b6552c0785 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -137,6 +138,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ AlgebraicDecisionTree error(const VectorValues &continuousVals) const; + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param continuousVals The continuous values at which to compute the error. + * @param discreteValues The discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const; + /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index e6248f5c93..5c25a09313 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -182,6 +182,12 @@ TEST(GaussianMixtureFactor, Error) { AlgebraicDecisionTree expected_error(discrete_keys, errors); EXPECT(assert_equal(expected_error, error_tree)); + + // Test for single leaf given discrete assignment P(X|M,Z). + DiscreteValues discreteVals; + discreteVals[m1.first] = 1; + EXPECT_DOUBLES_EQUAL(4.0, mixtureFactor.error(continuousVals, discreteVals), + 1e-9); } /* ************************************************************************* */ From ca14b7e6ece6bfb5dbb21ce7f9024de6c7567a76 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 1 Nov 2022 20:19:36 -0400 Subject: [PATCH 08/18] GaussianMixture error methods --- gtsam/hybrid/GaussianMixture.cpp | 19 ++++++++++ gtsam/hybrid/GaussianMixture.h | 20 +++++++++++ gtsam/hybrid/tests/testGaussianMixture.cpp | 40 ++++++++++++++++++++-- 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 5172a97983..c1194d2010 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -208,4 +208,23 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { conditionals_.root_ = pruned_conditionals.root_; } +/* *******************************************************************************/ +AlgebraicDecisionTree GaussianMixture::error( + const VectorValues &continuousVals) const { + // functor to convert from GaussianConditional to double error value. + auto errorFunc = + [continuousVals](const GaussianConditional::shared_ptr &conditional) { + return conditional->error(continuousVals); + }; + DecisionTree errorTree(conditionals_, errorFunc); + return errorTree; +} + +/* *******************************************************************************/ +double GaussianMixture::error(const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const { + auto conditional = conditionals_(discreteValues); + return conditional->error(continuousVals); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 9792a85323..b3b47fc87a 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -143,6 +143,26 @@ class GTSAM_EXPORT GaussianMixture /// Getter for the underlying Conditionals DecisionTree const Conditionals &conditionals(); + /** + * @brief Compute error of the GaussianMixture as a tree. + * + * @param continuousVals The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with corresponding keys + * as the factor but leaf values as the error. + */ + AlgebraicDecisionTree error(const VectorValues &continuousVals) const; + + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param continuousVals The continuous values at which to compute the error. + * @param discreteValues The discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const; + /** * @brief Prune the decision tree of Gaussian factors as per the discrete * `decisionTree`. diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 420e22315b..556a5f16a9 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -78,15 +78,51 @@ TEST(GaussianMixture, Equals) { GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); - GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals); + GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals); // Let's check that this worked: DiscreteValues mode; mode[m1.first] = 1; - auto actual = mixtureFactor(mode); + auto actual = mixture(mode); EXPECT(actual == conditional1); } +/* ************************************************************************* */ +/// Test error method of GaussianMixture. +TEST(GaussianMixture, Error) { + Matrix22 S1 = Matrix22::Identity(); + Matrix22 S2 = Matrix22::Identity() * 2; + Matrix22 R1 = Matrix22::Ones(); + Matrix22 R2 = Matrix22::Ones(); + Vector2 d1(1, 2), d2(2, 1); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional0 = boost::make_shared(X(1), d1, R1, + X(2), S1, model), + conditional1 = boost::make_shared(X(1), d2, R2, + X(2), S2, model); + + // Create decision tree + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals); + + VectorValues values; + values.insert(X(1), Vector2::Ones()); + values.insert(X(2), Vector2::Zero()); + auto error_tree = mixture.error(values); + + std::vector discrete_keys = {m1}; + std::vector leaves = {0.5, 4.3252595}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-6)); +} + /* ************************************************************************* */ int main() { TestResult tr; From 281ad3167ef5a82f7d63743114bc79b319a1fb71 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 2 Nov 2022 02:53:51 -0400 Subject: [PATCH 09/18] error method for HybridBayesNet --- gtsam/hybrid/GaussianMixture.cpp | 7 ++- gtsam/hybrid/HybridBayesNet.cpp | 41 +++++++++++++++++ gtsam/hybrid/HybridBayesNet.h | 13 ++++++ gtsam/hybrid/tests/testHybridBayesNet.cpp | 55 +++++++++++++++++++++++ 4 files changed, 115 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index c1194d2010..2c5aabf552 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -214,7 +214,12 @@ AlgebraicDecisionTree GaussianMixture::error( // functor to convert from GaussianConditional to double error value. auto errorFunc = [continuousVals](const GaussianConditional::shared_ptr &conditional) { - return conditional->error(continuousVals); + if (conditional) { + return conditional->error(continuousVals); + } else { + // return arbitrarily large error + return 1e50; + } }; DecisionTree errorTree(conditionals_, errorFunc); return errorTree; diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index cc27600f09..91d92ab0e2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -145,4 +145,45 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { return gbn.optimize(); } +/* ************************************************************************* */ +double HybridBayesNet::error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + GaussianBayesNet gbn = this->choose(discreteValues); + return gbn.error(continuousValues); +} + +/* ************************************************************************* */ +AlgebraicDecisionTree HybridBayesNet::error( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree; + + for (size_t idx = 0; idx < size(); idx++) { + AlgebraicDecisionTree conditional_error; + if (factors_.at(idx)->isHybrid()) { + // If factor is hybrid, select based on assignment. + GaussianMixture::shared_ptr gm = this->atMixture(idx); + conditional_error = gm->error(continuousValues); + + if (idx == 0) { + error_tree = conditional_error; + } else { + error_tree = error_tree + conditional_error; + } + + } else if (factors_.at(idx)->isContinuous()) { + // If continuous only, get the (double) error + // and add it to the error_tree + double error = this->atGaussian(idx)->error(continuousValues); + error_tree = error_tree.apply( + [error](double leaf_value) { return leaf_value + error; }); + + } else if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we skip. + continue; + } + } + + return error_tree; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index b8234d70ab..82e890cc4f 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -123,6 +123,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. HybridBayesNet prune(size_t maxNrLeaves) const; + /** + * @brief 0.5 * sum of squared Mahalanobis distances + * for a specific discrete assignment. + * + * @param continuousValues Continuous values at which to compute the error. + * @param discreteValues Discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const; + + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + /// @} private: diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 5885fdcdcc..4ca760f882 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -183,6 +183,61 @@ TEST(HybridBayesNet, OptimizeMultifrontal) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test bayes net error +TEST(HybridBayesNet, Error) { + Switching s(3); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + auto error_tree = hybridBayesNet->error(delta.continuous()); + + std::vector discrete_keys = {{M(1), 2}, {M(2), 2}}; + std::vector leaves = {0.0097568009, 3.3973404e-31, 0.029126214, + 0.0097568009}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-9)); + + // Error on pruned bayes net + auto prunedBayesNet = hybridBayesNet->prune(2); + auto pruned_error_tree = prunedBayesNet.error(delta.continuous()); + + std::vector pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009}; + AlgebraicDecisionTree expected_pruned_error(discrete_keys, + pruned_leaves); + + // regression + EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9)); + + // Verify error computation and check for specific error value + DiscreteValues discrete_values; + discrete_values[M(1)] = 1; + discrete_values[M(2)] = 1; + + double total_error = 0; + for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { + if (hybridBayesNet->at(idx)->isHybrid()) { + double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), + discrete_values); + total_error += error; + } else if (hybridBayesNet->at(idx)->isContinuous()) { + double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); + total_error += error; + } + } + + EXPECT_DOUBLES_EQUAL( + total_error, hybridBayesNet->error(delta.continuous(), discrete_values), + 1e-9); + EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); +} + /* ****************************************************************************/ // Test bayes net pruning TEST(HybridBayesNet, Prune) { From cff6505423929bcc771e0cf990916d8bf981a64f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 3 Nov 2022 06:33:51 -0400 Subject: [PATCH 10/18] add docstring for HybridBayesNet::error(VectorValues) --- gtsam/hybrid/HybridBayesNet.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 82e890cc4f..d64e561f8e 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -134,6 +134,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { double error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const; + /** + * @brief Compute conditional error for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which to compute the error. + * @return AlgebraicDecisionTree + */ AlgebraicDecisionTree error(const VectorValues &continuousValues) const; /// @} From 8fa7f443617d0a560c0800449bb28b7136f29577 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 3 Nov 2022 11:44:41 -0400 Subject: [PATCH 11/18] fix discrete keys --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 3c229476ed..8b8ca976b0 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -195,7 +195,7 @@ TEST(HybridBayesNet, Error) { HybridValues delta = hybridBayesNet->optimize(); auto error_tree = hybridBayesNet->error(delta.continuous()); - std::vector discrete_keys = {{M(1), 2}, {M(2), 2}}; + std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; std::vector leaves = {0.0097568009, 3.3973404e-31, 0.029126214, 0.0097568009}; AlgebraicDecisionTree expected_error(discrete_keys, leaves); @@ -216,8 +216,8 @@ TEST(HybridBayesNet, Error) { // Verify error computation and check for specific error value DiscreteValues discrete_values; + discrete_values[M(0)] = 1; discrete_values[M(1)] = 1; - discrete_values[M(2)] = 1; double total_error = 0; for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { From 9cb225ac20e64670d340bd89abd217a49eaff64d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 7 Nov 2022 15:42:24 -0500 Subject: [PATCH 12/18] Allow setting custom leaf value for AlgebraicDecisionTree --- gtsam/discrete/AlgebraicDecisionTree.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index cc3c8b8d77..b3f0d69b0e 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -71,7 +71,7 @@ namespace gtsam { static inline double id(const double& x) { return x; } }; - AlgebraicDecisionTree() : Base(1.0) {} + AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {} // Explicitly non-explicit constructor AlgebraicDecisionTree(const Base& add) : Base(add) {} From 551cc0d32b9948ee0b97543f5bdf4ef44314daa8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 7 Nov 2022 15:42:47 -0500 Subject: [PATCH 13/18] add error and probPrime methods to HybridGaussianFactorGraph --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 49 +++++++++++++++++++ gtsam/hybrid/HybridGaussianFactorGraph.h | 22 ++++++++- .../tests/testHybridGaussianFactorGraph.cpp | 30 ++++++++++++ 3 files changed, 100 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 041603fbd5..d6937957f2 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -423,4 +423,53 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { return ordering; } +/* ************************************************************************ */ +AlgebraicDecisionTree HybridGaussianFactorGraph::error( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree(0.0); + + for (size_t idx = 0; idx < size(); idx++) { + AlgebraicDecisionTree factor_error; + + if (factors_.at(idx)->isHybrid()) { + // If factor is hybrid, select based on assignment. + GaussianMixtureFactor::shared_ptr gaussianMixture = + boost::static_pointer_cast(factors_.at(idx)); + factor_error = gaussianMixture->error(continuousValues); + + if (idx == 0) { + error_tree = factor_error; + } else { + error_tree = error_tree + factor_error; + } + + } else if (factors_.at(idx)->isContinuous()) { + // If continuous only, get the (double) error + // and add it to the error_tree + auto hybridGaussianFactor = + boost::static_pointer_cast(factors_.at(idx)); + GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); + + double error = gaussian->error(continuousValues); + error_tree = error_tree.apply( + [error](double leaf_value) { return leaf_value + error; }); + + } else if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we skip. + continue; + } + } + + return error_tree; +} + +/* ************************************************************************ */ +AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree = this->error(continuousValues); + AlgebraicDecisionTree prob_tree = + error_tree.apply([](double error) { return exp(-error); }); + return prob_tree; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 6a03625001..c7e9aa60da 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -41,7 +41,7 @@ class JacobianFactor; /** * @brief Main elimination function for HybridGaussianFactorGraph. - * + * * @param factors The factor graph to eliminate. * @param keys The elimination ordering. * @return The conditional on the ordering keys and the remaining factors. @@ -170,6 +170,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } + /** + * @brief Compute error for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which to compute the error. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree error(const VectorValues& continuousValues) const; + + /** + * @brief Compute unnormalized probability for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which to compute the + * probability. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree probPrime( + const VectorValues& continuousValues) const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index ed6b97ab04..7877461b67 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -562,6 +562,36 @@ TEST(HybridGaussianFactorGraph, Conditionals) { EXPECT(assert_equal(expected_discrete, result.discrete())); } +/* ****************************************************************************/ +// Test hybrid gaussian factor graph error and unnormalized probabilities +TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { + Switching s(3); + + HybridGaussianFactorGraph graph = s.linearizedFactorGraph; + + Ordering hybridOrdering = graph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + graph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + auto error_tree = graph.error(delta.continuous()); + + std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; + std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-7)); + + auto probs = graph.probPrime(delta.continuous()); + std::vector prob_leaves = {0.36793249, 0.61247742, 0.59489556, + 0.99029064}; + AlgebraicDecisionTree expected_probs(discrete_keys, prob_leaves); + + // regression + EXPECT(assert_equal(expected_probs, probs, 1e-7)); +} + /* ************************************************************************* */ int main() { TestResult tr; From 07a616dcdae69013045af724fec36234307fe420 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 7 Nov 2022 15:42:59 -0500 Subject: [PATCH 14/18] add probPrime to HybridBayesNet --- gtsam/hybrid/HybridBayesNet.cpp | 6 ++++++ gtsam/hybrid/HybridBayesNet.h | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index fe87795fe0..f0d53c4165 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -273,4 +273,10 @@ AlgebraicDecisionTree HybridBayesNet::error( return error_tree; } +AlgebraicDecisionTree HybridBayesNet::probPrime( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree = this->error(continuousValues); + return error_tree.apply([](double error) { return exp(-error); }); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index f296ba6449..c6ac6dcec7 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -144,6 +144,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + /** + * @brief Compute unnormalized probability for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which to compute the + * probability. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree probPrime( + const VectorValues &continuousValues) const; + /// @} private: From 0e1c3b8cb6755577f8c8cf4891013d8931716510 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 22 Dec 2022 07:33:51 +0530 Subject: [PATCH 15/18] rename all *Vals to *Values --- gtsam/hybrid/GaussianMixture.cpp | 14 +++--- gtsam/hybrid/GaussianMixture.h | 14 +++--- gtsam/hybrid/GaussianMixtureFactor.cpp | 13 +++--- gtsam/hybrid/GaussianMixtureFactor.h | 12 ++--- gtsam/hybrid/MixtureFactor.h | 44 ++++++++++--------- gtsam/hybrid/hybrid.i | 6 +-- .../tests/testGaussianMixtureFactor.cpp | 16 +++---- gtsam/hybrid/tests/testMixtureFactor.cpp | 8 ++-- 8 files changed, 65 insertions(+), 62 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 4819eda657..314d4fc635 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -85,8 +85,8 @@ size_t GaussianMixture::nrComponents() const { /* *******************************************************************************/ GaussianConditional::shared_ptr GaussianMixture::operator()( - const DiscreteValues &discreteVals) const { - auto &ptr = conditionals_(discreteVals); + const DiscreteValues &discreteValues) const { + auto &ptr = conditionals_(discreteValues); if (!ptr) return nullptr; auto conditional = boost::dynamic_pointer_cast(ptr); if (conditional) @@ -209,12 +209,12 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { /* *******************************************************************************/ AlgebraicDecisionTree GaussianMixture::error( - const VectorValues &continuousVals) const { + const VectorValues &continuousValues) const { // functor to convert from GaussianConditional to double error value. auto errorFunc = - [continuousVals](const GaussianConditional::shared_ptr &conditional) { + [continuousValues](const GaussianConditional::shared_ptr &conditional) { if (conditional) { - return conditional->error(continuousVals); + return conditional->error(continuousValues); } else { // return arbitrarily large error return 1e50; @@ -225,10 +225,10 @@ AlgebraicDecisionTree GaussianMixture::error( } /* *******************************************************************************/ -double GaussianMixture::error(const VectorValues &continuousVals, +double GaussianMixture::error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { auto conditional = conditionals_(discreteValues); - return conditional->error(continuousVals); + return conditional->error(continuousValues); } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 4bb13d298f..88d5a02c0c 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -122,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture /// @{ GaussianConditional::shared_ptr operator()( - const DiscreteValues &discreteVals) const; + const DiscreteValues &discreteValues) const; /// Returns the total number of continuous components size_t nrComponents() const; @@ -147,21 +147,21 @@ class GTSAM_EXPORT GaussianMixture /** * @brief Compute error of the GaussianMixture as a tree. * - * @param continuousVals The continuous VectorValues. - * @return AlgebraicDecisionTree A decision tree with corresponding keys - * as the factor but leaf values as the error. + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the conditionals, and leaf values as the error. */ - AlgebraicDecisionTree error(const VectorValues &continuousVals) const; + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; /** * @brief Compute the error of this Gaussian Mixture given the continuous * values and a discrete assignment. * - * @param continuousVals The continuous values at which to compute the error. + * @param continuousValues Continuous values at which to compute the error. * @param discreteValues The discrete assignment for a specific mode sequence. * @return double */ - double error(const VectorValues &continuousVals, + double error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const; /** diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 16802516e3..f070fe07aa 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -98,21 +98,22 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() /* *******************************************************************************/ AlgebraicDecisionTree GaussianMixtureFactor::error( - const VectorValues &continuousVals) const { + const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [continuousVals](const GaussianFactor::shared_ptr &factor) { - return factor->error(continuousVals); - }; + auto errorFunc = + [continuousValues](const GaussianFactor::shared_ptr &factor) { + return factor->error(continuousValues); + }; DecisionTree errorTree(factors_, errorFunc); return errorTree; } /* *******************************************************************************/ double GaussianMixtureFactor::error( - const VectorValues &continuousVals, + const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { auto factor = factors_(discreteValues); - return factor->error(continuousVals); + return factor->error(continuousValues); } } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 2808ec49f7..0b65b5aa93 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -133,21 +133,21 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { /** * @brief Compute error of the GaussianMixtureFactor as a tree. * - * @param continuousVals The continuous VectorValues. - * @return AlgebraicDecisionTree A decision tree with corresponding keys - * as the factor but leaf values as the error. + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the factors involved, and leaf values as the error. */ - AlgebraicDecisionTree error(const VectorValues &continuousVals) const; + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; /** * @brief Compute the error of this Gaussian Mixture given the continuous * values and a discrete assignment. * - * @param continuousVals The continuous values at which to compute the error. + * @param continuousValues Continuous values at which to compute the error. * @param discreteValues The discrete assignment for a specific mode sequence. * @return double */ - double error(const VectorValues &continuousVals, + double error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const; /// Add MixtureFactor to a Sum, syntactic sugar. diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 511705cf3c..58a915d57b 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -128,14 +128,15 @@ class MixtureFactor : public HybridFactor { /** * @brief Compute error of the MixtureFactor as a tree. * - * @param continuousVals The continuous values for which to compute the error. - * @return AlgebraicDecisionTree A decision tree with corresponding keys - * as the factor but leaf values as the error. + * @param continuousValues The continuous values for which to compute the + * error. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the factor, and leaf values as the error. */ - AlgebraicDecisionTree error(const Values& continuousVals) const { + AlgebraicDecisionTree error(const Values& continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [continuousVals](const sharedFactor& factor) { - return factor->error(continuousVals); + auto errorFunc = [continuousValues](const sharedFactor& factor) { + return factor->error(continuousValues); }; DecisionTree errorTree(factors_, errorFunc); return errorTree; @@ -144,19 +145,19 @@ class MixtureFactor : public HybridFactor { /** * @brief Compute error of factor given both continuous and discrete values. * - * @param continuousVals The continuous Values. - * @param discreteVals The discrete Values. + * @param continuousValues The continuous Values. + * @param discreteValues The discrete Values. * @return double The error of this factor. */ - double error(const Values& continuousVals, - const DiscreteValues& discreteVals) const { - // Retrieve the factor corresponding to the assignment in discreteVals. - auto factor = factors_(discreteVals); + double error(const Values& continuousValues, + const DiscreteValues& discreteValues) const { + // Retrieve the factor corresponding to the assignment in discreteValues. + auto factor = factors_(discreteValues); // Compute the error for the selected factor - const double factorError = factor->error(continuousVals); + const double factorError = factor->error(continuousValues); if (normalized_) return factorError; - return factorError + - this->nonlinearFactorLogNormalizingConstant(factor, continuousVals); + return factorError + this->nonlinearFactorLogNormalizingConstant( + factor, continuousValues); } size_t dim() const { @@ -212,17 +213,18 @@ class MixtureFactor : public HybridFactor { /// Linearize specific nonlinear factors based on the assignment in /// discreteValues. GaussianFactor::shared_ptr linearize( - const Values& continuousVals, const DiscreteValues& discreteVals) const { - auto factor = factors_(discreteVals); - return factor->linearize(continuousVals); + const Values& continuousValues, + const DiscreteValues& discreteValues) const { + auto factor = factors_(discreteValues); + return factor->linearize(continuousValues); } /// Linearize all the continuous factors to get a GaussianMixtureFactor. boost::shared_ptr linearize( - const Values& continuousVals) const { + const Values& continuousValues) const { // functional to linearize each factor in the decision tree - auto linearizeDT = [continuousVals](const sharedFactor& factor) { - return factor->linearize(continuousVals); + auto linearizeDT = [continuousValues](const sharedFactor& factor) { + return factor->linearize(continuousValues); }; DecisionTree linearized_factors( diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 86029a48aa..90c76593ef 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -204,14 +204,14 @@ class MixtureFactor : gtsam::HybridFactor { const std::vector& factors, bool normalized = false); - double error(const gtsam::Values& continuousVals, - const gtsam::DiscreteValues& discreteVals) const; + double error(const gtsam::Values& continuousValues, + const gtsam::DiscreteValues& discreteValues) const; double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor, const gtsam::Values& values) const; GaussianMixtureFactor* linearize( - const gtsam::Values& continuousVals) const; + const gtsam::Values& continuousValues) const; void print(string s = "MixtureFactor\n", const gtsam::KeyFormatter& keyFormatter = diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 5c25a09313..14e1b8dadd 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -170,12 +170,12 @@ TEST(GaussianMixtureFactor, Error) { GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); - VectorValues continuousVals; - continuousVals.insert(X(1), Vector2(0, 0)); - continuousVals.insert(X(2), Vector2(1, 1)); + VectorValues continuousValues; + continuousValues.insert(X(1), Vector2(0, 0)); + continuousValues.insert(X(2), Vector2(1, 1)); // error should return a tree of errors, with nodes for each discrete value. - AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousVals); + AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); std::vector discrete_keys = {m1}; std::vector errors = {1, 4}; @@ -184,10 +184,10 @@ TEST(GaussianMixtureFactor, Error) { EXPECT(assert_equal(expected_error, error_tree)); // Test for single leaf given discrete assignment P(X|M,Z). - DiscreteValues discreteVals; - discreteVals[m1.first] = 1; - EXPECT_DOUBLES_EQUAL(4.0, mixtureFactor.error(continuousVals, discreteVals), - 1e-9); + DiscreteValues discreteValues; + discreteValues[m1.first] = 1; + EXPECT_DOUBLES_EQUAL( + 4.0, mixtureFactor.error(continuousValues, discreteValues), 1e-9); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index c00d70e5ab..5167f6ff6a 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -87,11 +87,11 @@ TEST(MixtureFactor, Error) { MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); - Values continuousVals; - continuousVals.insert(X(1), 0); - continuousVals.insert(X(2), 1); + Values continuousValues; + continuousValues.insert(X(1), 0); + continuousValues.insert(X(2), 1); - AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousVals); + AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); std::vector discrete_keys = {m1}; std::vector errors = {0.5, 0}; From 098d2ce4a4120a418d0987405394989db07a85ac Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 22 Dec 2022 08:26:08 +0530 Subject: [PATCH 16/18] Update docstrings --- gtsam/hybrid/HybridBayesNet.h | 6 ++++-- gtsam/hybrid/HybridGaussianFactorGraph.h | 9 ++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index c6ac6dcec7..f8ec609119 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -145,8 +145,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { AlgebraicDecisionTree error(const VectorValues &continuousValues) const; /** - * @brief Compute unnormalized probability for each discrete assignment, - * and return as a tree. + * @brief Compute unnormalized probability q(μ|M), + * for each discrete assignment, and return as a tree. + * q(μ|M) is the unnormalized probability at the MLE point μ, + * conditioned on the discrete variables. * * @param continuousValues Continuous values at which to compute the * probability. diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index c7e9aa60da..ac9ae1a462 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -99,11 +99,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph using shared_ptr = boost::shared_ptr; ///< shared_ptr to This using Values = gtsam::Values; ///< backwards compatibility - using Indices = KeyVector; ///> map from keys to values + using Indices = KeyVector; ///< map from keys to values /// @name Constructors /// @{ + /// @brief Default constructor. HybridGaussianFactorGraph() = default; /** @@ -174,14 +175,16 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @brief Compute error for each discrete assignment, * and return as a tree. * + * Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$. + * * @param continuousValues Continuous values at which to compute the error. * @return AlgebraicDecisionTree */ AlgebraicDecisionTree error(const VectorValues& continuousValues) const; /** - * @brief Compute unnormalized probability for each discrete assignment, - * and return as a tree. + * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ + * for each discrete assignment, and return as a tree. * * @param continuousValues Continuous values at which to compute the * probability. From d94b3199a05419020eb5a363e6c43bd573eb85f2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 22 Dec 2022 09:22:34 +0530 Subject: [PATCH 17/18] address review comments --- gtsam/hybrid/GaussianMixture.cpp | 6 ++++-- gtsam/hybrid/GaussianMixtureFactor.cpp | 1 + gtsam/hybrid/HybridBayesNet.cpp | 7 ++++++- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 5 +++++ gtsam/hybrid/MixtureFactor.h | 7 ++++--- gtsam/hybrid/hybrid.i | 6 ++++-- gtsam/hybrid/tests/testGaussianMixture.cpp | 11 +++++++++-- gtsam/hybrid/tests/testGaussianMixtureFactor.cpp | 1 + gtsam/hybrid/tests/testHybridBayesNet.cpp | 3 +-- gtsam/hybrid/tests/testMixtureFactor.cpp | 3 ++- 10 files changed, 37 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 314d4fc635..c0815b2d7a 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -210,13 +210,14 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { /* *******************************************************************************/ AlgebraicDecisionTree GaussianMixture::error( const VectorValues &continuousValues) const { - // functor to convert from GaussianConditional to double error value. + // functor to calculate to double error value from GaussianConditional. auto errorFunc = [continuousValues](const GaussianConditional::shared_ptr &conditional) { if (conditional) { return conditional->error(continuousValues); } else { - // return arbitrarily large error + // Return arbitrarily large error if conditional is null + // Conditional is null if it is pruned out. return 1e50; } }; @@ -227,6 +228,7 @@ AlgebraicDecisionTree GaussianMixture::error( /* *******************************************************************************/ double GaussianMixture::error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { + // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(discreteValues); return conditional->error(continuousValues); } diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index f070fe07aa..fd437f52c0 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -112,6 +112,7 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( double GaussianMixtureFactor::error( const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { + // Directly index to get the conditional, no need to build the whole tree. auto factor = factors_(discreteValues); return factor->error(continuousValues); } diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f0d53c4165..48c4b6d508 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -244,13 +244,16 @@ AlgebraicDecisionTree HybridBayesNet::error( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree; + // Iterate over each factor. for (size_t idx = 0; idx < size(); idx++) { AlgebraicDecisionTree conditional_error; + if (factors_.at(idx)->isHybrid()) { - // If factor is hybrid, select based on assignment. + // If factor is hybrid, select based on assignment and compute error. GaussianMixture::shared_ptr gm = this->atMixture(idx); conditional_error = gm->error(continuousValues); + // Assign for the first index, add error for subsequent ones. if (idx == 0) { error_tree = conditional_error; } else { @@ -261,6 +264,7 @@ AlgebraicDecisionTree HybridBayesNet::error( // If continuous only, get the (double) error // and add it to the error_tree double error = this->atGaussian(idx)->error(continuousValues); + // Add the computed error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); @@ -273,6 +277,7 @@ AlgebraicDecisionTree HybridBayesNet::error( return error_tree; } +/* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::probPrime( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree = this->error(continuousValues); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index d6937957f2..32653bdecf 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -428,6 +428,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); + // Iterate over each factor. for (size_t idx = 0; idx < size(); idx++) { AlgebraicDecisionTree factor_error; @@ -435,8 +436,10 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // If factor is hybrid, select based on assignment. GaussianMixtureFactor::shared_ptr gaussianMixture = boost::static_pointer_cast(factors_.at(idx)); + // Compute factor error. factor_error = gaussianMixture->error(continuousValues); + // If first factor, assign error, else add it. if (idx == 0) { error_tree = factor_error; } else { @@ -450,7 +453,9 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( boost::static_pointer_cast(factors_.at(idx)); GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); + // Compute the error of the gaussian factor. double error = gaussian->error(continuousValues); + // Add the gaussian factor error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 58a915d57b..f29a840229 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor { * elements based on the number of discrete keys and the cardinality of the * keys, so that the decision tree is constructed appropriately. * - * @tparam FACTOR The type of the factor shared pointers being passed in. Will - * be typecast to NonlinearFactor shared pointers. + * @tparam FACTOR The type of the factor shared pointers being passed in. + * Will be typecast to NonlinearFactor shared pointers. * @param keys Vector of keys for continuous factors. * @param discreteKeys Vector of discrete keys. - * @param factors Vector of shared pointers to factors. + * @param factors Vector of nonlinear factors. * @param normalized Flag indicating if the factor error is already * normalized. */ diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 90c76593ef..899c129e00 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -196,8 +196,10 @@ class HybridNonlinearFactorGraph { #include class MixtureFactor : gtsam::HybridFactor { - MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, - const gtsam::DecisionTree& factors, bool normalized = false); + MixtureFactor( + const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, + const gtsam::DecisionTree& factors, + bool normalized = false); template MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 556a5f16a9..310081f028 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -104,7 +104,7 @@ TEST(GaussianMixture, Error) { X(2), S2, model); // Create decision tree - DiscreteKey m1(1, 2); + DiscreteKey m1(M(1), 2); GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); @@ -115,12 +115,19 @@ TEST(GaussianMixture, Error) { values.insert(X(2), Vector2::Zero()); auto error_tree = mixture.error(values); + // regression std::vector discrete_keys = {m1}; std::vector leaves = {0.5, 4.3252595}; AlgebraicDecisionTree expected_error(discrete_keys, leaves); - // regression EXPECT(assert_equal(expected_error, error_tree, 1e-6)); + + // Regression for non-tree version. + DiscreteValues assignment; + assignment[M(1)] = 0; + EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); + assignment[M(1)] = 1; + EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 14e1b8dadd..ba0622ff9d 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -178,6 +178,7 @@ TEST(GaussianMixtureFactor, Error) { AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); std::vector discrete_keys = {m1}; + // Error values for regression test std::vector errors = {1, 4}; AlgebraicDecisionTree expected_error(discrete_keys, errors); diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 8b8ca976b0..3593e1952c 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -216,8 +216,7 @@ TEST(HybridBayesNet, Error) { // Verify error computation and check for specific error value DiscreteValues discrete_values; - discrete_values[M(0)] = 1; - discrete_values[M(1)] = 1; + insert(discrete_values)(M(0), 1)(M(1), 1); double total_error = 0; for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 5167f6ff6a..fe3212eda5 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -41,7 +41,8 @@ TEST(MixtureFactor, Constructor) { CHECK(it == factor.end()); } - +/* ************************************************************************* */ +// Test .print() output. TEST(MixtureFactor, Printing) { DiscreteKey m1(1, 2); double between0 = 0.0; From 23ec7eddb01c09922d38d8bc70d2a8da1763924e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 22 Dec 2022 09:53:22 +0530 Subject: [PATCH 18/18] cleaner way of assigning discrete values --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 3593e1952c..050c01aeda 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -216,7 +216,7 @@ TEST(HybridBayesNet, Error) { // Verify error computation and check for specific error value DiscreteValues discrete_values; - insert(discrete_values)(M(0), 1)(M(1), 1); + boost::assign::insert(discrete_values)(M(0), 1)(M(1), 1); double total_error = 0; for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {